#!/usr/bin/env python3 """ CNN Training Script for Image-to-Image Transformation Trains a convolutional neural network on multiple input/target image pairs. Usage: python3 train_cnn.py --input input_dir/ --target target_dir/ [options] Example: python3 train_cnn.py --input ./input --target ./output --layers 3 --epochs 100 """ import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import Dataset, DataLoader from torchvision import transforms from PIL import Image import os import sys import argparse import glob class ImagePairDataset(Dataset): """Dataset for loading matching input/target image pairs""" def __init__(self, input_dir, target_dir, transform=None): self.input_dir = input_dir self.target_dir = target_dir self.transform = transform # Find all images in input directory input_patterns = ['*.png', '*.jpg', '*.jpeg', '*.PNG', '*.JPG', '*.JPEG'] self.image_pairs = [] for pattern in input_patterns: input_files = glob.glob(os.path.join(input_dir, pattern)) for input_path in input_files: filename = os.path.basename(input_path) # Try to find matching target with same name but any supported extension target_path = None for ext in ['png', 'jpg', 'jpeg', 'PNG', 'JPG', 'JPEG']: base_name = os.path.splitext(filename)[0] candidate = os.path.join(target_dir, f"{base_name}.{ext}") if os.path.exists(candidate): target_path = candidate break if target_path: self.image_pairs.append((input_path, target_path)) if not self.image_pairs: raise ValueError(f"No matching image pairs found between {input_dir} and {target_dir}") print(f"Found {len(self.image_pairs)} matching image pairs") def __len__(self): return len(self.image_pairs) def __getitem__(self, idx): input_path, target_path = self.image_pairs[idx] # Load RGBD input (4 channels: RGB + Depth) input_img = Image.open(input_path).convert('RGBA') target_img = Image.open(target_path).convert('RGB') if self.transform: input_img = self.transform(input_img) target_img = self.transform(target_img) return input_img, target_img class SimpleCNN(nn.Module): """CNN for RGBD→grayscale with 7-channel input (RGBD + UV + gray)""" def __init__(self, num_layers=1, kernel_sizes=None): super(SimpleCNN, self).__init__() if kernel_sizes is None: kernel_sizes = [3] * num_layers assert len(kernel_sizes) == num_layers, "kernel_sizes must match num_layers" self.kernel_sizes = kernel_sizes self.layers = nn.ModuleList() for i, kernel_size in enumerate(kernel_sizes): padding = kernel_size // 2 if i < num_layers - 1: # Inner layers: 7→4 (RGBD output) self.layers.append(nn.Conv2d(7, 4, kernel_size=kernel_size, padding=padding, bias=True)) else: # Final layer: 7→1 (grayscale output) self.layers.append(nn.Conv2d(7, 1, kernel_size=kernel_size, padding=padding, bias=True)) def forward(self, x): # x: [B,4,H,W] - RGBD input (D = 1/z) B, C, H, W = x.shape # Normalize RGBD to [-1,1] x_norm = (x - 0.5) * 2.0 # Compute coordinates [0,1] then normalize to [-1,1] y_coords = torch.linspace(0, 1, H, device=x.device).view(1,1,H,1).expand(B,1,H,W) x_coords = torch.linspace(0, 1, W, device=x.device).view(1,1,1,W).expand(B,1,H,W) y_coords = (y_coords - 0.5) * 2.0 # [-1,1] x_coords = (x_coords - 0.5) * 2.0 # [-1,1] # Compute grayscale from original RGB (Rec.709) and normalize to [-1,1] gray = 0.2126*x[:,0:1] + 0.7152*x[:,1:2] + 0.0722*x[:,2:3] # [B,1,H,W] in [0,1] gray = (gray - 0.5) * 2.0 # [-1,1] # Layer 0 layer0_input = torch.cat([x_norm, x_coords, y_coords, gray], dim=1) # [B,7,H,W] out = self.layers[0](layer0_input) # [B,4,H,W] out = torch.tanh(out) # [-1,1] # Inner layers for i in range(1, len(self.layers)-1): layer_input = torch.cat([out, x_coords, y_coords, gray], dim=1) out = self.layers[i](layer_input) out = torch.tanh(out) # Final layer (grayscale output) final_input = torch.cat([out, x_coords, y_coords, gray], dim=1) out = self.layers[-1](final_input) # [B,1,H,W] in [-1,1] # Denormalize to [0,1] and expand to RGB for visualization out = (out + 1.0) * 0.5 return out.expand(-1, 3, -1, -1) def generate_layer_shader(output_path, num_layers, kernel_sizes): """Generate cnn_layer.wgsl with proper layer switches""" with open(output_path, 'w') as f: f.write("// CNN layer shader - uses modular convolution snippets\n") f.write("// Supports multi-pass rendering with residual connections\n") f.write("// DO NOT EDIT - Generated by train_cnn.py\n\n") f.write("@group(0) @binding(0) var smplr: sampler;\n") f.write("@group(0) @binding(1) var txt: texture_2d;\n\n") f.write("#include \"common_uniforms\"\n") f.write("#include \"cnn_activation\"\n") # Include necessary conv functions conv_sizes = set(kernel_sizes) for ks in sorted(conv_sizes): f.write(f"#include \"cnn_conv{ks}x{ks}\"\n") f.write("#include \"cnn_weights_generated\"\n\n") f.write("struct CNNLayerParams {\n") f.write(" layer_index: i32,\n") f.write(" blend_amount: f32,\n") f.write(" _pad: vec2,\n") f.write("};\n\n") f.write("@group(0) @binding(2) var uniforms: CommonUniforms;\n") f.write("@group(0) @binding(3) var params: CNNLayerParams;\n") f.write("@group(0) @binding(4) var original_input: texture_2d;\n\n") f.write("@vertex fn vs_main(@builtin(vertex_index) i: u32) -> @builtin(position) vec4 {\n") f.write(" var pos = array, 3>(\n") f.write(" vec2(-1.0, -1.0), vec2(3.0, -1.0), vec2(-1.0, 3.0)\n") f.write(" );\n") f.write(" return vec4(pos[i], 0.0, 1.0);\n") f.write("}\n\n") f.write("@fragment fn fs_main(@builtin(position) p: vec4) -> @location(0) vec4 {\n") f.write(" let uv = p.xy / uniforms.resolution;\n") f.write(" let input = textureSample(txt, smplr, uv);\n") f.write(" let original = textureSample(original_input, smplr, uv);\n") f.write(" var result = vec4(0.0);\n\n") # Generate layer switches for layer_idx in range(num_layers): is_final = layer_idx == num_layers - 1 if layer_idx == 0: f.write(f" // Layer 0: 7→4 (RGBD output)\n") f.write(f" if (params.layer_index == {layer_idx}) {{\n") f.write(f" result = cnn_conv3x3_7to4(txt, smplr, uv, uniforms.resolution,\n") f.write(f" original, weights_layer{layer_idx});\n") f.write(f" result = cnn_tanh(result); // Output in [-1,1]\n") f.write(f" // Denormalize to [0,1] for texture storage\n") f.write(f" result = (result + 1.0) * 0.5;\n") f.write(f" }}\n") elif not is_final: f.write(f" else if (params.layer_index == {layer_idx}) {{\n") f.write(f" result = cnn_conv3x3_7to4(txt, smplr, uv, uniforms.resolution,\n") f.write(f" original, weights_layer{layer_idx});\n") f.write(f" result = cnn_tanh(result); // Output in [-1,1]\n") f.write(f" // Denormalize to [0,1] for texture storage\n") f.write(f" result = (result + 1.0) * 0.5;\n") f.write(f" }}\n") else: f.write(f" else if (params.layer_index == {layer_idx}) {{\n") f.write(f" let gray_out = cnn_conv3x3_7to1(txt, smplr, uv, uniforms.resolution,\n") f.write(f" original, weights_layer{layer_idx});\n") f.write(f" // Denormalize from [-1,1] to [0,1]\n") f.write(f" let gray_01 = (gray_out + 1.0) * 0.5;\n") f.write(f" result = vec4(gray_01, gray_01, gray_01, 1.0); // Expand to RGB\n") f.write(f" }}\n") # Add else clause for invalid layer index if num_layers > 0: f.write(f" else {{\n") f.write(f" result = input;\n") f.write(f" }}\n") f.write("\n // Blend with ORIGINAL input from layer 0\n") f.write(" return mix(original, result, params.blend_amount);\n") f.write("}\n") def export_weights_to_wgsl(model, output_path, kernel_sizes): """Export trained weights to WGSL format""" with open(output_path, 'w') as f: f.write("// Auto-generated CNN weights\n") f.write("// DO NOT EDIT - Generated by train_cnn.py\n\n") for i, layer in enumerate(model.layers): weights = layer.weight.data.cpu().numpy() bias = layer.bias.data.cpu().numpy() out_ch, in_ch, kh, kw = weights.shape num_positions = kh * kw is_final = (i == len(model.layers) - 1) if is_final: # Final layer: 7→1, structure: array, 9> # [w0, w1, w2, w3, w4, w5, w6, bias] f.write(f"const weights_layer{i}: array, {num_positions}> = array(\n") for pos in range(num_positions): row, col = pos // kw, pos % kw vals = [f"{weights[0, in_c, row, col]:.6f}" for in_c in range(7)] vals.append(f"{bias[0]:.6f}") # Append bias as 8th element f.write(f" array({', '.join(vals)})") f.write(",\n" if pos < num_positions-1 else "\n") f.write(");\n\n") else: # Inner layers: 7→4, structure: array, 36> # Flattened: [pos0_ch0[7w+bias], pos0_ch1[7w+bias], ..., pos8_ch3[7w+bias]] num_entries = num_positions * 4 f.write(f"const weights_layer{i}: array, {num_entries}> = array(\n") for pos in range(num_positions): row, col = pos // kw, pos % kw for out_c in range(4): vals = [f"{weights[out_c, in_c, row, col]:.6f}" for in_c in range(7)] vals.append(f"{bias[out_c]:.6f}") # Append bias idx = pos * 4 + out_c f.write(f" array({', '.join(vals)})") f.write(",\n" if idx < num_entries-1 else "\n") f.write(");\n\n") def train(args): """Main training loop""" # Setup device device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Using device: {device}") # Prepare dataset transform = transforms.Compose([ transforms.Resize((256, 256)), transforms.ToTensor(), ]) dataset = ImagePairDataset(args.input, args.target, transform=transform) dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True) # Parse kernel sizes kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')] if len(kernel_sizes) == 1 and args.layers > 1: kernel_sizes = kernel_sizes * args.layers # Create model model = SimpleCNN(num_layers=args.layers, kernel_sizes=kernel_sizes).to(device) # Loss and optimizer criterion = nn.MSELoss() optimizer = optim.Adam(model.parameters(), lr=args.learning_rate) # Resume from checkpoint start_epoch = 0 if args.resume: if os.path.exists(args.resume): print(f"Loading checkpoint from {args.resume}...") checkpoint = torch.load(args.resume, map_location=device) model.load_state_dict(checkpoint['model_state']) optimizer.load_state_dict(checkpoint['optimizer_state']) start_epoch = checkpoint['epoch'] + 1 print(f"Resumed from epoch {start_epoch}") else: print(f"Warning: Checkpoint file '{args.resume}' not found, starting from scratch") # Training loop print(f"\nTraining for {args.epochs} epochs (starting from epoch {start_epoch})...") for epoch in range(start_epoch, args.epochs): epoch_loss = 0.0 for batch_idx, (inputs, targets) in enumerate(dataloader): inputs, targets = inputs.to(device), targets.to(device) optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, targets) loss.backward() optimizer.step() epoch_loss += loss.item() avg_loss = epoch_loss / len(dataloader) if (epoch + 1) % 10 == 0: print(f"Epoch [{epoch+1}/{args.epochs}], Loss: {avg_loss:.6f}") # Save checkpoint if args.checkpoint_every > 0 and (epoch + 1) % args.checkpoint_every == 0: checkpoint_dir = args.checkpoint_dir or 'training/checkpoints' os.makedirs(checkpoint_dir, exist_ok=True) checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth') torch.save({ 'epoch': epoch, 'model_state': model.state_dict(), 'optimizer_state': optimizer.state_dict(), 'loss': avg_loss, 'kernel_sizes': kernel_sizes, 'num_layers': args.layers }, checkpoint_path) print(f"Saved checkpoint to {checkpoint_path}") # Export weights and shader output_path = args.output or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl' print(f"\nExporting weights to {output_path}...") os.makedirs(os.path.dirname(output_path), exist_ok=True) export_weights_to_wgsl(model, output_path, kernel_sizes) # Generate layer shader shader_dir = os.path.dirname(output_path) shader_path = os.path.join(shader_dir, 'cnn_layer.wgsl') print(f"Generating layer shader to {shader_path}...") generate_layer_shader(shader_path, args.layers, kernel_sizes) print("Training complete!") def export_from_checkpoint(checkpoint_path, output_path=None): """Export WGSL files from checkpoint without training""" if not os.path.exists(checkpoint_path): print(f"Error: Checkpoint file '{checkpoint_path}' not found") sys.exit(1) print(f"Loading checkpoint from {checkpoint_path}...") checkpoint = torch.load(checkpoint_path, map_location='cpu') kernel_sizes = checkpoint['kernel_sizes'] num_layers = checkpoint['num_layers'] # Recreate model model = SimpleCNN(num_layers=num_layers, kernel_sizes=kernel_sizes) model.load_state_dict(checkpoint['model_state']) # Export weights output_path = output_path or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl' print(f"Exporting weights to {output_path}...") os.makedirs(os.path.dirname(output_path), exist_ok=True) export_weights_to_wgsl(model, output_path, kernel_sizes) # Generate layer shader shader_dir = os.path.dirname(output_path) shader_path = os.path.join(shader_dir, 'cnn_layer.wgsl') print(f"Generating layer shader to {shader_path}...") generate_layer_shader(shader_path, num_layers, kernel_sizes) print("Export complete!") def main(): parser = argparse.ArgumentParser(description='Train CNN for image-to-image transformation') parser.add_argument('--input', help='Input image directory') parser.add_argument('--target', help='Target image directory') parser.add_argument('--layers', type=int, default=1, help='Number of CNN layers (default: 1)') parser.add_argument('--kernel_sizes', default='3', help='Comma-separated kernel sizes (default: 3)') parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs (default: 100)') parser.add_argument('--batch_size', type=int, default=4, help='Batch size (default: 4)') parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (default: 0.001)') parser.add_argument('--output', help='Output WGSL file path (default: workspaces/main/shaders/cnn/cnn_weights_generated.wgsl)') parser.add_argument('--checkpoint-every', type=int, default=0, help='Save checkpoint every N epochs (default: 0 = disabled)') parser.add_argument('--checkpoint-dir', help='Checkpoint directory (default: training/checkpoints)') parser.add_argument('--resume', help='Resume from checkpoint file') parser.add_argument('--export-only', help='Export WGSL from checkpoint without training') args = parser.parse_args() # Export-only mode if args.export_only: export_from_checkpoint(args.export_only, args.output) return # Validate directories for training if not args.input or not args.target: print("Error: --input and --target required for training (or use --export-only)") sys.exit(1) if not os.path.isdir(args.input): print(f"Error: Input directory '{args.input}' does not exist") sys.exit(1) if not os.path.isdir(args.target): print(f"Error: Target directory '{args.target}' does not exist") sys.exit(1) train(args) if __name__ == "__main__": main()