summary | refs | log | tree | commit | diff
path: root/training
diff options
context:
space:
mode:
author: skal <pascal.massimino@gmail.com> 2026-02-10 21:11:05 +0100
committer: skal <pascal.massimino@gmail.com> 2026-02-10 21:11:05 +0100
commit: 7a05f4d33b611ba1e9b6c68e0d0bd67d6ea011ee (patch)
tree: a88109bee56197ffca8d7aacd07a878fae502d11 /training
parent: 2fbfc406abe5a42f45face9b07a91ec64c0d4f78 (diff)
refactor: Optimize CNN grayscale computation
Compute gray once per fragment using dot() instead of per-layer.
Pass gray as f32 parameter to conv functions instead of vec4 original.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'training')
-rwxr-xr-x  training/train_cnn.py  7
1 files changed, 4 insertions, 3 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py
index 902daa8..6bdb15f 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -172,6 +172,7 @@ def generate_layer_shader(output_path, num_layers, kernel_sizes):
f.write(" let uv = p.xy / uniforms.resolution;\n")
f.write(" let original_raw = textureSample(original_input, smplr, uv);\n")
f.write(" let original = (original_raw - 0.5) * 2.0; // Normalize to [-1,1]\n")
+ f.write(" let gray = dot(original.rgb, vec3<f32>(0.2126, 0.7152, 0.0722));\n")
f.write(" var result = vec4<f32>(0.0);\n\n")
# Generate layer switches
@@ -191,13 +192,13 @@ def generate_layer_shader(output_path, num_layers, kernel_sizes):
elif not is_final:
f.write(f" else if (params.layer_index == {layer_idx}) {{\n")
f.write(f" result = {conv_fn}(txt, smplr, uv, uniforms.resolution,\n")
- f.write(f" original, weights_layer{layer_idx});\n")
+ f.write(f" gray, weights_layer{layer_idx});\n")
f.write(f" result = cnn_tanh(result); // Keep in [-1,1]\n")
f.write(f" }}\n")
else:
f.write(f" else if (params.layer_index == {layer_idx}) {{\n")
f.write(f" let gray_out = {conv_fn}(txt, smplr, uv, uniforms.resolution,\n")
- f.write(f" original, weights_layer{layer_idx});\n")
+ f.write(f" gray, weights_layer{layer_idx});\n")
f.write(f" // gray_out already in [0,1] from clipped training\n")
f.write(f" result = vec4<f32>(gray_out, gray_out, gray_out, 1.0);\n")
f.write(f" return mix(original_raw, result, params.blend_amount); // [0,1]\n")
@@ -270,7 +271,7 @@ def generate_conv_src_function(kernel_size, output_path):
# Normalize center pixel for gray channel
f.write(f" let original = (textureSample(tex, samp, uv) - 0.5) * 2.0;\n")
- f.write(f" let gray = 0.2126*original.r + 0.7152*original.g + 0.0722*original.b;\n")
+ f.write(f" let gray = dot(original.rgb, vec3<f32>(0.2126, 0.7152, 0.0722));\n")
f.write(f" let uv_norm = (uv - 0.5) * 2.0;\n\n")
f.write(f" var sum = vec4<f32>(0.0);\n")