| author | skal <pascal.massimino@gmail.com> | 2026-02-10 21:21:53 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-02-10 21:21:53 +0100 |
| commit | b25132133590e39967ebd0f3205123ee6628674b (patch) | |
| tree | f2c96c0535183e97e16724ca64180b3f862ea71f /training/train_cnn.py | |
| parent | 861182bd33cae0b538f2e4beeb5b56e66c8f0ff7 (diff) | |
fix: Add clamp to CNN final layer to match PyTorch training
CNN output mismatch resolved: final layer (7→1) now clamps to [0,1].
Changes:
- Add clamp(sum, 0.0, 1.0) to cnn_conv3x3_7to1 and cnn_conv5x5_7to1
- Add generate_conv_final_function() to train_cnn.py for auto-generation
- Update comments to clarify clamping behavior
- Future exports will auto-generate final layers with correct clamp
PyTorch applies torch.clamp(out, 0.0, 1.0) to the final output; the
generated shaders were missing this operation, causing range mismatches
between training and inference.
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
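For context, here is a minimal PyTorch sketch of the training-side behavior the shaders now match. The module and its layer layout are illustrative, not taken from train_cnn.py; only the torch.clamp(out, 0.0, 1.0) call is confirmed by the commit message, and it corresponds to the clamp(sum, 0.0, 1.0) the generated WGSL now applies.

```python
import torch
import torch.nn as nn

class FinalLayer(nn.Module):
    """Hypothetical final CNN layer: 7 input channels -> 1 output channel."""
    def __init__(self, kernel_size: int = 3):
        super().__init__()
        self.conv = nn.Conv2d(7, 1, kernel_size, padding=kernel_size // 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.conv(x)
        # This is the clamp the generated WGSL was missing; the shader
        # now mirrors it with clamp(sum, 0.0, 1.0) in cnn_conv{K}x{K}_7to1.
        return torch.clamp(out, 0.0, 1.0)
```

Without the shader-side clamp, out-of-range sums flowed into the layer shader's final mix(original_raw, result, params.blend_amount), which expects [0,1] operands.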
Diffstat (limited to 'training/train_cnn.py')
| Mode | Path | Lines changed |
|---|---|---|
| -rwxr-xr-x | training/train_cnn.py | 97 |

1 file changed, 78 insertions, 19 deletions
```diff
diff --git a/training/train_cnn.py b/training/train_cnn.py
index 6bdb15f..7a2e85a 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -199,7 +199,7 @@ def generate_layer_shader(output_path, num_layers, kernel_sizes):
         f.write(f"  else if (params.layer_index == {layer_idx}) {{\n")
         f.write(f"    let gray_out = {conv_fn}(txt, smplr, uv, uniforms.resolution,\n")
         f.write(f"                             gray, weights_layer{layer_idx});\n")
-        f.write(f"    // gray_out already in [0,1] from clipped training\n")
+        f.write(f"    // gray_out in [0,1] (clamped to match PyTorch training)\n")
         f.write(f"    result = vec4<f32>(gray_out, gray_out, gray_out, 1.0);\n")
         f.write(f"    return mix(original_raw, result, params.blend_amount);  // [0,1]\n")
         f.write(f"  }}\n")
@@ -306,6 +306,53 @@ def generate_conv_src_function(kernel_size, output_path):
         f.write(f"}}\n")
 
 
+def generate_conv_final_function(kernel_size, output_path):
+    """Generate cnn_conv{K}x{K}_7to1() function for final layer with clamp"""
+
+    k = kernel_size
+    num_positions = k * k
+    radius = k // 2
+
+    with open(output_path, 'a') as f:
+        f.write(f"\n// Final layer: 7→1 channel (scalar output)\n")
+        f.write(f"// Assumes 'tex' is already normalized to [-1,1]\n")
+        f.write(f"// Output clamped to [0,1] to match PyTorch training\n")
+        f.write(f"fn cnn_conv{k}x{k}_7to1(\n")
+        f.write(f"    tex: texture_2d<f32>,\n")
+        f.write(f"    samp: sampler,\n")
+        f.write(f"    uv: vec2<f32>,\n")
+        f.write(f"    resolution: vec2<f32>,\n")
+        f.write(f"    gray: f32,\n")
+        f.write(f"    weights: array<array<f32, 8>, {num_positions}>\n")
+        f.write(f") -> f32 {{\n")
+        f.write(f"    let step = 1.0 / resolution;\n")
+        f.write(f"    let uv_norm = (uv - 0.5) * 2.0;\n\n")
+        f.write(f"    var sum = 0.0;\n")
+        f.write(f"    var pos = 0;\n\n")
+
+        # Convolution loop
+        f.write(f"    for (var dy = -{radius}; dy <= {radius}; dy++) {{\n")
+        f.write(f"        for (var dx = -{radius}; dx <= {radius}; dx++) {{\n")
+        f.write(f"            let offset = vec2<f32>(f32(dx), f32(dy)) * step;\n")
+        f.write(f"            let rgbd = textureSample(tex, samp, uv + offset);  // Already in [-1,1]\n\n")
+
+        # Accumulate
+        f.write(f"            sum += weights[pos][0] * rgbd.r;\n")
+        f.write(f"            sum += weights[pos][1] * rgbd.g;\n")
+        f.write(f"            sum += weights[pos][2] * rgbd.b;\n")
+        f.write(f"            sum += weights[pos][3] * rgbd.a;\n")
+        f.write(f"            sum += weights[pos][4] * uv_norm.x;\n")
+        f.write(f"            sum += weights[pos][5] * uv_norm.y;\n")
+        f.write(f"            sum += weights[pos][6] * gray;\n")
+        f.write(f"            sum += weights[pos][7];  // Bias\n\n")
+        f.write(f"            pos++;\n")
+        f.write(f"        }}\n")
+        f.write(f"    }}\n\n")
+
+        f.write(f"    return clamp(sum, 0.0, 1.0);  // Match PyTorch clamp\n")
+        f.write(f"}}\n")
+
+
 def train(args):
     """Main training loop"""
 
@@ -393,23 +440,29 @@ def train(args):
     print(f"Generating layer shader to {shader_path}...")
     generate_layer_shader(shader_path, args.layers, kernel_sizes)
 
-    # Generate _src variants for kernel sizes (skip 3x3, already exists)
+    # Generate _src and 7to1 variants for kernel sizes
     for ks in set(kernel_sizes):
-        if ks == 3:
-            continue
         conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl')
         if not os.path.exists(conv_path):
-            print(f"Warning: {conv_path} not found, skipping _src generation")
+            print(f"Warning: {conv_path} not found, skipping function generation")
             continue
-        # Check if _src already exists
         with open(conv_path, 'r') as f:
             content = f.read()
-        if f"cnn_conv{ks}x{ks}_7to4_src" in content:
-            continue
-        generate_conv_src_function(ks, conv_path)
-        print(f"Added _src variant to {conv_path}")
+        # Generate _src variant (skip 3x3, already exists)
+        if ks != 3 and f"cnn_conv{ks}x{ks}_7to4_src" not in content:
+            generate_conv_src_function(ks, conv_path)
+            print(f"Added _src variant to {conv_path}")
+            with open(conv_path, 'r') as f:
+                content = f.read()
+
+        # Generate 7to1 final layer with clamp (all kernel sizes)
+        if f"cnn_conv{ks}x{ks}_7to1" not in content:
+            generate_conv_final_function(ks, conv_path)
+            print(f"Added 7to1 variant with clamp to {conv_path}")
+        elif "clamp(sum, 0.0, 1.0)" not in content:
+            print(f"Warning: {conv_path} has 7to1 but missing clamp - manual fix needed")
 
     print("Training complete!")
@@ -443,23 +496,29 @@ def export_from_checkpoint(checkpoint_path, output_path=None):
     print(f"Generating layer shader to {shader_path}...")
     generate_layer_shader(shader_path, num_layers, kernel_sizes)
 
-    # Generate _src variants for kernel sizes (skip 3x3, already exists)
+    # Generate _src and 7to1 variants for kernel sizes
    for ks in set(kernel_sizes):
-        if ks == 3:
-            continue
         conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl')
         if not os.path.exists(conv_path):
-            print(f"Warning: {conv_path} not found, skipping _src generation")
+            print(f"Warning: {conv_path} not found, skipping function generation")
             continue
-        # Check if _src already exists
         with open(conv_path, 'r') as f:
             content = f.read()
-        if f"cnn_conv{ks}x{ks}_7to4_src" in content:
-            continue
-        generate_conv_src_function(ks, conv_path)
-        print(f"Added _src variant to {conv_path}")
+        # Generate _src variant (skip 3x3, already exists)
+        if ks != 3 and f"cnn_conv{ks}x{ks}_7to4_src" not in content:
+            generate_conv_src_function(ks, conv_path)
+            print(f"Added _src variant to {conv_path}")
+            with open(conv_path, 'r') as f:
+                content = f.read()
+
+        # Generate 7to1 final layer with clamp (all kernel sizes)
+        if f"cnn_conv{ks}x{ks}_7to1" not in content:
+            generate_conv_final_function(ks, conv_path)
+            print(f"Added 7to1 variant with clamp to {conv_path}")
+        elif "clamp(sum, 0.0, 1.0)" not in content:
+            print(f"Warning: {conv_path} has 7to1 but missing clamp - manual fix needed")
 
     print("Export complete!")
```
