diff options
Diffstat (limited to 'workspaces/main/shaders/cnn')
| -rw-r--r-- | workspaces/main/shaders/cnn/cnn_conv3x3.wgsl | 137 | ||||
| -rw-r--r-- | workspaces/main/shaders/cnn/cnn_conv5x5.wgsl | 78 | ||||
| -rw-r--r-- | workspaces/main/shaders/cnn/cnn_layer.wgsl | 31 | ||||
| -rw-r--r-- | workspaces/main/shaders/cnn/cnn_weights_generated.wgsl | 170 |
4 files changed, 344 insertions, 72 deletions
diff --git a/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl b/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl index 168c9e2..96ddf5b 100644 --- a/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl +++ b/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl @@ -1,53 +1,148 @@ // 3x3 convolution with weight indexing -// Samples 9 pixels, applies mat4 weights per sample -fn cnn_conv3x3( +// Source layers: 7→4 channels (RGBD output) +// Assumes 'tex' (the input) is *not* normalized to [-1,1], but is [0,1] +// UV coordinates remain in [0,1] and are normalized internally +// weights: array<array<f32, 8>, 36> (9 positions × 4 channels, each with 7 weights + bias) +fn cnn_conv3x3_7to4_src( tex: texture_2d<f32>, samp: sampler, uv: vec2<f32>, resolution: vec2<f32>, - weights: array<mat4x4<f32>, 9>, - bias: vec4<f32> + weights: array<array<f32, 8>, 36> ) -> vec4<f32> { let step = 1.0 / resolution; - var sum = bias; - var idx = 0; + // Compute grayscale from original (converted in [-1,1]) + let original = (textureSample(tex, samp, uv) - 0.5) * 2.0; + let gray = 0.2126*original.r + 0.7152*original.g + 0.0722*original.b; + + // Normalize UV to [-1,1] + let uv_norm = (uv - 0.5) * 2.0; + + var sum = vec4<f32>(0.0); + + var pos = 0; for (var dy = -1; dy <= 1; dy++) { for (var dx = -1; dx <= 1; dx++) { let offset = vec2<f32>(f32(dx), f32(dy)) * step; - let sample = textureSample(tex, samp, uv + offset); - sum += weights[idx] * sample; - idx++; + let rgbd = (textureSample(tex, samp, uv + offset) - .5) * 2.0; // convert to [-1,1] + + // 7-channel input: [R,G,B,D, uv.x, uv.y, gray] all in [-1,1] + let inputs = array<f32, 7>( + rgbd.r, rgbd.g, rgbd.b, rgbd.a, + uv_norm.x, uv_norm.y, gray + ); + + // Accumulate for each output channel (RGBD) + for (var out_c = 0; out_c < 4; out_c++) { + let idx = pos * 4 + out_c; + var channel_sum = weights[idx][7]; // Bias (8th element) + for (var in_c = 0; in_c < 7; in_c++) { + channel_sum += weights[idx][in_c] * inputs[in_c]; + } + sum[out_c] += channel_sum; + } + + pos++; } } - return sum; + return sum; // Output in [-1,1] range } -fn cnn_conv3x3_with_coord( +// Inner layers: 7→4 channels (RGBD output) +// Assumes 'tex' and 'original' are already normalized to [-1,1] +// UV coordinates remain in [0,1] and are normalized internally +// weights: array<array<f32, 8>, 36> (9 positions × 4 channels, each with 7 weights + bias) +fn cnn_conv3x3_7to4( tex: texture_2d<f32>, samp: sampler, uv: vec2<f32>, resolution: vec2<f32>, - rgba_weights: array<mat4x4<f32>, 9>, - coord_weights: mat2x4<f32>, - bias: vec4<f32> + original: vec4<f32>, + weights: array<array<f32, 8>, 36> ) -> vec4<f32> { let step = 1.0 / resolution; - var sum = bias; - sum += coord_weights * uv; + // Compute grayscale from original (already in [-1,1]) + let gray = 0.2126*original.r + 0.7152*original.g + 0.0722*original.b; + + // Normalize UV to [-1,1] + let uv_norm = (uv - 0.5) * 2.0; + + var sum = vec4<f32>(0.0); + + var pos = 0; + for (var dy = -1; dy <= 1; dy++) { + for (var dx = -1; dx <= 1; dx++) { + let offset = vec2<f32>(f32(dx), f32(dy)) * step; + let rgbd = textureSample(tex, samp, uv + offset); // Already in [-1,1] + + // 7-channel input: [R,G,B,D, uv.x, uv.y, gray] all in [-1,1] + let inputs = array<f32, 7>( + rgbd.r, rgbd.g, rgbd.b, rgbd.a, + uv_norm.x, uv_norm.y, gray + ); + + // Accumulate for each output channel (RGBD) + for (var out_c = 0; out_c < 4; out_c++) { + let idx = pos * 4 + out_c; + var channel_sum = weights[idx][7]; // Bias (8th element) + for (var in_c = 0; in_c < 7; in_c++) { + channel_sum += weights[idx][in_c] * inputs[in_c]; + } + sum[out_c] += channel_sum; + } + + pos++; + } + } + + return sum; // Output in [-1,1] range +} + +// Final layer: 7→1 channel (scalar output) +// Assumes 'tex' and 'original' are already normalized to [-1,1] +// UV coordinates remain in [0,1] and are normalized internally +// weights: array<array<f32, 8>, 9> (9 positions, each with 7 weights + bias) +fn cnn_conv3x3_7to1( + tex: texture_2d<f32>, + samp: sampler, + uv: vec2<f32>, + resolution: vec2<f32>, + original: vec4<f32>, + weights: array<array<f32, 8>, 9> +) -> f32 { + let step = 1.0 / resolution; + + // Compute grayscale from original (already in [-1,1]) + let gray = 0.2126*original.r + 0.7152*original.g + 0.0722*original.b; + + // Normalize UV to [-1,1] + let uv_norm = (uv - 0.5) * 2.0; - var idx = 0; + var sum = 0.0; + + var pos = 0; for (var dy = -1; dy <= 1; dy++) { for (var dx = -1; dx <= 1; dx++) { let offset = vec2<f32>(f32(dx), f32(dy)) * step; - let rgba = textureSample(tex, samp, uv + offset); - sum += rgba_weights[idx] * rgba; - idx++; + let rgbd = textureSample(tex, samp, uv + offset); // Already in [-1,1] + + // 7-channel input all in [-1,1] + sum += weights[pos][0] * rgbd.r; + sum += weights[pos][1] * rgbd.g; + sum += weights[pos][2] * rgbd.b; + sum += weights[pos][3] * rgbd.a; + sum += weights[pos][4] * uv_norm.x; + sum += weights[pos][5] * uv_norm.y; + sum += weights[pos][6] * gray; + sum += weights[pos][7]; // Bias + + pos++; } } - return sum; + return sum; // Output in [-1,1] } diff --git a/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl b/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl index bd9abfa..5136740 100644 --- a/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl +++ b/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl @@ -1,53 +1,85 @@ -// 5x5 convolution with 25 samples -// Applies mat4 weights per sample - -fn cnn_conv5x5( +// 5×5 variant for 7→4 channels (RGBD output) +// Assumes 'tex' and 'original' are already normalized to [-1,1] +// UV coordinates remain in [0,1] and are normalized internally +// weights: array<array<f32, 8>, 100> (25 positions × 4 channels, each with 7 weights + bias) +fn cnn_conv5x5_7to4( tex: texture_2d<f32>, samp: sampler, uv: vec2<f32>, resolution: vec2<f32>, - weights: array<mat4x4<f32>, 25>, - bias: vec4<f32> + original: vec4<f32>, + weights: array<array<f32, 8>, 100> ) -> vec4<f32> { let step = 1.0 / resolution; - var sum = bias; - var idx = 0; + + let gray = 0.2126*original.r + 0.7152*original.g + 0.0722*original.b; + let uv_norm = (uv - 0.5) * 2.0; + + var sum = vec4<f32>(0.0); + var pos = 0; for (var dy = -2; dy <= 2; dy++) { for (var dx = -2; dx <= 2; dx++) { let offset = vec2<f32>(f32(dx), f32(dy)) * step; - let sample = textureSample(tex, samp, uv + offset); - sum += weights[idx] * sample; - idx++; + let rgbd = textureSample(tex, samp, uv + offset); // Already in [-1,1] + + let inputs = array<f32, 7>( + rgbd.r, rgbd.g, rgbd.b, rgbd.a, + uv_norm.x, uv_norm.y, gray + ); + + for (var out_c = 0; out_c < 4; out_c++) { + let idx = pos * 4 + out_c; + var channel_sum = weights[idx][7]; + for (var in_c = 0; in_c < 7; in_c++) { + channel_sum += weights[idx][in_c] * inputs[in_c]; + } + sum[out_c] += channel_sum; + } + pos++; } } return sum; } -fn cnn_conv5x5_with_coord( +// 5×5 variant for 7→1 channel (scalar output) +// Assumes 'tex' and 'original' are already normalized to [-1,1] +// UV coordinates remain in [0,1] and are normalized internally +// weights: array<array<f32, 8>, 25> (25 positions, each with 7 weights + bias) +fn cnn_conv5x5_7to1( tex: texture_2d<f32>, samp: sampler, uv: vec2<f32>, resolution: vec2<f32>, - rgba_weights: array<mat4x4<f32>, 25>, - coord_weights: mat2x4<f32>, - bias: vec4<f32> -) -> vec4<f32> { + original: vec4<f32>, + weights: array<array<f32, 8>, 25> +) -> f32 { let step = 1.0 / resolution; - var sum = bias; - sum += coord_weights * uv; + let gray = 0.2126*original.r + 0.7152*original.g + 0.0722*original.b; + let uv_norm = (uv - 0.5) * 2.0; + + var sum = 0.0; + var pos = 0; - var idx = 0; for (var dy = -2; dy <= 2; dy++) { for (var dx = -2; dx <= 2; dx++) { let offset = vec2<f32>(f32(dx), f32(dy)) * step; - let rgba = textureSample(tex, samp, uv + offset); - sum += rgba_weights[idx] * rgba; - idx++; + let rgbd = textureSample(tex, samp, uv + offset); // Already in [-1,1] + + sum += weights[pos][0] * rgbd.r; + sum += weights[pos][1] * rgbd.g; + sum += weights[pos][2] * rgbd.b; + sum += weights[pos][3] * rgbd.a; + sum += weights[pos][4] * uv_norm.x; + sum += weights[pos][5] * uv_norm.y; + sum += weights[pos][6] * gray; + sum += weights[pos][7]; // Bias + + pos++; } } - return sum; + return sum; // Output in [-1,1] } diff --git a/workspaces/main/shaders/cnn/cnn_layer.wgsl b/workspaces/main/shaders/cnn/cnn_layer.wgsl index b2bab26..1b1b539 100644 --- a/workspaces/main/shaders/cnn/cnn_layer.wgsl +++ b/workspaces/main/shaders/cnn/cnn_layer.wgsl @@ -1,5 +1,6 @@ // CNN layer shader - uses modular convolution snippets // Supports multi-pass rendering with residual connections +// DO NOT EDIT - Generated by train_cnn.py @group(0) @binding(0) var smplr: sampler; @group(0) @binding(1) var txt: texture_2d<f32>; @@ -7,16 +8,18 @@ #include "common_uniforms" #include "cnn_activation" #include "cnn_conv3x3" +#include "cnn_conv5x5" #include "cnn_weights_generated" struct CNNLayerParams { layer_index: i32, - use_residual: i32, + blend_amount: f32, _pad: vec2<f32>, }; @group(0) @binding(2) var<uniform> uniforms: CommonUniforms; @group(0) @binding(3) var<uniform> params: CNNLayerParams; +@group(0) @binding(4) var original_input: texture_2d<f32>; @vertex fn vs_main(@builtin(vertex_index) i: u32) -> @builtin(position) vec4<f32> { var pos = array<vec2<f32>, 3>( @@ -27,20 +30,28 @@ struct CNNLayerParams { @fragment fn fs_main(@builtin(position) p: vec4<f32>) -> @location(0) vec4<f32> { let uv = p.xy / uniforms.resolution; + let original = (textureSample(original_input, smplr, uv) - 0.5) * 2.0; // Normalize to [-1,1] var result = vec4<f32>(0.0); - // Layer 0 uses coordinate-aware convolution + // Layer 0: 7→4 (RGBD output) if (params.layer_index == 0) { - result = cnn_conv3x3_with_coord(txt, smplr, uv, uniforms.resolution, - rgba_weights_layer0, coord_weights_layer0, bias_layer0); - result = cnn_tanh(result); + result = cnn_conv3x3_7to4_src(txt, smplr, uv, uniforms.resolution, weights_layer0); + result = cnn_tanh(result); // Keep in [-1,1] } - - // Residual connection - if (params.use_residual != 0) { - let input = textureSample(txt, smplr, uv); - result = input + result * 0.3; + else if (params.layer_index == 1) { + result = cnn_conv5x5_7to4(txt, smplr, uv, uniforms.resolution, + original, weights_layer1); + result = cnn_tanh(result); // Keep in [-1,1] } + else if (params.layer_index == 2) { // last layer + let gray_out = cnn_conv3x3_7to1(txt, smplr, uv, uniforms.resolution, + original, weights_layer2); + // At this point here, 'gray_out' is what the training script should have learned. + // Below is some extra code for visual output, excluded from training: + result = vec4<f32>(gray_out, gray_out, gray_out, 1.0); // Keep in [-1,1] + let blended = mix(original, result, params.blend_amount); + return (blended + 1.0) * 0.5; // Denormalize to [0,1] for display + } return result; } diff --git a/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl b/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl index e0a7dc4..e38669f 100644 --- a/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl +++ b/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl @@ -1,23 +1,157 @@ -// Generated CNN weights and biases -// DO NOT EDIT MANUALLY - regenerate with scripts/train_cnn.py +// Auto-generated CNN weights +// DO NOT EDIT - Generated by train_cnn.py -// Placeholder identity-like weights for initial testing -// Layer 0: 3x3 convolution with coordinate awareness -const rgba_weights_layer0: array<mat4x4<f32>, 9> = array( - mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), - mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), - mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), - mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), - mat4x4<f32>(1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0), - mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), - mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), - mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), - mat4x4<f32>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) +const weights_layer0: array<array<f32, 8>, 36> = array( + array<f32, 8>(0.074911, 0.143202, 0.086903, 0.070680, -0.031904, 0.122884, 0.191824, 0.071112), + array<f32, 8>(0.081964, 0.033505, 0.058371, -0.015971, -0.069732, -0.014956, 0.142861, 0.119666), + array<f32, 8>(0.231883, -0.160763, -0.147218, 0.161321, -0.031718, -0.065766, 0.093359, 0.171734), + array<f32, 8>(0.082047, 0.288492, 0.121087, 0.001740, -0.104745, -0.071150, 0.031105, 0.037989), + array<f32, 8>(0.139236, 0.160690, 0.022091, 0.070994, 0.008793, 0.059247, 0.215077, 0.071112), + array<f32, 8>(0.128842, 0.268017, -0.031546, 0.068152, -0.073793, 0.124100, 0.252295, 0.119666), + array<f32, 8>(0.077193, -0.080009, -0.160674, 0.101131, -0.152167, -0.035271, 0.067397, 0.171734), + array<f32, 8>(-0.073119, 0.204309, 0.005654, 0.101254, -0.063530, -0.040801, 0.213393, 0.037989), + array<f32, 8>(-0.024175, 0.018739, 0.095518, 0.096945, 0.088315, 0.079085, -0.069127, 0.071112), + array<f32, 8>(0.219014, 0.218505, 0.014228, 0.014379, 0.075954, -0.001065, 0.201142, 0.119666), + array<f32, 8>(0.182743, -0.041270, -0.085458, 0.092904, 0.020316, 0.036077, 0.020220, 0.171734), + array<f32, 8>(-0.210247, -0.072180, 0.017628, 0.084834, 0.050409, -0.067274, -0.130565, 0.037989), + array<f32, 8>(0.071649, -0.072076, -0.109385, -0.012436, 0.041505, -0.013451, -0.068780, 0.071112), + array<f32, 8>(0.083389, 0.133852, -0.018137, 0.086250, -0.006205, 0.052853, 0.137369, 0.119666), + array<f32, 8>(0.023275, 0.036871, -0.092898, -0.059569, -0.029758, -0.089218, -0.031705, 0.171734), + array<f32, 8>(0.054874, 0.290596, 0.157026, -0.127200, 0.054010, -0.163627, 0.185273, 0.037989), + array<f32, 8>(0.069455, -0.122527, 0.010922, -0.051404, -0.067941, 0.122001, 0.034784, 0.071112), + array<f32, 8>(0.263187, 0.346644, 0.094376, 0.080049, -0.013980, -0.020629, 0.287019, 0.119666), + array<f32, 8>(0.078601, -0.045813, 0.048391, 0.107248, -0.001537, 0.003619, 0.040853, 0.171734), + array<f32, 8>(-0.052910, 0.333324, -0.028273, 0.111413, 0.059925, 0.054957, 0.257592, 0.037989), + array<f32, 8>(0.037894, 0.001266, 0.039858, 0.027731, 0.156182, 0.094188, 0.021791, 0.071112), + array<f32, 8>(0.220401, 0.241493, 0.138405, 0.082160, 0.144517, -0.050410, 0.257101, 0.119666), + array<f32, 8>(0.055409, -0.103410, 0.049778, -0.023193, -0.116368, -0.085046, 0.047003, 0.171734), + array<f32, 8>(0.019721, 0.099621, 0.005697, -0.069641, -0.100712, 0.044279, -0.104894, 0.037989), + array<f32, 8>(0.132833, 0.144224, 0.075612, -0.052095, -0.027924, 0.029124, -0.012077, 0.071112), + array<f32, 8>(0.146387, 0.098381, 0.131536, 0.034274, -0.073611, 0.080596, 0.124333, 0.119666), + array<f32, 8>(0.118243, -0.165692, -0.091107, 0.001822, 0.003771, -0.053877, -0.045592, 0.171734), + array<f32, 8>(-0.146034, 0.167379, 0.036433, -0.074485, 0.047772, 0.007719, -0.057026, 0.037989), + array<f32, 8>(-0.105517, -0.143677, 0.006013, 0.038752, 0.082525, -0.070290, -0.082964, 0.071112), + array<f32, 8>(0.084325, 0.192342, 0.005734, 0.083787, 0.010618, 0.076732, 0.206159, 0.119666), + array<f32, 8>(0.025873, -0.002030, -0.008453, 0.189578, 0.077363, 0.014099, 0.086760, 0.171734), + array<f32, 8>(-0.040145, 0.209639, 0.131112, 0.021154, -0.046391, -0.055185, 0.110424, 0.037989), + array<f32, 8>(-0.091272, -0.149872, -0.018825, 0.109157, 0.037674, -0.067088, -0.199940, 0.071112), + array<f32, 8>(0.170814, 0.171591, -0.039657, 0.146638, -0.054918, -0.043451, 0.262821, 0.119666), + array<f32, 8>(0.183810, -0.147660, -0.144689, 0.045301, 0.055273, 0.017425, 0.136362, 0.171734), + array<f32, 8>(-0.078196, 0.116630, -0.138657, -0.140199, -0.052198, -0.040295, -0.093252, 0.037989) ); -const coord_weights_layer0 = mat2x4<f32>( - 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0 +const weights_layer1: array<array<f32, 8>, 100> = array( + array<f32, 8>(0.016434, 0.032528, 0.014184, -0.048073, 0.017625, 0.025898, 0.035017, -0.024743), + array<f32, 8>(-0.086195, 0.041682, 0.071182, -0.062423, -0.016809, -0.004450, -0.035094, 0.087283), + array<f32, 8>(-0.070627, 0.033625, 0.025104, -0.086014, -0.037153, -0.019897, 0.046995, -0.025117), + array<f32, 8>(-0.042999, 0.043590, -0.107547, 0.114733, -0.006566, 0.067189, 0.042215, -0.019137), + array<f32, 8>(-0.105321, 0.188555, -0.033070, 0.005478, -0.019701, -0.006125, -0.006207, -0.024743), + array<f32, 8>(-0.018644, 0.021491, 0.042713, 0.047051, 0.009930, -0.074932, 0.016032, 0.087283), + array<f32, 8>(-0.036977, 0.022728, -0.031943, -0.134525, -0.024105, 0.022550, 0.038872, -0.025117), + array<f32, 8>(-0.017196, 0.102869, -0.028182, 0.153466, -0.024727, 0.008610, -0.029993, -0.019137), + array<f32, 8>(-0.135262, 0.264086, 0.052894, 0.104268, -0.044918, 0.085902, 0.119113, -0.024743), + array<f32, 8>(0.052648, 0.081481, 0.063582, 0.016832, 0.100333, -0.095727, 0.022089, 0.087283), + array<f32, 8>(0.028176, 0.006417, -0.010806, -0.049843, 0.010670, 0.058400, 0.051595, -0.025117), + array<f32, 8>(-0.078976, 0.040644, -0.116569, 0.145770, 0.019023, 0.071229, 0.056151, -0.019137), + array<f32, 8>(-0.028693, 0.154285, -0.019369, 0.111634, 0.022241, -0.015484, 0.039056, -0.024743), + array<f32, 8>(-0.052688, -0.046999, -0.000280, -0.024856, 0.012262, 0.028524, -0.028633, 0.087283), + array<f32, 8>(-0.004525, 0.052883, 0.002108, -0.096774, 0.052697, -0.055029, -0.022623, -0.025117), + array<f32, 8>(-0.076488, 0.013246, -0.097773, 0.023400, 0.027572, 0.041318, 0.012556, -0.019137), + array<f32, 8>(0.028093, 0.007624, 0.021861, -0.079392, 0.053487, 0.065200, -0.084020, -0.024743), + array<f32, 8>(-0.027503, 0.010973, 0.077242, 0.105956, 0.003837, -0.032827, 0.062214, 0.087283), + array<f32, 8>(0.028159, 0.036260, 0.051032, -0.057339, -0.032511, -0.019800, -0.113611, -0.025117), + array<f32, 8>(-0.004438, 0.024692, -0.151404, 0.097579, -0.031042, 0.067771, -0.062624, -0.019137), + array<f32, 8>(-0.053284, 0.062195, 0.018403, -0.145339, 0.008091, -0.048359, 0.060338, -0.024743), + array<f32, 8>(0.035264, 0.022147, 0.014877, -0.010450, 0.048411, -0.011475, -0.025409, 0.087283), + array<f32, 8>(-0.095181, 0.095906, 0.022414, -0.068326, -0.035929, 0.041247, -0.066456, -0.025117), + array<f32, 8>(0.011500, 0.097427, -0.072423, 0.068691, 0.006129, 0.025585, -0.066149, -0.019137), + array<f32, 8>(0.000253, 0.207033, 0.041903, -0.018208, 0.080300, 0.029738, 0.170740, -0.024743), + array<f32, 8>(0.118473, -0.002532, 0.082055, 0.029355, -0.017353, -0.094582, -0.028445, 0.087283), + array<f32, 8>(-0.167765, 0.166992, -0.051393, 0.018985, 0.000246, -0.060339, -0.036368, -0.025117), + array<f32, 8>(-0.037902, 0.123576, -0.135429, 0.018780, 0.069222, -0.048750, 0.010303, -0.019137), + array<f32, 8>(0.092400, 0.317862, 0.056507, 0.269526, 0.015330, -0.078774, 0.213070, -0.024743), + array<f32, 8>(0.147994, -0.056838, -0.046159, 0.069406, -0.025076, -0.018648, 0.019698, 0.087283), + array<f32, 8>(-0.063516, 0.051390, -0.043280, 0.053602, 0.046148, 0.032013, -0.012079, -0.025117), + array<f32, 8>(-0.069387, 0.008554, -0.016392, 0.041428, 0.069626, -0.028865, 0.031068, -0.019137), + array<f32, 8>(0.001597, 0.092924, 0.064679, 0.242996, 0.070280, -0.047444, 0.155082, -0.024743), + array<f32, 8>(0.003761, -0.067148, 0.020808, -0.009994, 0.064026, -0.023521, -0.061335, 0.087283), + array<f32, 8>(0.013300, 0.048670, -0.058611, -0.104133, 0.060389, 0.022588, -0.085768, -0.025117), + array<f32, 8>(0.001996, 0.035599, -0.067395, 0.113355, -0.054467, 0.021354, -0.020545, -0.019137), + array<f32, 8>(0.024443, 0.016439, 0.095606, -0.006610, 0.056457, 0.009034, 0.048181, -0.024743), + array<f32, 8>(-0.081707, 0.089380, 0.012570, 0.040154, 0.006970, -0.097259, -0.003088, 0.087283), + array<f32, 8>(0.037347, -0.012520, -0.009110, -0.164514, -0.052337, 0.031441, -0.117828, -0.025117), + array<f32, 8>(-0.050695, 0.023007, -0.086370, 0.106721, -0.022698, -0.063039, 0.007639, -0.019137), + array<f32, 8>(-0.032690, 0.100637, 0.090612, -0.170336, -0.013709, 0.096891, -0.064632, -0.024743), + array<f32, 8>(0.005479, 0.068678, -0.014147, -0.117601, 0.033542, -0.026603, -0.034334, 0.087283), + array<f32, 8>(-0.049645, 0.161140, 0.019592, -0.020424, 0.021700, 0.046387, 0.070111, -0.025117), + array<f32, 8>(-0.075219, -0.030338, -0.042611, 0.045346, -0.012298, -0.029272, -0.048395, -0.019137), + array<f32, 8>(0.110303, 0.091954, 0.026566, -0.013034, -0.001918, 0.025677, -0.003027, -0.024743), + array<f32, 8>(0.084352, 0.004527, 0.042981, 0.040333, 0.011019, 0.011699, 0.053396, 0.087283), + array<f32, 8>(-0.151306, 0.282692, 0.038388, 0.199704, -0.024410, -0.021070, 0.135509, -0.025117), + array<f32, 8>(0.008868, 0.058833, -0.035204, 0.017617, 0.036727, -0.084137, 0.008426, -0.019137), + array<f32, 8>(0.111690, 0.202555, 0.002230, 0.104773, 0.043414, 0.094714, 0.024386, -0.024743), + array<f32, 8>(0.109470, -0.130369, -0.049615, 0.027567, 0.015618, 0.010219, -0.035927, 0.087283), + array<f32, 8>(0.013092, 0.191465, -0.022463, 0.306655, 0.046994, 0.023051, 0.114596, -0.025117), + array<f32, 8>(-0.095580, 0.067644, -0.069810, 0.058185, 0.079298, 0.042359, 0.102818, -0.019137), + array<f32, 8>(0.163902, 0.060505, 0.020250, 0.151637, -0.041346, 0.079968, -0.066609, -0.024743), + array<f32, 8>(0.007401, -0.119463, 0.029195, -0.118251, -0.057537, 0.057136, -0.162722, 0.087283), + array<f32, 8>(-0.036401, 0.152383, -0.049404, 0.188484, 0.069434, -0.056077, -0.041920, -0.025117), + array<f32, 8>(-0.070811, 0.042628, -0.080224, 0.133910, 0.054912, -0.086587, 0.104432, -0.019137), + array<f32, 8>(0.045319, 0.031249, -0.007304, -0.008136, 0.001678, 0.019408, -0.016683, -0.024743), + array<f32, 8>(-0.054316, -0.005207, -0.003794, -0.009173, -0.015797, 0.088869, -0.054766, 0.087283), + array<f32, 8>(0.036646, 0.049626, -0.038869, -0.049720, 0.012847, -0.054911, -0.012426, -0.025117), + array<f32, 8>(-0.002965, 0.087409, -0.027885, 0.089920, 0.013074, -0.106163, 0.065504, -0.019137), + array<f32, 8>(-0.004488, 0.102517, 0.092916, -0.079512, 0.001532, -0.048995, -0.041429, -0.024743), + array<f32, 8>(-0.062161, -0.027813, 0.037159, -0.030745, -0.017068, 0.084630, -0.046134, 0.087283), + array<f32, 8>(-0.017315, 0.191771, -0.050660, -0.140278, 0.038320, 0.037753, -0.043447, -0.025117), + array<f32, 8>(-0.079621, 0.091290, -0.098575, 0.055638, 0.007634, -0.051456, -0.011530, -0.019137), + array<f32, 8>(-0.044260, 0.010435, 0.104869, -0.029082, 0.038487, 0.004167, 0.020321, -0.024743), + array<f32, 8>(0.004107, -0.049898, -0.011912, 0.126974, 0.074958, 0.038876, 0.027066, 0.087283), + array<f32, 8>(0.022312, 0.332216, -0.028889, 0.171475, 0.052267, -0.023821, 0.193472, -0.025117), + array<f32, 8>(0.009104, -0.027289, -0.016718, 0.092231, 0.023904, -0.034162, 0.004693, -0.019137), + array<f32, 8>(0.022922, -0.036846, 0.071670, -0.118853, -0.046374, 0.005972, -0.079006, -0.024743), + array<f32, 8>(-0.086613, -0.033065, 0.032719, 0.081925, -0.025818, -0.065103, 0.010425, 0.087283), + array<f32, 8>(0.014945, 0.330249, -0.062079, 0.408858, 0.044895, -0.036703, 0.195226, -0.025117), + array<f32, 8>(0.021647, 0.086135, -0.013491, 0.027627, -0.033652, -0.016643, -0.037425, -0.019137), + array<f32, 8>(-0.028124, 0.039691, 0.108537, -0.123861, -0.071841, -0.034232, 0.009737, -0.024743), + array<f32, 8>(-0.095938, -0.080740, 0.047554, -0.145590, -0.041365, 0.031658, -0.027601, 0.087283), + array<f32, 8>(-0.050837, 0.179578, 0.020990, 0.240896, -0.038067, 0.007052, 0.036244, -0.025117), + array<f32, 8>(-0.100474, 0.012669, -0.123589, 0.147449, -0.056871, 0.029335, -0.041989, -0.019137), + array<f32, 8>(0.000809, 0.020182, 0.123381, 0.009990, 0.061892, -0.056804, 0.049866, -0.024743), + array<f32, 8>(-0.006123, 0.085572, -0.065080, -0.003607, -0.100605, -0.015746, 0.045932, 0.087283), + array<f32, 8>(-0.068945, 0.037700, -0.068738, 0.088604, 0.034364, -0.027429, -0.023157, -0.025117), + array<f32, 8>(-0.028689, 0.018089, -0.144344, 0.097751, -0.022261, 0.004934, 0.044538, -0.019137), + array<f32, 8>(-0.072695, 0.099329, 0.037965, -0.007148, -0.061809, -0.014461, -0.050644, -0.024743), + array<f32, 8>(-0.043364, -0.019908, 0.033602, -0.011686, -0.046646, -0.005387, 0.057703, 0.087283), + array<f32, 8>(0.020640, 0.058992, 0.042389, -0.111803, -0.000105, -0.069637, -0.058816, -0.025117), + array<f32, 8>(-0.090411, -0.034394, -0.135574, 0.085031, -0.020320, -0.002235, 0.079036, -0.019137), + array<f32, 8>(-0.035238, 0.052656, 0.011918, -0.032684, 0.067555, -0.047663, -0.013151, -0.024743), + array<f32, 8>(0.077223, 0.067583, -0.053024, 0.063017, -0.023909, -0.041936, 0.039041, 0.087283), + array<f32, 8>(-0.011154, 0.253355, 0.006886, 0.066990, -0.018613, -0.033851, 0.022408, -0.025117), + array<f32, 8>(-0.042376, 0.097067, -0.107170, 0.053378, 0.081423, -0.059980, -0.019982, -0.019137), + array<f32, 8>(-0.086462, 0.042703, 0.052655, -0.129460, -0.073930, -0.004732, -0.089001, -0.024743), + array<f32, 8>(0.019294, 0.036932, -0.046783, 0.172396, -0.003345, 0.029704, -0.013067, 0.087283), + array<f32, 8>(0.142370, 0.248269, -0.072705, 0.188676, 0.028917, -0.058974, -0.007950, -0.025117), + array<f32, 8>(-0.021378, 0.064055, -0.103605, -0.015491, -0.002155, -0.048161, -0.045529, -0.019137), + array<f32, 8>(0.006191, 0.063159, 0.005143, -0.101334, -0.020484, 0.038330, 0.010742, -0.024743), + array<f32, 8>(-0.123413, 0.027806, -0.063111, 0.060050, -0.087346, 0.080827, 0.016499, 0.087283), + array<f32, 8>(0.054552, 0.047349, 0.029259, 0.152502, -0.013689, -0.035447, -0.006584, -0.025117), + array<f32, 8>(-0.034984, 0.059972, -0.147872, 0.096835, 0.055766, -0.001973, -0.033631, -0.019137), + array<f32, 8>(0.004488, -0.060204, 0.120817, -0.095007, 0.040546, 0.026207, -0.011824, -0.024743), + array<f32, 8>(0.000380, 0.102988, 0.010112, -0.011668, 0.004855, -0.019988, -0.035633, 0.087283), + array<f32, 8>(0.003894, -0.083172, -0.046051, -0.005485, 0.017347, -0.057191, -0.085077, -0.025117), + array<f32, 8>(-0.066185, 0.092341, -0.135679, 0.009092, -0.015954, 0.003226, -0.010182, -0.019137) +); + +const weights_layer2: array<array<f32, 8>, 9> = array( + array<f32, 8>(0.071600, -0.118269, 0.093769, 0.096974, -0.002193, -0.065924, -0.125094, 0.018248), + array<f32, 8>(-0.089131, -0.053007, 0.150626, -0.051485, 0.087371, -0.078030, -0.045468, 0.018248), + array<f32, 8>(0.042144, 0.146191, 0.152445, 0.028572, 0.064491, -0.061860, 0.037828, 0.018248), + array<f32, 8>(-0.084747, -0.133062, -0.030736, 0.061174, -0.055809, -0.012031, 0.126923, 0.018248), + array<f32, 8>(-0.017155, -0.105189, 0.003457, 0.105491, 0.003587, 0.089110, -0.001623, 0.018248), + array<f32, 8>(-0.028012, -0.066691, 0.125358, -0.027705, 0.032134, 0.044475, -0.036991, 0.018248), + array<f32, 8>(0.094536, -0.038367, -0.009421, 0.027049, -0.103427, -0.065209, -0.110071, 0.018248), + array<f32, 8>(0.147956, 0.028446, 0.031066, 0.055667, -0.039952, 0.069251, 0.020060, 0.018248), + array<f32, 8>(0.067507, 0.154407, -0.017526, 0.064009, -0.014328, 0.022175, 0.015376, 0.018248) ); -const bias_layer0 = vec4<f32>(0.0, 0.0, 0.0, 0.0); |
