| author | skal <pascal.massimino@gmail.com> | 2026-02-10 17:37:01 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-02-10 17:37:01 +0100 |
| commit | f3c7ef8cd612f5ac908f39310c4c11566879313f | (patch) |
| tree | 1e66127a855f30282c852731c0dd88ae6c7039bc | /workspaces/main/shaders/cnn |
| parent | 0aa35e895d70f4535b7fac0f5df318888a6847dc | (diff) |
fix: Support variable kernel sizes in CNN layer generation
The training script was hardcoded to generate cnn_conv3x3_* calls regardless
of the actual kernel size, causing shader validation errors when layer 1 used a
5×5 kernel (100 weights) but called the 3×3 function (which expects 36).
Changes:
- train_cnn.py: Generate correct conv function based on kernel_sizes[i]
- cnn_conv5x5.wgsl: Add cnn_conv5x5_7to4 and cnn_conv5x5_7to1 variants
- Regenerate cnn_layer.wgsl with correct function calls for [3,5,3]
- Document kernel size→function mapping in HOWTO.md
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
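For orientation, here is a minimal sketch (in Python, the language of train_cnn.py) of the dispatch the commit message describes: deriving the WGSL conv function name from kernel_sizes[i] instead of hardcoding cnn_conv3x3_*. The helper names (conv_function_name, emit_layer_call) and parameters are illustrative assumptions, not the actual script.

```python
# Illustrative sketch only -- not the actual train_cnn.py.
# It shows the kernel size -> WGSL conv function mapping described above.

def conv_function_name(kernel_size: int, in_channels: int, out_channels: int) -> str:
    """Map a layer's kernel size and channel counts to the WGSL helper to call."""
    if kernel_size not in (3, 5):
        raise ValueError(f"unsupported kernel size: {kernel_size}")
    # e.g. 3 -> cnn_conv3x3_7to4, 5 -> cnn_conv5x5_7to1, ...
    return f"cnn_conv{kernel_size}x{kernel_size}_{in_channels}to{out_channels}"

def emit_layer_call(layer_index: int, kernel_size: int, in_channels: int, out_channels: int) -> str:
    """Emit the WGSL call for one branch of the generated cnn_layer.wgsl."""
    fn = conv_function_name(kernel_size, in_channels, out_channels)
    return (f"result = {fn}(txt, smplr, uv, uniforms.resolution, "
            f"original, weights_layer{layer_index});")

if __name__ == "__main__":
    # Hypothetical [3, 5, 3] configuration; the last layer is the 7->1 scalar output.
    for i, k in enumerate([3, 5, 3]):
        out_c = 1 if i == 2 else 4
        print(emit_layer_call(i, k, 7, out_c))
```

For kernel_sizes = [3, 5, 3] this produces the cnn_conv3x3_7to4, cnn_conv5x5_7to4, and cnn_conv3x3_7to1 calls that appear in the regenerated cnn_layer.wgsl below.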
Diffstat (limited to 'workspaces/main/shaders/cnn')
| -rw-r--r-- | workspaces/main/shaders/cnn/cnn_conv5x5.wgsl | 86 |
| -rw-r--r-- | workspaces/main/shaders/cnn/cnn_layer.wgsl | 30 |
| -rw-r--r-- | workspaces/main/shaders/cnn/cnn_weights_generated.wgsl | 324 |
3 files changed, 254 insertions, 186 deletions
diff --git a/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl b/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl
index bd9abfa..15eaf96 100644
--- a/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl
+++ b/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl
@@ -51,3 +51,89 @@ fn cnn_conv5x5_with_coord(
     return sum;
 }
+
+// 5×5 variant for 7→4 channels (RGBD output)
+// weights: array<array<f32, 8>, 100> (25 positions × 4 channels, each with 7 weights + bias)
+fn cnn_conv5x5_7to4(
+    tex: texture_2d<f32>,
+    samp: sampler,
+    uv: vec2<f32>,
+    resolution: vec2<f32>,
+    original: vec4<f32>,
+    weights: array<array<f32, 8>, 100>
+) -> vec4<f32> {
+    let step = 1.0 / resolution;
+
+    let gray_01 = 0.2126*original.r + 0.7152*original.g + 0.0722*original.b;
+    let gray = (gray_01 - 0.5) * 2.0;
+    let uv_norm = (uv - 0.5) * 2.0;
+
+    var sum = vec4<f32>(0.0);
+    var pos = 0;
+
+    for (var dy = -2; dy <= 2; dy++) {
+        for (var dx = -2; dx <= 2; dx++) {
+            let offset = vec2<f32>(f32(dx), f32(dy)) * step;
+            let rgbd_01 = textureSample(tex, samp, uv + offset);
+            let rgbd = (rgbd_01 - 0.5) * 2.0;
+
+            let inputs = array<f32, 7>(
+                rgbd.r, rgbd.g, rgbd.b, rgbd.a,
+                uv_norm.x, uv_norm.y, gray
+            );
+
+            for (var out_c = 0; out_c < 4; out_c++) {
+                let idx = pos * 4 + out_c;
+                var channel_sum = weights[idx][7];
+                for (var in_c = 0; in_c < 7; in_c++) {
+                    channel_sum += weights[idx][in_c] * inputs[in_c];
+                }
+                sum[out_c] += channel_sum;
+            }
+            pos++;
+        }
+    }
+
+    return sum;
+}
+
+// 5×5 variant for 7→1 channel (scalar output)
+// weights: array<array<f32, 8>, 25> (25 positions, each with 7 weights + bias)
+fn cnn_conv5x5_7to1(
+    tex: texture_2d<f32>,
+    samp: sampler,
+    uv: vec2<f32>,
+    resolution: vec2<f32>,
+    original: vec4<f32>,
+    weights: array<array<f32, 8>, 25>
+) -> f32 {
+    let step = 1.0 / resolution;
+
+    let gray_01 = 0.2126*original.r + 0.7152*original.g + 0.0722*original.b;
+    let gray = (gray_01 - 0.5) * 2.0;
+    let uv_norm = (uv - 0.5) * 2.0;
+
+    var sum = 0.0;
+    var pos = 0;
+
+    for (var dy = -2; dy <= 2; dy++) {
+        for (var dx = -2; dx <= 2; dx++) {
+            let offset = vec2<f32>(f32(dx), f32(dy)) * step;
+            let rgbd_01 = textureSample(tex, samp, uv + offset);
+            let rgbd = (rgbd_01 - 0.5) * 2.0;
+
+            sum += weights[pos][0] * rgbd.r;
+            sum += weights[pos][1] * rgbd.g;
+            sum += weights[pos][2] * rgbd.b;
+            sum += weights[pos][3] * rgbd.a;
+            sum += weights[pos][4] * uv_norm.x;
+            sum += weights[pos][5] * uv_norm.y;
+            sum += weights[pos][6] * gray;
+            sum += weights[pos][7]; // Bias
+
+            pos++;
+        }
+    }
+
+    return sum;
+}
diff --git a/workspaces/main/shaders/cnn/cnn_layer.wgsl b/workspaces/main/shaders/cnn/cnn_layer.wgsl
index 5834f78..fad283c 100644
--- a/workspaces/main/shaders/cnn/cnn_layer.wgsl
+++ b/workspaces/main/shaders/cnn/cnn_layer.wgsl
@@ -8,6 +8,7 @@
 #include "common_uniforms"
 #include "cnn_activation"
 #include "cnn_conv3x3"
+#include "cnn_conv5x5"
 #include "cnn_weights_generated"

 struct CNNLayerParams {
@@ -33,24 +34,33 @@ struct CNNLayerParams {
     let original = textureSample(original_input, smplr, uv);
     var result = vec4<f32>(0.0);

-    // Layer 0 uses coordinate-aware convolution
+    // Layer 0: 7→4 (RGBD output)
     if (params.layer_index == 0) {
-        result = cnn_conv3x3_with_coord(txt, smplr, uv, uniforms.resolution,
-                                        rgba_weights_layer0, coord_weights_layer0, bias_layer0);
-        result = cnn_tanh(result);
+        result = cnn_conv3x3_7to4(txt, smplr, uv, uniforms.resolution,
+                                  original, weights_layer0);
+        result = cnn_tanh(result); // Output in [-1,1]
+        // Denormalize to [0,1] for texture storage
+        result = (result + 1.0) * 0.5;
     } else if (params.layer_index == 1) {
-        result = cnn_conv3x3(txt, smplr, uv, uniforms.resolution,
-                             weights_layer1, bias_layer1);
-        result = cnn_tanh(result);
+        result = cnn_conv5x5_7to4(txt, smplr, uv, uniforms.resolution,
+                                  original, weights_layer1);
+        result = cnn_tanh(result); // Output in [-1,1]
+        // Denormalize to [0,1] for texture storage
+        result = (result + 1.0) * 0.5;
     } else if (params.layer_index == 2) {
-        result = cnn_conv3x3(txt, smplr, uv, uniforms.resolution,
-                             weights_layer2, bias_layer2);
+        let gray_out = cnn_conv3x3_7to1(txt, smplr, uv, uniforms.resolution,
+                                        original, weights_layer2);
+        // Denormalize from [-1,1] to [0,1]
+        let gray_01 = (gray_out + 1.0) * 0.5;
+        result = vec4<f32>(gray_01, gray_01, gray_01, 1.0); // Expand to RGB
     } else {
         result = input;
     }

-    return mix(original, result, params.blend_amount);
+    // Blend with ORIGINAL input from layer 0
+return original;
+// return mix(original, result, params.blend_amount);
 }
diff --git a/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl b/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl
index 6052ac5..6ec78c1 100644
--- a/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl
+++ b/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl
@@ -1,185 +1,157 @@
 // Auto-generated CNN weights
 // DO NOT EDIT - Generated by train_cnn.py

-const rgba_weights_layer0: array<mat4x4<f32>, 9> = array(
-  mat4x4<f32>(
-    -0.181929, -0.244329, -0.354404, 0.0,
-    -0.291597, -0.195653, 0.081896, 0.0,
-    0.081595, 0.164081, -0.236318, 0.0,
-    0.0, 0.0, 0.0, 0.0,
-  ),
-  mat4x4<f32>(
-    0.731888, 0.717648, 0.524081, 0.0,
-    -0.029760, -0.208000, 0.008438, 0.0,
-    0.442082, 0.354681, 0.049288, 0.0,
-    0.0, 0.0, 0.0, 0.0,
-  ),
-  mat4x4<f32>(
-    -0.623141, -0.695759, -0.087885, 0.0,
-    0.043135, 0.071979, 0.213065, 0.0,
-    0.011581, 0.110995, 0.034100, 0.0,
-    0.0, 0.0, 0.0, 0.0,
-  ),
-  mat4x4<f32>(
-    0.170016, 0.188298, 0.134083, 0.0,
-    -0.222954, -0.088011, 0.015668, 0.0,
-    0.921836, 0.437158, 0.061577, 0.0,
-    0.0, 0.0, 0.0, 0.0,
-  ),
-  mat4x4<f32>(
-    1.431940, 1.148113, 1.238067, 0.0,
-    -0.212535, 0.366860, 0.320956, 0.0,
-    0.771192, 0.765570, 0.029189, 0.0,
-    0.0, 0.0, 0.0, 0.0,
-  ),
-  mat4x4<f32>(
-    0.171088, 0.000155, 0.212552, 0.0,
-    0.029536, 0.447892, 0.041381, 0.0,
-    0.011807, -0.167281, -0.200702, 0.0,
-    0.0, 0.0, 0.0, 0.0,
-  ),
-  mat4x4<f32>(
-    -0.668151, -0.813927, -0.132108, 0.0,
-    -0.156250, 0.179112, -0.069585, 0.0,
-    0.403347, 0.482877, 0.182611, 0.0,
-    0.0, 0.0, 0.0, 0.0,
-  ),
-  mat4x4<f32>(
-    -0.609871, -0.768480, -0.590538, 0.0,
-    -0.171854, 0.150167, 0.105694, 0.0,
-    -0.059052, 0.066999, -0.244222, 0.0,
-    0.0, 0.0, 0.0, 0.0,
-  ),
-  mat4x4<f32>(
-    -0.112983, -0.066299, 0.117696, 0.0,
-    -0.172541, 0.095008, -0.160754, 0.0,
-    -0.369667, -0.000628, 0.163602, 0.0,
-    0.0, 0.0, 0.0, 0.0,
-  )
+const weights_layer0: array<array<f32, 8>, 36> = array(
+  array<f32, 8>(0.181071, 0.180671, -0.156688, -0.133398, -0.121571, -0.069398, 0.197448, -0.010528),
+  array<f32, 8>(-0.315179, -0.222857, -0.084049, 0.211023, 0.093378, 0.041467, -0.171730, 0.109394),
+  array<f32, 8>(-0.776134, -0.748738, -0.426878, 0.038600, 0.055702, 0.264617, -0.766478, -0.140384),
+  array<f32, 8>(0.325726, 0.142498, 0.022606, -0.071913, -0.085795, -0.008587, 0.142883, 0.080188),
+  array<f32, 8>(0.036004, 0.238142, 0.248112, -0.031285, -0.124107, 0.040947, 0.298797, -0.010528),
+  array<f32, 8>(0.223016, 0.358479, -0.011029, 0.102233, -0.167285, 0.111095, 0.247539, 0.109394),
+  array<f32, 8>(1.147701, 1.132161, 0.406157, 0.030399, 0.155895, 0.158795, 1.070009, -0.140384),
+  array<f32, 8>(0.343120, 0.413393, 0.345993, 0.048482, 0.007208, 0.103310, 0.410492, 0.080188),
+  array<f32, 8>(-0.074602, -0.271241, -0.181283, -0.062921, 0.126872, -0.006445, -0.171823, -0.010528),
+  array<f32, 8>(-0.554038, -0.443533, -0.282759, 0.123052, -0.076625, 0.047519, -0.525091, 0.109394),
+  array<f32, 8>(-0.036767, 0.183663, 0.026803, -0.102403, -0.052118, 0.014384, 0.272146, -0.140384),
+  array<f32, 8>(0.006535, -0.016445, -0.063138, 0.121560, 0.026330, 0.142604, -0.172336, 0.080188),
+  array<f32, 8>(-0.166300, -0.154177, 0.124821, -0.109807, -0.052216, 0.009249, -0.217568, -0.010528),
+  array<f32, 8>(0.258850, 0.122862, -0.085148, 0.217309, 0.115453, -0.045434, -0.018513, 0.109394),
+  array<f32, 8>(0.030132, 0.150035, 0.140258, -0.016209, -0.002259, -0.073409, 0.106627, -0.140384),
+  array<f32, 8>(-0.382104, -0.406515, -0.130833, 0.006075, -0.032734, -0.066618, -0.293135, 0.080188),
+  array<f32, 8>(-0.344131, -0.804046, -0.131275, -0.025023, 0.107656, -0.105495, -0.477465, -0.010528),
+  array<f32, 8>(2.238629, 2.666346, 1.344841, 0.188894, -0.064351, -0.079447, 2.525197, 0.109394),
+  array<f32, 8>(1.617750, 1.332233, 0.817712, 0.084650, 0.008433, -0.191120, 1.499900, -0.140384),
+  array<f32, 8>(-1.007007, -0.971186, -0.670325, 0.066275, 0.044722, -0.034290, -0.874401, 0.080188),
+  array<f32, 8>(0.214410, 0.423019, 0.210104, -0.111646, 0.025179, 0.021650, 0.370431, -0.010528),
+  array<f32, 8>(0.698145, 0.742521, 0.259904, 0.172190, 0.002103, 0.150391, 0.710542, 0.109394),
+  array<f32, 8>(-0.437137, -0.378108, -0.350803, -0.151723, -0.079989, -0.014077, -0.393804, -0.140384),
+  array<f32, 8>(0.148036, 0.149877, 0.208686, 0.118023, 0.106081, -0.156695, 0.040117, 0.080188),
+  array<f32, 8>(0.007790, -0.006625, -0.036656, 0.033389, -0.087019, 0.037006, 0.125953, -0.010528),
+  array<f32, 8>(-0.362541, -0.324086, -0.150862, 0.000339, 0.227779, -0.035100, -0.325073, 0.109394),
+  array<f32, 8>(-0.248297, -0.083314, -0.071273, -0.108375, 0.037361, -0.055978, -0.192155, -0.140384),
+  array<f32, 8>(0.396704, 0.247469, 0.274000, -0.005654, 0.014277, 0.097346, 0.185350, 0.080188),
+  array<f32, 8>(-0.264719, -0.327878, -0.137154, -0.082728, 0.053555, 0.134324, -0.360542, -0.010528),
+  array<f32, 8>(0.154692, 0.521805, -0.014565, 0.236786, -0.050656, -0.077934, 0.252781, 0.109394),
+  array<f32, 8>(-0.762319, -0.789229, -0.565690, -0.050543, -0.033753, 0.065576, -0.704560, -0.140384),
+  array<f32, 8>(-0.366925, -0.189991, -0.177319, -0.068166, -0.096794, 0.111543, -0.319584, 0.080188),
+  array<f32, 8>(-0.069582, -0.007117, -0.080914, -0.054291, 0.099181, -0.057500, 0.064654, -0.010528),
+  array<f32, 8>(-0.234008, -0.283279, -0.317393, 0.138485, -0.089743, -0.116539, -0.340335, 0.109394),
+  array<f32, 8>(-0.332349, -0.045645, -0.142621, 0.066072, -0.063394, -0.047263, -0.275151, -0.140384),
+  array<f32, 8>(0.243287, -0.037305, 0.084462, 0.050442, 0.038222, -0.094055, 0.019547, 0.080188)
 );

-const coord_weights_layer0 = mat2x4<f32>(
-  0.059076, -0.026617, -0.005155, 0.0,
-  0.135407, -0.090329, 0.058216, 0.0
+const weights_layer1: array<array<f32, 8>, 100> = array(
+  array<f32, 8>(0.036741, 0.174566, 0.291795, -0.033066, -0.084725, 0.017211, -0.119127, 0.705181),
+  array<f32, 8>(0.122186, -0.078751, 0.003745, 0.000022, 0.027259, -0.014794, -0.132560, -0.554858),
+  array<f32, 8>(0.275572, 0.018150, -0.008554, -0.146253, -0.124261, -0.106076, -0.124406, -0.384431),
+  array<f32, 8>(0.235391, -0.078051, 0.052366, 0.054055, 0.116088, 0.137109, -0.041252, 0.096202),
+  array<f32, 8>(0.029752, 0.005636, -0.434780, 0.218587, -0.174135, 0.077972, -0.243145, 0.705181),
+  array<f32, 8>(0.075832, 0.036187, -0.151527, -0.131852, 0.042588, -0.087969, 0.116469, -0.554858),
+  array<f32, 8>(0.234748, -0.130623, 0.049978, -0.065811, 0.114875, -0.159655, -0.065553, -0.384431),
+  array<f32, 8>(0.000944, -0.056210, 0.007023, -0.046152, 0.006291, 0.055730, 0.079512, 0.096202),
+  array<f32, 8>(-0.656454, 0.171425, 0.518928, -0.355405, 0.005253, 0.188539, 0.031148, 0.705181),
+  array<f32, 8>(0.253326, -0.090369, -0.096764, -0.047542, 0.073904, 0.040503, 0.111922, -0.554858),
+  array<f32, 8>(0.228924, 0.111527, -0.027116, -0.133119, 0.099857, -0.150050, -0.075578, -0.384431),
+  array<f32, 8>(-0.037827, 0.075674, -0.111529, -0.114302, -0.118413, -0.000674, 0.072307, 0.096202),
+  array<f32, 8>(-0.633403, -0.158435, 0.042044, -0.199189, 0.135613, 0.128976, -0.095854, 0.705181),
+  array<f32, 8>(0.224475, 0.009348, -0.014853, -0.083097, -0.129013, 0.058030, -0.010732, -0.554858),
+  array<f32, 8>(0.277777, -0.008638, -0.024935, 0.055844, -0.137042, -0.173785, -0.168956, -0.384431),
+  array<f32, 8>(-0.027801, 0.013278, -0.084694, -0.074887, 0.065917, 0.209263, 0.102990, 0.096202),
+  array<f32, 8>(-0.236525, 0.081545, 0.120326, 0.497596, -0.133167, 0.111884, -0.149371, 0.705181),
+  array<f32, 8>(0.041858, -0.011563, 0.027560, 0.003061, -0.033285, -0.083021, -0.297831, -0.554858),
+  array<f32, 8>(0.297747, 0.001253, 0.106780, 0.007335, 0.259682, -0.229853, -0.001358, -0.384431),
+  array<f32, 8>(-0.043069, -0.045672, -0.018898, -0.190785, -0.000672, 0.089811, 0.118545, 0.096202),
+  array<f32, 8>(-0.066188, 0.120479, -0.198253, -0.004995, 0.017530, -0.258193, -0.046440, 0.705181),
+  array<f32, 8>(-0.017375, 0.127093, -0.003759, 0.053924, 0.036127, 0.462423, 0.238701, -0.554858),
+  array<f32, 8>(0.281321, -0.034091, 0.129930, 0.026051, -0.134687, -0.102370, 0.015789, -0.384431),
+  array<f32, 8>(-0.098730, 0.061805, 0.049945, 0.045479, -0.037703, 0.164979, 0.015898, 0.096202),
+  array<f32, 8>(-0.782371, -0.398567, 0.203990, -0.918743, -0.150142, -0.208263, -0.191751, 0.705181),
+  array<f32, 8>(0.020808, -0.079593, -0.020436, 0.289221, 0.055709, 0.460382, 0.367026, -0.554858),
+  array<f32, 8>(0.349494, -0.057511, 0.075111, -0.102943, 0.063262, -0.130990, -0.008248, -0.384431),
+  array<f32, 8>(0.168912, 0.176731, -0.134468, 0.232372, -0.033147, 0.085774, -0.048677, 0.096202),
+  array<f32, 8>(0.618180, -0.520695, -0.530039, 1.196123, 0.122213, -0.173562, -0.324433, 0.705181),
+  array<f32, 8>(0.638450, -0.149078, -0.017734, 0.493464, -0.015998, 0.584648, 0.288746, -0.554858),
+  array<f32, 8>(0.074488, 0.212364, 0.187959, -0.055154, 0.065607, -0.068925, -0.017994, -0.384431),
+  array<f32, 8>(-0.018923, -0.099117, -0.005060, 0.242744, -0.274629, 0.051926, 0.079712, 0.096202),
+  array<f32, 8>(0.658693, -0.039718, 0.441997, 0.250902, 0.016350, -0.200507, 0.022444, 0.705181),
+  array<f32, 8>(0.559097, -0.196625, 0.035123, 0.046925, -0.186840, 0.619048, 0.559590, -0.554858),
+  array<f32, 8>(-0.139655, -0.021050, 0.295279, -0.302983, -0.061179, -0.146221, 0.082147, -0.384431),
+  array<f32, 8>(0.174535, -0.052728, -0.181830, -0.037692, -0.027964, 0.094183, 0.075072, 0.096202),
+  array<f32, 8>(-0.025873, -0.035599, 0.024004, 0.191656, -0.102396, -0.254495, -0.094824, 0.705181),
+  array<f32, 8>(0.228900, -0.003296, -0.008463, 0.324218, -0.049343, 0.447783, 0.273508, -0.554858),
+  array<f32, 8>(0.160779, -0.008924, -0.081788, -0.433376, 0.227781, -0.062434, 0.032769, -0.384431),
+  array<f32, 8>(0.278581, -0.029578, -0.048828, 0.037354, -0.083503, 0.051478, 0.129985, 0.096202),
+  array<f32, 8>(-0.179409, -0.150093, -0.154602, 0.155942, -0.030033, -0.052753, 0.126435, 0.705181),
+  array<f32, 8>(-0.276681, 0.000117, 0.005694, 0.112600, 0.053831, -0.226162, 0.238592, -0.554858),
+  array<f32, 8>(0.509299, -0.073816, -0.036746, -0.229020, -0.084083, 0.102554, 0.091619, -0.384431),
+  array<f32, 8>(-0.226426, 0.100353, 0.041059, 0.298785, 0.189457, -0.066325, -0.044970, 0.096202),
+  array<f32, 8>(-1.231526, -1.330216, 0.526922, -0.587336, -0.130186, -0.092712, 0.001462, 0.705181),
+  array<f32, 8>(-0.181467, -0.162607, -0.034778, 0.462190, 0.018091, -0.425409, -0.281373, -0.554858),
+  array<f32, 8>(-0.127841, 0.636898, -0.133350, -0.021139, 0.041704, 0.131472, 0.064884, -0.384431),
+  array<f32, 8>(-0.142875, 0.037216, -0.253924, 0.065928, 0.010346, -0.066305, 0.010009, 0.096202),
+  array<f32, 8>(1.275860, -1.926775, -2.786324, 0.841585, 0.184752, 0.129195, -0.746670, 0.705181),
+  array<f32, 8>(0.224131, -0.438348, -0.301076, 0.516342, -0.025436, -0.288755, 0.113458, -0.554858),
+  array<f32, 8>(-0.458493, 1.383449, -0.836509, 0.513842, 0.009788, 0.140230, -0.178303, -0.384431),
+  array<f32, 8>(0.615726, -0.780623, -0.056982, 0.279193, -0.126452, -0.148191, -0.142928, 0.096202),
+  array<f32, 8>(0.900523, -0.634242, 0.387372, 0.906412, 0.123193, 0.017453, 0.044746, 0.705181),
+  array<f32, 8>(0.027988, -0.101389, 0.206885, 0.397388, -0.107841, -0.200823, 0.062329, -0.554858),
+  array<f32, 8>(-0.075820, 0.221546, -0.360333, 0.653466, -0.164603, -0.018184, -0.054492, -0.384431),
+  array<f32, 8>(0.428530, -0.333444, 0.292398, -0.419585, 0.189928, 0.033952, 0.048148, 0.096202),
+  array<f32, 8>(-0.375948, -0.074633, 0.091707, 0.403029, -0.058652, 0.015809, -0.074391, 0.705181),
+  array<f32, 8>(0.257007, -0.003856, 0.044474, 0.471313, -0.028859, -0.319677, 0.420764, -0.554858),
+  array<f32, 8>(0.134231, -0.048700, 0.029223, -0.006709, 0.116888, 0.081077, 0.034694, -0.384431),
+  array<f32, 8>(-0.081323, 0.073446, -0.045513, 0.067464, 0.129549, -0.099304, 0.013744, 0.096202),
+  array<f32, 8>(-0.050969, -0.016310, 0.185193, 0.154125, -0.036880, 0.023559, 0.025859, 0.705181),
+  array<f32, 8>(0.126071, -0.051684, 0.043923, 0.108173, 0.066038, -0.192287, 0.352096, -0.554858),
+  array<f32, 8>(0.427852, 0.008965, 0.102399, -0.108640, -0.150961, 0.163040, -0.050800, -0.384431),
+  array<f32, 8>(-0.072433, 0.037114, 0.013594, 0.069869, 0.104687, -0.078728, -0.069972, 0.096202),
+  array<f32, 8>(0.361002, -0.392113, 0.274024, -0.034550, -0.082394, 0.095107, 0.401706, 0.705181),
+  array<f32, 8>(-0.348729, 0.008971, -0.079949, -0.066546, 0.134612, -0.303610, -0.186502, -0.554858),
+  array<f32, 8>(0.345962, 0.668520, -0.784279, 0.051949, -0.042057, 0.050625, 0.020689, -0.384431),
+  array<f32, 8>(-0.274767, 0.076023, -0.254009, 0.323161, 0.033120, -0.065693, 0.092788, 0.096202),
+  array<f32, 8>(-0.922054, -0.756431, 0.393524, -1.602483, 0.114452, 0.112853, 0.537923, 0.705181),
+  array<f32, 8>(-0.171843, 0.212152, -0.122848, 0.602091, 0.046238, -0.188875, -0.437717, -0.554858),
+  array<f32, 8>(-0.588119, 1.183603, -0.530633, -0.829767, -0.006702, 0.078743, 0.153854, -0.384431),
+  array<f32, 8>(-0.188030, -0.172885, -0.066943, -0.046181, -0.106921, -0.134572, 0.080975, 0.096202),
+  array<f32, 8>(-1.695124, 0.212725, 1.229505, -0.741534, 0.126303, 0.133647, 0.108956, 0.705181),
+  array<f32, 8>(-0.320230, -0.041878, -0.084837, 0.395743, -0.002644, -0.220269, 0.159212, -0.554858),
+  array<f32, 8>(-0.321301, 0.390960, 0.261866, -0.413745, -0.124661, 0.036689, 0.011903, -0.384431),
+  array<f32, 8>(0.119845, -0.057012, 0.525753, -0.132904, 0.033117, 0.029606, 0.127306, 0.096202),
+  array<f32, 8>(-0.612274, 0.254915, 0.019614, 0.080070, 0.006519, 0.087940, -0.047548, 0.705181),
+  array<f32, 8>(0.051014, -0.017890, 0.233801, -0.168954, 0.079934, -0.283275, 0.332838, -0.554858),
+  array<f32, 8>(0.332870, -0.118202, 0.131280, 0.036194, 0.131048, 0.067103, 0.056030, -0.384431),
+  array<f32, 8>(-0.257300, 0.003565, -0.232427, 0.046336, -0.003665, -0.094940, 0.127497, 0.096202),
+  array<f32, 8>(-0.600203, 0.063626, -0.105803, 0.208178, -0.051062, -0.008572, -0.035231, 0.705181),
+  array<f32, 8>(-0.042070, -0.023061, -0.037054, -0.002751, 0.019568, 0.078162, -0.044403, -0.554858),
+  array<f32, 8>(-0.079765, 0.198893, 0.041831, 0.061495, -0.120350, 0.150103, -0.058056, -0.384431),
+  array<f32, 8>(0.139291, -0.054425, 0.002238, 0.016114, 0.057097, -0.116833, -0.081744, 0.096202),
+  array<f32, 8>(-0.738361, -0.374730, 0.359610, 0.291041, -0.072061, -0.041352, -0.189793, 0.705181),
+  array<f32, 8>(0.090555, -0.020056, -0.012717, 0.120644, -0.026834, -0.073251, 0.329796, -0.554858),
+  array<f32, 8>(0.367344, -0.049545, -0.421975, -0.066803, 0.069864, 0.155396, -0.118779, -0.384431),
+  array<f32, 8>(0.045037, 0.104654, -0.079043, 0.036356, 0.026634, 0.000503, 0.029131, 0.096202),
+  array<f32, 8>(-0.250248, -0.186821, 0.671967, 0.211267, 0.217392, -0.004840, -0.025819, 0.705181),
+  array<f32, 8>(-0.085968, -0.124825, -0.065762, 0.084599, -0.049381, -0.046614, 0.102548, -0.554858),
+  array<f32, 8>(0.691882, 0.082284, 0.786431, -0.352507, 0.007640, 0.130147, 0.079168, -0.384431),
+  array<f32, 8>(-0.089554, 0.033068, -0.013298, 0.223259, -0.209770, -0.050661, -0.024264, 0.096202),
+  array<f32, 8>(-0.214123, -0.168090, -1.084093, -0.058018, 0.160565, -0.012314, 0.023839, 0.705181),
+  array<f32, 8>(-0.078129, -0.058517, 0.103285, -0.001551, -0.050279, 0.056731, 0.340157, -0.554858),
+  array<f32, 8>(0.540997, 0.024887, -0.143659, -0.515214, -0.222923, 0.116850, 0.008144, -0.384431),
+  array<f32, 8>(-0.257782, 0.012741, 0.142405, 0.064579, 0.010678, -0.069590, -0.029995, 0.096202),
+  array<f32, 8>(-0.153197, -0.103723, 0.189375, -0.136753, 0.037238, -0.016964, 0.086707, 0.705181),
+  array<f32, 8>(0.168810, 0.040280, -0.134904, -0.028552, -0.059283, -0.131581, 0.522745, -0.554858),
+  array<f32, 8>(0.450779, 0.116955, -0.198275, -0.132313, 0.138815, 0.283598, 0.010077, -0.384431),
+  array<f32, 8>(-0.010634, -0.029003, -0.015055, 0.000537, -0.000948, -0.070859, 0.027198, 0.096202)
 );

-const bias_layer0 = vec4<f32>(-0.526177, -0.569862, -1.370040, 0.0);
-
-const weights_layer1: array<mat4x4<f32>, 9> = array(
-  mat4x4<f32>(
-    0.180029, -1.107249, 0.570741, 0.0,
-    -0.098536, 0.079545, -0.083257, 0.0,
-    -0.020066, 0.333084, 0.039506, 0.0,
-    0.0, 0.0, 0.0, 0.0,
-  ),
-  mat4x4<f32>(
-    3.068946, -1.783570, -0.550517, 0.0,
-    -0.296369, -0.080958, 0.040260, 0.0,
-    -0.093713, -0.212577, -0.110011, 0.0,
-    0.0, 0.0, 0.0, 0.0,
-  ),
-  mat4x4<f32>(
-    2.282564, -0.538192, -0.793214, 0.0,
-    -0.395788, 0.130881, 0.078571, 0.0,
-    -0.041375, 0.061666, 0.045651, 0.0,
-    0.0, 0.0, 0.0, 0.0,
-  ),
-  mat4x4<f32>(
-    -0.267284, -1.971639, -0.099616, 0.0,
-    -0.084432, 0.139794, 0.007091, 0.0,
-    -0.103042, -0.104340, 0.067299, 0.0,
-    0.0, 0.0, 0.0, 0.0,
-  ),
-  mat4x4<f32>(
-    -5.233469, -2.252747, -3.555217, 0.0,
-    0.647940, -0.178858, 0.351633, 0.0,
-    -0.014237, -0.505881, 0.165940, 0.0,
-    0.0, 0.0, 0.0, 0.0,
-  ),
-  mat4x4<f32>(
-    -0.121700, -0.677386, -2.435040, 0.0,
-    0.084806, -0.028000, 0.380387, 0.0,
-    -0.020906, -0.279161, 0.041915, 0.0,
-    0.0, 0.0, 0.0, 0.0,
-  ),
-  mat4x4<f32>(
-    2.982562, -0.298441, -0.147775, 0.0,
-    -0.291832, 0.102875, -0.128590, 0.0,
-    -0.091786, 0.104389, -0.188678, 0.0,
-    0.0, 0.0, 0.0, 0.0,
-  ),
-  mat4x4<f32>(
-    -4.434978, -0.261830, -2.436411, 0.0,
-    0.349188, -0.245908, 0.272592, 0.0,
-    0.010322, -0.148525, -0.031531, 0.0,
-    0.0, 0.0, 0.0, 0.0,
-  ),
-  mat4x4<f32>(
-    0.129886, 1.516168, -0.755576, 0.0,
-    0.133138, -0.260276, 0.028059, 0.0,
-    0.001185, 0.141547, -0.003606, 0.0,
-    0.0, 0.0, 0.0, 0.0,
-  )
+const weights_layer2: array<array<f32, 8>, 9> = array(
+  array<f32, 8>(-0.049127, 0.044294, 0.110612, -0.009803, -0.027652, -0.198405, -0.024763, 0.092761),
+  array<f32, 8>(-0.104378, 0.082831, 0.347193, -0.040960, -0.067023, -0.078163, 0.003278, 0.092761),
+  array<f32, 8>(-0.044334, 0.015010, 0.085171, -0.000658, -0.048960, -0.180710, -0.118388, 0.092761),
+  array<f32, 8>(-0.090802, 0.239246, 0.060319, -0.157119, -0.096438, 0.122192, -0.110975, 0.092761),
+  array<f32, 8>(-0.243697, 0.489006, 0.108902, -0.465329, 0.124279, 0.126061, 0.159364, 0.092761),
+  array<f32, 8>(-0.102790, 0.124908, 0.017353, -0.057910, 0.113450, 0.203798, 0.088674, 0.092761),
+  array<f32, 8>(-0.038260, 0.080576, 0.018707, -0.018687, -0.003492, 0.084815, 0.019988, 0.092761),
+  array<f32, 8>(-0.048685, 0.067695, 0.012956, 0.020401, -0.090458, -0.066818, -0.060523, 0.092761),
+  array<f32, 8>(-0.060803, 0.066513, 0.004983, -0.006159, 0.095386, -0.025576, 0.056029, 0.092761)
 );

-const bias_layer1 = vec4<f32>(1.367986, -1.148709, -0.650040, 0.0);
-
-const weights_layer2: array<mat4x4<f32>, 9> = array(
-  mat4x4<f32>(
-    -0.137003, -0.289376, 0.625000, 0.0,
-    -0.120120, -0.238968, 0.448432, 0.0,
-    -0.142094, -0.253706, 0.458181, 0.0,
-    0.0, 0.0, 0.0, 0.0,
-  ),
-  mat4x4<f32>(
-    -0.337017, -0.757585, 0.135953, 0.0,
-    -0.304432, -0.553491, 0.419907, 0.0,
-    -0.313585, -0.467667, 0.615326, 0.0,
-    0.0, 0.0, 0.0, 0.0,
-  ),
-  mat4x4<f32>(
-    -0.161089, -0.328735, 0.612679, 0.0,
-    -0.137144, -0.172882, 0.176362, 0.0,
-    -0.153195, -0.061571, 0.173977, 0.0,
-    0.0, 0.0, 0.0, 0.0,
-  ),
-  mat4x4<f32>(
-    -0.227814, -0.544193, -0.564658, 0.0,
-    -0.211743, -0.430586, 0.080349, 0.0,
-    -0.214442, -0.417501, 0.880266, 0.0,
-    0.0, 0.0, 0.0, 0.0,
-  ),
-  mat4x4<f32>(
-    -0.435370, -0.295169, -0.865976, 0.0,
-    -0.423147, -0.274780, 0.323049, 0.0,
-    -0.411180, -0.062517, 1.099769, 0.0,
-    0.0, 0.0, 0.0, 0.0,
-  ),
-  mat4x4<f32>(
-    -0.199573, -0.488030, -0.396440, 0.0,
-    -0.187844, -0.360516, -0.156646, 0.0,
-    -0.188681, -0.292304, -0.134645, 0.0,
-    0.0, 0.0, 0.0, 0.0,
-  ),
-  mat4x4<f32>(
-    -0.123218, -0.287990, 0.154656, 0.0,
-    -0.112954, -0.282778, 0.498742, 0.0,
-    -0.139083, -0.319337, 1.112621, 0.0,
-    0.0, 0.0, 0.0, 0.0,
-  ),
-  mat4x4<f32>(
-    -0.267477, -0.691374, -0.028960, 0.0,
-    -0.246348, -0.585583, 0.401194, 0.0,
-    -0.253279, -0.562875, 1.105818, 0.0,
-    0.0, 0.0, 0.0, 0.0,
-  ),
-  mat4x4<f32>(
-    -0.083133, -0.131627, 0.460039, 0.0,
-    -0.071126, -0.108601, 0.163545, 0.0,
-    -0.092579, -0.110020, 0.131282, 0.0,
-    0.0, 0.0, 0.0, 0.0,
-  )
-);
-
-const bias_layer2 = vec4<f32>(-1.805686, -0.798340, 0.462318, 0.0);
-
