diff options
Diffstat (limited to 'workspaces/main/shaders/cnn/cnn_conv3x3.wgsl')
| -rw-r--r-- | workspaces/main/shaders/cnn/cnn_conv3x3.wgsl | 90 |
1 files changed, 30 insertions, 60 deletions
diff --git a/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl b/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl index 00eae22..c032767 100644 --- a/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl +++ b/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl @@ -1,15 +1,15 @@ -// 3x3 convolution with weight indexing +// 3x3 convolution (vec4-optimized) // Source layers: 7→4 channels (RGBD output) // Assumes 'tex' (the input) is *not* normalized to [-1,1], but is [0,1] // UV coordinates remain in [0,1] and are normalized internally -// weights: array<array<f32, 8>, 36> (9 positions × 4 channels, each with 7 weights + bias) +// weights: array<vec4<f32>, 72> (9 pos × 4 ch × 2 vec4) fn cnn_conv3x3_7to4_src( tex: texture_2d<f32>, samp: sampler, uv: vec2<f32>, resolution: vec2<f32>, - weights: array<array<f32, 8>, 36> + weights: array<vec4<f32>, 72> ) -> vec4<f32> { let step = 1.0 / resolution; @@ -26,42 +26,31 @@ fn cnn_conv3x3_7to4_src( for (var dy = -1; dy <= 1; dy++) { for (var dx = -1; dx <= 1; dx++) { let offset = vec2<f32>(f32(dx), f32(dy)) * step; - let rgbd = (textureSample(tex, samp, uv + offset) - .5) * 2.0; // convert to [-1,1] + let rgbd = (textureSample(tex, samp, uv + offset) - .5) * 2.0; + let in1 = vec4<f32>(uv_norm, gray, 1.0); - // 7-channel input: [R,G,B,D, uv.x, uv.y, gray] all in [-1,1] - let inputs = array<f32, 7>( - rgbd.r, rgbd.g, rgbd.b, rgbd.a, - uv_norm.x, uv_norm.y, gray - ); - - // Accumulate for each output channel (RGBD) - for (var out_c = 0; out_c < 4; out_c++) { - let idx = pos * 4 + out_c; - var channel_sum = weights[idx][7]; // Bias (8th element) - for (var in_c = 0; in_c < 7; in_c++) { - channel_sum += weights[idx][in_c] * inputs[in_c]; - } - sum[out_c] += channel_sum; - } - - pos++; + sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1); + sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1); + sum.b += dot(weights[pos+4], rgbd) + dot(weights[pos+5], in1); + sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1); + pos += 8; } } - return sum; // Output in [-1,1] range + return sum; } -// Inner layers: 7→4 channels (RGBD output) +// Inner layers: 7→4 channels (vec4-optimized) // Assumes 'tex' is already normalized to [-1,1] // UV coordinates remain in [0,1] and are normalized internally -// weights: array<array<f32, 8>, 36> (9 positions × 4 channels, each with 7 weights + bias) +// weights: array<vec4<f32>, 72> (9 pos × 4 ch × 2 vec4) fn cnn_conv3x3_7to4( tex: texture_2d<f32>, samp: sampler, uv: vec2<f32>, resolution: vec2<f32>, gray: f32, - weights: array<array<f32, 8>, 36> + weights: array<vec4<f32>, 72> ) -> vec4<f32> { let step = 1.0 / resolution; @@ -74,42 +63,31 @@ fn cnn_conv3x3_7to4( for (var dy = -1; dy <= 1; dy++) { for (var dx = -1; dx <= 1; dx++) { let offset = vec2<f32>(f32(dx), f32(dy)) * step; - let rgbd = textureSample(tex, samp, uv + offset); // Already in [-1,1] - - // 7-channel input: [R,G,B,D, uv.x, uv.y, gray] all in [-1,1] - let inputs = array<f32, 7>( - rgbd.r, rgbd.g, rgbd.b, rgbd.a, - uv_norm.x, uv_norm.y, gray - ); + let rgbd = textureSample(tex, samp, uv + offset); + let in1 = vec4<f32>(uv_norm, gray, 1.0); - // Accumulate for each output channel (RGBD) - for (var out_c = 0; out_c < 4; out_c++) { - let idx = pos * 4 + out_c; - var channel_sum = weights[idx][7]; // Bias (8th element) - for (var in_c = 0; in_c < 7; in_c++) { - channel_sum += weights[idx][in_c] * inputs[in_c]; - } - sum[out_c] += channel_sum; - } - - pos++; + sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1); + sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1); + sum.b += dot(weights[pos+4], rgbd) + dot(weights[pos+5], in1); + sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1); + pos += 8; } } - return sum; // Output in [-1,1] range + return sum; } -// Final layer: 7→1 channel (scalar output) +// Final layer: 7→1 channel (vec4-optimized) // Assumes 'tex' is already normalized to [-1,1] // UV coordinates remain in [0,1] and are normalized internally -// weights: array<array<f32, 8>, 9> (9 positions, each with 7 weights + bias) +// weights: array<vec4<f32>, 18> (9 pos × 2 vec4) fn cnn_conv3x3_7to1( tex: texture_2d<f32>, samp: sampler, uv: vec2<f32>, resolution: vec2<f32>, gray: f32, - weights: array<array<f32, 8>, 9> + weights: array<vec4<f32>, 18> ) -> f32 { let step = 1.0 / resolution; @@ -122,21 +100,13 @@ fn cnn_conv3x3_7to1( for (var dy = -1; dy <= 1; dy++) { for (var dx = -1; dx <= 1; dx++) { let offset = vec2<f32>(f32(dx), f32(dy)) * step; - let rgbd = textureSample(tex, samp, uv + offset); // Already in [-1,1] - - // 7-channel input all in [-1,1] - sum += weights[pos][0] * rgbd.r; - sum += weights[pos][1] * rgbd.g; - sum += weights[pos][2] * rgbd.b; - sum += weights[pos][3] * rgbd.a; - sum += weights[pos][4] * uv_norm.x; - sum += weights[pos][5] * uv_norm.y; - sum += weights[pos][6] * gray; - sum += weights[pos][7]; // Bias + let rgbd = textureSample(tex, samp, uv + offset); + let in1 = vec4<f32>(uv_norm, gray, 1.0); - pos++; + sum += dot(weights[pos], rgbd) + dot(weights[pos+1], in1); + pos += 2; } } - return clamp(sum, 0.0, 1.0); // Match PyTorch clamp + return clamp(sum, 0.0, 1.0); } |
