#!/usr/bin/env python3
"""
CNN Training Script for Image-to-Image Transformation

Trains a convolutional neural network on multiple input/target image pairs.

Usage:
    # Training
    python3 train_cnn.py --input input_dir/ --target target_dir/ [options]

    # Inference (generate ground truth)
    python3 train_cnn.py --infer image.png --export-only checkpoint.pth --output result.png

Example:
    python3 train_cnn.py --input ./input --target ./output --layers 3 --epochs 100
    python3 train_cnn.py --infer input.png --export-only checkpoints/checkpoint_epoch_10000.pth
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import sys
import argparse
import glob


class ImagePairDataset(Dataset):
    """Dataset for loading matching input/target image pairs"""

    def __init__(self, input_dir, target_dir, transform=None):
        self.input_dir = input_dir
        self.target_dir = target_dir
        self.transform = transform

        # Find all images in input directory
        input_patterns = ['*.png', '*.jpg', '*.jpeg', '*.PNG', '*.JPG', '*.JPEG']
        self.image_pairs = []
        for pattern in input_patterns:
            input_files = glob.glob(os.path.join(input_dir, pattern))
            for input_path in input_files:
                filename = os.path.basename(input_path)
                # Try to find matching target with same name but any supported extension
                target_path = None
                for ext in ['png', 'jpg', 'jpeg', 'PNG', 'JPG', 'JPEG']:
                    base_name = os.path.splitext(filename)[0]
                    candidate = os.path.join(target_dir, f"{base_name}.{ext}")
                    if os.path.exists(candidate):
                        target_path = candidate
                        break
                if target_path:
                    self.image_pairs.append((input_path, target_path))

        if not self.image_pairs:
            raise ValueError(f"No matching image pairs found between {input_dir} and {target_dir}")
        print(f"Found {len(self.image_pairs)} matching image pairs")

    def __len__(self):
        return len(self.image_pairs)

    def __getitem__(self, idx):
        input_path, target_path = self.image_pairs[idx]

        # Load RGBD input (4 channels: RGB + Depth)
        input_img = Image.open(input_path).convert('RGBA')
        target_img = Image.open(target_path).convert('RGB')

        if self.transform:
            input_img = self.transform(input_img)
            target_img = self.transform(target_img)

        return input_img, target_img
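
# Usage sketch for the dataset above (hypothetical paths): pairs are matched
# by basename, so input/scene01.png pairs with target/scene01.jpg. With the
# same transform that train() uses:
#
#     ds = ImagePairDataset('./input', './target',
#                           transform=transforms.Compose([
#                               transforms.Resize((256, 256)),
#                               transforms.ToTensor()]))
#     x, y = ds[0]  # x: [4,256,256] RGBA/RGBD in [0,1], y: [3,256,256] RGB
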

class SimpleCNN(nn.Module):
    """CNN for RGBD→grayscale with 7-channel input (RGBD + UV + gray)"""

    def __init__(self, num_layers=1, kernel_sizes=None):
        super(SimpleCNN, self).__init__()
        if kernel_sizes is None:
            kernel_sizes = [3] * num_layers
        assert len(kernel_sizes) == num_layers, "kernel_sizes must match num_layers"
        self.kernel_sizes = kernel_sizes

        self.layers = nn.ModuleList()
        for i, kernel_size in enumerate(kernel_sizes):
            padding = kernel_size // 2
            if i < num_layers - 1:
                # Inner layers: 7→4 (RGBD output)
                self.layers.append(nn.Conv2d(7, 4, kernel_size=kernel_size, padding=padding, bias=True))
            else:
                # Final layer: 7→1 (grayscale output)
                self.layers.append(nn.Conv2d(7, 1, kernel_size=kernel_size, padding=padding, bias=True))

    def forward(self, x):
        # x: [B,4,H,W] - RGBD input (D = 1/z)
        B, C, H, W = x.shape

        # Normalize RGBD to [-1,1]
        x_norm = (x - 0.5) * 2.0

        # Compute coordinates [0,1] then normalize to [-1,1]
        y_coords = torch.linspace(0, 1, H, device=x.device).view(1, 1, H, 1).expand(B, 1, H, W)
        x_coords = torch.linspace(0, 1, W, device=x.device).view(1, 1, 1, W).expand(B, 1, H, W)
        y_coords = (y_coords - 0.5) * 2.0  # [-1,1]
        x_coords = (x_coords - 0.5) * 2.0  # [-1,1]

        # Compute grayscale from original RGB (Rec.709) and normalize to [-1,1]
        gray = 0.2126 * x[:, 0:1] + 0.7152 * x[:, 1:2] + 0.0722 * x[:, 2:3]  # [B,1,H,W] in [0,1]
        gray = (gray - 0.5) * 2.0  # [-1,1]

        # Every layer sees 7 channels: the previous 4-channel activation
        # (x_norm for layer 0) plus the constant UV and gray features.
        # The single loop also covers num_layers == 1, where the only
        # layer is the final 7→1 convolution.
        out = x_norm
        for i, layer in enumerate(self.layers):
            layer_input = torch.cat([out, x_coords, y_coords, gray], dim=1)  # [B,7,H,W]
            out = layer(layer_input)
            if i < len(self.layers) - 1:
                out = torch.tanh(out)  # Hidden activations stay in [-1,1]

        out = torch.clamp(out, 0.0, 1.0)  # Final grayscale output clipped to [0,1]
        return out.expand(-1, 3, -1, -1)
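
# Minimal shape sanity check for SimpleCNN; a sketch, not wired into the CLI.
# Handy when changing --layers/--kernel_sizes: the output must stay a
# 3-channel image in [0,1] regardless of topology.
def _check_model_shapes(num_layers=3, kernel_sizes=(3, 3, 3)):
    model = SimpleCNN(num_layers=num_layers, kernel_sizes=list(kernel_sizes))
    dummy = torch.rand(2, 4, 64, 64)  # [B,4,H,W] RGBD batch in [0,1]
    with torch.no_grad():
        out = model(dummy)
    assert out.shape == (2, 3, 64, 64)  # Grayscale expanded to 3 channels
    assert 0.0 <= out.min().item() and out.max().item() <= 1.0
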

def generate_layer_shader(output_path, num_layers, kernel_sizes):
    """Generate cnn_layer.wgsl with proper layer switches"""
    with open(output_path, 'w') as f:
        f.write("// CNN layer shader - uses modular convolution snippets\n")
        f.write("// Supports multi-pass rendering with residual connections\n")
        f.write("// DO NOT EDIT - Generated by train_cnn.py\n\n")
        f.write("@group(0) @binding(0) var smplr: sampler;\n")
        f.write("@group(0) @binding(1) var txt: texture_2d<f32>;\n\n")
        f.write("#include \"common_uniforms\"\n")
        f.write("#include \"cnn_activation\"\n")

        # Include necessary conv functions
        conv_sizes = set(kernel_sizes)
        for ks in sorted(conv_sizes):
            f.write(f"#include \"cnn_conv{ks}x{ks}\"\n")
        f.write("#include \"cnn_weights_generated\"\n\n")

        f.write("struct CNNLayerParams {\n")
        f.write("    layer_index: i32,\n")
        f.write("    blend_amount: f32,\n")
        f.write("    _pad: vec2<f32>,\n")
        f.write("};\n\n")
        f.write("@group(0) @binding(2) var<uniform> uniforms: CommonUniforms;\n")
        f.write("@group(0) @binding(3) var<uniform> params: CNNLayerParams;\n")
        f.write("@group(0) @binding(4) var original_input: texture_2d<f32>;\n\n")

        f.write("@vertex fn vs_main(@builtin(vertex_index) i: u32) -> @builtin(position) vec4<f32> {\n")
        f.write("    var pos = array<vec2<f32>, 3>(\n")
        f.write("        vec2<f32>(-1.0, -1.0), vec2<f32>(3.0, -1.0), vec2<f32>(-1.0, 3.0)\n")
        f.write("    );\n")
        f.write("    return vec4<f32>(pos[i], 0.0, 1.0);\n")
        f.write("}\n\n")

        f.write("@fragment fn fs_main(@builtin(position) p: vec4<f32>) -> @location(0) vec4<f32> {\n")
        f.write("    let uv = p.xy / uniforms.resolution;\n")
        f.write("    let original_raw = textureSample(original_input, smplr, uv);\n")
        f.write("    let original = (original_raw - 0.5) * 2.0; // Normalize to [-1,1]\n")
        f.write("    var result = vec4<f32>(0.0);\n\n")

        # Generate layer switches
        for layer_idx in range(num_layers):
            is_final = layer_idx == num_layers - 1
            ks = kernel_sizes[layer_idx]
            conv_fn = f"cnn_conv{ks}x{ks}_7to4" if not is_final else f"cnn_conv{ks}x{ks}_7to1"

            if layer_idx == 0:
                conv_fn_src = f"cnn_conv{ks}x{ks}_7to4_src"
                f.write(f"    // Layer 0: 7→4 (RGBD output, normalizes [0,1] input)\n")
                f.write(f"    if (params.layer_index == {layer_idx}) {{\n")
                f.write(f"        result = {conv_fn_src}(txt, smplr, uv, uniforms.resolution,\n")
                f.write(f"                               weights_layer{layer_idx});\n")
                f.write(f"        result = cnn_tanh(result);\n")
                f.write(f"    }}\n")
            elif not is_final:
                f.write(f"    else if (params.layer_index == {layer_idx}) {{\n")
                f.write(f"        result = {conv_fn}(txt, smplr, uv, uniforms.resolution,\n")
                f.write(f"                           original, weights_layer{layer_idx});\n")
                f.write(f"        result = cnn_tanh(result); // Keep in [-1,1]\n")
                f.write(f"    }}\n")
            else:
                f.write(f"    else if (params.layer_index == {layer_idx}) {{\n")
                f.write(f"        let gray_out = {conv_fn}(txt, smplr, uv, uniforms.resolution,\n")
                f.write(f"                                 original, weights_layer{layer_idx});\n")
                f.write(f"        // gray_out already in [0,1] from clipped training\n")
                f.write(f"        let original_denorm = (original + 1.0) * 0.5;\n")
                f.write(f"        result = vec4<f32>(gray_out, gray_out, gray_out, 1.0);\n")
                f.write(f"        let blended = mix(original_denorm, result, params.blend_amount);\n")
                f.write(f"        return blended; // [0,1]\n")
                f.write(f"    }}\n")

        # Add else clause for invalid layer index
        if num_layers > 0:
            f.write(f"    else {{\n")
            f.write(f"        return textureSample(txt, smplr, uv);\n")
            f.write(f"    }}\n")

        f.write("\n    // Non-final layers: denormalize for display\n")
        f.write("    return (result + 1.0) * 0.5; // [-1,1] → [0,1]\n")
        f.write("}\n")


def export_weights_to_wgsl(model, output_path, kernel_sizes):
    """Export trained weights to WGSL format"""
    with open(output_path, 'w') as f:
        f.write("// Auto-generated CNN weights\n")
        f.write("// DO NOT EDIT - Generated by train_cnn.py\n\n")

        for i, layer in enumerate(model.layers):
            weights = layer.weight.data.cpu().numpy()
            bias = layer.bias.data.cpu().numpy()
            out_ch, in_ch, kh, kw = weights.shape
            num_positions = kh * kw
            is_final = (i == len(model.layers) - 1)

            if is_final:
                # Final layer: 7→1, structure: array<array<f32, 8>, 9>
                # [w0, w1, w2, w3, w4, w5, w6, bias]
                f.write(f"const weights_layer{i}: array<array<f32, 8>, {num_positions}> = array(\n")
                for pos in range(num_positions):
                    row, col = pos // kw, pos % kw
                    vals = [f"{weights[0, in_c, row, col]:.6f}" for in_c in range(7)]
                    vals.append(f"{bias[0]:.6f}")  # Append bias as 8th element
                    f.write(f"    array({', '.join(vals)})")
                    f.write(",\n" if pos < num_positions - 1 else "\n")
                f.write(");\n\n")
            else:
                # Inner layers: 7→4, structure: array<array<f32, 8>, 36>
                # Flattened: [pos0_ch0[7w+bias], pos0_ch1[7w+bias], ..., pos8_ch3[7w+bias]]
                num_entries = num_positions * 4
                f.write(f"const weights_layer{i}: array<array<f32, 8>, {num_entries}> = array(\n")
                for pos in range(num_positions):
                    row, col = pos // kw, pos % kw
                    for out_c in range(4):
                        vals = [f"{weights[out_c, in_c, row, col]:.6f}" for in_c in range(7)]
                        vals.append(f"{bias[out_c]:.6f}")  # Append bias
                        idx = pos * 4 + out_c
                        f.write(f"    array({', '.join(vals)})")
                        f.write(",\n" if idx < num_entries - 1 else "\n")
                f.write(");\n\n")
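
# Packing example for the export above: a 3x3 inner layer has
# num_positions = 9, so the constant is array<array<f32, 8>, 36>, flattened
# as idx = pos * 4 + out_c with pos = row * kw + col. Each 8-vector holds
# [w_in0..w_in6, bias], which the generated WGSL reads back as:
#
#     let idx = pos * 4 + out_c;
#     var channel_sum = weights[idx][7];                // bias
#     channel_sum += weights[idx][in_c] * inputs[in_c]; // weighted inputs
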

def generate_conv_src_function(kernel_size, output_path):
    """Generate cnn_conv{K}x{K}_7to4_src() function for layer 0"""
    k = kernel_size
    num_positions = k * k
    radius = k // 2

    with open(output_path, 'a') as f:
        f.write(f"\n// Source layer: 7→4 channels (RGBD output)\n")
        f.write(f"// Normalizes [0,1] input to [-1,1] internally\n")
        f.write(f"fn cnn_conv{k}x{k}_7to4_src(\n")
        f.write(f"    tex: texture_2d<f32>,\n")
        f.write(f"    samp: sampler,\n")
        f.write(f"    uv: vec2<f32>,\n")
        f.write(f"    resolution: vec2<f32>,\n")
        f.write(f"    weights: array<array<f32, 8>, {num_positions * 4}>\n")
        f.write(f") -> vec4<f32> {{\n")
        f.write(f"    let step = 1.0 / resolution;\n\n")

        # Normalize center pixel for gray channel
        f.write(f"    let original = (textureSample(tex, samp, uv) - 0.5) * 2.0;\n")
        f.write(f"    let gray = 0.2126*original.r + 0.7152*original.g + 0.0722*original.b;\n")
        f.write(f"    let uv_norm = (uv - 0.5) * 2.0;\n\n")
        f.write(f"    var sum = vec4<f32>(0.0);\n")
        f.write(f"    var pos = 0;\n\n")

        # Convolution loop
        f.write(f"    for (var dy = -{radius}; dy <= {radius}; dy++) {{\n")
        f.write(f"        for (var dx = -{radius}; dx <= {radius}; dx++) {{\n")
        f.write(f"            let offset = vec2<f32>(f32(dx), f32(dy)) * step;\n")
        f.write(f"            let rgbd = (textureSample(tex, samp, uv + offset) - 0.5) * 2.0;\n\n")

        # 7-channel input
        f.write(f"            let inputs = array<f32, 7>(\n")
        f.write(f"                rgbd.r, rgbd.g, rgbd.b, rgbd.a,\n")
        f.write(f"                uv_norm.x, uv_norm.y, gray\n")
        f.write(f"            );\n\n")

        # Accumulate
        f.write(f"            for (var out_c = 0; out_c < 4; out_c++) {{\n")
        f.write(f"                let idx = pos * 4 + out_c;\n")
        f.write(f"                var channel_sum = weights[idx][7];\n")
        f.write(f"                for (var in_c = 0; in_c < 7; in_c++) {{\n")
        f.write(f"                    channel_sum += weights[idx][in_c] * inputs[in_c];\n")
        f.write(f"                }}\n")
        f.write(f"                sum[out_c] += channel_sum;\n")
        f.write(f"            }}\n")
        f.write(f"            pos++;\n")
        f.write(f"        }}\n")
        f.write(f"    }}\n\n")
        f.write(f"    return sum;\n")
        f.write(f"}}\n")
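
# For example, --kernel_sizes 5 appends a function with this header (sketch
# of the generated WGSL; the body follows the 3x3 pattern with a 25-position
# loop over 100 weight rows):
#
#     fn cnn_conv5x5_7to4_src(
#         tex: texture_2d<f32>,
#         samp: sampler,
#         uv: vec2<f32>,
#         resolution: vec2<f32>,
#         weights: array<array<f32, 8>, 100>
#     ) -> vec4<f32> { ... }
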

def train(args):
    """Main training loop"""
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Prepare dataset
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    dataset = ImagePairDataset(args.input, args.target, transform=transform)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    # Parse kernel sizes
    kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
    if len(kernel_sizes) == 1 and args.layers > 1:
        kernel_sizes = kernel_sizes * args.layers

    # Create model
    model = SimpleCNN(num_layers=args.layers, kernel_sizes=kernel_sizes).to(device)

    # Loss and optimizer
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

    # Resume from checkpoint
    start_epoch = 0
    if args.resume:
        if os.path.exists(args.resume):
            print(f"Loading checkpoint from {args.resume}...")
            checkpoint = torch.load(args.resume, map_location=device)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            start_epoch = checkpoint['epoch'] + 1
            print(f"Resumed from epoch {start_epoch}")
        else:
            print(f"Warning: Checkpoint file '{args.resume}' not found, starting from scratch")

    # Training loop
    print(f"\nTraining for {args.epochs} epochs (starting from epoch {start_epoch})...")
    for epoch in range(start_epoch, args.epochs):
        epoch_loss = 0.0
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(dataloader)
        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{args.epochs}], Loss: {avg_loss:.6f}")

        # Save checkpoint
        if args.checkpoint_every > 0 and (epoch + 1) % args.checkpoint_every == 0:
            checkpoint_dir = args.checkpoint_dir or 'training/checkpoints'
            os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')
            torch.save({
                'epoch': epoch,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'loss': avg_loss,
                'kernel_sizes': kernel_sizes,
                'num_layers': args.layers
            }, checkpoint_path)
            print(f"Saved checkpoint to {checkpoint_path}")

    # Export weights and shader
    output_path = args.output or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl'
    print(f"\nExporting weights to {output_path}...")
    out_dir = os.path.dirname(output_path)
    if out_dir:  # Guard: dirname is '' for bare filenames
        os.makedirs(out_dir, exist_ok=True)
    export_weights_to_wgsl(model, output_path, kernel_sizes)

    # Generate layer shader
    shader_dir = os.path.dirname(output_path)
    shader_path = os.path.join(shader_dir, 'cnn_layer.wgsl')
    print(f"Generating layer shader to {shader_path}...")
    generate_layer_shader(shader_path, args.layers, kernel_sizes)

    # Generate _src variants for kernel sizes (skip 3x3, already exists)
    for ks in set(kernel_sizes):
        if ks == 3:
            continue
        conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl')
        if not os.path.exists(conv_path):
            print(f"Warning: {conv_path} not found, skipping _src generation")
            continue
        # Check if _src already exists
        with open(conv_path, 'r') as f:
            content = f.read()
        if f"cnn_conv{ks}x{ks}_7to4_src" in content:
            continue
        generate_conv_src_function(ks, conv_path)
        print(f"Added _src variant to {conv_path}")

    print("Training complete!")


def export_from_checkpoint(checkpoint_path, output_path=None):
    """Export WGSL files from checkpoint without training"""
    if not os.path.exists(checkpoint_path):
        print(f"Error: Checkpoint file '{checkpoint_path}' not found")
        sys.exit(1)

    print(f"Loading checkpoint from {checkpoint_path}...")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    kernel_sizes = checkpoint['kernel_sizes']
    num_layers = checkpoint['num_layers']

    # Recreate model
    model = SimpleCNN(num_layers=num_layers, kernel_sizes=kernel_sizes)
    model.load_state_dict(checkpoint['model_state'])

    # Export weights
    output_path = output_path or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl'
    print(f"Exporting weights to {output_path}...")
    out_dir = os.path.dirname(output_path)
    if out_dir:  # Guard: dirname is '' for bare filenames
        os.makedirs(out_dir, exist_ok=True)
    export_weights_to_wgsl(model, output_path, kernel_sizes)

    # Generate layer shader
    shader_dir = os.path.dirname(output_path)
    shader_path = os.path.join(shader_dir, 'cnn_layer.wgsl')
    print(f"Generating layer shader to {shader_path}...")
    generate_layer_shader(shader_path, num_layers, kernel_sizes)

    # Generate _src variants for kernel sizes (skip 3x3, already exists)
    for ks in set(kernel_sizes):
        if ks == 3:
            continue
        conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl')
        if not os.path.exists(conv_path):
            print(f"Warning: {conv_path} not found, skipping _src generation")
            continue
        # Check if _src already exists
        with open(conv_path, 'r') as f:
            content = f.read()
        if f"cnn_conv{ks}x{ks}_7to4_src" in content:
            continue
        generate_conv_src_function(ks, conv_path)
        print(f"Added _src variant to {conv_path}")

    print("Export complete!")


def infer_from_checkpoint(checkpoint_path, input_path, output_path):
    """Run inference on single image to generate ground truth"""
    if not os.path.exists(checkpoint_path):
        print(f"Error: Checkpoint '{checkpoint_path}' not found")
        sys.exit(1)
    if not os.path.exists(input_path):
        print(f"Error: Input image '{input_path}' not found")
        sys.exit(1)

    print(f"Loading checkpoint from {checkpoint_path}...")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Reconstruct model
    model = SimpleCNN(
        num_layers=checkpoint['num_layers'],
        kernel_sizes=checkpoint['kernel_sizes']
    )
    model.load_state_dict(checkpoint['model_state'])
    model.eval()

    # Load image [0,1]
    print(f"Loading input image: {input_path}")
    img = Image.open(input_path).convert('RGBA')
    img_tensor = transforms.ToTensor()(img).unsqueeze(0)  # [1,4,H,W]

    # Inference
    print("Running inference...")
    with torch.no_grad():
        out = model(img_tensor)  # [1,3,H,W] in [0,1]

    # Save
    print(f"Saving output to: {output_path}")
    out_dir = os.path.dirname(output_path)
    if out_dir:  # Guard: dirname is '' for bare filenames
        os.makedirs(out_dir, exist_ok=True)
    transforms.ToPILImage()(out.squeeze(0)).save(output_path)
    print("Done!")
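
# Checkpoint layout, as written by train() and consumed by --resume,
# --export-only, and --infer:
#
#     {
#         'epoch': int,              # last completed epoch (0-based)
#         'model_state': dict,       # model.state_dict()
#         'optimizer_state': dict,   # optimizer.state_dict()
#         'loss': float,             # average loss for that epoch
#         'kernel_sizes': list,      # e.g. [3, 3, 3]
#         'num_layers': int,
#     }
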

def main():
    parser = argparse.ArgumentParser(description='Train CNN for image-to-image transformation')
    parser.add_argument('--input', help='Input image directory (training) or single image (inference)')
    parser.add_argument('--target', help='Target image directory')
    parser.add_argument('--layers', type=int, default=1, help='Number of CNN layers (default: 1)')
    parser.add_argument('--kernel_sizes', default='3', help='Comma-separated kernel sizes (default: 3)')
    parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs (default: 100)')
    parser.add_argument('--batch_size', type=int, default=4, help='Batch size (default: 4)')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (default: 0.001)')
    parser.add_argument('--output', help='Output path (WGSL for training/export, PNG for inference)')
    parser.add_argument('--checkpoint-every', type=int, default=0, help='Save checkpoint every N epochs (default: 0 = disabled)')
    parser.add_argument('--checkpoint-dir', help='Checkpoint directory (default: training/checkpoints)')
    parser.add_argument('--resume', help='Resume from checkpoint file')
    parser.add_argument('--export-only', help='Export WGSL from checkpoint without training')
    parser.add_argument('--infer', help='Run inference on single image (requires --export-only for checkpoint)')
    args = parser.parse_args()

    # Inference mode
    if args.infer:
        checkpoint = args.export_only
        if not checkpoint:
            print("Error: --infer requires --export-only <checkpoint>")
            sys.exit(1)
        output_path = args.output or 'inference_output.png'
        infer_from_checkpoint(checkpoint, args.infer, output_path)
        return

    # Export-only mode
    if args.export_only:
        export_from_checkpoint(args.export_only, args.output)
        return

    # Validate directories for training
    if not args.input or not args.target:
        print("Error: --input and --target required for training (or use --export-only)")
        sys.exit(1)
    if not os.path.isdir(args.input):
        print(f"Error: Input directory '{args.input}' does not exist")
        sys.exit(1)
    if not os.path.isdir(args.target):
        print(f"Error: Target directory '{args.target}' does not exist")
        sys.exit(1)

    train(args)


if __name__ == "__main__":
    main()