summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-02-11 00:26:25 +0100
committerskal <pascal.massimino@gmail.com>2026-02-11 00:26:25 +0100
commitc49d828f101b435d73a76fcfc8444cf76aeda22f (patch)
tree06978626cbb614f52434c4fdd40ccb197d7064c8
parent65fa059a1e5f81901735031ae329b1313ea6679d (diff)
opt: Move invariant in1 calculation outside CNN convolution loops
The in1 vector (uv_norm, gray, 1.0) is loop-invariant and doesn't depend on dx/dy offset. Moving it outside the convolution loop eliminates redundant computation and enables better SIMD optimization. Updated both shader files and train.py code generation. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
-rwxr-xr-xtraining/train_cnn.py12
-rw-r--r--workspaces/main/shaders/cnn/cnn_conv3x3.wgsl4
-rw-r--r--workspaces/main/shaders/cnn/cnn_conv5x5.wgsl6
3 files changed, 11 insertions, 11 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py
index 497a07b..d8522ed 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -420,7 +420,8 @@ def generate_conv_src_function(kernel_size, output_path):
# Normalize center pixel for gray channel
f.write(f" let original = (textureSample(tex, samp, uv) - 0.5) * 2.0;\n")
f.write(f" let gray = dot(original.rgb, vec3<f32>(0.2126, 0.7152, 0.0722));\n")
- f.write(f" let uv_norm = (uv - 0.5) * 2.0;\n\n")
+ f.write(f" let uv_norm = (uv - 0.5) * 2.0;\n")
+ f.write(f" let in1 = vec4<f32>(uv_norm, gray, 1.0);\n\n")
f.write(f" var sum = vec4<f32>(0.0);\n")
f.write(f" var pos = 0;\n\n")
@@ -429,8 +430,7 @@ def generate_conv_src_function(kernel_size, output_path):
f.write(f" for (var dy = -{radius}; dy <= {radius}; dy++) {{\n")
f.write(f" for (var dx = -{radius}; dx <= {radius}; dx++) {{\n")
f.write(f" let offset = vec2<f32>(f32(dx), f32(dy)) * step;\n")
- f.write(f" let rgbd = (textureSample(tex, samp, uv + offset) - 0.5) * 2.0;\n")
- f.write(f" let in1 = vec4<f32>(uv_norm, gray, 1.0);\n\n")
+ f.write(f" let rgbd = (textureSample(tex, samp, uv + offset) - 0.5) * 2.0;\n\n")
# Accumulate with dot products (unrolled)
f.write(f" sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1);\n")
@@ -465,7 +465,8 @@ def generate_conv_final_function(kernel_size, output_path):
f.write(f" weights: array<vec4<f32>, {num_positions * 2}>\n")
f.write(f") -> f32 {{\n")
f.write(f" let step = 1.0 / resolution;\n")
- f.write(f" let uv_norm = (uv - 0.5) * 2.0;\n\n")
+ f.write(f" let uv_norm = (uv - 0.5) * 2.0;\n")
+ f.write(f" let in1 = vec4<f32>(uv_norm, gray, 1.0);\n\n")
f.write(f" var sum = 0.0;\n")
f.write(f" var pos = 0;\n\n")
@@ -473,8 +474,7 @@ def generate_conv_final_function(kernel_size, output_path):
f.write(f" for (var dy = -{radius}; dy <= {radius}; dy++) {{\n")
f.write(f" for (var dx = -{radius}; dx <= {radius}; dx++) {{\n")
f.write(f" let offset = vec2<f32>(f32(dx), f32(dy)) * step;\n")
- f.write(f" let rgbd = textureSample(tex, samp, uv + offset);\n")
- f.write(f" let in1 = vec4<f32>(uv_norm, gray, 1.0);\n\n")
+ f.write(f" let rgbd = textureSample(tex, samp, uv + offset);\n\n")
# Accumulate with dot products
f.write(f" sum += dot(weights[pos], rgbd) + dot(weights[pos+1], in1);\n")
diff --git a/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl b/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl
index c032767..1a5a3e1 100644
--- a/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl
+++ b/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl
@@ -19,6 +19,7 @@ fn cnn_conv3x3_7to4_src(
// Normalize UV to [-1,1]
let uv_norm = (uv - 0.5) * 2.0;
+ let in1 = vec4<f32>(uv_norm, gray, 1.0);
var sum = vec4<f32>(0.0);
@@ -27,7 +28,6 @@ fn cnn_conv3x3_7to4_src(
for (var dx = -1; dx <= 1; dx++) {
let offset = vec2<f32>(f32(dx), f32(dy)) * step;
let rgbd = (textureSample(tex, samp, uv + offset) - .5) * 2.0;
- let in1 = vec4<f32>(uv_norm, gray, 1.0);
sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1);
sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1);
@@ -93,6 +93,7 @@ fn cnn_conv3x3_7to1(
// Normalize UV to [-1,1]
let uv_norm = (uv - 0.5) * 2.0;
+ let in1 = vec4<f32>(uv_norm, gray, 1.0);
var sum = 0.0;
@@ -101,7 +102,6 @@ fn cnn_conv3x3_7to1(
for (var dx = -1; dx <= 1; dx++) {
let offset = vec2<f32>(f32(dx), f32(dy)) * step;
let rgbd = textureSample(tex, samp, uv + offset);
- let in1 = vec4<f32>(uv_norm, gray, 1.0);
sum += dot(weights[pos], rgbd) + dot(weights[pos+1], in1);
pos += 2;
diff --git a/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl b/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl
index 119930f..ba2a4b7 100644
--- a/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl
+++ b/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl
@@ -12,6 +12,7 @@ fn cnn_conv5x5_7to4(
) -> vec4<f32> {
let step = 1.0 / resolution;
let uv_norm = (uv - 0.5) * 2.0;
+ let in1 = vec4<f32>(uv_norm, gray, 1.0);
var sum = vec4<f32>(0.0);
var pos = 0;
@@ -20,7 +21,6 @@ fn cnn_conv5x5_7to4(
for (var dx = -2; dx <= 2; dx++) {
let offset = vec2<f32>(f32(dx), f32(dy)) * step;
let rgbd = textureSample(tex, samp, uv + offset);
- let in1 = vec4<f32>(uv_norm, gray, 1.0);
sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1);
sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1);
@@ -47,6 +47,7 @@ fn cnn_conv5x5_7to1(
) -> f32 {
let step = 1.0 / resolution;
let uv_norm = (uv - 0.5) * 2.0;
+ let in1 = vec4<f32>(uv_norm, gray, 1.0);
var sum = 0.0;
var pos = 0;
@@ -55,7 +56,6 @@ fn cnn_conv5x5_7to1(
for (var dx = -2; dx <= 2; dx++) {
let offset = vec2<f32>(f32(dx), f32(dy)) * step;
let rgbd = textureSample(tex, samp, uv + offset);
- let in1 = vec4<f32>(uv_norm, gray, 1.0);
sum += dot(weights[pos], rgbd) + dot(weights[pos+1], in1);
pos += 2;
@@ -79,6 +79,7 @@ fn cnn_conv5x5_7to4_src(
let original = (textureSample(tex, samp, uv) - 0.5) * 2.0;
let gray = dot(original.rgb, vec3<f32>(0.2126, 0.7152, 0.0722));
let uv_norm = (uv - 0.5) * 2.0;
+ let in1 = vec4<f32>(uv_norm, gray, 1.0);
var sum = vec4<f32>(0.0);
var pos = 0;
@@ -87,7 +88,6 @@ fn cnn_conv5x5_7to4_src(
for (var dx = -2; dx <= 2; dx++) {
let offset = vec2<f32>(f32(dx), f32(dy)) * step;
let rgbd = (textureSample(tex, samp, uv + offset) - 0.5) * 2.0;
- let in1 = vec4<f32>(uv_norm, gray, 1.0);
sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1);
sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1);