diff options
Diffstat (limited to 'workspaces/main/shaders')
| -rw-r--r-- | workspaces/main/shaders/cnn/cnn_conv3x3.wgsl | 100 | ||||
| -rw-r--r-- | workspaces/main/shaders/cnn/cnn_layer.wgsl | 24 | ||||
| -rw-r--r-- | workspaces/main/shaders/cnn/cnn_weights_generated.wgsl | 194 |
3 files changed, 295 insertions, 23 deletions
diff --git a/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl b/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl
index 168c9e2..df58b4d 100644
--- a/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl
+++ b/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl
@@ -51,3 +51,103 @@ fn cnn_conv3x3_with_coord(
 
     return sum;
 }
+
+// Inner layers: 3x3 convolution, 7→4 channels (RGBD output); samples tex at uv plus the 8 neighbours one texel (1/resolution) away
+// weights: array<array<f32, 8>, 36>; row = pos * 4 + out_c (9 positions × 4 channels), elements 0..6 are input weights, element 7 is a bias added at every position
+fn cnn_conv3x3_7to4(
+    tex: texture_2d<f32>,
+    samp: sampler,
+    uv: vec2<f32>,
+    resolution: vec2<f32>,
+    original: vec4<f32>,
+    weights: array<array<f32, 8>, 36>
+) -> vec4<f32> {
+    let step = 1.0 / resolution;
+
+    // Luma of `original` (0.2126/0.7152/0.0722 weights), remapped by (x - 0.5) * 2; [-1,1] assumes original is in [0,1]
+    let gray_01 = 0.2126*original.r + 0.7152*original.g + 0.0722*original.b;
+    let gray = (gray_01 - 0.5) * 2.0;
+
+    // Remap UV by (x - 0.5) * 2; [-1,1] assumes uv is in [0,1]
+    let uv_norm = (uv - 0.5) * 2.0;
+
+    var sum = vec4<f32>(0.0);
+
+    var pos = 0;
+    for (var dy = -1; dy <= 1; dy++) {
+        for (var dx = -1; dx <= 1; dx++) {
+            let offset = vec2<f32>(f32(dx), f32(dy)) * step;
+            let rgbd_01 = textureSample(tex, samp, uv + offset);
+
+            // Remap sampled RGBD by (x - 0.5) * 2; [-1,1] assumes the texture holds [0,1] values
+            let rgbd = (rgbd_01 - 0.5) * 2.0;
+
+            // 7-channel input: [R,G,B,D, uv.x, uv.y, gray]; only RGBD varies per tap, uv and gray are computed once above
+            let inputs = array<f32, 7>(
+                rgbd.r, rgbd.g, rgbd.b, rgbd.a,
+                uv_norm.x, uv_norm.y, gray
+            );
+
+            // Accumulate bias + weighted inputs into each output channel (RGBD)
+            for (var out_c = 0; out_c < 4; out_c++) {
+                let idx = pos * 4 + out_c;
+                var channel_sum = weights[idx][7];  // Bias (8th element), re-added for every tap
+                for (var in_c = 0; in_c < 7; in_c++) {
+                    channel_sum += weights[idx][in_c] * inputs[in_c];
+                }
+                sum[out_c] += channel_sum;
+            }
+
+            pos++;
+        }
+    }
+
+    return sum;  // Raw accumulated sum; no activation or clamp is applied here
+}
+
+// Final layer: 3x3 convolution, 7→1 channel (scalar output); same sampling and input remapping as the 7→4 variant
+// weights: array<array<f32, 8>, 9>; one row per tap (pos), elements 0..6 are input weights, element 7 is a bias added at every tap
+fn cnn_conv3x3_7to1(
+    tex: texture_2d<f32>,
+    samp: sampler,
+    uv: vec2<f32>,
+    resolution: vec2<f32>,
+    original: vec4<f32>,
+    weights: array<array<f32, 8>, 9>
+) -> f32 {
+    let step = 1.0 / resolution;
+
+    // Luma of `original` (0.2126/0.7152/0.0722 weights), remapped by (x - 0.5) * 2; [-1,1] assumes original is in [0,1]
+    let gray_01 = 0.2126*original.r + 0.7152*original.g + 0.0722*original.b;
+    let gray = (gray_01 - 0.5) * 2.0;
+
+    // Remap UV by (x - 0.5) * 2; [-1,1] assumes uv is in [0,1]
+    let uv_norm = (uv - 0.5) * 2.0;
+
+    var sum = 0.0;
+
+    var pos = 0;
+    for (var dy = -1; dy <= 1; dy++) {
+        for (var dx = -1; dx <= 1; dx++) {
+            let offset = vec2<f32>(f32(dx), f32(dy)) * step;
+            let rgbd_01 = textureSample(tex, samp, uv + offset);
+
+            // Remap sampled RGBD by (x - 0.5) * 2; [-1,1] assumes the texture holds [0,1] values
+            let rgbd = (rgbd_01 - 0.5) * 2.0;
+
+            // Unrolled dot product over the 7 inputs [R,G,B,D, uv.x, uv.y, gray]
+            sum += weights[pos][0] * rgbd.r;
+            sum += weights[pos][1] * rgbd.g;
+            sum += weights[pos][2] * rgbd.b;
+            sum += weights[pos][3] * rgbd.a;
+            sum += weights[pos][4] * uv_norm.x;
+            sum += weights[pos][5] * uv_norm.y;
+            sum += weights[pos][6] * gray;
+            sum += weights[pos][7];  // Bias (added at every tap)
+
+            pos++;
+        }
+    }
+
+    return sum;  // Raw sum, not clamped here; original note: needs denormalization
+}
diff --git a/workspaces/main/shaders/cnn/cnn_layer.wgsl b/workspaces/main/shaders/cnn/cnn_layer.wgsl
index b2bab26..5834f78 100644
--- a/workspaces/main/shaders/cnn/cnn_layer.wgsl
+++ b/workspaces/main/shaders/cnn/cnn_layer.wgsl
@@ -1,5 +1,6 @@
 // CNN layer shader - uses modular convolution snippets
 // Supports multi-pass rendering with residual connections
+// DO NOT EDIT - Generated by train_cnn.py
 
 @group(0) @binding(0) var smplr: sampler;
 @group(0) @binding(1) var txt: texture_2d<f32>;
@@ -11,12 +12,13 @@
 
 struct CNNLayerParams {
     layer_index: i32,
-    use_residual: i32,
+    blend_amount: f32,
     _pad: vec2<f32>,
 };
 
 @group(0) @binding(2) var<uniform> uniforms: CommonUniforms;
 @group(0) @binding(3) var<uniform> params: CNNLayerParams;
+@group(0) @binding(4) var original_input: texture_2d<f32>;
 
 @vertex fn vs_main(@builtin(vertex_index) i: u32) -> @builtin(position) vec4<f32> {
     var pos = array<vec2<f32>, 3>(
@@ -27,6 +29,8 @@ struct CNNLayerParams {
 
 @fragment fn fs_main(@builtin(position) p: vec4<f32>) -> @location(0) vec4<f32> {
     let uv = p.xy / uniforms.resolution;
+    let input = textureSample(txt, smplr, uv);
+    let original = textureSample(original_input, smplr, uv);
     var result = vec4<f32>(0.0);
 
     // Layer 0 uses coordinate-aware convolution
@@ -35,12 +39,18 @@ struct CNNLayerParams {
             rgba_weights_layer0, coord_weights_layer0, bias_layer0);
         result = cnn_tanh(result);
     }
-
-    // Residual connection
-    if (params.use_residual != 0) {
-        let input = textureSample(txt, smplr, uv);
-        result = input + result * 0.3;
+    else if (params.layer_index == 1) {
+        result = cnn_conv3x3(txt, smplr, uv, uniforms.resolution,
+            weights_layer1, bias_layer1);
+        result = cnn_tanh(result);
+    }
+    else if (params.layer_index == 2) {
+        result = cnn_conv3x3(txt, smplr, uv, uniforms.resolution,
+            weights_layer2, bias_layer2);
+    }
+    else {
+        result = input;
     }
 
-    return result;
+    return mix(original, result, params.blend_amount);
 }
diff --git a/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl b/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl
index e0a7dc4..6052ac5 100644
--- a/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl
+++ b/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl
@@ -1,23 +1,185 @@
-// Generated CNN weights and biases
-// DO NOT EDIT MANUALLY - regenerate with scripts/train_cnn.py
+// Auto-generated CNN weights: per layer, nine mat4x4<f32> plus a vec4<f32> bias (layer 0 also has mat2x4 coord weights)
+// DO NOT EDIT - Generated by train_cnn.py
 
-// Placeholder identity-like weights for initial testing
-// Layer 0: 3x3 convolution with coordinate awareness
 const rgba_weights_layer0: array<mat4x4<f32>, 9> = array(
-    mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
-    mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
-    mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
-    mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
-    mat4x4<f32>(1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0),
-    mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
-    mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
-    mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
-    mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
+    mat4x4<f32>(
+        -0.181929, -0.244329, -0.354404, 0.0,
+        -0.291597, -0.195653, 0.081896, 0.0,
+        0.081595, 0.164081, -0.236318, 0.0,
+        0.0, 0.0, 0.0, 0.0,
+    ),
+    mat4x4<f32>(
+        0.731888, 0.717648, 0.524081, 0.0,
+        -0.029760, -0.208000, 0.008438, 0.0,
+        0.442082, 0.354681, 0.049288, 0.0,
+        0.0, 0.0, 0.0, 0.0,
+    ),
+    mat4x4<f32>(
+        -0.623141, -0.695759, -0.087885, 0.0,
+        0.043135, 0.071979, 0.213065, 0.0,
+        0.011581, 0.110995, 0.034100, 0.0,
+        0.0, 0.0, 0.0, 0.0,
+    ),
+    mat4x4<f32>(
+        0.170016, 0.188298, 0.134083, 0.0,
+        -0.222954, -0.088011, 0.015668, 0.0,
+        0.921836, 0.437158, 0.061577, 0.0,
+        0.0, 0.0, 0.0, 0.0,
+    ),
+    mat4x4<f32>(
+        1.431940, 1.148113, 1.238067, 0.0,
+        -0.212535, 0.366860, 0.320956, 0.0,
+        0.771192, 0.765570, 0.029189, 0.0,
+        0.0, 0.0, 0.0, 0.0,
+    ),
+    mat4x4<f32>(
+        0.171088, 0.000155, 0.212552, 0.0,
+        0.029536, 0.447892, 0.041381, 0.0,
+        0.011807, -0.167281, -0.200702, 0.0,
+        0.0, 0.0, 0.0, 0.0,
+    ),
+    mat4x4<f32>(
+        -0.668151, -0.813927, -0.132108, 0.0,
+        -0.156250, 0.179112, -0.069585, 0.0,
+        0.403347, 0.482877, 0.182611, 0.0,
+        0.0, 0.0, 0.0, 0.0,
+    ),
+    mat4x4<f32>(
+        -0.609871, -0.768480, -0.590538, 0.0,
+        -0.171854, 0.150167, 0.105694, 0.0,
+        -0.059052, 0.066999, -0.244222, 0.0,
+        0.0, 0.0, 0.0, 0.0,
+    ),
+    mat4x4<f32>(
+        -0.112983, -0.066299, 0.117696, 0.0,
+        -0.172541, 0.095008, -0.160754, 0.0,
+        -0.369667, -0.000628, 0.163602, 0.0,
+        0.0, 0.0, 0.0, 0.0,
+    )
 );
 
 const coord_weights_layer0 = mat2x4<f32>(
-    0.0, 0.0, 0.0, 0.0,
-    0.0, 0.0, 0.0, 0.0
+    0.059076, -0.026617, -0.005155, 0.0,
+    0.135407, -0.090329, 0.058216, 0.0
 );
 
-const bias_layer0 = vec4<f32>(0.0, 0.0, 0.0, 0.0);
+const bias_layer0 = vec4<f32>(-0.526177, -0.569862, -1.370040, 0.0);
+
+const weights_layer1: array<mat4x4<f32>, 9> = array(
+    mat4x4<f32>(
+        0.180029, -1.107249, 0.570741, 0.0,
+        -0.098536, 0.079545, -0.083257, 0.0,
+        -0.020066, 0.333084, 0.039506, 0.0,
+        0.0, 0.0, 0.0, 0.0,
+    ),
+    mat4x4<f32>(
+        3.068946, -1.783570, -0.550517, 0.0,
+        -0.296369, -0.080958, 0.040260, 0.0,
+        -0.093713, -0.212577, -0.110011, 0.0,
+        0.0, 0.0, 0.0, 0.0,
+    ),
+    mat4x4<f32>(
+        2.282564, -0.538192, -0.793214, 0.0,
+        -0.395788, 0.130881, 0.078571, 0.0,
+        -0.041375, 0.061666, 0.045651, 0.0,
+        0.0, 0.0, 0.0, 0.0,
+    ),
+    mat4x4<f32>(
+        -0.267284, -1.971639, -0.099616, 0.0,
+        -0.084432, 0.139794, 0.007091, 0.0,
+        -0.103042, -0.104340, 0.067299, 0.0,
+        0.0, 0.0, 0.0, 0.0,
+    ),
+    mat4x4<f32>(
+        -5.233469, -2.252747, -3.555217, 0.0,
+        0.647940, -0.178858, 0.351633, 0.0,
+        -0.014237, -0.505881, 0.165940, 0.0,
+        0.0, 0.0, 0.0, 0.0,
+    ),
+    mat4x4<f32>(
+        -0.121700, -0.677386, -2.435040, 0.0,
+        0.084806, -0.028000, 0.380387, 0.0,
+        -0.020906, -0.279161, 0.041915, 0.0,
+        0.0, 0.0, 0.0, 0.0,
+    ),
+    mat4x4<f32>(
+        2.982562, -0.298441, -0.147775, 0.0,
+        -0.291832, 0.102875, -0.128590, 0.0,
+        -0.091786, 0.104389, -0.188678, 0.0,
+        0.0, 0.0, 0.0, 0.0,
+    ),
+    mat4x4<f32>(
+        -4.434978, -0.261830, -2.436411, 0.0,
+        0.349188, -0.245908, 0.272592, 0.0,
+        0.010322, -0.148525, -0.031531, 0.0,
+        0.0, 0.0, 0.0, 0.0,
+    ),
+    mat4x4<f32>(
+        0.129886, 1.516168, -0.755576, 0.0,
+        0.133138, -0.260276, 0.028059, 0.0,
+        0.001185, 0.141547, -0.003606, 0.0,
+        0.0, 0.0, 0.0, 0.0,
+    )
+);
+
+const bias_layer1 = vec4<f32>(1.367986, -1.148709, -0.650040, 0.0);
+
+const weights_layer2: array<mat4x4<f32>, 9> = array(
+    mat4x4<f32>(
+        -0.137003, -0.289376, 0.625000, 0.0,
+        -0.120120, -0.238968, 0.448432, 0.0,
+        -0.142094, -0.253706, 0.458181, 0.0,
+        0.0, 0.0, 0.0, 0.0,
+    ),
+    mat4x4<f32>(
+        -0.337017, -0.757585, 0.135953, 0.0,
+        -0.304432, -0.553491, 0.419907, 0.0,
+        -0.313585, -0.467667, 0.615326, 0.0,
+        0.0, 0.0, 0.0, 0.0,
+    ),
+    mat4x4<f32>(
+        -0.161089, -0.328735, 0.612679, 0.0,
+        -0.137144, -0.172882, 0.176362, 0.0,
+        -0.153195, -0.061571, 0.173977, 0.0,
+        0.0, 0.0, 0.0, 0.0,
+    ),
+    mat4x4<f32>(
+        -0.227814, -0.544193, -0.564658, 0.0,
+        -0.211743, -0.430586, 0.080349, 0.0,
+        -0.214442, -0.417501, 0.880266, 0.0,
+        0.0, 0.0, 0.0, 0.0,
+    ),
+    mat4x4<f32>(
+        -0.435370, -0.295169, -0.865976, 0.0,
+        -0.423147, -0.274780, 0.323049, 0.0,
+        -0.411180, -0.062517, 1.099769, 0.0,
+        0.0, 0.0, 0.0, 0.0,
+    ),
+    mat4x4<f32>(
+        -0.199573, -0.488030, -0.396440, 0.0,
+        -0.187844, -0.360516, -0.156646, 0.0,
+        -0.188681, -0.292304, -0.134645, 0.0,
+        0.0, 0.0, 0.0, 0.0,
+    ),
+    mat4x4<f32>(
+        -0.123218, -0.287990, 0.154656, 0.0,
+        -0.112954, -0.282778, 0.498742, 0.0,
+        -0.139083, -0.319337, 1.112621, 0.0,
+        0.0, 0.0, 0.0, 0.0,
+    ),
+    mat4x4<f32>(
+        -0.267477, -0.691374, -0.028960, 0.0,
+        -0.246348, -0.585583, 0.401194, 0.0,
+        -0.253279, -0.562875, 1.105818, 0.0,
+        0.0, 0.0, 0.0, 0.0,
+    ),
+    mat4x4<f32>(
+        -0.083133, -0.131627, 0.460039, 0.0,
+        -0.071126, -0.108601, 0.163545, 0.0,
+        -0.092579, -0.110020, 0.131282, 0.0,
+        0.0, 0.0, 0.0, 0.0,
+    )
+);
+
+const bias_layer2 = vec4<f32>(-1.805686, -0.798340, 0.462318, 0.0);
+
|
