diff options
| author | skal <pascal.massimino@gmail.com> | 2026-02-10 21:11:05 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-02-10 21:11:05 +0100 |
| commit | 7a05f4d33b611ba1e9b6c68e0d0bd67d6ea011ee (patch) | |
| tree | a88109bee56197ffca8d7aacd07a878fae502d11 /workspaces | |
| parent | 2fbfc406abe5a42f45face9b07a91ec64c0d4f78 (diff) | |
refactor: Optimize CNN grayscale computation
Compute gray once per fragment using dot() instead of per-layer.
Pass gray as f32 parameter to conv functions instead of vec4 original.
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'workspaces')
| -rw-r--r-- | workspaces/main/shaders/cnn/cnn_conv3x3.wgsl | 16 |
| -rw-r--r-- | workspaces/main/shaders/cnn/cnn_conv5x5.wgsl | 14 |
| -rw-r--r-- | workspaces/main/shaders/cnn/cnn_layer.wgsl | 5 |
3 files changed, 13 insertions, 22 deletions
diff --git a/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl b/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl index 96ddf5b..79b0350 100644 --- a/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl +++ b/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl @@ -15,7 +15,7 @@ fn cnn_conv3x3_7to4_src( // Compute grayscale from original (converted in [-1,1]) let original = (textureSample(tex, samp, uv) - 0.5) * 2.0; - let gray = 0.2126*original.r + 0.7152*original.g + 0.0722*original.b; + let gray = dot(original.rgb, vec3<f32>(0.2126, 0.7152, 0.0722)); // Normalize UV to [-1,1] let uv_norm = (uv - 0.5) * 2.0; @@ -52,7 +52,7 @@ fn cnn_conv3x3_7to4_src( } // Inner layers: 7→4 channels (RGBD output) -// Assumes 'tex' and 'original' are already normalized to [-1,1] +// Assumes 'tex' is already normalized to [-1,1] // UV coordinates remain in [0,1] and are normalized internally // weights: array<array<f32, 8>, 36> (9 positions × 4 channels, each with 7 weights + bias) fn cnn_conv3x3_7to4( @@ -60,14 +60,11 @@ fn cnn_conv3x3_7to4( samp: sampler, uv: vec2<f32>, resolution: vec2<f32>, - original: vec4<f32>, + gray: f32, weights: array<array<f32, 8>, 36> ) -> vec4<f32> { let step = 1.0 / resolution; - // Compute grayscale from original (already in [-1,1]) - let gray = 0.2126*original.r + 0.7152*original.g + 0.0722*original.b; - // Normalize UV to [-1,1] let uv_norm = (uv - 0.5) * 2.0; @@ -103,7 +100,7 @@ fn cnn_conv3x3_7to4( } // Final layer: 7→1 channel (scalar output) -// Assumes 'tex' and 'original' are already normalized to [-1,1] +// Assumes 'tex' is already normalized to [-1,1] // UV coordinates remain in [0,1] and are normalized internally // weights: array<array<f32, 8>, 9> (9 positions, each with 7 weights + bias) fn cnn_conv3x3_7to1( @@ -111,14 +108,11 @@ fn cnn_conv3x3_7to1( samp: sampler, uv: vec2<f32>, resolution: vec2<f32>, - original: vec4<f32>, + gray: f32, weights: array<array<f32, 8>, 9> ) -> f32 { let step = 1.0 / resolution; - // Compute grayscale from original (already in 
[-1,1]) - let gray = 0.2126*original.r + 0.7152*original.g + 0.0722*original.b; - // Normalize UV to [-1,1] let uv_norm = (uv - 0.5) * 2.0; diff --git a/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl b/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl index 0f261dd..5570589 100644 --- a/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl +++ b/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl @@ -1,5 +1,5 @@ // 5×5 variant for 7→4 channels (RGBD output) -// Assumes 'tex' and 'original' are already normalized to [-1,1] +// Assumes 'tex' is already normalized to [-1,1] // UV coordinates remain in [0,1] and are normalized internally // weights: array<array<f32, 8>, 100> (25 positions × 4 channels, each with 7 weights + bias) fn cnn_conv5x5_7to4( @@ -7,12 +7,10 @@ fn cnn_conv5x5_7to4( samp: sampler, uv: vec2<f32>, resolution: vec2<f32>, - original: vec4<f32>, + gray: f32, weights: array<array<f32, 8>, 100> ) -> vec4<f32> { let step = 1.0 / resolution; - - let gray = 0.2126*original.r + 0.7152*original.g + 0.0722*original.b; let uv_norm = (uv - 0.5) * 2.0; var sum = vec4<f32>(0.0); @@ -44,7 +42,7 @@ fn cnn_conv5x5_7to4( } // 5×5 variant for 7→1 channel (scalar output) -// Assumes 'tex' and 'original' are already normalized to [-1,1] +// Assumes 'tex' is already normalized to [-1,1] // UV coordinates remain in [0,1] and are normalized internally // weights: array<array<f32, 8>, 25> (25 positions, each with 7 weights + bias) fn cnn_conv5x5_7to1( @@ -52,12 +50,10 @@ fn cnn_conv5x5_7to1( samp: sampler, uv: vec2<f32>, resolution: vec2<f32>, - original: vec4<f32>, + gray: f32, weights: array<array<f32, 8>, 25> ) -> f32 { let step = 1.0 / resolution; - - let gray = 0.2126*original.r + 0.7152*original.g + 0.0722*original.b; let uv_norm = (uv - 0.5) * 2.0; var sum = 0.0; @@ -96,7 +92,7 @@ fn cnn_conv5x5_7to4_src( let step = 1.0 / resolution; let original = (textureSample(tex, samp, uv) - 0.5) * 2.0; - let gray = 0.2126*original.r + 0.7152*original.g + 0.0722*original.b; + let gray = 
dot(original.rgb, vec3<f32>(0.2126, 0.7152, 0.0722)); let uv_norm = (uv - 0.5) * 2.0; var sum = vec4<f32>(0.0); diff --git a/workspaces/main/shaders/cnn/cnn_layer.wgsl b/workspaces/main/shaders/cnn/cnn_layer.wgsl index 3f970df..e67ad31 100644 --- a/workspaces/main/shaders/cnn/cnn_layer.wgsl +++ b/workspaces/main/shaders/cnn/cnn_layer.wgsl @@ -32,6 +32,7 @@ struct CNNLayerParams { let uv = p.xy / uniforms.resolution; let original_raw = textureSample(original_input, smplr, uv); let original = (original_raw - 0.5) * 2.0; // Normalize to [-1,1] + let gray = dot(original.rgb, vec3<f32>(0.2126, 0.7152, 0.0722)); var result = vec4<f32>(0.0); // Layer 0: 7→4 (RGBD output, normalizes [0,1] input) @@ -42,12 +43,12 @@ struct CNNLayerParams { } else if (params.layer_index == 1) { result = cnn_conv5x5_7to4(txt, smplr, uv, uniforms.resolution, - original, weights_layer1); + gray, weights_layer1); result = cnn_tanh(result); // Keep in [-1,1] } else if (params.layer_index == 2) { let gray_out = cnn_conv3x3_7to1(txt, smplr, uv, uniforms.resolution, - original, weights_layer2); + gray, weights_layer2); // gray_out already in [0,1] from clipped training result = vec4<f32>(gray_out, gray_out, gray_out, 1.0); return mix(original_raw, result, params.blend_amount); // [0,1] |
