Diffstat (limited to 'cnn_v1/training/train_cnn.py')
-rwxr-xr-x  cnn_v1/training/train_cnn.py  943
1 file changed, 943 insertions(+), 0 deletions(-)
diff --git a/cnn_v1/training/train_cnn.py b/cnn_v1/training/train_cnn.py
new file mode 100755
index 0000000..4171dcb
--- /dev/null
+++ b/cnn_v1/training/train_cnn.py
@@ -0,0 +1,943 @@
+#!/usr/bin/env python3
+"""
+CNN Training Script for Image-to-Image Transformation
+
+Trains a convolutional neural network on multiple input/target image pairs.
+
+Usage:
+ # Training
+ python3 train_cnn.py --input input_dir/ --target target_dir/ [options]
+
+    # Inference (generate ground-truth output to validate the WGSL shader)
+ python3 train_cnn.py --infer image.png --export-only checkpoint.pth --output result.png
+
+Example:
+ python3 train_cnn.py --input ./input --target ./output --layers 3 --epochs 100
+ python3 train_cnn.py --infer input.png --export-only checkpoints/checkpoint_epoch_10000.pth
+"""
+
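+# Third-party dependencies: torch, torchvision, Pillow (PIL), numpy,
+# and opencv-python (cv2).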
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms
+from PIL import Image
+import numpy as np
+import cv2
+import os
+import sys
+import argparse
+import glob
+
+
+class ImagePairDataset(Dataset):
+ """Dataset for loading matching input/target image pairs"""
+
+ def __init__(self, input_dir, target_dir, transform=None):
+ self.input_dir = input_dir
+ self.target_dir = target_dir
+ self.transform = transform
+
+ # Find all images in input directory
+ input_patterns = ['*.png', '*.jpg', '*.jpeg', '*.PNG', '*.JPG', '*.JPEG']
+ self.image_pairs = []
+
+ for pattern in input_patterns:
+ input_files = glob.glob(os.path.join(input_dir, pattern))
+ for input_path in input_files:
+ filename = os.path.basename(input_path)
+ # Try to find matching target with same name but any supported extension
+ target_path = None
+ for ext in ['png', 'jpg', 'jpeg', 'PNG', 'JPG', 'JPEG']:
+ base_name = os.path.splitext(filename)[0]
+ candidate = os.path.join(target_dir, f"{base_name}.{ext}")
+ if os.path.exists(candidate):
+ target_path = candidate
+ break
+
+ if target_path:
+ self.image_pairs.append((input_path, target_path))
+
+        # '*.png' and '*.PNG' can both match on case-insensitive filesystems;
+        # dedupe so each pair is counted once
+        self.image_pairs = sorted(set(self.image_pairs))
+
+        if not self.image_pairs:
+ raise ValueError(f"No matching image pairs found between {input_dir} and {target_dir}")
+
+ print(f"Found {len(self.image_pairs)} matching image pairs")
+
+ def __len__(self):
+ return len(self.image_pairs)
+
+ def __getitem__(self, idx):
+ input_path, target_path = self.image_pairs[idx]
+
+ # Load RGBD input (4 channels: RGB + Depth)
+ input_img = Image.open(input_path).convert('RGBA')
+ target_img = Image.open(target_path).convert('RGB')
+
+ if self.transform:
+ input_img = self.transform(input_img)
+ target_img = self.transform(target_img)
+
+ return input_img, target_img
+
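+# Minimal usage sketch (hypothetical paths):
+#   ds = ImagePairDataset('./input', './target', transform=transforms.ToTensor())
+#   rgba, rgb = ds[0]  # [4,H,W] RGBA input tensor, [3,H,W] RGB target tensor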
+
+class PatchDataset(Dataset):
+ """Dataset for extracting salient patches from image pairs"""
+
+ def __init__(self, input_dir, target_dir, patch_size=32, patches_per_image=64,
+ detector='harris', transform=None):
+ self.input_dir = input_dir
+ self.target_dir = target_dir
+ self.patch_size = patch_size
+ self.patches_per_image = patches_per_image
+ self.detector = detector
+ self.transform = transform
+
+ # Find all image pairs
+ input_patterns = ['*.png', '*.jpg', '*.jpeg', '*.PNG', '*.JPG', '*.JPEG']
+ self.image_pairs = []
+
+ for pattern in input_patterns:
+ input_files = glob.glob(os.path.join(input_dir, pattern))
+ for input_path in input_files:
+ filename = os.path.basename(input_path)
+ target_path = None
+ for ext in ['png', 'jpg', 'jpeg', 'PNG', 'JPG', 'JPEG']:
+ base_name = os.path.splitext(filename)[0]
+ candidate = os.path.join(target_dir, f"{base_name}.{ext}")
+ if os.path.exists(candidate):
+ target_path = candidate
+ break
+
+ if target_path:
+ self.image_pairs.append((input_path, target_path))
+
+        # '*.png' and '*.PNG' can both match on case-insensitive filesystems;
+        # dedupe so each pair is counted once
+        self.image_pairs = sorted(set(self.image_pairs))
+
+        if not self.image_pairs:
+ raise ValueError(f"No matching image pairs found between {input_dir} and {target_dir}")
+
+ print(f"Found {len(self.image_pairs)} image pairs")
+ print(f"Extracting {patches_per_image} patches per image using {detector} detector")
+ print(f"Total patches: {len(self.image_pairs) * patches_per_image}")
+
+ def __len__(self):
+ return len(self.image_pairs) * self.patches_per_image
+
+ def _detect_salient_points(self, img_array):
+ """Detect salient points using specified detector"""
+ gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
+ h, w = gray.shape
+ half_patch = self.patch_size // 2
+
+        if self.detector == 'harris':
+            # Harris corner detection (useHarrisDetector=True; the default
+            # scoring would silently fall back to Shi-Tomasi)
+            corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2,
+                                              qualityLevel=0.01, minDistance=half_patch,
+                                              useHarrisDetector=True)
+        elif self.detector == 'fast':
+            # FAST feature detection, strongest responses first
+            fast = cv2.FastFeatureDetector_create(threshold=20)
+            keypoints = fast.detect(gray, None)
+            keypoints = sorted(keypoints, key=lambda kp: kp.response, reverse=True)
+            corners = np.array([[kp.pt[0], kp.pt[1]] for kp in keypoints[:self.patches_per_image * 2]])
+            corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None
+ elif self.detector == 'shi-tomasi':
+ # Shi-Tomasi corner detection (goodFeaturesToTrack with different params)
+ corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2,
+ qualityLevel=0.01, minDistance=half_patch,
+ useHarrisDetector=False)
+ elif self.detector == 'gradient':
+ # High-gradient regions
+ grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
+ grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
+ gradient_mag = np.sqrt(grad_x**2 + grad_y**2)
+
+ # Find top gradient locations
+ threshold = np.percentile(gradient_mag, 95)
+ y_coords, x_coords = np.where(gradient_mag > threshold)
+
+ if len(x_coords) > self.patches_per_image * 2:
+ indices = np.random.choice(len(x_coords), self.patches_per_image * 2, replace=False)
+ x_coords = x_coords[indices]
+ y_coords = y_coords[indices]
+
+ corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)])
+ corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None
+ else:
+ raise ValueError(f"Unknown detector: {self.detector}")
+
+ # Fallback to random if no corners found
+ if corners is None or len(corners) == 0:
+ x_coords = np.random.randint(half_patch, w - half_patch, self.patches_per_image)
+ y_coords = np.random.randint(half_patch, h - half_patch, self.patches_per_image)
+ corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)])
+ corners = corners.reshape(-1, 1, 2)
+
+ # Filter valid corners (within bounds)
+ valid_corners = []
+ for corner in corners:
+ x, y = int(corner[0][0]), int(corner[0][1])
+ if half_patch <= x < w - half_patch and half_patch <= y < h - half_patch:
+ valid_corners.append((x, y))
+ if len(valid_corners) >= self.patches_per_image:
+ break
+
+ # Fill with random if not enough
+ while len(valid_corners) < self.patches_per_image:
+ x = np.random.randint(half_patch, w - half_patch)
+ y = np.random.randint(half_patch, h - half_patch)
+ valid_corners.append((x, y))
+
+ return valid_corners
+
+ def __getitem__(self, idx):
+ img_idx = idx // self.patches_per_image
+ patch_idx = idx % self.patches_per_image
+
+ input_path, target_path = self.image_pairs[img_idx]
+
+ # Load images
+ input_img = Image.open(input_path).convert('RGBA')
+ target_img = Image.open(target_path).convert('RGB')
+
+ # Detect salient points (use input image for detection)
+ input_array = np.array(input_img)[:, :, :3] # Use RGB for detection
+ corners = self._detect_salient_points(input_array)
+
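+        # Note: corners are recomputed on every __getitem__ call. Harris and
+        # Shi-Tomasi are deterministic, so patch_idx indexes a stable list;
+        # the 'gradient' detector and the random fallback are not.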
+ # Extract patch at specified index
+ x, y = corners[patch_idx]
+ half_patch = self.patch_size // 2
+
+ # Crop patches
+ input_patch = input_img.crop((x - half_patch, y - half_patch,
+ x + half_patch, y + half_patch))
+ target_patch = target_img.crop((x - half_patch, y - half_patch,
+ x + half_patch, y + half_patch))
+
+ if self.transform:
+ input_patch = self.transform(input_patch)
+ target_patch = self.transform(target_patch)
+
+ return input_patch, target_patch
+
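+# Patch-mode sketch (hypothetical paths): corner-centered crops, no resizing
+#   ds = PatchDataset('./input', './target', patch_size=32,
+#                     patches_per_image=64, detector='harris',
+#                     transform=transforms.ToTensor())
+#   inp, tgt = ds[0]  # [4,32,32] RGBA patch, [3,32,32] RGB patch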
+
+class SimpleCNN(nn.Module):
+    """CNN for RGBD→RGB; every layer sees 7 input channels
+    (4 feature channels + normalized x/y coordinates + source grayscale).
+
+    The final layer emits grayscale, expanded to a 3-channel RGB output.
+    """
+
+ def __init__(self, num_layers=1, kernel_sizes=None):
+ super(SimpleCNN, self).__init__()
+
+ if kernel_sizes is None:
+ kernel_sizes = [3] * num_layers
+
+ assert len(kernel_sizes) == num_layers, "kernel_sizes must match num_layers"
+
+ self.kernel_sizes = kernel_sizes
+ self.layers = nn.ModuleList()
+
+ for i, kernel_size in enumerate(kernel_sizes):
+ padding = kernel_size // 2
+ if i < num_layers - 1:
+ # Inner layers: 7→4 (RGBD output)
+ self.layers.append(nn.Conv2d(7, 4, kernel_size=kernel_size, padding=padding, bias=True))
+ else:
+ # Final layer: 7→1 (grayscale output)
+ self.layers.append(nn.Conv2d(7, 1, kernel_size=kernel_size, padding=padding, bias=True))
+
+    def forward(self, x, return_intermediates=False):
+        # x: [B,4,H,W] - RGBD input (D = 1/z)
+        B, C, H, W = x.shape
+
+        intermediates = [] if return_intermediates else None
+
+        # Normalize RGBD to [-1,1]
+        x_norm = (x - 0.5) * 2.0
+
+        # Compute normalized coordinates [-1,1]
+        y_coords = torch.linspace(-1, 1, H, device=x.device).view(1,1,H,1).expand(B,1,H,W)
+        x_coords = torch.linspace(-1, 1, W, device=x.device).view(1,1,1,W).expand(B,1,H,W)
+
+        # Compute grayscale from original RGB (Rec.709) and normalize to [-1,1]
+        gray = 0.2126*x[:,0:1] + 0.7152*x[:,1:2] + 0.0722*x[:,2:3]  # [B,1,H,W] in [0,1]
+        gray = (gray - 0.5) * 2.0  # [-1,1]
+
+        # Every layer sees 7 channels: the previous 4-channel output (or the
+        # normalized RGBD source for layer 0) + x + y + gray
+        out = x_norm
+        for layer in self.layers[:-1]:
+            layer_input = torch.cat([out, x_coords, y_coords, gray], dim=1)  # [B,7,H,W]
+            out = torch.tanh(layer(layer_input))  # [B,4,H,W] in [-1,1]
+            if return_intermediates:
+                intermediates.append(out.clone())
+
+        # Final layer (grayscale→RGB); with num_layers=1 it consumes the
+        # source channels directly instead of crashing on a 4-channel input
+        final_input = torch.cat([out, x_coords, y_coords, gray], dim=1)
+        out = self.layers[-1](final_input)  # [B,1,H,W] grayscale
+        out = torch.sigmoid(out)  # Map to [0,1] with smooth gradients
+        final_out = out.expand(-1, 3, -1, -1)  # [B,3,H,W] expand to RGB
+
+        if return_intermediates:
+            return final_out, intermediates
+        return final_out
+
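+# Shape sanity check (sketch): the model is fully convolutional, so any H×W
+# works and the output is grayscale expanded to RGB:
+#   m = SimpleCNN(num_layers=3, kernel_sizes=[3, 3, 3])
+#   y = m(torch.rand(1, 4, 64, 64))  # RGBD in [0,1]
+#   assert y.shape == (1, 3, 64, 64)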
+
+def generate_layer_shader(output_path, num_layers, kernel_sizes):
+ """Generate cnn_layer.wgsl with proper layer switches"""
+
+ with open(output_path, 'w') as f:
+ f.write("// CNN layer shader - uses modular convolution snippets\n")
+ f.write("// Supports multi-pass rendering with residual connections\n")
+ f.write("// DO NOT EDIT - Generated by train_cnn.py\n\n")
+ f.write("@group(0) @binding(0) var smplr: sampler;\n")
+ f.write("@group(0) @binding(1) var txt: texture_2d<f32>;\n\n")
+ f.write("#include \"common_uniforms\"\n")
+ f.write("#include \"cnn_activation\"\n")
+
+ # Include necessary conv functions
+ conv_sizes = set(kernel_sizes)
+ for ks in sorted(conv_sizes):
+ f.write(f"#include \"cnn_conv{ks}x{ks}\"\n")
+ f.write("#include \"cnn_weights_generated\"\n\n")
+
+ f.write("struct CNNLayerParams {\n")
+ f.write(" layer_index: i32,\n")
+ f.write(" blend_amount: f32,\n")
+ f.write(" _pad: vec2<f32>,\n")
+ f.write("};\n\n")
+ f.write("@group(0) @binding(2) var<uniform> uniforms: CommonUniforms;\n")
+ f.write("@group(0) @binding(3) var<uniform> params: CNNLayerParams;\n")
+ f.write("@group(0) @binding(4) var original_input: texture_2d<f32>;\n\n")
+ f.write("@vertex fn vs_main(@builtin(vertex_index) i: u32) -> @builtin(position) vec4<f32> {\n")
+ f.write(" var pos = array<vec2<f32>, 3>(\n")
+ f.write(" vec2<f32>(-1.0, -1.0), vec2<f32>(3.0, -1.0), vec2<f32>(-1.0, 3.0)\n")
+ f.write(" );\n")
+ f.write(" return vec4<f32>(pos[i], 0.0, 1.0);\n")
+ f.write("}\n\n")
+ f.write("@fragment fn fs_main(@builtin(position) p: vec4<f32>) -> @location(0) vec4<f32> {\n")
+ f.write(" // Match PyTorch linspace\n")
+ f.write(" let uv = (p.xy - 0.5) / (uniforms.resolution - 1.0);\n")
+ f.write(" let original_raw = textureSample(original_input, smplr, uv);\n")
+ f.write(" let original = (original_raw - 0.5) * 2.0; // Normalize to [-1,1]\n")
+ f.write(" let gray = (dot(original_raw.rgb, vec3<f32>(0.2126, 0.7152, 0.0722)) - 0.5) * 2.0;\n")
+ f.write(" var result = vec4<f32>(0.0);\n\n")
+
+ # Generate layer switches
+ for layer_idx in range(num_layers):
+ is_final = layer_idx == num_layers - 1
+ ks = kernel_sizes[layer_idx]
+ conv_fn = f"cnn_conv{ks}x{ks}_7to4" if not is_final else f"cnn_conv{ks}x{ks}_7to1"
+
+ if layer_idx == 0:
+ conv_fn_src = f"cnn_conv{ks}x{ks}_7to4_src"
+ f.write(f" // Layer 0: 7→4 (RGBD output, normalizes [0,1] input)\n")
+ f.write(f" if (params.layer_index == {layer_idx}) {{\n")
+ f.write(f" result = {conv_fn_src}(txt, smplr, uv, uniforms.resolution, weights_layer{layer_idx});\n")
+ f.write(f" result = cnn_tanh(result);\n")
+ f.write(f" }}\n")
+ elif not is_final:
+ f.write(f" else if (params.layer_index == {layer_idx}) {{\n")
+ f.write(f" result = {conv_fn}(txt, smplr, uv, uniforms.resolution, gray, weights_layer{layer_idx});\n")
+ f.write(f" result = cnn_tanh(result); // Keep in [-1,1]\n")
+ f.write(f" }}\n")
+ else:
+ f.write(f" else if (params.layer_index == {layer_idx}) {{\n")
+ f.write(f" let sum = {conv_fn}(txt, smplr, uv, uniforms.resolution, gray, weights_layer{layer_idx});\n")
+ f.write(f" let gray_out = 1.0 / (1.0 + exp(-sum)); // Sigmoid activation\n")
+ f.write(f" result = vec4<f32>(gray_out, gray_out, gray_out, 1.0);\n")
+ f.write(f" return mix(original_raw, result, params.blend_amount); // [0,1]\n")
+ f.write(f" }}\n")
+
+ f.write(" return result; // [-1,1]\n")
+ f.write("}\n")
+
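+# Multi-pass protocol (as generated, assuming a float render target): each
+# pass renders one layer. Non-final passes return tanh outputs in [-1,1],
+# which the host feeds back as 'txt'; 'original_input' always holds the raw
+# [0,1] source used for the gray and UV channels. The final pass applies a
+# sigmoid, blends with the original via params.blend_amount, and returns
+# [0,1]. Note the generated switch assumes at least two layers: layer 0 is
+# always emitted as the 7→4 source pass.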
+
+def export_weights_to_wgsl(model, output_path, kernel_sizes):
+ """Export trained weights to WGSL format (vec4-optimized)"""
+
+ with open(output_path, 'w') as f:
+ f.write("// Auto-generated CNN weights (vec4-optimized)\n")
+ f.write("// DO NOT EDIT - Generated by train_cnn.py\n\n")
+
+ for i, layer in enumerate(model.layers):
+ weights = layer.weight.data.cpu().numpy()
+ bias = layer.bias.data.cpu().numpy()
+ out_ch, in_ch, kh, kw = weights.shape
+ num_positions = kh * kw
+
+ is_final = (i == len(model.layers) - 1)
+
+ if is_final:
+                # Final layer: 7→1, num_positions × 2 vec4s (18 for 3x3)
+                # Input: [rgba | uv, gray, 1] → 2 vec4s per position
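+                # Bias handling: the shader adds dot(weights[...], in1) at
+                # every tap, and in1.w is the constant 1.0, so the bias is
+                # split evenly across the k*k taps; summing over the window
+                # recovers the full bias.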
+ f.write(f"const weights_layer{i}: array<vec4<f32>, {num_positions * 2}> = array(\n")
+ for pos in range(num_positions):
+ row, col = pos // kw, pos % kw
+ # First vec4: [w0, w1, w2, w3] (rgba)
+ v0 = [f"{weights[0, in_c, row, col]:.6f}" for in_c in range(4)]
+ # Second vec4: [w4, w5, w6, bias] (uv, gray, 1)
+ v1 = [f"{weights[0, in_c, row, col]:.6f}" for in_c in range(4, 7)]
+ v1.append(f"{bias[0] / num_positions:.6f}")
+ f.write(f" vec4<f32>({', '.join(v0)}),\n")
+ f.write(f" vec4<f32>({', '.join(v1)})")
+ f.write(",\n" if pos < num_positions-1 else "\n")
+ f.write(");\n\n")
+ else:
+                # Inner layers: 7→4, num_positions × 4 filters × 2 vec4s (72 for 3x3)
+                # Each filter: 2 vec4s for the [rgba] and [uv, gray, 1] inputs
+ num_vec4s = num_positions * 4 * 2
+ f.write(f"const weights_layer{i}: array<vec4<f32>, {num_vec4s}> = array(\n")
+ for pos in range(num_positions):
+ row, col = pos // kw, pos % kw
+ for out_c in range(4):
+ # First vec4: [w0, w1, w2, w3] (rgba)
+ v0 = [f"{weights[out_c, in_c, row, col]:.6f}" for in_c in range(4)]
+ # Second vec4: [w4, w5, w6, bias] (uv, gray, 1)
+ v1 = [f"{weights[out_c, in_c, row, col]:.6f}" for in_c in range(4, 7)]
+ v1.append(f"{bias[out_c] / num_positions:.6f}")
+ idx = (pos * 4 + out_c) * 2
+ f.write(f" vec4<f32>({', '.join(v0)}),\n")
+ f.write(f" vec4<f32>({', '.join(v1)})")
+ f.write(",\n" if idx < num_vec4s-2 else "\n")
+ f.write(");\n\n")
+
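+# Layout sketch for a 3x3 inner layer: 9 taps × 4 filters × 2 vec4s = 72
+# entries, grouped by tap and then by filter within each tap, matching the
+# pos += 8 walk in the generated cnn_conv3x3_7to4().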
+
+def generate_conv_base_function(kernel_size, output_path):
+ """Generate cnn_conv{K}x{K}_7to4() function for inner layers (vec4-optimized)"""
+
+ k = kernel_size
+ num_positions = k * k
+ radius = k // 2
+
+ with open(output_path, 'a') as f:
+ f.write(f"\n// Inner layers: 7→4 channels (vec4-optimized)\n")
+ f.write(f"// Assumes 'tex' is already normalized to [-1,1]\n")
+ f.write(f"fn cnn_conv{k}x{k}_7to4(\n")
+ f.write(f" tex: texture_2d<f32>,\n")
+ f.write(f" samp: sampler,\n")
+ f.write(f" uv: vec2<f32>,\n")
+ f.write(f" resolution: vec2<f32>,\n")
+ f.write(f" gray: f32,\n")
+ f.write(f" weights: array<vec4<f32>, {num_positions * 8}>\n")
+ f.write(f") -> vec4<f32> {{\n")
+ f.write(f" let step = 1.0 / resolution;\n")
+ f.write(f" let uv_norm = (uv - 0.5) * 2.0;\n\n")
+ f.write(f" var sum = vec4<f32>(0.0);\n")
+ f.write(f" var pos = 0;\n\n")
+
+ # Convolution loop
+ f.write(f" for (var dy = -{radius}; dy <= {radius}; dy++) {{\n")
+ f.write(f" for (var dx = -{radius}; dx <= {radius}; dx++) {{\n")
+ f.write(f" let offset = vec2<f32>(f32(dx), f32(dy)) * step;\n")
+ f.write(f" let rgbd = textureSample(tex, samp, uv + offset);\n")
+ f.write(f" let in1 = vec4<f32>(uv_norm, gray, 1.0);\n\n")
+
+ # Accumulate
+ f.write(f" sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1);\n")
+ f.write(f" sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1);\n")
+ f.write(f" sum.b += dot(weights[pos+4], rgbd) + dot(weights[pos+5], in1);\n")
+ f.write(f" sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1);\n")
+ f.write(f" pos += 8;\n")
+ f.write(f" }}\n")
+ f.write(f" }}\n\n")
+
+ f.write(f" return sum;\n")
+ f.write(f"}}\n")
+
+
+def generate_conv_src_function(kernel_size, output_path):
+ """Generate cnn_conv{K}x{K}_7to4_src() function for layer 0 (vec4-optimized)"""
+
+ k = kernel_size
+ num_positions = k * k
+ radius = k // 2
+
+ with open(output_path, 'a') as f:
+ f.write(f"\n// Source layer: 7→4 channels (vec4-optimized)\n")
+ f.write(f"// Normalizes [0,1] input to [-1,1] internally\n")
+ f.write(f"fn cnn_conv{k}x{k}_7to4_src(\n")
+ f.write(f" tex: texture_2d<f32>,\n")
+ f.write(f" samp: sampler,\n")
+ f.write(f" uv: vec2<f32>,\n")
+ f.write(f" resolution: vec2<f32>,\n")
+ f.write(f" weights: array<vec4<f32>, {num_positions * 8}>\n")
+ f.write(f") -> vec4<f32> {{\n")
+ f.write(f" let step = 1.0 / resolution;\n\n")
+
+ # Normalize center pixel for gray channel
+ f.write(f" let original = (textureSample(tex, samp, uv) - 0.5) * 2.0;\n")
+ f.write(f" let gray = dot(original.rgb, vec3<f32>(0.2126, 0.7152, 0.0722));\n")
+ f.write(f" let uv_norm = (uv - 0.5) * 2.0;\n")
+ f.write(f" let in1 = vec4<f32>(uv_norm, gray, 1.0);\n\n")
+
+ f.write(f" var sum = vec4<f32>(0.0);\n")
+ f.write(f" var pos = 0;\n\n")
+
+ # Convolution loop
+ f.write(f" for (var dy = -{radius}; dy <= {radius}; dy++) {{\n")
+ f.write(f" for (var dx = -{radius}; dx <= {radius}; dx++) {{\n")
+ f.write(f" let offset = vec2<f32>(f32(dx), f32(dy)) * step;\n")
+ f.write(f" let rgbd = (textureSample(tex, samp, uv + offset) - 0.5) * 2.0;\n\n")
+
+ # Accumulate with dot products (unrolled)
+ f.write(f" sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1);\n")
+ f.write(f" sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1);\n")
+ f.write(f" sum.b += dot(weights[pos+4], rgbd) + dot(weights[pos+5], in1);\n")
+ f.write(f" sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1);\n")
+ f.write(f" pos += 8;\n")
+ f.write(f" }}\n")
+ f.write(f" }}\n\n")
+
+ f.write(f" return sum;\n")
+ f.write(f"}}\n")
+
+
+def generate_conv_final_function(kernel_size, output_path):
+ """Generate cnn_conv{K}x{K}_7to1() function for final layer (vec4-optimized)"""
+
+ k = kernel_size
+ num_positions = k * k
+ radius = k // 2
+
+ with open(output_path, 'a') as f:
+ f.write(f"\n// Final layer: 7→1 channel (vec4-optimized)\n")
+ f.write(f"// Assumes 'tex' is already normalized to [-1,1]\n")
+ f.write(f"// Returns raw sum (activation applied at call site)\n")
+ f.write(f"fn cnn_conv{k}x{k}_7to1(\n")
+ f.write(f" tex: texture_2d<f32>,\n")
+ f.write(f" samp: sampler,\n")
+ f.write(f" uv: vec2<f32>,\n")
+ f.write(f" resolution: vec2<f32>,\n")
+ f.write(f" gray: f32,\n")
+ f.write(f" weights: array<vec4<f32>, {num_positions * 2}>\n")
+ f.write(f") -> f32 {{\n")
+ f.write(f" let step = 1.0 / resolution;\n")
+ f.write(f" let uv_norm = (uv - 0.5) * 2.0;\n")
+ f.write(f" let in1 = vec4<f32>(uv_norm, gray, 1.0);\n\n")
+ f.write(f" var sum = 0.0;\n")
+ f.write(f" var pos = 0;\n\n")
+
+ # Convolution loop
+ f.write(f" for (var dy = -{radius}; dy <= {radius}; dy++) {{\n")
+ f.write(f" for (var dx = -{radius}; dx <= {radius}; dx++) {{\n")
+ f.write(f" let offset = vec2<f32>(f32(dx), f32(dy)) * step;\n")
+ f.write(f" let rgbd = textureSample(tex, samp, uv + offset);\n\n")
+
+ # Accumulate with dot products
+ f.write(f" sum += dot(weights[pos], rgbd) + dot(weights[pos+1], in1);\n")
+ f.write(f" pos += 2;\n")
+ f.write(f" }}\n")
+ f.write(f" }}\n\n")
+
+ f.write(f" return sum;\n")
+ f.write(f"}}\n")
+
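+# Together these generators emit every conv entry point one kernel size
+# needs: cnn_conv{K}x{K}_7to4_src (layer 0), _7to4 (inner layers), and
+# _7to1 (final layer).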
+
+def train(args):
+ """Main training loop"""
+
+ # Setup device
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+
+ # Prepare dataset
+ if args.patch_size:
+ # Patch-based training (preserves natural scale)
+ transform = transforms.Compose([
+ transforms.ToTensor(),
+ ])
+ dataset = PatchDataset(args.input, args.target,
+ patch_size=args.patch_size,
+ patches_per_image=args.patches_per_image,
+ detector=args.detector,
+ transform=transform)
+ else:
+ # Full-image training (resize mode)
+ transform = transforms.Compose([
+ transforms.Resize((256, 256)),
+ transforms.ToTensor(),
+ ])
+ dataset = ImagePairDataset(args.input, args.target, transform=transform)
+
+ dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
+
+ # Parse kernel sizes
+ kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
+ if len(kernel_sizes) == 1 and args.layers > 1:
+ kernel_sizes = kernel_sizes * args.layers
+
+ # Create model
+ model = SimpleCNN(num_layers=args.layers, kernel_sizes=kernel_sizes).to(device)
+
+ # Loss and optimizer
+ criterion = nn.MSELoss()
+ optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
+
+ # Resume from checkpoint
+ start_epoch = 0
+ if args.resume:
+ if os.path.exists(args.resume):
+ print(f"Loading checkpoint from {args.resume}...")
+ checkpoint = torch.load(args.resume, map_location=device)
+ model.load_state_dict(checkpoint['model_state'])
+ optimizer.load_state_dict(checkpoint['optimizer_state'])
+ start_epoch = checkpoint['epoch'] + 1
+ print(f"Resumed from epoch {start_epoch}")
+ else:
+ print(f"Warning: Checkpoint file '{args.resume}' not found, starting from scratch")
+
+    # Compute valid center region (exclude conv padding borders)
+    # Each layer with kernel k contributes k//2 border pixels,
+    # e.g. kernel_sizes=[3,5,3] → border = 1 + 2 + 1 = 4
+    border = sum(ks // 2 for ks in kernel_sizes)
+
+ # Early stopping setup
+ loss_history = []
+ early_stop_triggered = False
+
+ # Training loop
+ print(f"\nTraining for {args.epochs} epochs (starting from epoch {start_epoch})...")
+ print(f"Computing loss on center region only (excluding {border}px border)")
+ if args.early_stop_patience > 0:
+ print(f"Early stopping: patience={args.early_stop_patience}, eps={args.early_stop_eps}")
+
+ for epoch in range(start_epoch, args.epochs):
+ epoch_loss = 0.0
+ for batch_idx, (inputs, targets) in enumerate(dataloader):
+ inputs, targets = inputs.to(device), targets.to(device)
+
+ optimizer.zero_grad()
+ outputs = model(inputs)
+
+ # Only compute loss on center pixels with valid neighborhoods
+ if border > 0 and outputs.shape[2] > 2*border and outputs.shape[3] > 2*border:
+ outputs_center = outputs[:, :, border:-border, border:-border]
+ targets_center = targets[:, :, border:-border, border:-border]
+ loss = criterion(outputs_center, targets_center)
+ else:
+ loss = criterion(outputs, targets)
+
+ loss.backward()
+ optimizer.step()
+
+ epoch_loss += loss.item()
+
+ avg_loss = epoch_loss / len(dataloader)
+ if (epoch + 1) % 10 == 0:
+ print(f"Epoch [{epoch+1}/{args.epochs}], Loss: {avg_loss:.6f}")
+
+ # Early stopping check
+ if args.early_stop_patience > 0:
+ loss_history.append(avg_loss)
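+            # The window includes the current epoch: compare against the
+            # loss recorded (patience - 1) epochs ago.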
+ if len(loss_history) >= args.early_stop_patience:
+ oldest_loss = loss_history[-args.early_stop_patience]
+ loss_change = abs(avg_loss - oldest_loss)
+ if loss_change < args.early_stop_eps:
+ print(f"Early stopping triggered at epoch {epoch+1}")
+ print(f"Loss change over last {args.early_stop_patience} epochs: {loss_change:.8f} < {args.early_stop_eps}")
+ early_stop_triggered = True
+ break
+
+ # Save checkpoint
+ if args.checkpoint_every > 0 and (epoch + 1) % args.checkpoint_every == 0:
+ checkpoint_dir = args.checkpoint_dir or 'training/checkpoints'
+ os.makedirs(checkpoint_dir, exist_ok=True)
+ checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')
+ torch.save({
+ 'epoch': epoch,
+ 'model_state': model.state_dict(),
+ 'optimizer_state': optimizer.state_dict(),
+ 'loss': avg_loss,
+ 'kernel_sizes': kernel_sizes,
+ 'num_layers': args.layers
+ }, checkpoint_path)
+ print(f"Saved checkpoint to {checkpoint_path}")
+
+ # Export weights and shader
+ output_path = args.output or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl'
+ print(f"\nExporting weights to {output_path}...")
+    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)  # bare filenames have no dirname
+ export_weights_to_wgsl(model, output_path, kernel_sizes)
+
+ # Generate layer shader
+ shader_dir = os.path.dirname(output_path)
+ shader_path = os.path.join(shader_dir, 'cnn_layer.wgsl')
+ print(f"Generating layer shader to {shader_path}...")
+ generate_layer_shader(shader_path, args.layers, kernel_sizes)
+
+ # Generate conv shader files for all kernel sizes
+ for ks in set(kernel_sizes):
+ conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl')
+
+ # Create file with header if it doesn't exist
+ if not os.path.exists(conv_path):
+ print(f"Creating {conv_path}...")
+ with open(conv_path, 'w') as f:
+ f.write(f"// {ks}x{ks} convolution (vec4-optimized)\n")
+ generate_conv_base_function(ks, conv_path)
+ generate_conv_src_function(ks, conv_path)
+ generate_conv_final_function(ks, conv_path)
+ print(f"Generated complete {conv_path}")
+ continue
+
+ # File exists, check for missing functions
+ with open(conv_path, 'r') as f:
+ content = f.read()
+
+ # Generate base 7to4 if missing
+ if f"cnn_conv{ks}x{ks}_7to4" not in content:
+ generate_conv_base_function(ks, conv_path)
+ print(f"Added base 7to4 to {conv_path}")
+ with open(conv_path, 'r') as f:
+ content = f.read()
+
+ # Generate _src variant if missing
+ if f"cnn_conv{ks}x{ks}_7to4_src" not in content:
+ generate_conv_src_function(ks, conv_path)
+ print(f"Added _src variant to {conv_path}")
+ with open(conv_path, 'r') as f:
+ content = f.read()
+
+ # Generate 7to1 final layer if missing
+ if f"cnn_conv{ks}x{ks}_7to1" not in content:
+ generate_conv_final_function(ks, conv_path)
+ print(f"Added 7to1 variant to {conv_path}")
+
+ print("Training complete!")
+
+
+def export_from_checkpoint(checkpoint_path, output_path=None):
+ """Export WGSL files from checkpoint without training"""
+
+ if not os.path.exists(checkpoint_path):
+ print(f"Error: Checkpoint file '{checkpoint_path}' not found")
+ sys.exit(1)
+
+ print(f"Loading checkpoint from {checkpoint_path}...")
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
+
+ kernel_sizes = checkpoint['kernel_sizes']
+ num_layers = checkpoint['num_layers']
+
+ # Recreate model
+ model = SimpleCNN(num_layers=num_layers, kernel_sizes=kernel_sizes)
+ model.load_state_dict(checkpoint['model_state'])
+
+ # Export weights
+ output_path = output_path or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl'
+ print(f"Exporting weights to {output_path}...")
+    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)  # bare filenames have no dirname
+ export_weights_to_wgsl(model, output_path, kernel_sizes)
+
+ # Generate layer shader
+ shader_dir = os.path.dirname(output_path)
+ shader_path = os.path.join(shader_dir, 'cnn_layer.wgsl')
+ print(f"Generating layer shader to {shader_path}...")
+ generate_layer_shader(shader_path, num_layers, kernel_sizes)
+
+ # Generate conv shader files for all kernel sizes
+ for ks in set(kernel_sizes):
+ conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl')
+
+ # Create file with header if it doesn't exist
+ if not os.path.exists(conv_path):
+ print(f"Creating {conv_path}...")
+ with open(conv_path, 'w') as f:
+ f.write(f"// {ks}x{ks} convolution (vec4-optimized)\n")
+ generate_conv_base_function(ks, conv_path)
+ generate_conv_src_function(ks, conv_path)
+ generate_conv_final_function(ks, conv_path)
+ print(f"Generated complete {conv_path}")
+ continue
+
+ # File exists, check for missing functions
+ with open(conv_path, 'r') as f:
+ content = f.read()
+
+ # Generate base 7to4 if missing
+ if f"cnn_conv{ks}x{ks}_7to4" not in content:
+ generate_conv_base_function(ks, conv_path)
+ print(f"Added base 7to4 to {conv_path}")
+ with open(conv_path, 'r') as f:
+ content = f.read()
+
+ # Generate _src variant if missing
+ if f"cnn_conv{ks}x{ks}_7to4_src" not in content:
+ generate_conv_src_function(ks, conv_path)
+ print(f"Added _src variant to {conv_path}")
+ with open(conv_path, 'r') as f:
+ content = f.read()
+
+ # Generate 7to1 final layer if missing
+ if f"cnn_conv{ks}x{ks}_7to1" not in content:
+ generate_conv_final_function(ks, conv_path)
+ print(f"Added 7to1 variant to {conv_path}")
+
+ print("Export complete!")
+
+
+def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32, save_intermediates=None, zero_weights=False, debug_hex=False):
+    """Run full-image inference to match WGSL shader behavior
+
+    The network is fully convolutional, so the image is processed in a single
+    forward pass; patch_size is accepted for CLI symmetry but unused.
+    Outputs RGBA PNG (RGB from model + alpha from input).
+    """
+
+ if not os.path.exists(checkpoint_path):
+ print(f"Error: Checkpoint '{checkpoint_path}' not found")
+ sys.exit(1)
+
+ if not os.path.exists(input_path):
+ print(f"Error: Input image '{input_path}' not found")
+ sys.exit(1)
+
+ print(f"Loading checkpoint from {checkpoint_path}...")
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
+
+ # Reconstruct model
+ model = SimpleCNN(
+ num_layers=checkpoint['num_layers'],
+ kernel_sizes=checkpoint['kernel_sizes']
+ )
+ model.load_state_dict(checkpoint['model_state'])
+
+ # Debug: Zero out all weights and biases
+ if zero_weights:
+ print("DEBUG: Zeroing out all weights and biases")
+ for layer in model.layers:
+ with torch.no_grad():
+ layer.weight.zero_()
+ layer.bias.zero_()
+
+ model.eval()
+
+ # Load image
+ print(f"Loading input image: {input_path}")
+ img = Image.open(input_path).convert('RGBA')
+ img_tensor = transforms.ToTensor()(img).unsqueeze(0) # [1,4,H,W]
+ W, H = img.size
+
+    # Process full image in a single pass (fully convolutional; matches WGSL shader)
+    print(f"Processing full image ({W}×{H})...")
+ with torch.no_grad():
+ if save_intermediates:
+ output_tensor, intermediates = model(img_tensor, return_intermediates=True)
+ else:
+ output_tensor = model(img_tensor) # [1,3,H,W] RGB
+
+ # Convert to numpy and append alpha
+ output = output_tensor.squeeze(0).permute(1, 2, 0).numpy() # [H,W,3] RGB
+ alpha = img_tensor[0, 3:4, :, :].permute(1, 2, 0).numpy() # [H,W,1] alpha from input
+ output_rgba = np.concatenate([output, alpha], axis=2) # [H,W,4] RGBA
+
+ # Debug: print first 8 pixels as hex
+ if debug_hex:
+ output_u8 = (output_rgba * 255).astype(np.uint8)
+ print("First 8 pixels (RGBA hex):")
+ for i in range(min(8, output_u8.shape[0] * output_u8.shape[1])):
+ y, x = i // output_u8.shape[1], i % output_u8.shape[1]
+ r, g, b, a = output_u8[y, x]
+ print(f" [{i}] 0x{r:02X}{g:02X}{b:02X}{a:02X}")
+
+ # Save final output as RGBA
+ print(f"Saving output to: {output_path}")
+    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
+ output_img = Image.fromarray((output_rgba * 255).astype(np.uint8), mode='RGBA')
+ output_img.save(output_path)
+
+ # Save intermediates if requested
+ if save_intermediates:
+ os.makedirs(save_intermediates, exist_ok=True)
+ print(f"Saving {len(intermediates)} intermediate layers to: {save_intermediates}")
+ for layer_idx, layer_tensor in enumerate(intermediates):
+ # Convert [-1,1] to [0,1] for visualization
+ layer_data = (layer_tensor.squeeze(0).permute(1, 2, 0).numpy() + 1.0) * 0.5
+ layer_u8 = (layer_data.clip(0, 1) * 255).astype(np.uint8)
+
+ # Debug: print first 8 pixels as hex
+ if debug_hex:
+ print(f"Layer {layer_idx} first 8 pixels (RGBA hex):")
+ for i in range(min(8, layer_u8.shape[0] * layer_u8.shape[1])):
+ y, x = i // layer_u8.shape[1], i % layer_u8.shape[1]
+ if layer_u8.shape[2] == 4:
+ r, g, b, a = layer_u8[y, x]
+ print(f" [{i}] 0x{r:02X}{g:02X}{b:02X}{a:02X}")
+ else:
+ r, g, b = layer_u8[y, x]
+ print(f" [{i}] 0x{r:02X}{g:02X}{b:02X}")
+
+ # Save all 4 channels for intermediate layers
+ if layer_data.shape[2] == 4:
+ layer_img = Image.fromarray(layer_u8, mode='RGBA')
+ else:
+ layer_img = Image.fromarray(layer_u8)
+ layer_path = os.path.join(save_intermediates, f'layer_{layer_idx}.png')
+ layer_img.save(layer_path)
+ print(f" Saved layer {layer_idx} to {layer_path}")
+
+ print("Done!")
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Train CNN for image-to-image transformation')
+ parser.add_argument('--input', help='Input image directory (training) or single image (inference)')
+ parser.add_argument('--target', help='Target image directory')
+ parser.add_argument('--layers', type=int, default=1, help='Number of CNN layers (default: 1)')
+ parser.add_argument('--kernel_sizes', default='3', help='Comma-separated kernel sizes (default: 3)')
+ parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs (default: 100)')
+ parser.add_argument('--batch_size', type=int, default=4, help='Batch size (default: 4)')
+ parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (default: 0.001)')
+ parser.add_argument('--output', help='Output path (WGSL for training/export, PNG for inference)')
+ parser.add_argument('--checkpoint-every', type=int, default=0, help='Save checkpoint every N epochs (default: 0 = disabled)')
+ parser.add_argument('--checkpoint-dir', help='Checkpoint directory (default: training/checkpoints)')
+ parser.add_argument('--resume', help='Resume from checkpoint file')
+ parser.add_argument('--export-only', help='Export WGSL from checkpoint without training')
+ parser.add_argument('--infer', help='Run inference on single image (requires --export-only for checkpoint)')
+ parser.add_argument('--patch-size', type=int, help='Extract patches of this size (e.g., 32) instead of resizing (default: None = resize to 256x256)')
+ parser.add_argument('--patches-per-image', type=int, default=64, help='Number of patches to extract per image (default: 64)')
+ parser.add_argument('--detector', default='harris', choices=['harris', 'fast', 'shi-tomasi', 'gradient'],
+ help='Salient point detector for patch extraction (default: harris)')
+ parser.add_argument('--early-stop-patience', type=int, default=0, help='Stop if loss changes less than eps over N epochs (default: 0 = disabled)')
+ parser.add_argument('--early-stop-eps', type=float, default=1e-6, help='Loss change threshold for early stopping (default: 1e-6)')
+ parser.add_argument('--save-intermediates', help='Directory to save intermediate layer outputs (inference only)')
+ parser.add_argument('--zero-weights', action='store_true', help='Zero out all weights/biases during inference (debug only)')
+ parser.add_argument('--debug-hex', action='store_true', help='Print first 8 pixels as hex (debug only)')
+
+ args = parser.parse_args()
+
+ # Inference mode
+ if args.infer:
+ checkpoint = args.export_only
+ if not checkpoint:
+ print("Error: --infer requires --export-only <checkpoint>")
+ sys.exit(1)
+ output_path = args.output or 'inference_output.png'
+ patch_size = args.patch_size or 32
+ infer_from_checkpoint(checkpoint, args.infer, output_path, patch_size, args.save_intermediates, args.zero_weights, args.debug_hex)
+ return
+
+ # Export-only mode
+ if args.export_only:
+ export_from_checkpoint(args.export_only, args.output)
+ return
+
+ # Validate directories for training
+ if not args.input or not args.target:
+ print("Error: --input and --target required for training (or use --export-only)")
+ sys.exit(1)
+
+ if not os.path.isdir(args.input):
+ print(f"Error: Input directory '{args.input}' does not exist")
+ sys.exit(1)
+
+ if not os.path.isdir(args.target):
+ print(f"Error: Target directory '{args.target}' does not exist")
+ sys.exit(1)
+
+ train(args)
+
+
+if __name__ == "__main__":
+ main()