diff options
| author | skal <pascal.massimino@gmail.com> | 2026-02-10 21:11:05 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-02-10 21:11:05 +0100 |
| commit | 7a05f4d33b611ba1e9b6c68e0d0bd67d6ea011ee (patch) | |
| tree | a88109bee56197ffca8d7aacd07a878fae502d11 /training | |
| parent | 2fbfc406abe5a42f45face9b07a91ec64c0d4f78 (diff) | |
refactor: Optimize CNN grayscale computation
Compute gray once per fragment using dot() instead of recomputing it in every layer.
Pass gray as an f32 parameter to the conv functions instead of the vec4 original.
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'training')
| -rwxr-xr-x | training/train_cnn.py | 7 |
1 file changed, 4 insertions, 3 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py index 902daa8..6bdb15f 100755 --- a/training/train_cnn.py +++ b/training/train_cnn.py @@ -172,6 +172,7 @@ def generate_layer_shader(output_path, num_layers, kernel_sizes): f.write(" let uv = p.xy / uniforms.resolution;\n") f.write(" let original_raw = textureSample(original_input, smplr, uv);\n") f.write(" let original = (original_raw - 0.5) * 2.0; // Normalize to [-1,1]\n") + f.write(" let gray = dot(original.rgb, vec3<f32>(0.2126, 0.7152, 0.0722));\n") f.write(" var result = vec4<f32>(0.0);\n\n") # Generate layer switches @@ -191,13 +192,13 @@ def generate_layer_shader(output_path, num_layers, kernel_sizes): elif not is_final: f.write(f" else if (params.layer_index == {layer_idx}) {{\n") f.write(f" result = {conv_fn}(txt, smplr, uv, uniforms.resolution,\n") - f.write(f" original, weights_layer{layer_idx});\n") + f.write(f" gray, weights_layer{layer_idx});\n") f.write(f" result = cnn_tanh(result); // Keep in [-1,1]\n") f.write(f" }}\n") else: f.write(f" else if (params.layer_index == {layer_idx}) {{\n") f.write(f" let gray_out = {conv_fn}(txt, smplr, uv, uniforms.resolution,\n") - f.write(f" original, weights_layer{layer_idx});\n") + f.write(f" gray, weights_layer{layer_idx});\n") f.write(f" // gray_out already in [0,1] from clipped training\n") f.write(f" result = vec4<f32>(gray_out, gray_out, gray_out, 1.0);\n") f.write(f" return mix(original_raw, result, params.blend_amount); // [0,1]\n") @@ -270,7 +271,7 @@ def generate_conv_src_function(kernel_size, output_path): # Normalize center pixel for gray channel f.write(f" let original = (textureSample(tex, samp, uv) - 0.5) * 2.0;\n") - f.write(f" let gray = 0.2126*original.r + 0.7152*original.g + 0.0722*original.b;\n") + f.write(f" let gray = dot(original.rgb, vec3<f32>(0.2126, 0.7152, 0.0722));\n") f.write(f" let uv_norm = (uv - 0.5) * 2.0;\n\n") f.write(f" var sum = vec4<f32>(0.0);\n") |
