#!/usr/bin/env python3
"""
CNN Training Script for Image-to-Image Transformation

Trains a convolutional neural network on multiple input/target image pairs.

Usage:
    # Training
    python3 train_cnn.py --input input_dir/ --target target_dir/ [options]

    # Inference (generate ground truth)
    python3 train_cnn.py --infer image.png --export-only checkpoint.pth --output result.png

Example:
    python3 train_cnn.py --input ./input --target ./output --layers 3 --epochs 100
    python3 train_cnn.py --infer input.png --export-only checkpoints/checkpoint_epoch_10000.pth
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import cv2
import os
import sys
import argparse
import glob


class ImagePairDataset(Dataset):
    """Dataset for loading matching input/target image pairs"""

    def __init__(self, input_dir, target_dir, transform=None):
        self.input_dir = input_dir
        self.target_dir = target_dir
        self.transform = transform

        # Find all images in input directory
        input_patterns = ['*.png', '*.jpg', '*.jpeg', '*.PNG', '*.JPG', '*.JPEG']
        self.image_pairs = []

        for pattern in input_patterns:
            input_files = glob.glob(os.path.join(input_dir, pattern))
            for input_path in input_files:
                filename = os.path.basename(input_path)
                # Try to find matching target with same name but any supported extension
                target_path = None
                for ext in ['png', 'jpg', 'jpeg', 'PNG', 'JPG', 'JPEG']:
                    base_name = os.path.splitext(filename)[0]
                    candidate = os.path.join(target_dir, f"{base_name}.{ext}")
                    if os.path.exists(candidate):
                        target_path = candidate
                        break
                if target_path:
                    self.image_pairs.append((input_path, target_path))

        if not self.image_pairs:
            raise ValueError(f"No matching image pairs found between {input_dir} and {target_dir}")

        print(f"Found {len(self.image_pairs)} matching image pairs")

    def __len__(self):
        return len(self.image_pairs)

    def __getitem__(self, idx):
        input_path, target_path = self.image_pairs[idx]

        # Load RGBD input (4 channels: RGB + Depth)
        input_img = Image.open(input_path).convert('RGBA')
        target_img = Image.open(target_path).convert('RGB')

        if self.transform:
            input_img = self.transform(input_img)
            target_img = self.transform(target_img)

        return input_img, target_img


class PatchDataset(Dataset):
    """Dataset for extracting salient patches from image pairs"""

    def __init__(self, input_dir, target_dir, patch_size=32, patches_per_image=64,
                 detector='harris', transform=None):
        self.input_dir = input_dir
        self.target_dir = target_dir
        self.patch_size = patch_size
        self.patches_per_image = patches_per_image
        self.detector = detector
        self.transform = transform

        # Find all image pairs
        input_patterns = ['*.png', '*.jpg', '*.jpeg', '*.PNG', '*.JPG', '*.JPEG']
        self.image_pairs = []

        for pattern in input_patterns:
            input_files = glob.glob(os.path.join(input_dir, pattern))
            for input_path in input_files:
                filename = os.path.basename(input_path)
                target_path = None
                for ext in ['png', 'jpg', 'jpeg', 'PNG', 'JPG', 'JPEG']:
                    base_name = os.path.splitext(filename)[0]
                    candidate = os.path.join(target_dir, f"{base_name}.{ext}")
                    if os.path.exists(candidate):
                        target_path = candidate
                        break
                if target_path:
                    self.image_pairs.append((input_path, target_path))

        if not self.image_pairs:
            raise ValueError(f"No matching image pairs found between {input_dir} and {target_dir}")

        print(f"Found {len(self.image_pairs)} image pairs")
        print(f"Extracting {patches_per_image} patches per image using {detector} detector")
        print(f"Total patches: {len(self.image_pairs) * patches_per_image}")

    def __len__(self):
        return len(self.image_pairs) * self.patches_per_image

    def _detect_salient_points(self, img_array):
        """Detect salient points using specified detector"""
        gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
        h, w = gray.shape
        half_patch = self.patch_size // 2

        if self.detector == 'harris':
            # Harris corner detection
            corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2,
                                              qualityLevel=0.01, minDistance=half_patch,
                                              useHarrisDetector=True)
        elif self.detector == 'fast':
            # FAST feature detection
            fast = cv2.FastFeatureDetector_create(threshold=20)
            keypoints = fast.detect(gray, None)
            corners = np.array([[kp.pt[0], kp.pt[1]] for kp in keypoints[:self.patches_per_image * 2]])
            corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None
        elif self.detector == 'shi-tomasi':
            # Shi-Tomasi corner detection (goodFeaturesToTrack without the Harris measure)
            corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2,
                                              qualityLevel=0.01, minDistance=half_patch,
                                              useHarrisDetector=False)
        elif self.detector == 'gradient':
            # High-gradient regions
            grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
            grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
            gradient_mag = np.sqrt(grad_x**2 + grad_y**2)

            # Find top gradient locations
            threshold = np.percentile(gradient_mag, 95)
            y_coords, x_coords = np.where(gradient_mag > threshold)
            if len(x_coords) > self.patches_per_image * 2:
                indices = np.random.choice(len(x_coords), self.patches_per_image * 2, replace=False)
                x_coords = x_coords[indices]
                y_coords = y_coords[indices]
            corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)])
            corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None
        else:
            raise ValueError(f"Unknown detector: {self.detector}")

        # Fallback to random if no corners found
        if corners is None or len(corners) == 0:
            x_coords = np.random.randint(half_patch, w - half_patch, self.patches_per_image)
            y_coords = np.random.randint(half_patch, h - half_patch, self.patches_per_image)
            corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)])
            corners = corners.reshape(-1, 1, 2)

        # Filter valid corners (within bounds)
        valid_corners = []
        for corner in corners:
            x, y = int(corner[0][0]), int(corner[0][1])
            if half_patch <= x < w - half_patch and half_patch <= y < h - half_patch:
                valid_corners.append((x, y))
            if len(valid_corners) >= self.patches_per_image:
                break

        # Fill with random if not enough
        while len(valid_corners) < self.patches_per_image:
            x = np.random.randint(half_patch, w - half_patch)
            y = np.random.randint(half_patch, h - half_patch)
            valid_corners.append((x, y))

        return valid_corners

    def __getitem__(self, idx):
        img_idx = idx // self.patches_per_image
        patch_idx = idx % self.patches_per_image

        input_path, target_path = self.image_pairs[img_idx]

        # Load images
        input_img = Image.open(input_path).convert('RGBA')
        target_img = Image.open(target_path).convert('RGB')

        # Detect salient points (use input image for detection)
        input_array = np.array(input_img)[:, :, :3]  # Use RGB for detection
        corners = self._detect_salient_points(input_array)

        # Extract patch at specified index
        x, y = corners[patch_idx]
        half_patch = self.patch_size // 2

        # Crop patches
        input_patch = input_img.crop((x - half_patch, y - half_patch,
                                      x + half_patch, y + half_patch))
        target_patch = target_img.crop((x - half_patch, y - half_patch,
                                        x + half_patch, y + half_patch))

        if self.transform:
            input_patch = self.transform(input_patch)
            target_patch = self.transform(target_patch)

        return input_patch, target_patch


class SimpleCNN(nn.Module):
    """CNN for RGBD→grayscale with 7-channel input (RGBD + UV + gray)"""

    def __init__(self, num_layers=1, kernel_sizes=None):
        super(SimpleCNN, self).__init__()

        if kernel_sizes is None:
            kernel_sizes = [3] * num_layers
        assert len(kernel_sizes) == num_layers, "kernel_sizes must match num_layers"

        self.kernel_sizes = kernel_sizes
        self.layers = nn.ModuleList()

        for i, kernel_size in enumerate(kernel_sizes):
            padding = kernel_size // 2
            if i < num_layers - 1:
                # Inner layers: 7→4 (RGBD output)
                self.layers.append(nn.Conv2d(7, 4, kernel_size=kernel_size, padding=padding, bias=True))
            else:
                # Final layer: 7→1 (grayscale output)
                self.layers.append(nn.Conv2d(7, 1, kernel_size=kernel_size, padding=padding, bias=True))

    def forward(self, x):
        # x: [B,4,H,W] - RGBD input (D = 1/z)
        B, C, H, W = x.shape

        # Normalize RGBD to [-1,1]
        x_norm = (x - 0.5) * 2.0

        # Compute coordinates [0,1] then normalize to [-1,1]
        y_coords = torch.linspace(0, 1, H, device=x.device).view(1, 1, H, 1).expand(B, 1, H, W)
        x_coords = torch.linspace(0, 1, W, device=x.device).view(1, 1, 1, W).expand(B, 1, H, W)
        y_coords = (y_coords - 0.5) * 2.0  # [-1,1]
        x_coords = (x_coords - 0.5) * 2.0  # [-1,1]

        # Compute grayscale from original RGB (Rec.709) and normalize to [-1,1]
        gray = 0.2126 * x[:, 0:1] + 0.7152 * x[:, 1:2] + 0.0722 * x[:, 2:3]  # [B,1,H,W] in [0,1]
        gray = (gray - 0.5) * 2.0  # [-1,1]

        # Layer 0 (for a single-layer model this is already the final 7→1 layer)
        layer0_input = torch.cat([x_norm, x_coords, y_coords, gray], dim=1)  # [B,7,H,W]
        out = self.layers[0](layer0_input)
        if len(self.layers) == 1:
            out = torch.clamp(out, 0.0, 1.0)  # [B,1,H,W], clip to [0,1]
            return out.expand(-1, 3, -1, -1)
        out = torch.tanh(out)  # [B,4,H,W] in [-1,1]

        # Inner layers
        for i in range(1, len(self.layers) - 1):
            layer_input = torch.cat([out, x_coords, y_coords, gray], dim=1)
            out = self.layers[i](layer_input)
            out = torch.tanh(out)

        # Final layer (grayscale output)
        final_input = torch.cat([out, x_coords, y_coords, gray], dim=1)
        out = self.layers[-1](final_input)  # [B,1,H,W]
        out = torch.clamp(out, 0.0, 1.0)  # Clip to [0,1]

        return out.expand(-1, 3, -1, -1)


def generate_layer_shader(output_path, num_layers, kernel_sizes):
    """Generate cnn_layer.wgsl with proper layer switches"""
    with open(output_path, 'w') as f:
        f.write("// CNN layer shader - uses modular convolution snippets\n")
        f.write("// Supports multi-pass rendering with residual connections\n")
        f.write("// DO NOT EDIT - Generated by train_cnn.py\n\n")

        f.write("@group(0) @binding(0) var smplr: sampler;\n")
        f.write("@group(0) @binding(1) var txt: texture_2d<f32>;\n\n")

        f.write("#include \"common_uniforms\"\n")
        f.write("#include \"cnn_activation\"\n")

        # Include necessary conv functions
        conv_sizes = set(kernel_sizes)
        for ks in sorted(conv_sizes):
            f.write(f"#include \"cnn_conv{ks}x{ks}\"\n")
        f.write("#include \"cnn_weights_generated\"\n\n")

        f.write("struct CNNLayerParams {\n")
        f.write("    layer_index: i32,\n")
        f.write("    blend_amount: f32,\n")
        f.write("    _pad: vec2<f32>,\n")
        f.write("};\n\n")

        f.write("@group(0) @binding(2) var<uniform> uniforms: CommonUniforms;\n")
        f.write("@group(0) @binding(3) var<uniform> params: CNNLayerParams;\n")
        f.write("@group(0) @binding(4) var original_input: texture_2d<f32>;\n\n")

        f.write("@vertex fn vs_main(@builtin(vertex_index) i: u32) -> @builtin(position) vec4<f32> {\n")
        f.write("    var pos = array<vec2<f32>, 3>(\n")
        f.write("        vec2(-1.0, -1.0), vec2(3.0, -1.0), vec2(-1.0, 3.0)\n")
        f.write("    );\n")
        f.write("    return vec4(pos[i], 0.0, 1.0);\n")
        f.write("}\n\n")

        f.write("@fragment fn fs_main(@builtin(position) p: vec4<f32>) -> @location(0) vec4<f32> {\n")
        f.write("    // Match PyTorch linspace: pixel_idx / (size - 1), not pixel_center / size\n")
        f.write("    let uv = (p.xy - 0.5) / (uniforms.resolution - 1.0);\n")
        f.write("    let original_raw = textureSample(original_input, smplr, uv);\n")
        f.write("    let original = (original_raw - 0.5) * 2.0; // Normalize to [-1,1]\n")
        f.write("    let gray = dot(original.rgb, vec3(0.2126, 0.7152, 0.0722));\n")
        f.write("    var result = vec4(0.0);\n\n")

        # Generate layer switches
        for layer_idx in range(num_layers):
            is_final = layer_idx == num_layers - 1
            ks = kernel_sizes[layer_idx]
            conv_fn = f"cnn_conv{ks}x{ks}_7to4" if not is_final else f"cnn_conv{ks}x{ks}_7to1"

            if layer_idx == 0:
                conv_fn_src = f"cnn_conv{ks}x{ks}_7to4_src"
                f.write(f"    // Layer 0: 7→4 (RGBD output, normalizes [0,1] input)\n")
                f.write(f"    if (params.layer_index == {layer_idx}) {{\n")
                f.write(f"        result = {conv_fn_src}(txt, smplr, uv, uniforms.resolution,\n")
                f.write(f"                              weights_layer{layer_idx});\n")
                f.write(f"        result = cnn_tanh(result);\n")
                f.write(f"    }}\n")
            elif not is_final:
                f.write(f"    else if (params.layer_index == {layer_idx}) {{\n")
                f.write(f"        result = {conv_fn}(txt, smplr, uv, uniforms.resolution,\n")
                f.write(f"                           gray, weights_layer{layer_idx});\n")
                f.write(f"        result = cnn_tanh(result); // Keep in [-1,1]\n")
                f.write(f"    }}\n")
            else:
                f.write(f"    else if (params.layer_index == {layer_idx}) {{\n")
                f.write(f"        let gray_out = {conv_fn}(txt, smplr, uv, uniforms.resolution,\n")
                f.write(f"                                 gray, weights_layer{layer_idx});\n")
                f.write(f"        // gray_out in [0,1] (clamped to match PyTorch training)\n")
                f.write(f"        result = vec4(gray_out, gray_out, gray_out, 1.0);\n")
                f.write(f"        return mix(original_raw, result, params.blend_amount); // [0,1]\n")
                f.write(f"    }}\n")

        f.write("    return result; // [-1,1]\n")
        f.write("}\n")


def export_weights_to_wgsl(model, output_path, kernel_sizes):
    """Export trained weights to WGSL format (vec4-optimized)"""
    with open(output_path, 'w') as f:
        f.write("// Auto-generated CNN weights (vec4-optimized)\n")
        f.write("// DO NOT EDIT - Generated by train_cnn.py\n\n")

        for i, layer in enumerate(model.layers):
            weights = layer.weight.data.cpu().numpy()
            bias = layer.bias.data.cpu().numpy()
            out_ch, in_ch, kh, kw = weights.shape
            num_positions = kh * kw
            is_final = (i == len(model.layers) - 1)

            if is_final:
                # Final layer: 7→1, structure: array<vec4<f32>, 18> (9 pos × 2 vec4)
                # Input: [rgba, uv_gray_1] → 2 vec4s per position
                f.write(f"const weights_layer{i}: array<vec4<f32>, {num_positions * 2}> = array(\n")
                for pos in range(num_positions):
                    row, col = pos // kw, pos % kw
                    # First vec4: [w0, w1, w2, w3] (rgba)
                    v0 = [f"{weights[0, in_c, row, col]:.6f}" for in_c in range(4)]
                    # Second vec4: [w4, w5, w6, bias] (uv, gray, 1)
                    v1 = [f"{weights[0, in_c, row, col]:.6f}" for in_c in range(4, 7)]
                    v1.append(f"{bias[0]:.6f}")
                    f.write(f"    vec4({', '.join(v0)}),\n")
                    f.write(f"    vec4({', '.join(v1)})")
                    f.write(",\n" if pos < num_positions - 1 else "\n")
                f.write(");\n\n")
            else:
                # Inner layers: 7→4, structure: array<vec4<f32>, 72> (36 entries × 2 vec4)
                # Each filter: 2 vec4s for [rgba][uv_gray_1] inputs
                num_vec4s = num_positions * 4 * 2
                f.write(f"const weights_layer{i}: array<vec4<f32>, {num_vec4s}> = array(\n")
                for pos in range(num_positions):
                    row, col = pos // kw, pos % kw
                    for out_c in range(4):
                        # First vec4: [w0, w1, w2, w3] (rgba)
                        v0 = [f"{weights[out_c, in_c, row, col]:.6f}" for in_c in range(4)]
                        # Second vec4: [w4, w5, w6, bias] (uv, gray, 1)
                        v1 = [f"{weights[out_c, in_c, row, col]:.6f}" for in_c in range(4, 7)]
                        v1.append(f"{bias[out_c]:.6f}")
                        idx = (pos * 4 + out_c) * 2
                        f.write(f"    vec4({', '.join(v0)}),\n")
                        f.write(f"    vec4({', '.join(v1)})")
                        f.write(",\n" if idx < num_vec4s - 2 else "\n")
                f.write(");\n\n")
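
# Packing reference for the arrays emitted above (descriptive of
# export_weights_to_wgsl; concrete counts assume a 3x3 kernel):
#   Final 7→1 layer: each of the 9 kernel positions contributes two vec4s,
#       vec4(w_r, w_g, w_b, w_d)       -- RGBD weights for that tap
#       vec4(w_u, w_v, w_gray, bias)   -- UV/gray weights, bias in the .w slot
#   giving array<vec4<f32>, 18>.
#   Inner 7→4 layers: the same pair is written per output channel, so a 3x3
#   kernel produces 9 positions x 4 outputs x 2 = 72 vec4 entries.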


def generate_conv_src_function(kernel_size, output_path):
    """Generate cnn_conv{K}x{K}_7to4_src() function for layer 0 (vec4-optimized)"""
    k = kernel_size
    num_positions = k * k
    radius = k // 2

    with open(output_path, 'a') as f:
        f.write(f"\n// Source layer: 7→4 channels (vec4-optimized)\n")
        f.write(f"// Normalizes [0,1] input to [-1,1] internally\n")
        f.write(f"fn cnn_conv{k}x{k}_7to4_src(\n")
        f.write(f"    tex: texture_2d<f32>,\n")
        f.write(f"    samp: sampler,\n")
        f.write(f"    uv: vec2<f32>,\n")
        f.write(f"    resolution: vec2<f32>,\n")
        f.write(f"    weights: array<vec4<f32>, {num_positions * 8}>\n")
        f.write(f") -> vec4<f32> {{\n")
        f.write(f"    let step = 1.0 / resolution;\n\n")

        # Normalize center pixel for gray channel
        f.write(f"    let original = (textureSample(tex, samp, uv) - 0.5) * 2.0;\n")
        f.write(f"    let gray = dot(original.rgb, vec3(0.2126, 0.7152, 0.0722));\n")
        f.write(f"    let uv_norm = (uv - 0.5) * 2.0;\n")
        f.write(f"    let in1 = vec4(uv_norm, gray, 1.0);\n\n")

        f.write(f"    var sum = vec4(0.0);\n")
        f.write(f"    var pos = 0;\n\n")

        # Convolution loop
        f.write(f"    for (var dy = -{radius}; dy <= {radius}; dy++) {{\n")
        f.write(f"        for (var dx = -{radius}; dx <= {radius}; dx++) {{\n")
        f.write(f"            let offset = vec2(f32(dx), f32(dy)) * step;\n")
        f.write(f"            let rgbd = (textureSample(tex, samp, uv + offset) - 0.5) * 2.0;\n\n")

        # Accumulate with dot products (unrolled)
        f.write(f"            sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1);\n")
        f.write(f"            sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1);\n")
        f.write(f"            sum.b += dot(weights[pos+4], rgbd) + dot(weights[pos+5], in1);\n")
        f.write(f"            sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1);\n")
        f.write(f"            pos += 8;\n")
        f.write(f"        }}\n")
        f.write(f"    }}\n\n")

        f.write(f"    return sum;\n")
        f.write(f"}}\n")


def generate_conv_final_function(kernel_size, output_path):
    """Generate cnn_conv{K}x{K}_7to1() function for final layer (vec4-optimized)"""
    k = kernel_size
    num_positions = k * k
    radius = k // 2

    with open(output_path, 'a') as f:
        f.write(f"\n// Final layer: 7→1 channel (vec4-optimized)\n")
        f.write(f"// Assumes 'tex' is already normalized to [-1,1]\n")
        f.write(f"// Output clamped to [0,1] to match PyTorch training\n")
        f.write(f"fn cnn_conv{k}x{k}_7to1(\n")
        f.write(f"    tex: texture_2d<f32>,\n")
        f.write(f"    samp: sampler,\n")
        f.write(f"    uv: vec2<f32>,\n")
        f.write(f"    resolution: vec2<f32>,\n")
        f.write(f"    gray: f32,\n")
        f.write(f"    weights: array<vec4<f32>, {num_positions * 2}>\n")
        f.write(f") -> f32 {{\n")
        f.write(f"    let step = 1.0 / resolution;\n")
        f.write(f"    let uv_norm = (uv - 0.5) * 2.0;\n")
        f.write(f"    let in1 = vec4(uv_norm, gray, 1.0);\n\n")

        f.write(f"    var sum = 0.0;\n")
        f.write(f"    var pos = 0;\n\n")

        # Convolution loop
        f.write(f"    for (var dy = -{radius}; dy <= {radius}; dy++) {{\n")
        f.write(f"        for (var dx = -{radius}; dx <= {radius}; dx++) {{\n")
        f.write(f"            let offset = vec2(f32(dx), f32(dy)) * step;\n")
        f.write(f"            let rgbd = textureSample(tex, samp, uv + offset);\n\n")

        # Accumulate with dot products
        f.write(f"            sum += dot(weights[pos], rgbd) + dot(weights[pos+1], in1);\n")
        f.write(f"            pos += 2;\n")
        f.write(f"        }}\n")
        f.write(f"    }}\n\n")

        f.write(f"    return clamp(sum, 0.0, 1.0);\n")
        f.write(f"}}\n")


def train(args):
    """Main training loop"""
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Prepare dataset
    if args.patch_size:
        # Patch-based training (preserves natural scale)
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        dataset = PatchDataset(args.input, args.target,
                               patch_size=args.patch_size,
                               patches_per_image=args.patches_per_image,
                               detector=args.detector,
                               transform=transform)
    else:
        # Full-image training (resize mode)
        transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ])
        dataset = ImagePairDataset(args.input, args.target, transform=transform)

    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    # Parse kernel sizes
    kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
    if len(kernel_sizes) == 1 and args.layers > 1:
        kernel_sizes = kernel_sizes * args.layers

    # Create model
    model = SimpleCNN(num_layers=args.layers, kernel_sizes=kernel_sizes).to(device)

    # Loss and optimizer
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

    # Resume from checkpoint
    start_epoch = 0
    if args.resume:
        if os.path.exists(args.resume):
            print(f"Loading checkpoint from {args.resume}...")
            checkpoint = torch.load(args.resume, map_location=device)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            start_epoch = checkpoint['epoch'] + 1
            print(f"Resumed from epoch {start_epoch}")
        else:
            print(f"Warning: Checkpoint file '{args.resume}' not found, starting from scratch")

    # Training loop
    print(f"\nTraining for {args.epochs} epochs (starting from epoch {start_epoch})...")
    for epoch in range(start_epoch, args.epochs):
        epoch_loss = 0.0
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(dataloader)
        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{args.epochs}], Loss: {avg_loss:.6f}")

        # Save checkpoint
        if args.checkpoint_every > 0 and (epoch + 1) % args.checkpoint_every == 0:
            checkpoint_dir = args.checkpoint_dir or 'training/checkpoints'
            os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')
            torch.save({
                'epoch': epoch,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'loss': avg_loss,
                'kernel_sizes': kernel_sizes,
                'num_layers': args.layers
            }, checkpoint_path)
            print(f"Saved checkpoint to {checkpoint_path}")

    # Export weights and shader
    output_path = args.output or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl'
    print(f"\nExporting weights to {output_path}...")
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    export_weights_to_wgsl(model, output_path, kernel_sizes)

    # Generate layer shader
    shader_dir = os.path.dirname(output_path)
    shader_path = os.path.join(shader_dir, 'cnn_layer.wgsl')
    print(f"Generating layer shader to {shader_path}...")
    generate_layer_shader(shader_path, args.layers, kernel_sizes)

    # Generate _src and 7to1 variants for kernel sizes
    for ks in set(kernel_sizes):
        conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl')
        if not os.path.exists(conv_path):
            print(f"Warning: {conv_path} not found, skipping function generation")
            continue

        with open(conv_path, 'r') as f:
            content = f.read()

        # Generate _src variant (skip 3x3, already exists)
        if ks != 3 and f"cnn_conv{ks}x{ks}_7to4_src" not in content:
            generate_conv_src_function(ks, conv_path)
            print(f"Added _src variant to {conv_path}")

        with open(conv_path, 'r') as f:
            content = f.read()

        # Generate 7to1 final layer with clamp (all kernel sizes)
        if f"cnn_conv{ks}x{ks}_7to1" not in content:
            generate_conv_final_function(ks, conv_path)
            print(f"Added 7to1 variant with clamp to {conv_path}")
        elif "clamp(sum, 0.0, 1.0)" not in content:
            print(f"Warning: {conv_path} has 7to1 but missing clamp - manual fix needed")

    print("Training complete!")
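
# Illustrative invocations for patch-based training and standalone export
# (paths below are placeholders; every flag shown is defined in main()):
#   python3 train_cnn.py --input ./input --target ./target \
#       --layers 3 --kernel_sizes 5,3,3 --patch-size 32 --patches-per-image 128 \
#       --detector gradient --epochs 2000 --checkpoint-every 500
#   python3 train_cnn.py --export-only training/checkpoints/checkpoint_epoch_2000.pth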


def export_from_checkpoint(checkpoint_path, output_path=None):
    """Export WGSL files from checkpoint without training"""
    if not os.path.exists(checkpoint_path):
        print(f"Error: Checkpoint file '{checkpoint_path}' not found")
        sys.exit(1)

    print(f"Loading checkpoint from {checkpoint_path}...")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    kernel_sizes = checkpoint['kernel_sizes']
    num_layers = checkpoint['num_layers']

    # Recreate model
    model = SimpleCNN(num_layers=num_layers, kernel_sizes=kernel_sizes)
    model.load_state_dict(checkpoint['model_state'])

    # Export weights
    output_path = output_path or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl'
    print(f"Exporting weights to {output_path}...")
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    export_weights_to_wgsl(model, output_path, kernel_sizes)

    # Generate layer shader
    shader_dir = os.path.dirname(output_path)
    shader_path = os.path.join(shader_dir, 'cnn_layer.wgsl')
    print(f"Generating layer shader to {shader_path}...")
    generate_layer_shader(shader_path, num_layers, kernel_sizes)

    # Generate _src and 7to1 variants for kernel sizes
    for ks in set(kernel_sizes):
        conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl')
        if not os.path.exists(conv_path):
            print(f"Warning: {conv_path} not found, skipping function generation")
            continue

        with open(conv_path, 'r') as f:
            content = f.read()

        # Generate _src variant (skip 3x3, already exists)
        if ks != 3 and f"cnn_conv{ks}x{ks}_7to4_src" not in content:
            generate_conv_src_function(ks, conv_path)
            print(f"Added _src variant to {conv_path}")

        with open(conv_path, 'r') as f:
            content = f.read()

        # Generate 7to1 final layer with clamp (all kernel sizes)
        if f"cnn_conv{ks}x{ks}_7to1" not in content:
            generate_conv_final_function(ks, conv_path)
            print(f"Added 7to1 variant with clamp to {conv_path}")
        elif "clamp(sum, 0.0, 1.0)" not in content:
            print(f"Warning: {conv_path} has 7to1 but missing clamp - manual fix needed")

    print("Export complete!")


def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32):
    """Run patch-based inference to match training distribution"""
    if not os.path.exists(checkpoint_path):
        print(f"Error: Checkpoint '{checkpoint_path}' not found")
        sys.exit(1)
    if not os.path.exists(input_path):
        print(f"Error: Input image '{input_path}' not found")
        sys.exit(1)

    print(f"Loading checkpoint from {checkpoint_path}...")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Reconstruct model
    model = SimpleCNN(
        num_layers=checkpoint['num_layers'],
        kernel_sizes=checkpoint['kernel_sizes']
    )
    model.load_state_dict(checkpoint['model_state'])
    model.eval()

    # Load image
    print(f"Loading input image: {input_path}")
    img = Image.open(input_path).convert('RGBA')
    W, H = img.size

    # Tile into patches
    print(f"Processing {patch_size}×{patch_size} patches...")
    output = np.zeros((H, W, 3), dtype=np.float32)

    for y in range(0, H, patch_size):
        for x in range(0, W, patch_size):
            x_end = min(x + patch_size, W)
            y_end = min(y + patch_size, H)

            # Extract patch
            patch = img.crop((x, y, x_end, y_end))
            patch_tensor = transforms.ToTensor()(patch).unsqueeze(0)  # [1,4,h,w]

            # Inference
            with torch.no_grad():
                out_patch = model(patch_tensor)  # [1,3,h,w]

            # Write to output
            out_np = out_patch.squeeze(0).permute(1, 2, 0).numpy()
            output[y:y_end, x:x_end] = out_np

    # Save
    print(f"Saving output to: {output_path}")
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    output_img = Image.fromarray((output * 255).astype(np.uint8))
    output_img.save(output_path)
    print("Done!")


def main():
    parser = argparse.ArgumentParser(description='Train CNN for image-to-image transformation')
    parser.add_argument('--input', help='Input image directory (training) or single image (inference)')
    parser.add_argument('--target', help='Target image directory')
    parser.add_argument('--layers', type=int, default=1, help='Number of CNN layers (default: 1)')
    parser.add_argument('--kernel_sizes', default='3', help='Comma-separated kernel sizes (default: 3)')
    parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs (default: 100)')
    parser.add_argument('--batch_size', type=int, default=4, help='Batch size (default: 4)')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (default: 0.001)')
    parser.add_argument('--output', help='Output path (WGSL for training/export, PNG for inference)')
    parser.add_argument('--checkpoint-every', type=int, default=0,
                        help='Save checkpoint every N epochs (default: 0 = disabled)')
    parser.add_argument('--checkpoint-dir', help='Checkpoint directory (default: training/checkpoints)')
    parser.add_argument('--resume', help='Resume from checkpoint file')
    parser.add_argument('--export-only', help='Export WGSL from checkpoint without training')
    parser.add_argument('--infer', help='Run inference on single image (requires --export-only for checkpoint)')
    parser.add_argument('--patch-size', type=int,
                        help='Extract patches of this size (e.g., 32) instead of resizing (default: None = resize to 256x256)')
    parser.add_argument('--patches-per-image', type=int, default=64,
                        help='Number of patches to extract per image (default: 64)')
    parser.add_argument('--detector', default='harris', choices=['harris', 'fast', 'shi-tomasi', 'gradient'],
                        help='Salient point detector for patch extraction (default: harris)')
    args = parser.parse_args()

    # Inference mode
    if args.infer:
        checkpoint = args.export_only
        if not checkpoint:
            print("Error: --infer requires --export-only <checkpoint>")
            sys.exit(1)
        output_path = args.output or 'inference_output.png'
        patch_size = args.patch_size or 32
        infer_from_checkpoint(checkpoint, args.infer, output_path, patch_size)
        return

    # Export-only mode
    if args.export_only:
        export_from_checkpoint(args.export_only, args.output)
        return

    # Validate directories for training
    if not args.input or not args.target:
        print("Error: --input and --target required for training (or use --export-only)")
        sys.exit(1)
    if not os.path.isdir(args.input):
        print(f"Error: Input directory '{args.input}' does not exist")
        sys.exit(1)
    if not os.path.isdir(args.target):
        print(f"Error: Target directory '{args.target}' does not exist")
        sys.exit(1)

    train(args)


if __name__ == "__main__":
    main()