diff options
Diffstat (limited to 'workspaces/main/shaders/cnn')
| -rw-r--r-- | workspaces/main/shaders/cnn/cnn_layer.wgsl | 25 | ||||
| -rw-r--r-- | workspaces/main/shaders/cnn/cnn_weights_generated.wgsl | 194 |
2 files changed, 196 insertions, 23 deletions
diff --git a/workspaces/main/shaders/cnn/cnn_layer.wgsl b/workspaces/main/shaders/cnn/cnn_layer.wgsl index b2bab26..2285ef9 100644 --- a/workspaces/main/shaders/cnn/cnn_layer.wgsl +++ b/workspaces/main/shaders/cnn/cnn_layer.wgsl @@ -1,5 +1,6 @@ // CNN layer shader - uses modular convolution snippets // Supports multi-pass rendering with residual connections +// DO NOT EDIT - Generated by train_cnn.py @group(0) @binding(0) var smplr: sampler; @group(0) @binding(1) var txt: texture_2d<f32>; @@ -11,12 +12,13 @@ struct CNNLayerParams { layer_index: i32, - use_residual: i32, + blend_amount: f32, _pad: vec2<f32>, }; @group(0) @binding(2) var<uniform> uniforms: CommonUniforms; @group(0) @binding(3) var<uniform> params: CNNLayerParams; +@group(0) @binding(4) var original_input: texture_2d<f32>; @vertex fn vs_main(@builtin(vertex_index) i: u32) -> @builtin(position) vec4<f32> { var pos = array<vec2<f32>, 3>( @@ -27,6 +29,8 @@ struct CNNLayerParams { @fragment fn fs_main(@builtin(position) p: vec4<f32>) -> @location(0) vec4<f32> { let uv = p.xy / uniforms.resolution; + let input = textureSample(txt, smplr, uv); + let original = textureSample(original_input, smplr, uv); var result = vec4<f32>(0.0); // Layer 0 uses coordinate-aware convolution @@ -35,12 +39,19 @@ struct CNNLayerParams { rgba_weights_layer0, coord_weights_layer0, bias_layer0); result = cnn_tanh(result); } - - // Residual connection - if (params.use_residual != 0) { - let input = textureSample(txt, smplr, uv); - result = input + result * 0.3; + else if (params.layer_index == 1) { + result = cnn_conv3x3(txt, smplr, uv, uniforms.resolution, + weights_layer1, bias_layer1); + result = cnn_tanh(result); + } + else if (params.layer_index == 2) { + result = cnn_conv3x3(txt, smplr, uv, uniforms.resolution, + weights_layer2, bias_layer2); + } + else { + result = input; } - return result; + // Blend with ORIGINAL input from layer 0 + return mix(original, result, params.blend_amount); } diff --git a/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl b/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl index e0a7dc4..6052ac5 100644 --- a/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl +++ b/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl @@ -1,23 +1,185 @@ -// Generated CNN weights and biases -// DO NOT EDIT MANUALLY - regenerate with scripts/train_cnn.py +// Auto-generated CNN weights +// DO NOT EDIT - Generated by train_cnn.py -// Placeholder identity-like weights for initial testing -// Layer 0: 3x3 convolution with coordinate awareness const rgba_weights_layer0: array<mat4x4<f32>, 9> = array( - mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), - mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), - mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), - mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), - mat4x4<f32>(1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0), - mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), - mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), - mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), - mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) + mat4x4<f32>( + -0.181929, -0.244329, -0.354404, 0.0, + -0.291597, -0.195653, 0.081896, 0.0, + 0.081595, 0.164081, -0.236318, 0.0, + 0.0, 0.0, 0.0, 0.0, + ), + mat4x4<f32>( + 0.731888, 0.717648, 0.524081, 0.0, + -0.029760, -0.208000, 0.008438, 0.0, + 0.442082, 0.354681, 0.049288, 0.0, + 0.0, 0.0, 0.0, 0.0, + ), + mat4x4<f32>( + -0.623141, -0.695759, -0.087885, 0.0, + 0.043135, 0.071979, 0.213065, 0.0, + 0.011581, 0.110995, 0.034100, 0.0, + 0.0, 0.0, 0.0, 0.0, + ), + mat4x4<f32>( + 0.170016, 0.188298, 0.134083, 0.0, + -0.222954, -0.088011, 0.015668, 0.0, + 0.921836, 0.437158, 0.061577, 0.0, + 0.0, 0.0, 0.0, 0.0, + ), + mat4x4<f32>( + 1.431940, 1.148113, 1.238067, 0.0, + -0.212535, 0.366860, 0.320956, 0.0, + 0.771192, 0.765570, 0.029189, 0.0, + 0.0, 0.0, 0.0, 0.0, + ), + mat4x4<f32>( + 0.171088, 0.000155, 0.212552, 0.0, + 0.029536, 0.447892, 0.041381, 0.0, + 0.011807, -0.167281, -0.200702, 0.0, + 0.0, 0.0, 0.0, 0.0, + ), + mat4x4<f32>( + -0.668151, -0.813927, -0.132108, 0.0, + -0.156250, 0.179112, -0.069585, 0.0, + 0.403347, 0.482877, 0.182611, 0.0, + 0.0, 0.0, 0.0, 0.0, + ), + mat4x4<f32>( + -0.609871, -0.768480, -0.590538, 0.0, + -0.171854, 0.150167, 0.105694, 0.0, + -0.059052, 0.066999, -0.244222, 0.0, + 0.0, 0.0, 0.0, 0.0, + ), + mat4x4<f32>( + -0.112983, -0.066299, 0.117696, 0.0, + -0.172541, 0.095008, -0.160754, 0.0, + -0.369667, -0.000628, 0.163602, 0.0, + 0.0, 0.0, 0.0, 0.0, + ) ); const coord_weights_layer0 = mat2x4<f32>( - 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0 + 0.059076, -0.026617, -0.005155, 0.0, + 0.135407, -0.090329, 0.058216, 0.0 ); -const bias_layer0 = vec4<f32>(0.0, 0.0, 0.0, 0.0); +const bias_layer0 = vec4<f32>(-0.526177, -0.569862, -1.370040, 0.0); + +const weights_layer1: array<mat4x4<f32>, 9> = array( + mat4x4<f32>( + 0.180029, -1.107249, 0.570741, 0.0, + -0.098536, 0.079545, -0.083257, 0.0, + -0.020066, 0.333084, 0.039506, 0.0, + 0.0, 0.0, 0.0, 0.0, + ), + mat4x4<f32>( + 3.068946, -1.783570, -0.550517, 0.0, + -0.296369, -0.080958, 0.040260, 0.0, + -0.093713, -0.212577, -0.110011, 0.0, + 0.0, 0.0, 0.0, 0.0, + ), + mat4x4<f32>( + 2.282564, -0.538192, -0.793214, 0.0, + -0.395788, 0.130881, 0.078571, 0.0, + -0.041375, 0.061666, 0.045651, 0.0, + 0.0, 0.0, 0.0, 0.0, + ), + mat4x4<f32>( + -0.267284, -1.971639, -0.099616, 0.0, + -0.084432, 0.139794, 0.007091, 0.0, + -0.103042, -0.104340, 0.067299, 0.0, + 0.0, 0.0, 0.0, 0.0, + ), + mat4x4<f32>( + -5.233469, -2.252747, -3.555217, 0.0, + 0.647940, -0.178858, 0.351633, 0.0, + -0.014237, -0.505881, 0.165940, 0.0, + 0.0, 0.0, 0.0, 0.0, + ), + mat4x4<f32>( + -0.121700, -0.677386, -2.435040, 0.0, + 0.084806, -0.028000, 0.380387, 0.0, + -0.020906, -0.279161, 0.041915, 0.0, + 0.0, 0.0, 0.0, 0.0, + ), + mat4x4<f32>( + 2.982562, -0.298441, -0.147775, 0.0, + -0.291832, 0.102875, -0.128590, 0.0, + -0.091786, 0.104389, -0.188678, 0.0, + 0.0, 0.0, 0.0, 0.0, + ), + mat4x4<f32>( + -4.434978, -0.261830, -2.436411, 0.0, + 0.349188, -0.245908, 0.272592, 0.0, + 0.010322, -0.148525, -0.031531, 0.0, + 0.0, 0.0, 0.0, 0.0, + ), + mat4x4<f32>( + 0.129886, 1.516168, -0.755576, 0.0, + 0.133138, -0.260276, 0.028059, 0.0, + 0.001185, 0.141547, -0.003606, 0.0, + 0.0, 0.0, 0.0, 0.0, + ) +); + +const bias_layer1 = vec4<f32>(1.367986, -1.148709, -0.650040, 0.0); + +const weights_layer2: array<mat4x4<f32>, 9> = array( + mat4x4<f32>( + -0.137003, -0.289376, 0.625000, 0.0, + -0.120120, -0.238968, 0.448432, 0.0, + -0.142094, -0.253706, 0.458181, 0.0, + 0.0, 0.0, 0.0, 0.0, + ), + mat4x4<f32>( + -0.337017, -0.757585, 0.135953, 0.0, + -0.304432, -0.553491, 0.419907, 0.0, + -0.313585, -0.467667, 0.615326, 0.0, + 0.0, 0.0, 0.0, 0.0, + ), + mat4x4<f32>( + -0.161089, -0.328735, 0.612679, 0.0, + -0.137144, -0.172882, 0.176362, 0.0, + -0.153195, -0.061571, 0.173977, 0.0, + 0.0, 0.0, 0.0, 0.0, + ), + mat4x4<f32>( + -0.227814, -0.544193, -0.564658, 0.0, + -0.211743, -0.430586, 0.080349, 0.0, + -0.214442, -0.417501, 0.880266, 0.0, + 0.0, 0.0, 0.0, 0.0, + ), + mat4x4<f32>( + -0.435370, -0.295169, -0.865976, 0.0, + -0.423147, -0.274780, 0.323049, 0.0, + -0.411180, -0.062517, 1.099769, 0.0, + 0.0, 0.0, 0.0, 0.0, + ), + mat4x4<f32>( + -0.199573, -0.488030, -0.396440, 0.0, + -0.187844, -0.360516, -0.156646, 0.0, + -0.188681, -0.292304, -0.134645, 0.0, + 0.0, 0.0, 0.0, 0.0, + ), + mat4x4<f32>( + -0.123218, -0.287990, 0.154656, 0.0, + -0.112954, -0.282778, 0.498742, 0.0, + -0.139083, -0.319337, 1.112621, 0.0, + 0.0, 0.0, 0.0, 0.0, + ), + mat4x4<f32>( + -0.267477, -0.691374, -0.028960, 0.0, + -0.246348, -0.585583, 0.401194, 0.0, + -0.253279, -0.562875, 1.105818, 0.0, + 0.0, 0.0, 0.0, 0.0, + ), + mat4x4<f32>( + -0.083133, -0.131627, 0.460039, 0.0, + -0.071126, -0.108601, 0.163545, 0.0, + -0.092579, -0.110020, 0.131282, 0.0, + 0.0, 0.0, 0.0, 0.0, + ) +); + +const bias_layer2 = vec4<f32>(-1.805686, -0.798340, 0.462318, 0.0); + |
