| author | skal <pascal.massimino@gmail.com> | 2026-02-10 20:00:26 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-02-10 20:00:26 +0100 |
| commit | 2a2369e38fbe1bf8261968dafc88dac73bdda7ce (patch) | |
| tree | bf4505ca53e501af9dab61b172ecc60f3671d79c | |
| parent | 3153b55135788c8a7691929f913a2c9b96a44154 (diff) | |
fix: CNN training normalization pipeline consistency
**Training changes:**
- Final layer now outputs [0,1] directly with torch.clamp()
- Removed denormalization step (was converting [-1,1] to [0,1])
- Network learns [0,1] output natively
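
The training-side change is confined to the very end of the forward pass. A minimal sketch of the new behaviour, following the diff below; `final_layer` is a hypothetical wrapper around the code in `SimpleCNN.forward` in `train_cnn.py`:

```python
import torch

def final_layer(layers, out, x_coords, y_coords, gray):
    # 7-channel input to the last conv: 4 feature channels + x/y coords + gray
    final_input = torch.cat([out, x_coords, y_coords, gray], dim=1)
    out = layers[-1](final_input)      # [B,1,H,W], trained to target [0,1] directly
    out = torch.clamp(out, 0.0, 1.0)   # clip instead of denormalizing from [-1,1]
    return out.expand(-1, 3, -1, -1)   # grayscale -> RGB for visualization
```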
**Shader generation fixes:**
- Layer 0 uses _src variant (5 params, normalizes [0,1] input internally)
- Removed pre-normalization of input texture (handled by _src)
- Final layer blending: gray_out already [0,1], no denormalization needed
- Added generate_conv_src_function() for all kernel sizes
- Auto-generates _src variants when exporting (skips if exists)
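
The export-time auto-generation amounts to the loop below, a condensed sketch of what the diff adds to both `train()` and `export_from_checkpoint()`; `ensure_src_variants` is a hypothetical name for the inline loop:

```python
import os

def ensure_src_variants(shader_dir, kernel_sizes):
    for ks in set(kernel_sizes):
        if ks == 3:
            continue                   # 3x3 already ships a hand-written _src variant
        conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl')
        if not os.path.exists(conv_path):
            continue                   # nothing to append to; warn and skip
        with open(conv_path) as f:
            if f"cnn_conv{ks}x{ks}_7to4_src" in f.read():
                continue               # variant already present, don't duplicate
        generate_conv_src_function(ks, conv_path)  # appends the _src function to the file
```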
**Cleanup:**
- Removed obsolete 4-channel functions from cnn_conv5x5.wgsl
- Keep only 7-channel variants (_7to4, _7to1, _7to4_src)
**Normalization flow:**
[0,1] texture → _src normalizes to [-1,1] → tanh [-1,1] → ... → final conv [0,1] clipped
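
Only two mappings are involved; sketched here in Python (the names `to_signed`/`to_unsigned` are illustrative), they appear in the generated WGSL as `(x - 0.5) * 2.0` and `(x + 1.0) * 0.5`:

```python
def to_signed(x):
    # [0,1] texture value -> [-1,1]; done inside the _src variant for layer 0
    return (x - 0.5) * 2.0

def to_unsigned(x):
    # [-1,1] activation -> [0,1] for display; only non-final layers need this now
    return (x + 1.0) * 0.5
```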
handoff(Claude): CNN normalization pipeline fixed and consistent with training
| -rwxr-xr-x | training/train_cnn.py | 122 |
| -rw-r--r-- | workspaces/main/shaders/cnn/cnn_conv3x3.wgsl | 2 |
| -rw-r--r-- | workspaces/main/shaders/cnn/cnn_conv5x5.wgsl | 56 |
3 files changed, 109 insertions, 71 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py
index 2250e9c..3312768 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -126,10 +126,8 @@ class SimpleCNN(nn.Module):
         # Final layer (grayscale output)
         final_input = torch.cat([out, x_coords, y_coords, gray], dim=1)
-        out = self.layers[-1](final_input)  # [B,1,H,W] in [-1,1]
-
-        # Denormalize to [0,1] and expand to RGB for visualization
-        out = (out + 1.0) * 0.5
+        out = self.layers[-1](final_input)  # [B,1,H,W]
+        out = torch.clamp(out, 0.0, 1.0)  # Clip to [0,1]
 
         return out.expand(-1, 3, -1, -1)
@@ -167,8 +165,6 @@ def generate_layer_shader(output_path, num_layers, kernel_sizes):
     f.write("}\n\n")
     f.write("@fragment fn fs_main(@builtin(position) p: vec4<f32>) -> @location(0) vec4<f32> {\n")
     f.write("  let uv = p.xy / uniforms.resolution;\n")
-    f.write("  let input_raw = textureSample(txt, smplr, uv);\n")
-    f.write("  let input = (input_raw - 0.5) * 2.0;  // Normalize to [-1,1]\n")
     f.write("  let original_raw = textureSample(original_input, smplr, uv);\n")
     f.write("  let original = (original_raw - 0.5) * 2.0;  // Normalize to [-1,1]\n")
     f.write("  var result = vec4<f32>(0.0);\n\n")
@@ -180,11 +176,12 @@ def generate_layer_shader(output_path, num_layers, kernel_sizes):
         conv_fn = f"cnn_conv{ks}x{ks}_7to4" if not is_final else f"cnn_conv{ks}x{ks}_7to1"
 
         if layer_idx == 0:
-            f.write(f"  // Layer 0: 7→4 (RGBD output)\n")
+            conv_fn_src = f"cnn_conv{ks}x{ks}_7to4_src"
+            f.write(f"  // Layer 0: 7→4 (RGBD output, normalizes [0,1] input)\n")
             f.write(f"  if (params.layer_index == {layer_idx}) {{\n")
-            f.write(f"    result = {conv_fn}(txt, smplr, uv, uniforms.resolution,\n")
-            f.write(f"                       original, weights_layer{layer_idx});\n")
-            f.write(f"    result = cnn_tanh(result);  // Keep in [-1,1]\n")
+            f.write(f"    result = {conv_fn_src}(txt, smplr, uv, uniforms.resolution,\n")
+            f.write(f"                           weights_layer{layer_idx});\n")
+            f.write(f"    result = cnn_tanh(result);\n")
             f.write(f"  }}\n")
         elif not is_final:
             f.write(f"  else if (params.layer_index == {layer_idx}) {{\n")
@@ -196,18 +193,21 @@ def generate_layer_shader(output_path, num_layers, kernel_sizes):
             f.write(f"  else if (params.layer_index == {layer_idx}) {{\n")
             f.write(f"    let gray_out = {conv_fn}(txt, smplr, uv, uniforms.resolution,\n")
             f.write(f"                             original, weights_layer{layer_idx});\n")
-            f.write(f"    result = vec4<f32>(gray_out, gray_out, gray_out, 1.0);  // Keep in [-1,1]\n")
+            f.write(f"    // gray_out already in [0,1] from clipped training\n")
+            f.write(f"    let original_denorm = (original + 1.0) * 0.5;\n")
+            f.write(f"    result = vec4<f32>(gray_out, gray_out, gray_out, 1.0);\n")
+            f.write(f"    let blended = mix(original_denorm, result, params.blend_amount);\n")
+            f.write(f"    return blended;  // [0,1]\n")
             f.write(f"  }}\n")
 
     # Add else clause for invalid layer index
     if num_layers > 0:
         f.write(f"  else {{\n")
-        f.write(f"    result = input;\n")
+        f.write(f"    return textureSample(txt, smplr, uv);\n")
         f.write(f"  }}\n")
 
-    f.write("\n  // Blend with ORIGINAL input from layer 0 and denormalize for display\n")
-    f.write("  let blended = mix(original, result, params.blend_amount);\n")
-    f.write("  return (blended + 1.0) * 0.5;  // Denormalize to [0,1] for display\n")
+    f.write("\n  // Non-final layers: denormalize for display\n")
+    f.write("  return (result + 1.0) * 0.5;  // [-1,1] → [0,1]\n")
     f.write("}\n")
@@ -253,6 +253,62 @@ def export_weights_to_wgsl(model, output_path, kernel_sizes):
         f.write(");\n\n")
 
 
+def generate_conv_src_function(kernel_size, output_path):
+    """Generate cnn_conv{K}x{K}_7to4_src() function for layer 0"""
+
+    k = kernel_size
+    num_positions = k * k
+    radius = k // 2
+
+    with open(output_path, 'a') as f:
+        f.write(f"\n// Source layer: 7→4 channels (RGBD output)\n")
+        f.write(f"// Normalizes [0,1] input to [-1,1] internally\n")
+        f.write(f"fn cnn_conv{k}x{k}_7to4_src(\n")
+        f.write(f"  tex: texture_2d<f32>,\n")
+        f.write(f"  samp: sampler,\n")
+        f.write(f"  uv: vec2<f32>,\n")
+        f.write(f"  resolution: vec2<f32>,\n")
+        f.write(f"  weights: array<array<f32, 8>, {num_positions * 4}>\n")
+        f.write(f") -> vec4<f32> {{\n")
+        f.write(f"  let step = 1.0 / resolution;\n\n")
+
+        # Normalize center pixel for gray channel
+        f.write(f"  let original = (textureSample(tex, samp, uv) - 0.5) * 2.0;\n")
+        f.write(f"  let gray = 0.2126*original.r + 0.7152*original.g + 0.0722*original.b;\n")
+        f.write(f"  let uv_norm = (uv - 0.5) * 2.0;\n\n")
+
+        f.write(f"  var sum = vec4<f32>(0.0);\n")
+        f.write(f"  var pos = 0;\n\n")
+
+        # Convolution loop
+        f.write(f"  for (var dy = -{radius}; dy <= {radius}; dy++) {{\n")
+        f.write(f"    for (var dx = -{radius}; dx <= {radius}; dx++) {{\n")
+        f.write(f"      let offset = vec2<f32>(f32(dx), f32(dy)) * step;\n")
+        f.write(f"      let rgbd = (textureSample(tex, samp, uv + offset) - 0.5) * 2.0;\n\n")
+
+        # 7-channel input
+        f.write(f"      let inputs = array<f32, 7>(\n")
+        f.write(f"        rgbd.r, rgbd.g, rgbd.b, rgbd.a,\n")
+        f.write(f"        uv_norm.x, uv_norm.y, gray\n")
+        f.write(f"      );\n\n")
+
+        # Accumulate
+        f.write(f"      for (var out_c = 0; out_c < 4; out_c++) {{\n")
+        f.write(f"        let idx = pos * 4 + out_c;\n")
+        f.write(f"        var channel_sum = weights[idx][7];\n")
+        f.write(f"        for (var in_c = 0; in_c < 7; in_c++) {{\n")
+        f.write(f"          channel_sum += weights[idx][in_c] * inputs[in_c];\n")
+        f.write(f"        }}\n")
+        f.write(f"        sum[out_c] += channel_sum;\n")
+        f.write(f"      }}\n")
+        f.write(f"      pos++;\n")
+        f.write(f"    }}\n")
+        f.write(f"  }}\n\n")
+
+        f.write(f"  return sum;\n")
+        f.write(f"}}\n")
+
+
 def train(args):
     """Main training loop"""
@@ -340,6 +396,24 @@ def train(args):
     print(f"Generating layer shader to {shader_path}...")
     generate_layer_shader(shader_path, args.layers, kernel_sizes)
 
+    # Generate _src variants for kernel sizes (skip 3x3, already exists)
+    for ks in set(kernel_sizes):
+        if ks == 3:
+            continue
+        conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl')
+        if not os.path.exists(conv_path):
+            print(f"Warning: {conv_path} not found, skipping _src generation")
+            continue
+
+        # Check if _src already exists
+        with open(conv_path, 'r') as f:
+            content = f.read()
+        if f"cnn_conv{ks}x{ks}_7to4_src" in content:
+            continue
+
+        generate_conv_src_function(ks, conv_path)
+        print(f"Added _src variant to {conv_path}")
+
     print("Training complete!")
@@ -372,6 +446,24 @@ def export_from_checkpoint(checkpoint_path, output_path=None):
     print(f"Generating layer shader to {shader_path}...")
     generate_layer_shader(shader_path, num_layers, kernel_sizes)
 
+    # Generate _src variants for kernel sizes (skip 3x3, already exists)
+    for ks in set(kernel_sizes):
+        if ks == 3:
+            continue
+        conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl')
+        if not os.path.exists(conv_path):
+            print(f"Warning: {conv_path} not found, skipping _src generation")
+            continue
+
+        # Check if _src already exists
+        with open(conv_path, 'r') as f:
+            content = f.read()
+        if f"cnn_conv{ks}x{ks}_7to4_src" in content:
+            continue
+
+        generate_conv_src_function(ks, conv_path)
+        print(f"Added _src variant to {conv_path}")
+
     print("Export complete!")
diff --git a/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl b/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl
index ebb87b5..96ddf5b 100644
--- a/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl
+++ b/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl
@@ -144,5 +144,5 @@ fn cnn_conv3x3_7to1(
     }
   }
 
-  return sum;  // Output in [-1,1], needs denormalization
+  return sum;  // Output in [-1,1]
 }
diff --git a/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl b/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl
index bfb4ebb..5136740 100644
--- a/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl
+++ b/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl
@@ -1,57 +1,3 @@
-// 5x5 convolution with 25 samples
-// Applies mat4 weights per sample
-
-fn cnn_conv5x5(
-  tex: texture_2d<f32>,
-  samp: sampler,
-  uv: vec2<f32>,
-  resolution: vec2<f32>,
-  weights: array<mat4x4<f32>, 25>,
-  bias: vec4<f32>
-) -> vec4<f32> {
-  let step = 1.0 / resolution;
-  var sum = bias;
-  var idx = 0;
-
-  for (var dy = -2; dy <= 2; dy++) {
-    for (var dx = -2; dx <= 2; dx++) {
-      let offset = vec2<f32>(f32(dx), f32(dy)) * step;
-      let sample = textureSample(tex, samp, uv + offset);
-      sum += weights[idx] * sample;
-      idx++;
-    }
-  }
-
-  return sum;
-}
-
-fn cnn_conv5x5_with_coord(
-  tex: texture_2d<f32>,
-  samp: sampler,
-  uv: vec2<f32>,
-  resolution: vec2<f32>,
-  rgba_weights: array<mat4x4<f32>, 25>,
-  coord_weights: mat2x4<f32>,
-  bias: vec4<f32>
-) -> vec4<f32> {
-  let step = 1.0 / resolution;
-  var sum = bias;
-
-  sum += coord_weights * uv;
-
-  var idx = 0;
-  for (var dy = -2; dy <= 2; dy++) {
-    for (var dx = -2; dx <= 2; dx++) {
-      let offset = vec2<f32>(f32(dx), f32(dy)) * step;
-      let rgba = textureSample(tex, samp, uv + offset);
-      sum += rgba_weights[idx] * rgba;
-      idx++;
-    }
-  }
-
-  return sum;
-}
-
 // 5×5 variant for 7→4 channels (RGBD output)
 // Assumes 'tex' and 'original' are already normalized to [-1,1]
 // UV coordinates remain in [0,1] and are normalized internally
@@ -135,5 +81,5 @@ fn cnn_conv5x5_7to1(
     }
   }
 
-  return sum;
+  return sum;  // Output in [-1,1]
 }
