diff options
Diffstat (limited to 'training/train_cnn.py')
| -rwxr-xr-x | training/train_cnn.py | 444 |
1 files changed, 337 insertions, 107 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py index 82f0b48..16f8e7a 100755 --- a/training/train_cnn.py +++ b/training/train_cnn.py @@ -5,10 +5,15 @@ CNN Training Script for Image-to-Image Transformation Trains a convolutional neural network on multiple input/target image pairs. Usage: + # Training python3 train_cnn.py --input input_dir/ --target target_dir/ [options] + # Inference (generate ground truth) + python3 train_cnn.py --infer image.png --export-only checkpoint.pth --output result.png + Example: python3 train_cnn.py --input ./input --target ./output --layers 3 --epochs 100 + python3 train_cnn.py --infer input.png --export-only checkpoints/checkpoint_epoch_10000.pth """ import torch @@ -62,7 +67,8 @@ class ImagePairDataset(Dataset): def __getitem__(self, idx): input_path, target_path = self.image_pairs[idx] - input_img = Image.open(input_path).convert('RGB') + # Load RGBD input (4 channels: RGB + Depth) + input_img = Image.open(input_path).convert('RGBA') target_img = Image.open(target_path).convert('RGB') if self.transform: @@ -72,27 +78,8 @@ class ImagePairDataset(Dataset): return input_img, target_img -class CoordConv2d(nn.Module): - """Conv2d that accepts coordinate input separate from spatial patches""" - - def __init__(self, in_channels, out_channels, kernel_size, padding=0): - super().__init__() - self.conv_rgba = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=False) - self.coord_weights = nn.Parameter(torch.randn(out_channels, 2) * 0.01) - self.bias = nn.Parameter(torch.zeros(out_channels)) - - def forward(self, x, coords): - # x: [B, C, H, W] image - # coords: [B, 2, H, W] coordinate grid - out = self.conv_rgba(x) - B, C, H, W = out.shape - coord_contrib = torch.einsum('bchw,oc->bohw', coords, self.coord_weights) - out = out + coord_contrib + self.bias.view(1, -1, 1, 1) - return out - - class SimpleCNN(nn.Module): - """Simple CNN for image-to-image transformation""" + """CNN for RGBD→grayscale with 7-channel input (RGBD + UV + gray)""" def __init__(self, num_layers=1, kernel_sizes=None): super(SimpleCNN, self).__init__() @@ -107,30 +94,126 @@ class SimpleCNN(nn.Module): for i, kernel_size in enumerate(kernel_sizes): padding = kernel_size // 2 - if i == 0: - self.layers.append(CoordConv2d(3, 3, kernel_size, padding=padding)) + if i < num_layers - 1: + # Inner layers: 7→4 (RGBD output) + self.layers.append(nn.Conv2d(7, 4, kernel_size=kernel_size, padding=padding, bias=True)) else: - self.layers.append(nn.Conv2d(3, 3, kernel_size=kernel_size, padding=padding, bias=True)) - - self.use_residual = True + # Final layer: 7→1 (grayscale output) + self.layers.append(nn.Conv2d(7, 1, kernel_size=kernel_size, padding=padding, bias=True)) def forward(self, x): + # x: [B,4,H,W] - RGBD input (D = 1/z) B, C, H, W = x.shape + + # Normalize RGBD to [-1,1] + x_norm = (x - 0.5) * 2.0 + + # Compute coordinates [0,1] then normalize to [-1,1] y_coords = torch.linspace(0, 1, H, device=x.device).view(1,1,H,1).expand(B,1,H,W) x_coords = torch.linspace(0, 1, W, device=x.device).view(1,1,1,W).expand(B,1,H,W) - coords = torch.cat([x_coords, y_coords], dim=1) + y_coords = (y_coords - 0.5) * 2.0 # [-1,1] + x_coords = (x_coords - 0.5) * 2.0 # [-1,1] + + # Compute grayscale from original RGB (Rec.709) and normalize to [-1,1] + gray = 0.2126*x[:,0:1] + 0.7152*x[:,1:2] + 0.0722*x[:,2:3] # [B,1,H,W] in [0,1] + gray = (gray - 0.5) * 2.0 # [-1,1] + + # Layer 0 + layer0_input = torch.cat([x_norm, x_coords, y_coords, gray], dim=1) # [B,7,H,W] + out = self.layers[0](layer0_input) # [B,4,H,W] + out = torch.tanh(out) # [-1,1] + + # Inner layers + for i in range(1, len(self.layers)-1): + layer_input = torch.cat([out, x_coords, y_coords, gray], dim=1) + out = self.layers[i](layer_input) + out = torch.tanh(out) + + # Final layer (grayscale output) + final_input = torch.cat([out, x_coords, y_coords, gray], dim=1) + out = self.layers[-1](final_input) # [B,1,H,W] + out = torch.clamp(out, 0.0, 1.0) # Clip to [0,1] + return out.expand(-1, 3, -1, -1) + + +def generate_layer_shader(output_path, num_layers, kernel_sizes): + """Generate cnn_layer.wgsl with proper layer switches""" + + with open(output_path, 'w') as f: + f.write("// CNN layer shader - uses modular convolution snippets\n") + f.write("// Supports multi-pass rendering with residual connections\n") + f.write("// DO NOT EDIT - Generated by train_cnn.py\n\n") + f.write("@group(0) @binding(0) var smplr: sampler;\n") + f.write("@group(0) @binding(1) var txt: texture_2d<f32>;\n\n") + f.write("#include \"common_uniforms\"\n") + f.write("#include \"cnn_activation\"\n") + + # Include necessary conv functions + conv_sizes = set(kernel_sizes) + for ks in sorted(conv_sizes): + f.write(f"#include \"cnn_conv{ks}x{ks}\"\n") + f.write("#include \"cnn_weights_generated\"\n\n") - out = self.layers[0](x, coords) - out = torch.tanh(out) + f.write("struct CNNLayerParams {\n") + f.write(" layer_index: i32,\n") + f.write(" blend_amount: f32,\n") + f.write(" _pad: vec2<f32>,\n") + f.write("};\n\n") + f.write("@group(0) @binding(2) var<uniform> uniforms: CommonUniforms;\n") + f.write("@group(0) @binding(3) var<uniform> params: CNNLayerParams;\n") + f.write("@group(0) @binding(4) var original_input: texture_2d<f32>;\n\n") + f.write("@vertex fn vs_main(@builtin(vertex_index) i: u32) -> @builtin(position) vec4<f32> {\n") + f.write(" var pos = array<vec2<f32>, 3>(\n") + f.write(" vec2<f32>(-1.0, -1.0), vec2<f32>(3.0, -1.0), vec2<f32>(-1.0, 3.0)\n") + f.write(" );\n") + f.write(" return vec4<f32>(pos[i], 0.0, 1.0);\n") + f.write("}\n\n") + f.write("@fragment fn fs_main(@builtin(position) p: vec4<f32>) -> @location(0) vec4<f32> {\n") + f.write(" let uv = p.xy / uniforms.resolution;\n") + f.write(" let original_raw = textureSample(original_input, smplr, uv);\n") + f.write(" let original = (original_raw - 0.5) * 2.0; // Normalize to [-1,1]\n") + f.write(" var result = vec4<f32>(0.0);\n\n") - for i in range(1, len(self.layers)): - out = self.layers[i](out) - if i < len(self.layers) - 1: - out = torch.tanh(out) + # Generate layer switches + for layer_idx in range(num_layers): + is_final = layer_idx == num_layers - 1 + ks = kernel_sizes[layer_idx] + conv_fn = f"cnn_conv{ks}x{ks}_7to4" if not is_final else f"cnn_conv{ks}x{ks}_7to1" - if self.use_residual: - out = x + out * 0.3 - return out + if layer_idx == 0: + conv_fn_src = f"cnn_conv{ks}x{ks}_7to4_src" + f.write(f" // Layer 0: 7→4 (RGBD output, normalizes [0,1] input)\n") + f.write(f" if (params.layer_index == {layer_idx}) {{\n") + f.write(f" result = {conv_fn_src}(txt, smplr, uv, uniforms.resolution,\n") + f.write(f" weights_layer{layer_idx});\n") + f.write(f" result = cnn_tanh(result);\n") + f.write(f" }}\n") + elif not is_final: + f.write(f" else if (params.layer_index == {layer_idx}) {{\n") + f.write(f" result = {conv_fn}(txt, smplr, uv, uniforms.resolution,\n") + f.write(f" original, weights_layer{layer_idx});\n") + f.write(f" result = cnn_tanh(result); // Keep in [-1,1]\n") + f.write(f" }}\n") + else: + f.write(f" else if (params.layer_index == {layer_idx}) {{\n") + f.write(f" let gray_out = {conv_fn}(txt, smplr, uv, uniforms.resolution,\n") + f.write(f" original, weights_layer{layer_idx});\n") + f.write(f" // gray_out already in [0,1] from clipped training\n") + f.write(f" let original_denorm = (original + 1.0) * 0.5;\n") + f.write(f" result = vec4<f32>(gray_out, gray_out, gray_out, 1.0);\n") + f.write(f" let blended = mix(original_denorm, result, params.blend_amount);\n") + f.write(f" return blended; // [0,1]\n") + f.write(f" }}\n") + + # Add else clause for invalid layer index + if num_layers > 0: + f.write(f" else {{\n") + f.write(f" return textureSample(txt, smplr, uv);\n") + f.write(f" }}\n") + + f.write("\n // Non-final layers: denormalize for display\n") + f.write(" return (result + 1.0) * 0.5; // [-1,1] → [0,1]\n") + f.write("}\n") def export_weights_to_wgsl(model, output_path, kernel_sizes): @@ -140,82 +223,95 @@ def export_weights_to_wgsl(model, output_path, kernel_sizes): f.write("// Auto-generated CNN weights\n") f.write("// DO NOT EDIT - Generated by train_cnn.py\n\n") - layer_idx = 0 for i, layer in enumerate(model.layers): - if isinstance(layer, CoordConv2d): - # Export RGBA weights - weights = layer.conv_rgba.weight.data.cpu().numpy() - kernel_size = kernel_sizes[layer_idx] - out_ch, in_ch, kh, kw = weights.shape - num_positions = kh * kw + weights = layer.weight.data.cpu().numpy() + bias = layer.bias.data.cpu().numpy() + out_ch, in_ch, kh, kw = weights.shape + num_positions = kh * kw - f.write(f"const rgba_weights_layer{layer_idx}: array<mat4x4<f32>, {num_positions}> = array(\n") + is_final = (i == len(model.layers) - 1) + + if is_final: + # Final layer: 7→1, structure: array<array<f32, 8>, 9> + # [w0, w1, w2, w3, w4, w5, w6, bias] + f.write(f"const weights_layer{i}: array<array<f32, 8>, {num_positions}> = array(\n") for pos in range(num_positions): - row = pos // kw - col = pos % kw - f.write(" mat4x4<f32>(\n") - for out_c in range(min(4, out_ch)): - vals = [] - for in_c in range(min(4, in_ch)): - vals.append(f"{weights[out_c, in_c, row, col]:.6f}") - f.write(f" {', '.join(vals)},\n") - f.write(" )") - if pos < num_positions - 1: - f.write(",\n") - else: - f.write("\n") + row, col = pos // kw, pos % kw + vals = [f"{weights[0, in_c, row, col]:.6f}" for in_c in range(7)] + vals.append(f"{bias[0]:.6f}") # Append bias as 8th element + f.write(f" array<f32, 8>({', '.join(vals)})") + f.write(",\n" if pos < num_positions-1 else "\n") f.write(");\n\n") - - # Export coordinate weights - coord_w = layer.coord_weights.data.cpu().numpy() - f.write(f"const coord_weights_layer{layer_idx} = mat2x4<f32>(\n") - for c in range(2): - vals = [f"{coord_w[out_c, c]:.6f}" for out_c in range(min(4, coord_w.shape[0]))] - f.write(f" {', '.join(vals)}") - if c < 1: - f.write(",\n") - else: - f.write("\n") + else: + # Inner layers: 7→4, structure: array<array<f32, 8>, 36> + # Flattened: [pos0_ch0[7w+bias], pos0_ch1[7w+bias], ..., pos8_ch3[7w+bias]] + num_entries = num_positions * 4 + f.write(f"const weights_layer{i}: array<array<f32, 8>, {num_entries}> = array(\n") + for pos in range(num_positions): + row, col = pos // kw, pos % kw + for out_c in range(4): + vals = [f"{weights[out_c, in_c, row, col]:.6f}" for in_c in range(7)] + vals.append(f"{bias[out_c]:.6f}") # Append bias + idx = pos * 4 + out_c + f.write(f" array<f32, 8>({', '.join(vals)})") + f.write(",\n" if idx < num_entries-1 else "\n") f.write(");\n\n") - # Export bias - bias = layer.bias.data.cpu().numpy() - f.write(f"const bias_layer{layer_idx} = vec4<f32>(") - f.write(", ".join([f"{b:.6f}" for b in bias[:4]])) - f.write(");\n\n") - layer_idx += 1 - elif isinstance(layer, nn.Conv2d): - # Standard conv layer - weights = layer.weight.data.cpu().numpy() - kernel_size = kernel_sizes[layer_idx] - out_ch, in_ch, kh, kw = weights.shape - num_positions = kh * kw +def generate_conv_src_function(kernel_size, output_path): + """Generate cnn_conv{K}x{K}_7to4_src() function for layer 0""" - f.write(f"const weights_layer{layer_idx}: array<mat4x4<f32>, {num_positions}> = array(\n") - for pos in range(num_positions): - row = pos // kw - col = pos % kw - f.write(" mat4x4<f32>(\n") - for out_c in range(min(4, out_ch)): - vals = [] - for in_c in range(min(4, in_ch)): - vals.append(f"{weights[out_c, in_c, row, col]:.6f}") - f.write(f" {', '.join(vals)},\n") - f.write(" )") - if pos < num_positions - 1: - f.write(",\n") - else: - f.write("\n") - f.write(");\n\n") + k = kernel_size + num_positions = k * k + radius = k // 2 - # Export bias - bias = layer.bias.data.cpu().numpy() - f.write(f"const bias_layer{layer_idx} = vec4<f32>(") - f.write(", ".join([f"{b:.6f}" for b in bias[:4]])) - f.write(");\n\n") + with open(output_path, 'a') as f: + f.write(f"\n// Source layer: 7→4 channels (RGBD output)\n") + f.write(f"// Normalizes [0,1] input to [-1,1] internally\n") + f.write(f"fn cnn_conv{k}x{k}_7to4_src(\n") + f.write(f" tex: texture_2d<f32>,\n") + f.write(f" samp: sampler,\n") + f.write(f" uv: vec2<f32>,\n") + f.write(f" resolution: vec2<f32>,\n") + f.write(f" weights: array<array<f32, 8>, {num_positions * 4}>\n") + f.write(f") -> vec4<f32> {{\n") + f.write(f" let step = 1.0 / resolution;\n\n") + + # Normalize center pixel for gray channel + f.write(f" let original = (textureSample(tex, samp, uv) - 0.5) * 2.0;\n") + f.write(f" let gray = 0.2126*original.r + 0.7152*original.g + 0.0722*original.b;\n") + f.write(f" let uv_norm = (uv - 0.5) * 2.0;\n\n") + + f.write(f" var sum = vec4<f32>(0.0);\n") + f.write(f" var pos = 0;\n\n") + + # Convolution loop + f.write(f" for (var dy = -{radius}; dy <= {radius}; dy++) {{\n") + f.write(f" for (var dx = -{radius}; dx <= {radius}; dx++) {{\n") + f.write(f" let offset = vec2<f32>(f32(dx), f32(dy)) * step;\n") + f.write(f" let rgbd = (textureSample(tex, samp, uv + offset) - 0.5) * 2.0;\n\n") - layer_idx += 1 + # 7-channel input + f.write(f" let inputs = array<f32, 7>(\n") + f.write(f" rgbd.r, rgbd.g, rgbd.b, rgbd.a,\n") + f.write(f" uv_norm.x, uv_norm.y, gray\n") + f.write(f" );\n\n") + + # Accumulate + f.write(f" for (var out_c = 0; out_c < 4; out_c++) {{\n") + f.write(f" let idx = pos * 4 + out_c;\n") + f.write(f" var channel_sum = weights[idx][7];\n") + f.write(f" for (var in_c = 0; in_c < 7; in_c++) {{\n") + f.write(f" channel_sum += weights[idx][in_c] * inputs[in_c];\n") + f.write(f" }}\n") + f.write(f" sum[out_c] += channel_sum;\n") + f.write(f" }}\n") + f.write(f" pos++;\n") + f.write(f" }}\n") + f.write(f" }}\n\n") + + f.write(f" return sum;\n") + f.write(f"}}\n") def train(args): @@ -293,32 +389,166 @@ def train(args): }, checkpoint_path) print(f"Saved checkpoint to {checkpoint_path}") - # Export weights + # Export weights and shader output_path = args.output or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl' print(f"\nExporting weights to {output_path}...") os.makedirs(os.path.dirname(output_path), exist_ok=True) export_weights_to_wgsl(model, output_path, kernel_sizes) + # Generate layer shader + shader_dir = os.path.dirname(output_path) + shader_path = os.path.join(shader_dir, 'cnn_layer.wgsl') + print(f"Generating layer shader to {shader_path}...") + generate_layer_shader(shader_path, args.layers, kernel_sizes) + + # Generate _src variants for kernel sizes (skip 3x3, already exists) + for ks in set(kernel_sizes): + if ks == 3: + continue + conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl') + if not os.path.exists(conv_path): + print(f"Warning: {conv_path} not found, skipping _src generation") + continue + + # Check if _src already exists + with open(conv_path, 'r') as f: + content = f.read() + if f"cnn_conv{ks}x{ks}_7to4_src" in content: + continue + + generate_conv_src_function(ks, conv_path) + print(f"Added _src variant to {conv_path}") + print("Training complete!") +def export_from_checkpoint(checkpoint_path, output_path=None): + """Export WGSL files from checkpoint without training""" + + if not os.path.exists(checkpoint_path): + print(f"Error: Checkpoint file '{checkpoint_path}' not found") + sys.exit(1) + + print(f"Loading checkpoint from {checkpoint_path}...") + checkpoint = torch.load(checkpoint_path, map_location='cpu') + + kernel_sizes = checkpoint['kernel_sizes'] + num_layers = checkpoint['num_layers'] + + # Recreate model + model = SimpleCNN(num_layers=num_layers, kernel_sizes=kernel_sizes) + model.load_state_dict(checkpoint['model_state']) + + # Export weights + output_path = output_path or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl' + print(f"Exporting weights to {output_path}...") + os.makedirs(os.path.dirname(output_path), exist_ok=True) + export_weights_to_wgsl(model, output_path, kernel_sizes) + + # Generate layer shader + shader_dir = os.path.dirname(output_path) + shader_path = os.path.join(shader_dir, 'cnn_layer.wgsl') + print(f"Generating layer shader to {shader_path}...") + generate_layer_shader(shader_path, num_layers, kernel_sizes) + + # Generate _src variants for kernel sizes (skip 3x3, already exists) + for ks in set(kernel_sizes): + if ks == 3: + continue + conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl') + if not os.path.exists(conv_path): + print(f"Warning: {conv_path} not found, skipping _src generation") + continue + + # Check if _src already exists + with open(conv_path, 'r') as f: + content = f.read() + if f"cnn_conv{ks}x{ks}_7to4_src" in content: + continue + + generate_conv_src_function(ks, conv_path) + print(f"Added _src variant to {conv_path}") + + print("Export complete!") + + +def infer_from_checkpoint(checkpoint_path, input_path, output_path): + """Run inference on single image to generate ground truth""" + + if not os.path.exists(checkpoint_path): + print(f"Error: Checkpoint '{checkpoint_path}' not found") + sys.exit(1) + + if not os.path.exists(input_path): + print(f"Error: Input image '{input_path}' not found") + sys.exit(1) + + print(f"Loading checkpoint from {checkpoint_path}...") + checkpoint = torch.load(checkpoint_path, map_location='cpu') + + # Reconstruct model + model = SimpleCNN( + num_layers=checkpoint['num_layers'], + kernel_sizes=checkpoint['kernel_sizes'] + ) + model.load_state_dict(checkpoint['model_state']) + model.eval() + + # Load image [0,1] + print(f"Loading input image: {input_path}") + img = Image.open(input_path).convert('RGBA') + img_tensor = transforms.ToTensor()(img).unsqueeze(0) # [1,4,H,W] + + # Inference + print("Running inference...") + with torch.no_grad(): + out = model(img_tensor) # [1,3,H,W] in [0,1] + + # Save + print(f"Saving output to: {output_path}") + os.makedirs(os.path.dirname(output_path), exist_ok=True) + transforms.ToPILImage()(out.squeeze(0)).save(output_path) + print("Done!") + + def main(): parser = argparse.ArgumentParser(description='Train CNN for image-to-image transformation') - parser.add_argument('--input', required=True, help='Input image directory') - parser.add_argument('--target', required=True, help='Target image directory') + parser.add_argument('--input', help='Input image directory (training) or single image (inference)') + parser.add_argument('--target', help='Target image directory') parser.add_argument('--layers', type=int, default=1, help='Number of CNN layers (default: 1)') parser.add_argument('--kernel_sizes', default='3', help='Comma-separated kernel sizes (default: 3)') parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs (default: 100)') parser.add_argument('--batch_size', type=int, default=4, help='Batch size (default: 4)') parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (default: 0.001)') - parser.add_argument('--output', help='Output WGSL file path (default: workspaces/main/shaders/cnn/cnn_weights_generated.wgsl)') + parser.add_argument('--output', help='Output path (WGSL for training/export, PNG for inference)') parser.add_argument('--checkpoint-every', type=int, default=0, help='Save checkpoint every N epochs (default: 0 = disabled)') parser.add_argument('--checkpoint-dir', help='Checkpoint directory (default: training/checkpoints)') parser.add_argument('--resume', help='Resume from checkpoint file') + parser.add_argument('--export-only', help='Export WGSL from checkpoint without training') + parser.add_argument('--infer', help='Run inference on single image (requires --export-only for checkpoint)') args = parser.parse_args() - # Validate directories + # Inference mode + if args.infer: + checkpoint = args.export_only + if not checkpoint: + print("Error: --infer requires --export-only <checkpoint>") + sys.exit(1) + output_path = args.output or 'inference_output.png' + infer_from_checkpoint(checkpoint, args.infer, output_path) + return + + # Export-only mode + if args.export_only: + export_from_checkpoint(args.export_only, args.output) + return + + # Validate directories for training + if not args.input or not args.target: + print("Error: --input and --target required for training (or use --export-only)") + sys.exit(1) + if not os.path.isdir(args.input): print(f"Error: Input directory '{args.input}' does not exist") sys.exit(1) |
