author    skal <pascal.massimino@gmail.com>  2026-02-10 21:21:53 +0100
committer skal <pascal.massimino@gmail.com>  2026-02-10 21:21:53 +0100
commit    b25132133590e39967ebd0f3205123ee6628674b (patch)
tree      f2c96c0535183e97e16724ca64180b3f862ea71f /training/train_cnn.py
parent    861182bd33cae0b538f2e4beeb5b56e66c8f0ff7 (diff)
fix: Add clamp to CNN final layer to match PyTorch training
CNN output mismatch resolved: the final layer (7→1) now clamps to [0,1].

Changes:
- Add clamp(sum, 0.0, 1.0) to cnn_conv3x3_7to1 and cnn_conv5x5_7to1
- Add generate_conv_final_function() to train_cnn.py for auto-generation
- Update comments to clarify clamping behavior
- Future exports will auto-generate final layers with the correct clamp

PyTorch uses torch.clamp(out, 0.0, 1.0) on the final output; the shaders
were missing this critical operation, causing range mismatches.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
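For reference, a minimal training-side sketch of what the shaders must now
match. The class name and wiring below are assumptions, not repo code; what
the commit confirms is the 7→1 final convolution and the
torch.clamp(out, 0.0, 1.0) applied to its output.

import torch
import torch.nn as nn

# Illustrative sketch (class name and wiring assumed): the generated WGSL
# accumulates 8 weights per tap -- the four texture channels (rgbd),
# uv_norm.x, uv_norm.y, gray, plus a bias -- i.e. a KxK convolution over
# 7 input channels down to 1, and PyTorch clamps that output to [0,1].
class FinalLayer(nn.Module):
    def __init__(self, kernel_size: int = 3):
        super().__init__()
        self.conv = nn.Conv2d(7, 1, kernel_size, padding=kernel_size // 2)

    def forward(self, x):  # x: (N, 7, H, W)
        out = self.conv(x)
        # This is the clamp that cnn_conv{K}x{K}_7to1 now mirrors with
        # clamp(sum, 0.0, 1.0).
        return torch.clamp(out, 0.0, 1.0)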
Diffstat (limited to 'training/train_cnn.py')
-rwxr-xr-x  training/train_cnn.py  97
1 file changed, 78 insertions(+), 19 deletions(-)
diff --git a/training/train_cnn.py b/training/train_cnn.py
index 6bdb15f..7a2e85a 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -199,7 +199,7 @@ def generate_layer_shader(output_path, num_layers, kernel_sizes):
f.write(f" else if (params.layer_index == {layer_idx}) {{\n")
f.write(f" let gray_out = {conv_fn}(txt, smplr, uv, uniforms.resolution,\n")
f.write(f" gray, weights_layer{layer_idx});\n")
- f.write(f" // gray_out already in [0,1] from clipped training\n")
+ f.write(f" // gray_out in [0,1] (clamped to match PyTorch training)\n")
f.write(f" result = vec4<f32>(gray_out, gray_out, gray_out, 1.0);\n")
f.write(f" return mix(original_raw, result, params.blend_amount); // [0,1]\n")
f.write(f" }}\n")
@@ -306,6 +306,53 @@ def generate_conv_src_function(kernel_size, output_path):
f.write(f"}}\n")
+def generate_conv_final_function(kernel_size, output_path):
+ """Generate cnn_conv{K}x{K}_7to1() function for final layer with clamp"""
+
+ k = kernel_size
+ num_positions = k * k
+ radius = k // 2
+
+ with open(output_path, 'a') as f:
+ f.write(f"\n// Final layer: 7→1 channel (scalar output)\n")
+ f.write(f"// Assumes 'tex' is already normalized to [-1,1]\n")
+ f.write(f"// Output clamped to [0,1] to match PyTorch training\n")
+ f.write(f"fn cnn_conv{k}x{k}_7to1(\n")
+ f.write(f" tex: texture_2d<f32>,\n")
+ f.write(f" samp: sampler,\n")
+ f.write(f" uv: vec2<f32>,\n")
+ f.write(f" resolution: vec2<f32>,\n")
+ f.write(f" gray: f32,\n")
+ f.write(f" weights: array<array<f32, 8>, {num_positions}>\n")
+ f.write(f") -> f32 {{\n")
+ f.write(f" let step = 1.0 / resolution;\n")
+ f.write(f" let uv_norm = (uv - 0.5) * 2.0;\n\n")
+ f.write(f" var sum = 0.0;\n")
+ f.write(f" var pos = 0;\n\n")
+
+ # Convolution loop
+ f.write(f" for (var dy = -{radius}; dy <= {radius}; dy++) {{\n")
+ f.write(f" for (var dx = -{radius}; dx <= {radius}; dx++) {{\n")
+ f.write(f" let offset = vec2<f32>(f32(dx), f32(dy)) * step;\n")
+ f.write(f" let rgbd = textureSample(tex, samp, uv + offset); // Already in [-1,1]\n\n")
+
+ # Accumulate
+ f.write(f" sum += weights[pos][0] * rgbd.r;\n")
+ f.write(f" sum += weights[pos][1] * rgbd.g;\n")
+ f.write(f" sum += weights[pos][2] * rgbd.b;\n")
+ f.write(f" sum += weights[pos][3] * rgbd.a;\n")
+ f.write(f" sum += weights[pos][4] * uv_norm.x;\n")
+ f.write(f" sum += weights[pos][5] * uv_norm.y;\n")
+ f.write(f" sum += weights[pos][6] * gray;\n")
+ f.write(f" sum += weights[pos][7]; // Bias\n\n")
+ f.write(f" pos++;\n")
+ f.write(f" }}\n")
+ f.write(f" }}\n\n")
+
+ f.write(f" return clamp(sum, 0.0, 1.0); // Match PyTorch clamp\n")
+ f.write(f"}}\n")
+
+
def train(args):
"""Main training loop"""
@@ -393,23 +440,29 @@ def train(args):
print(f"Generating layer shader to {shader_path}...")
generate_layer_shader(shader_path, args.layers, kernel_sizes)
- # Generate _src variants for kernel sizes (skip 3x3, already exists)
+ # Generate _src and 7to1 variants for kernel sizes
for ks in set(kernel_sizes):
- if ks == 3:
- continue
conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl')
if not os.path.exists(conv_path):
- print(f"Warning: {conv_path} not found, skipping _src generation")
+ print(f"Warning: {conv_path} not found, skipping function generation")
continue
- # Check if _src already exists
with open(conv_path, 'r') as f:
content = f.read()
- if f"cnn_conv{ks}x{ks}_7to4_src" in content:
- continue
- generate_conv_src_function(ks, conv_path)
- print(f"Added _src variant to {conv_path}")
+ # Generate _src variant (skip 3x3, already exists)
+ if ks != 3 and f"cnn_conv{ks}x{ks}_7to4_src" not in content:
+ generate_conv_src_function(ks, conv_path)
+ print(f"Added _src variant to {conv_path}")
+ with open(conv_path, 'r') as f:
+ content = f.read()
+
+ # Generate 7to1 final layer with clamp (all kernel sizes)
+ if f"cnn_conv{ks}x{ks}_7to1" not in content:
+ generate_conv_final_function(ks, conv_path)
+ print(f"Added 7to1 variant with clamp to {conv_path}")
+ elif "clamp(sum, 0.0, 1.0)" not in content:
+ print(f"Warning: {conv_path} has 7to1 but missing clamp - manual fix needed")
print("Training complete!")
@@ -443,23 +496,29 @@ def export_from_checkpoint(checkpoint_path, output_path=None):
print(f"Generating layer shader to {shader_path}...")
generate_layer_shader(shader_path, num_layers, kernel_sizes)
- # Generate _src variants for kernel sizes (skip 3x3, already exists)
+ # Generate _src and 7to1 variants for kernel sizes
for ks in set(kernel_sizes):
- if ks == 3:
- continue
conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl')
if not os.path.exists(conv_path):
- print(f"Warning: {conv_path} not found, skipping _src generation")
+ print(f"Warning: {conv_path} not found, skipping function generation")
continue
- # Check if _src already exists
with open(conv_path, 'r') as f:
content = f.read()
- if f"cnn_conv{ks}x{ks}_7to4_src" in content:
- continue
- generate_conv_src_function(ks, conv_path)
- print(f"Added _src variant to {conv_path}")
+ # Generate _src variant (skip 3x3, already exists)
+ if ks != 3 and f"cnn_conv{ks}x{ks}_7to4_src" not in content:
+ generate_conv_src_function(ks, conv_path)
+ print(f"Added _src variant to {conv_path}")
+ with open(conv_path, 'r') as f:
+ content = f.read()
+
+ # Generate 7to1 final layer with clamp (all kernel sizes)
+ if f"cnn_conv{ks}x{ks}_7to1" not in content:
+ generate_conv_final_function(ks, conv_path)
+ print(f"Added 7to1 variant with clamp to {conv_path}")
+ elif "clamp(sum, 0.0, 1.0)" not in content:
+ print(f"Warning: {conv_path} has 7to1 but missing clamp - manual fix needed")
print("Export complete!")