summaryrefslogtreecommitdiff
path: root/workspaces
diff options
context:
space:
mode:
Diffstat (limited to 'workspaces')
-rw-r--r--workspaces/main/shaders/cnn/cnn_conv5x5.wgsl86
-rw-r--r--workspaces/main/shaders/cnn/cnn_layer.wgsl30
-rw-r--r--workspaces/main/shaders/cnn/cnn_weights_generated.wgsl324
3 files changed, 254 insertions, 186 deletions
diff --git a/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl b/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl
index bd9abfa..15eaf96 100644
--- a/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl
+++ b/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl
@@ -51,3 +51,89 @@ fn cnn_conv5x5_with_coord(
return sum;
}
+
+// 5×5 variant for 7→4 channels (RGBD output)
+// weights: array<array<f32, 8>, 100> (25 positions × 4 channels, each with 7 weights + bias)
+fn cnn_conv5x5_7to4(
+ tex: texture_2d<f32>,
+ samp: sampler,
+ uv: vec2<f32>,
+ resolution: vec2<f32>,
+ original: vec4<f32>,
+ weights: array<array<f32, 8>, 100>
+) -> vec4<f32> {
+ let step = 1.0 / resolution;
+
+ let gray_01 = 0.2126*original.r + 0.7152*original.g + 0.0722*original.b;
+ let gray = (gray_01 - 0.5) * 2.0;
+ let uv_norm = (uv - 0.5) * 2.0;
+
+ var sum = vec4<f32>(0.0);
+ var pos = 0;
+
+ for (var dy = -2; dy <= 2; dy++) {
+ for (var dx = -2; dx <= 2; dx++) {
+ let offset = vec2<f32>(f32(dx), f32(dy)) * step;
+ let rgbd_01 = textureSample(tex, samp, uv + offset);
+ let rgbd = (rgbd_01 - 0.5) * 2.0;
+
+ let inputs = array<f32, 7>(
+ rgbd.r, rgbd.g, rgbd.b, rgbd.a,
+ uv_norm.x, uv_norm.y, gray
+ );
+
+ for (var out_c = 0; out_c < 4; out_c++) {
+ let idx = pos * 4 + out_c;
+ var channel_sum = weights[idx][7];
+ for (var in_c = 0; in_c < 7; in_c++) {
+ channel_sum += weights[idx][in_c] * inputs[in_c];
+ }
+ sum[out_c] += channel_sum;
+ }
+ pos++;
+ }
+ }
+
+ return sum;
+}
+
+// 5×5 variant for 7→1 channel (scalar output)
+// weights: array<array<f32, 8>, 25> (25 positions, each with 7 weights + bias)
+fn cnn_conv5x5_7to1(
+ tex: texture_2d<f32>,
+ samp: sampler,
+ uv: vec2<f32>,
+ resolution: vec2<f32>,
+ original: vec4<f32>,
+ weights: array<array<f32, 8>, 25>
+) -> f32 {
+ let step = 1.0 / resolution;
+
+ let gray_01 = 0.2126*original.r + 0.7152*original.g + 0.0722*original.b;
+ let gray = (gray_01 - 0.5) * 2.0;
+ let uv_norm = (uv - 0.5) * 2.0;
+
+ var sum = 0.0;
+ var pos = 0;
+
+ for (var dy = -2; dy <= 2; dy++) {
+ for (var dx = -2; dx <= 2; dx++) {
+ let offset = vec2<f32>(f32(dx), f32(dy)) * step;
+ let rgbd_01 = textureSample(tex, samp, uv + offset);
+ let rgbd = (rgbd_01 - 0.5) * 2.0;
+
+ sum += weights[pos][0] * rgbd.r;
+ sum += weights[pos][1] * rgbd.g;
+ sum += weights[pos][2] * rgbd.b;
+ sum += weights[pos][3] * rgbd.a;
+ sum += weights[pos][4] * uv_norm.x;
+ sum += weights[pos][5] * uv_norm.y;
+ sum += weights[pos][6] * gray;
+ sum += weights[pos][7]; // Bias
+
+ pos++;
+ }
+ }
+
+ return sum;
+}
diff --git a/workspaces/main/shaders/cnn/cnn_layer.wgsl b/workspaces/main/shaders/cnn/cnn_layer.wgsl
index 5834f78..fad283c 100644
--- a/workspaces/main/shaders/cnn/cnn_layer.wgsl
+++ b/workspaces/main/shaders/cnn/cnn_layer.wgsl
@@ -8,6 +8,7 @@
#include "common_uniforms"
#include "cnn_activation"
#include "cnn_conv3x3"
+#include "cnn_conv5x5"
#include "cnn_weights_generated"
struct CNNLayerParams {
@@ -33,24 +34,33 @@ struct CNNLayerParams {
let original = textureSample(original_input, smplr, uv);
var result = vec4<f32>(0.0);
- // Layer 0 uses coordinate-aware convolution
+ // Layer 0: 7→4 (RGBD output)
if (params.layer_index == 0) {
- result = cnn_conv3x3_with_coord(txt, smplr, uv, uniforms.resolution,
- rgba_weights_layer0, coord_weights_layer0, bias_layer0);
- result = cnn_tanh(result);
+ result = cnn_conv3x3_7to4(txt, smplr, uv, uniforms.resolution,
+ original, weights_layer0);
+ result = cnn_tanh(result); // Output in [-1,1]
+ // Denormalize to [0,1] for texture storage
+ result = (result + 1.0) * 0.5;
}
else if (params.layer_index == 1) {
- result = cnn_conv3x3(txt, smplr, uv, uniforms.resolution,
- weights_layer1, bias_layer1);
- result = cnn_tanh(result);
+ result = cnn_conv5x5_7to4(txt, smplr, uv, uniforms.resolution,
+ original, weights_layer1);
+ result = cnn_tanh(result); // Output in [-1,1]
+ // Denormalize to [0,1] for texture storage
+ result = (result + 1.0) * 0.5;
}
else if (params.layer_index == 2) {
- result = cnn_conv3x3(txt, smplr, uv, uniforms.resolution,
- weights_layer2, bias_layer2);
+ let gray_out = cnn_conv3x3_7to1(txt, smplr, uv, uniforms.resolution,
+ original, weights_layer2);
+ // Denormalize from [-1,1] to [0,1]
+ let gray_01 = (gray_out + 1.0) * 0.5;
+ result = vec4<f32>(gray_01, gray_01, gray_01, 1.0); // Expand to RGB
}
else {
result = input;
}
- return mix(original, result, params.blend_amount);
+ // Blend with ORIGINAL input from layer 0
+return original;
+// return mix(original, result, params.blend_amount);
}
diff --git a/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl b/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl
index 6052ac5..6ec78c1 100644
--- a/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl
+++ b/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl
@@ -1,185 +1,157 @@
// Auto-generated CNN weights
// DO NOT EDIT - Generated by train_cnn.py
-const rgba_weights_layer0: array<mat4x4<f32>, 9> = array(
- mat4x4<f32>(
- -0.181929, -0.244329, -0.354404, 0.0,
- -0.291597, -0.195653, 0.081896, 0.0,
- 0.081595, 0.164081, -0.236318, 0.0,
- 0.0, 0.0, 0.0, 0.0,
- ),
- mat4x4<f32>(
- 0.731888, 0.717648, 0.524081, 0.0,
- -0.029760, -0.208000, 0.008438, 0.0,
- 0.442082, 0.354681, 0.049288, 0.0,
- 0.0, 0.0, 0.0, 0.0,
- ),
- mat4x4<f32>(
- -0.623141, -0.695759, -0.087885, 0.0,
- 0.043135, 0.071979, 0.213065, 0.0,
- 0.011581, 0.110995, 0.034100, 0.0,
- 0.0, 0.0, 0.0, 0.0,
- ),
- mat4x4<f32>(
- 0.170016, 0.188298, 0.134083, 0.0,
- -0.222954, -0.088011, 0.015668, 0.0,
- 0.921836, 0.437158, 0.061577, 0.0,
- 0.0, 0.0, 0.0, 0.0,
- ),
- mat4x4<f32>(
- 1.431940, 1.148113, 1.238067, 0.0,
- -0.212535, 0.366860, 0.320956, 0.0,
- 0.771192, 0.765570, 0.029189, 0.0,
- 0.0, 0.0, 0.0, 0.0,
- ),
- mat4x4<f32>(
- 0.171088, 0.000155, 0.212552, 0.0,
- 0.029536, 0.447892, 0.041381, 0.0,
- 0.011807, -0.167281, -0.200702, 0.0,
- 0.0, 0.0, 0.0, 0.0,
- ),
- mat4x4<f32>(
- -0.668151, -0.813927, -0.132108, 0.0,
- -0.156250, 0.179112, -0.069585, 0.0,
- 0.403347, 0.482877, 0.182611, 0.0,
- 0.0, 0.0, 0.0, 0.0,
- ),
- mat4x4<f32>(
- -0.609871, -0.768480, -0.590538, 0.0,
- -0.171854, 0.150167, 0.105694, 0.0,
- -0.059052, 0.066999, -0.244222, 0.0,
- 0.0, 0.0, 0.0, 0.0,
- ),
- mat4x4<f32>(
- -0.112983, -0.066299, 0.117696, 0.0,
- -0.172541, 0.095008, -0.160754, 0.0,
- -0.369667, -0.000628, 0.163602, 0.0,
- 0.0, 0.0, 0.0, 0.0,
- )
+const weights_layer0: array<array<f32, 8>, 36> = array(
+ array<f32, 8>(0.181071, 0.180671, -0.156688, -0.133398, -0.121571, -0.069398, 0.197448, -0.010528),
+ array<f32, 8>(-0.315179, -0.222857, -0.084049, 0.211023, 0.093378, 0.041467, -0.171730, 0.109394),
+ array<f32, 8>(-0.776134, -0.748738, -0.426878, 0.038600, 0.055702, 0.264617, -0.766478, -0.140384),
+ array<f32, 8>(0.325726, 0.142498, 0.022606, -0.071913, -0.085795, -0.008587, 0.142883, 0.080188),
+ array<f32, 8>(0.036004, 0.238142, 0.248112, -0.031285, -0.124107, 0.040947, 0.298797, -0.010528),
+ array<f32, 8>(0.223016, 0.358479, -0.011029, 0.102233, -0.167285, 0.111095, 0.247539, 0.109394),
+ array<f32, 8>(1.147701, 1.132161, 0.406157, 0.030399, 0.155895, 0.158795, 1.070009, -0.140384),
+ array<f32, 8>(0.343120, 0.413393, 0.345993, 0.048482, 0.007208, 0.103310, 0.410492, 0.080188),
+ array<f32, 8>(-0.074602, -0.271241, -0.181283, -0.062921, 0.126872, -0.006445, -0.171823, -0.010528),
+ array<f32, 8>(-0.554038, -0.443533, -0.282759, 0.123052, -0.076625, 0.047519, -0.525091, 0.109394),
+ array<f32, 8>(-0.036767, 0.183663, 0.026803, -0.102403, -0.052118, 0.014384, 0.272146, -0.140384),
+ array<f32, 8>(0.006535, -0.016445, -0.063138, 0.121560, 0.026330, 0.142604, -0.172336, 0.080188),
+ array<f32, 8>(-0.166300, -0.154177, 0.124821, -0.109807, -0.052216, 0.009249, -0.217568, -0.010528),
+ array<f32, 8>(0.258850, 0.122862, -0.085148, 0.217309, 0.115453, -0.045434, -0.018513, 0.109394),
+ array<f32, 8>(0.030132, 0.150035, 0.140258, -0.016209, -0.002259, -0.073409, 0.106627, -0.140384),
+ array<f32, 8>(-0.382104, -0.406515, -0.130833, 0.006075, -0.032734, -0.066618, -0.293135, 0.080188),
+ array<f32, 8>(-0.344131, -0.804046, -0.131275, -0.025023, 0.107656, -0.105495, -0.477465, -0.010528),
+ array<f32, 8>(2.238629, 2.666346, 1.344841, 0.188894, -0.064351, -0.079447, 2.525197, 0.109394),
+ array<f32, 8>(1.617750, 1.332233, 0.817712, 0.084650, 0.008433, -0.191120, 1.499900, -0.140384),
+ array<f32, 8>(-1.007007, -0.971186, -0.670325, 0.066275, 0.044722, -0.034290, -0.874401, 0.080188),
+ array<f32, 8>(0.214410, 0.423019, 0.210104, -0.111646, 0.025179, 0.021650, 0.370431, -0.010528),
+ array<f32, 8>(0.698145, 0.742521, 0.259904, 0.172190, 0.002103, 0.150391, 0.710542, 0.109394),
+ array<f32, 8>(-0.437137, -0.378108, -0.350803, -0.151723, -0.079989, -0.014077, -0.393804, -0.140384),
+ array<f32, 8>(0.148036, 0.149877, 0.208686, 0.118023, 0.106081, -0.156695, 0.040117, 0.080188),
+ array<f32, 8>(0.007790, -0.006625, -0.036656, 0.033389, -0.087019, 0.037006, 0.125953, -0.010528),
+ array<f32, 8>(-0.362541, -0.324086, -0.150862, 0.000339, 0.227779, -0.035100, -0.325073, 0.109394),
+ array<f32, 8>(-0.248297, -0.083314, -0.071273, -0.108375, 0.037361, -0.055978, -0.192155, -0.140384),
+ array<f32, 8>(0.396704, 0.247469, 0.274000, -0.005654, 0.014277, 0.097346, 0.185350, 0.080188),
+ array<f32, 8>(-0.264719, -0.327878, -0.137154, -0.082728, 0.053555, 0.134324, -0.360542, -0.010528),
+ array<f32, 8>(0.154692, 0.521805, -0.014565, 0.236786, -0.050656, -0.077934, 0.252781, 0.109394),
+ array<f32, 8>(-0.762319, -0.789229, -0.565690, -0.050543, -0.033753, 0.065576, -0.704560, -0.140384),
+ array<f32, 8>(-0.366925, -0.189991, -0.177319, -0.068166, -0.096794, 0.111543, -0.319584, 0.080188),
+ array<f32, 8>(-0.069582, -0.007117, -0.080914, -0.054291, 0.099181, -0.057500, 0.064654, -0.010528),
+ array<f32, 8>(-0.234008, -0.283279, -0.317393, 0.138485, -0.089743, -0.116539, -0.340335, 0.109394),
+ array<f32, 8>(-0.332349, -0.045645, -0.142621, 0.066072, -0.063394, -0.047263, -0.275151, -0.140384),
+ array<f32, 8>(0.243287, -0.037305, 0.084462, 0.050442, 0.038222, -0.094055, 0.019547, 0.080188)
);
-const coord_weights_layer0 = mat2x4<f32>(
- 0.059076, -0.026617, -0.005155, 0.0,
- 0.135407, -0.090329, 0.058216, 0.0
+const weights_layer1: array<array<f32, 8>, 100> = array(
+ array<f32, 8>(0.036741, 0.174566, 0.291795, -0.033066, -0.084725, 0.017211, -0.119127, 0.705181),
+ array<f32, 8>(0.122186, -0.078751, 0.003745, 0.000022, 0.027259, -0.014794, -0.132560, -0.554858),
+ array<f32, 8>(0.275572, 0.018150, -0.008554, -0.146253, -0.124261, -0.106076, -0.124406, -0.384431),
+ array<f32, 8>(0.235391, -0.078051, 0.052366, 0.054055, 0.116088, 0.137109, -0.041252, 0.096202),
+ array<f32, 8>(0.029752, 0.005636, -0.434780, 0.218587, -0.174135, 0.077972, -0.243145, 0.705181),
+ array<f32, 8>(0.075832, 0.036187, -0.151527, -0.131852, 0.042588, -0.087969, 0.116469, -0.554858),
+ array<f32, 8>(0.234748, -0.130623, 0.049978, -0.065811, 0.114875, -0.159655, -0.065553, -0.384431),
+ array<f32, 8>(0.000944, -0.056210, 0.007023, -0.046152, 0.006291, 0.055730, 0.079512, 0.096202),
+ array<f32, 8>(-0.656454, 0.171425, 0.518928, -0.355405, 0.005253, 0.188539, 0.031148, 0.705181),
+ array<f32, 8>(0.253326, -0.090369, -0.096764, -0.047542, 0.073904, 0.040503, 0.111922, -0.554858),
+ array<f32, 8>(0.228924, 0.111527, -0.027116, -0.133119, 0.099857, -0.150050, -0.075578, -0.384431),
+ array<f32, 8>(-0.037827, 0.075674, -0.111529, -0.114302, -0.118413, -0.000674, 0.072307, 0.096202),
+ array<f32, 8>(-0.633403, -0.158435, 0.042044, -0.199189, 0.135613, 0.128976, -0.095854, 0.705181),
+ array<f32, 8>(0.224475, 0.009348, -0.014853, -0.083097, -0.129013, 0.058030, -0.010732, -0.554858),
+ array<f32, 8>(0.277777, -0.008638, -0.024935, 0.055844, -0.137042, -0.173785, -0.168956, -0.384431),
+ array<f32, 8>(-0.027801, 0.013278, -0.084694, -0.074887, 0.065917, 0.209263, 0.102990, 0.096202),
+ array<f32, 8>(-0.236525, 0.081545, 0.120326, 0.497596, -0.133167, 0.111884, -0.149371, 0.705181),
+ array<f32, 8>(0.041858, -0.011563, 0.027560, 0.003061, -0.033285, -0.083021, -0.297831, -0.554858),
+ array<f32, 8>(0.297747, 0.001253, 0.106780, 0.007335, 0.259682, -0.229853, -0.001358, -0.384431),
+ array<f32, 8>(-0.043069, -0.045672, -0.018898, -0.190785, -0.000672, 0.089811, 0.118545, 0.096202),
+ array<f32, 8>(-0.066188, 0.120479, -0.198253, -0.004995, 0.017530, -0.258193, -0.046440, 0.705181),
+ array<f32, 8>(-0.017375, 0.127093, -0.003759, 0.053924, 0.036127, 0.462423, 0.238701, -0.554858),
+ array<f32, 8>(0.281321, -0.034091, 0.129930, 0.026051, -0.134687, -0.102370, 0.015789, -0.384431),
+ array<f32, 8>(-0.098730, 0.061805, 0.049945, 0.045479, -0.037703, 0.164979, 0.015898, 0.096202),
+ array<f32, 8>(-0.782371, -0.398567, 0.203990, -0.918743, -0.150142, -0.208263, -0.191751, 0.705181),
+ array<f32, 8>(0.020808, -0.079593, -0.020436, 0.289221, 0.055709, 0.460382, 0.367026, -0.554858),
+ array<f32, 8>(0.349494, -0.057511, 0.075111, -0.102943, 0.063262, -0.130990, -0.008248, -0.384431),
+ array<f32, 8>(0.168912, 0.176731, -0.134468, 0.232372, -0.033147, 0.085774, -0.048677, 0.096202),
+ array<f32, 8>(0.618180, -0.520695, -0.530039, 1.196123, 0.122213, -0.173562, -0.324433, 0.705181),
+ array<f32, 8>(0.638450, -0.149078, -0.017734, 0.493464, -0.015998, 0.584648, 0.288746, -0.554858),
+ array<f32, 8>(0.074488, 0.212364, 0.187959, -0.055154, 0.065607, -0.068925, -0.017994, -0.384431),
+ array<f32, 8>(-0.018923, -0.099117, -0.005060, 0.242744, -0.274629, 0.051926, 0.079712, 0.096202),
+ array<f32, 8>(0.658693, -0.039718, 0.441997, 0.250902, 0.016350, -0.200507, 0.022444, 0.705181),
+ array<f32, 8>(0.559097, -0.196625, 0.035123, 0.046925, -0.186840, 0.619048, 0.559590, -0.554858),
+ array<f32, 8>(-0.139655, -0.021050, 0.295279, -0.302983, -0.061179, -0.146221, 0.082147, -0.384431),
+ array<f32, 8>(0.174535, -0.052728, -0.181830, -0.037692, -0.027964, 0.094183, 0.075072, 0.096202),
+ array<f32, 8>(-0.025873, -0.035599, 0.024004, 0.191656, -0.102396, -0.254495, -0.094824, 0.705181),
+ array<f32, 8>(0.228900, -0.003296, -0.008463, 0.324218, -0.049343, 0.447783, 0.273508, -0.554858),
+ array<f32, 8>(0.160779, -0.008924, -0.081788, -0.433376, 0.227781, -0.062434, 0.032769, -0.384431),
+ array<f32, 8>(0.278581, -0.029578, -0.048828, 0.037354, -0.083503, 0.051478, 0.129985, 0.096202),
+ array<f32, 8>(-0.179409, -0.150093, -0.154602, 0.155942, -0.030033, -0.052753, 0.126435, 0.705181),
+ array<f32, 8>(-0.276681, 0.000117, 0.005694, 0.112600, 0.053831, -0.226162, 0.238592, -0.554858),
+ array<f32, 8>(0.509299, -0.073816, -0.036746, -0.229020, -0.084083, 0.102554, 0.091619, -0.384431),
+ array<f32, 8>(-0.226426, 0.100353, 0.041059, 0.298785, 0.189457, -0.066325, -0.044970, 0.096202),
+ array<f32, 8>(-1.231526, -1.330216, 0.526922, -0.587336, -0.130186, -0.092712, 0.001462, 0.705181),
+ array<f32, 8>(-0.181467, -0.162607, -0.034778, 0.462190, 0.018091, -0.425409, -0.281373, -0.554858),
+ array<f32, 8>(-0.127841, 0.636898, -0.133350, -0.021139, 0.041704, 0.131472, 0.064884, -0.384431),
+ array<f32, 8>(-0.142875, 0.037216, -0.253924, 0.065928, 0.010346, -0.066305, 0.010009, 0.096202),
+ array<f32, 8>(1.275860, -1.926775, -2.786324, 0.841585, 0.184752, 0.129195, -0.746670, 0.705181),
+ array<f32, 8>(0.224131, -0.438348, -0.301076, 0.516342, -0.025436, -0.288755, 0.113458, -0.554858),
+ array<f32, 8>(-0.458493, 1.383449, -0.836509, 0.513842, 0.009788, 0.140230, -0.178303, -0.384431),
+ array<f32, 8>(0.615726, -0.780623, -0.056982, 0.279193, -0.126452, -0.148191, -0.142928, 0.096202),
+ array<f32, 8>(0.900523, -0.634242, 0.387372, 0.906412, 0.123193, 0.017453, 0.044746, 0.705181),
+ array<f32, 8>(0.027988, -0.101389, 0.206885, 0.397388, -0.107841, -0.200823, 0.062329, -0.554858),
+ array<f32, 8>(-0.075820, 0.221546, -0.360333, 0.653466, -0.164603, -0.018184, -0.054492, -0.384431),
+ array<f32, 8>(0.428530, -0.333444, 0.292398, -0.419585, 0.189928, 0.033952, 0.048148, 0.096202),
+ array<f32, 8>(-0.375948, -0.074633, 0.091707, 0.403029, -0.058652, 0.015809, -0.074391, 0.705181),
+ array<f32, 8>(0.257007, -0.003856, 0.044474, 0.471313, -0.028859, -0.319677, 0.420764, -0.554858),
+ array<f32, 8>(0.134231, -0.048700, 0.029223, -0.006709, 0.116888, 0.081077, 0.034694, -0.384431),
+ array<f32, 8>(-0.081323, 0.073446, -0.045513, 0.067464, 0.129549, -0.099304, 0.013744, 0.096202),
+ array<f32, 8>(-0.050969, -0.016310, 0.185193, 0.154125, -0.036880, 0.023559, 0.025859, 0.705181),
+ array<f32, 8>(0.126071, -0.051684, 0.043923, 0.108173, 0.066038, -0.192287, 0.352096, -0.554858),
+ array<f32, 8>(0.427852, 0.008965, 0.102399, -0.108640, -0.150961, 0.163040, -0.050800, -0.384431),
+ array<f32, 8>(-0.072433, 0.037114, 0.013594, 0.069869, 0.104687, -0.078728, -0.069972, 0.096202),
+ array<f32, 8>(0.361002, -0.392113, 0.274024, -0.034550, -0.082394, 0.095107, 0.401706, 0.705181),
+ array<f32, 8>(-0.348729, 0.008971, -0.079949, -0.066546, 0.134612, -0.303610, -0.186502, -0.554858),
+ array<f32, 8>(0.345962, 0.668520, -0.784279, 0.051949, -0.042057, 0.050625, 0.020689, -0.384431),
+ array<f32, 8>(-0.274767, 0.076023, -0.254009, 0.323161, 0.033120, -0.065693, 0.092788, 0.096202),
+ array<f32, 8>(-0.922054, -0.756431, 0.393524, -1.602483, 0.114452, 0.112853, 0.537923, 0.705181),
+ array<f32, 8>(-0.171843, 0.212152, -0.122848, 0.602091, 0.046238, -0.188875, -0.437717, -0.554858),
+ array<f32, 8>(-0.588119, 1.183603, -0.530633, -0.829767, -0.006702, 0.078743, 0.153854, -0.384431),
+ array<f32, 8>(-0.188030, -0.172885, -0.066943, -0.046181, -0.106921, -0.134572, 0.080975, 0.096202),
+ array<f32, 8>(-1.695124, 0.212725, 1.229505, -0.741534, 0.126303, 0.133647, 0.108956, 0.705181),
+ array<f32, 8>(-0.320230, -0.041878, -0.084837, 0.395743, -0.002644, -0.220269, 0.159212, -0.554858),
+ array<f32, 8>(-0.321301, 0.390960, 0.261866, -0.413745, -0.124661, 0.036689, 0.011903, -0.384431),
+ array<f32, 8>(0.119845, -0.057012, 0.525753, -0.132904, 0.033117, 0.029606, 0.127306, 0.096202),
+ array<f32, 8>(-0.612274, 0.254915, 0.019614, 0.080070, 0.006519, 0.087940, -0.047548, 0.705181),
+ array<f32, 8>(0.051014, -0.017890, 0.233801, -0.168954, 0.079934, -0.283275, 0.332838, -0.554858),
+ array<f32, 8>(0.332870, -0.118202, 0.131280, 0.036194, 0.131048, 0.067103, 0.056030, -0.384431),
+ array<f32, 8>(-0.257300, 0.003565, -0.232427, 0.046336, -0.003665, -0.094940, 0.127497, 0.096202),
+ array<f32, 8>(-0.600203, 0.063626, -0.105803, 0.208178, -0.051062, -0.008572, -0.035231, 0.705181),
+ array<f32, 8>(-0.042070, -0.023061, -0.037054, -0.002751, 0.019568, 0.078162, -0.044403, -0.554858),
+ array<f32, 8>(-0.079765, 0.198893, 0.041831, 0.061495, -0.120350, 0.150103, -0.058056, -0.384431),
+ array<f32, 8>(0.139291, -0.054425, 0.002238, 0.016114, 0.057097, -0.116833, -0.081744, 0.096202),
+ array<f32, 8>(-0.738361, -0.374730, 0.359610, 0.291041, -0.072061, -0.041352, -0.189793, 0.705181),
+ array<f32, 8>(0.090555, -0.020056, -0.012717, 0.120644, -0.026834, -0.073251, 0.329796, -0.554858),
+ array<f32, 8>(0.367344, -0.049545, -0.421975, -0.066803, 0.069864, 0.155396, -0.118779, -0.384431),
+ array<f32, 8>(0.045037, 0.104654, -0.079043, 0.036356, 0.026634, 0.000503, 0.029131, 0.096202),
+ array<f32, 8>(-0.250248, -0.186821, 0.671967, 0.211267, 0.217392, -0.004840, -0.025819, 0.705181),
+ array<f32, 8>(-0.085968, -0.124825, -0.065762, 0.084599, -0.049381, -0.046614, 0.102548, -0.554858),
+ array<f32, 8>(0.691882, 0.082284, 0.786431, -0.352507, 0.007640, 0.130147, 0.079168, -0.384431),
+ array<f32, 8>(-0.089554, 0.033068, -0.013298, 0.223259, -0.209770, -0.050661, -0.024264, 0.096202),
+ array<f32, 8>(-0.214123, -0.168090, -1.084093, -0.058018, 0.160565, -0.012314, 0.023839, 0.705181),
+ array<f32, 8>(-0.078129, -0.058517, 0.103285, -0.001551, -0.050279, 0.056731, 0.340157, -0.554858),
+ array<f32, 8>(0.540997, 0.024887, -0.143659, -0.515214, -0.222923, 0.116850, 0.008144, -0.384431),
+ array<f32, 8>(-0.257782, 0.012741, 0.142405, 0.064579, 0.010678, -0.069590, -0.029995, 0.096202),
+ array<f32, 8>(-0.153197, -0.103723, 0.189375, -0.136753, 0.037238, -0.016964, 0.086707, 0.705181),
+ array<f32, 8>(0.168810, 0.040280, -0.134904, -0.028552, -0.059283, -0.131581, 0.522745, -0.554858),
+ array<f32, 8>(0.450779, 0.116955, -0.198275, -0.132313, 0.138815, 0.283598, 0.010077, -0.384431),
+ array<f32, 8>(-0.010634, -0.029003, -0.015055, 0.000537, -0.000948, -0.070859, 0.027198, 0.096202)
);
-const bias_layer0 = vec4<f32>(-0.526177, -0.569862, -1.370040, 0.0);
-
-const weights_layer1: array<mat4x4<f32>, 9> = array(
- mat4x4<f32>(
- 0.180029, -1.107249, 0.570741, 0.0,
- -0.098536, 0.079545, -0.083257, 0.0,
- -0.020066, 0.333084, 0.039506, 0.0,
- 0.0, 0.0, 0.0, 0.0,
- ),
- mat4x4<f32>(
- 3.068946, -1.783570, -0.550517, 0.0,
- -0.296369, -0.080958, 0.040260, 0.0,
- -0.093713, -0.212577, -0.110011, 0.0,
- 0.0, 0.0, 0.0, 0.0,
- ),
- mat4x4<f32>(
- 2.282564, -0.538192, -0.793214, 0.0,
- -0.395788, 0.130881, 0.078571, 0.0,
- -0.041375, 0.061666, 0.045651, 0.0,
- 0.0, 0.0, 0.0, 0.0,
- ),
- mat4x4<f32>(
- -0.267284, -1.971639, -0.099616, 0.0,
- -0.084432, 0.139794, 0.007091, 0.0,
- -0.103042, -0.104340, 0.067299, 0.0,
- 0.0, 0.0, 0.0, 0.0,
- ),
- mat4x4<f32>(
- -5.233469, -2.252747, -3.555217, 0.0,
- 0.647940, -0.178858, 0.351633, 0.0,
- -0.014237, -0.505881, 0.165940, 0.0,
- 0.0, 0.0, 0.0, 0.0,
- ),
- mat4x4<f32>(
- -0.121700, -0.677386, -2.435040, 0.0,
- 0.084806, -0.028000, 0.380387, 0.0,
- -0.020906, -0.279161, 0.041915, 0.0,
- 0.0, 0.0, 0.0, 0.0,
- ),
- mat4x4<f32>(
- 2.982562, -0.298441, -0.147775, 0.0,
- -0.291832, 0.102875, -0.128590, 0.0,
- -0.091786, 0.104389, -0.188678, 0.0,
- 0.0, 0.0, 0.0, 0.0,
- ),
- mat4x4<f32>(
- -4.434978, -0.261830, -2.436411, 0.0,
- 0.349188, -0.245908, 0.272592, 0.0,
- 0.010322, -0.148525, -0.031531, 0.0,
- 0.0, 0.0, 0.0, 0.0,
- ),
- mat4x4<f32>(
- 0.129886, 1.516168, -0.755576, 0.0,
- 0.133138, -0.260276, 0.028059, 0.0,
- 0.001185, 0.141547, -0.003606, 0.0,
- 0.0, 0.0, 0.0, 0.0,
- )
+const weights_layer2: array<array<f32, 8>, 9> = array(
+ array<f32, 8>(-0.049127, 0.044294, 0.110612, -0.009803, -0.027652, -0.198405, -0.024763, 0.092761),
+ array<f32, 8>(-0.104378, 0.082831, 0.347193, -0.040960, -0.067023, -0.078163, 0.003278, 0.092761),
+ array<f32, 8>(-0.044334, 0.015010, 0.085171, -0.000658, -0.048960, -0.180710, -0.118388, 0.092761),
+ array<f32, 8>(-0.090802, 0.239246, 0.060319, -0.157119, -0.096438, 0.122192, -0.110975, 0.092761),
+ array<f32, 8>(-0.243697, 0.489006, 0.108902, -0.465329, 0.124279, 0.126061, 0.159364, 0.092761),
+ array<f32, 8>(-0.102790, 0.124908, 0.017353, -0.057910, 0.113450, 0.203798, 0.088674, 0.092761),
+ array<f32, 8>(-0.038260, 0.080576, 0.018707, -0.018687, -0.003492, 0.084815, 0.019988, 0.092761),
+ array<f32, 8>(-0.048685, 0.067695, 0.012956, 0.020401, -0.090458, -0.066818, -0.060523, 0.092761),
+ array<f32, 8>(-0.060803, 0.066513, 0.004983, -0.006159, 0.095386, -0.025576, 0.056029, 0.092761)
);
-const bias_layer1 = vec4<f32>(1.367986, -1.148709, -0.650040, 0.0);
-
-const weights_layer2: array<mat4x4<f32>, 9> = array(
- mat4x4<f32>(
- -0.137003, -0.289376, 0.625000, 0.0,
- -0.120120, -0.238968, 0.448432, 0.0,
- -0.142094, -0.253706, 0.458181, 0.0,
- 0.0, 0.0, 0.0, 0.0,
- ),
- mat4x4<f32>(
- -0.337017, -0.757585, 0.135953, 0.0,
- -0.304432, -0.553491, 0.419907, 0.0,
- -0.313585, -0.467667, 0.615326, 0.0,
- 0.0, 0.0, 0.0, 0.0,
- ),
- mat4x4<f32>(
- -0.161089, -0.328735, 0.612679, 0.0,
- -0.137144, -0.172882, 0.176362, 0.0,
- -0.153195, -0.061571, 0.173977, 0.0,
- 0.0, 0.0, 0.0, 0.0,
- ),
- mat4x4<f32>(
- -0.227814, -0.544193, -0.564658, 0.0,
- -0.211743, -0.430586, 0.080349, 0.0,
- -0.214442, -0.417501, 0.880266, 0.0,
- 0.0, 0.0, 0.0, 0.0,
- ),
- mat4x4<f32>(
- -0.435370, -0.295169, -0.865976, 0.0,
- -0.423147, -0.274780, 0.323049, 0.0,
- -0.411180, -0.062517, 1.099769, 0.0,
- 0.0, 0.0, 0.0, 0.0,
- ),
- mat4x4<f32>(
- -0.199573, -0.488030, -0.396440, 0.0,
- -0.187844, -0.360516, -0.156646, 0.0,
- -0.188681, -0.292304, -0.134645, 0.0,
- 0.0, 0.0, 0.0, 0.0,
- ),
- mat4x4<f32>(
- -0.123218, -0.287990, 0.154656, 0.0,
- -0.112954, -0.282778, 0.498742, 0.0,
- -0.139083, -0.319337, 1.112621, 0.0,
- 0.0, 0.0, 0.0, 0.0,
- ),
- mat4x4<f32>(
- -0.267477, -0.691374, -0.028960, 0.0,
- -0.246348, -0.585583, 0.401194, 0.0,
- -0.253279, -0.562875, 1.105818, 0.0,
- 0.0, 0.0, 0.0, 0.0,
- ),
- mat4x4<f32>(
- -0.083133, -0.131627, 0.460039, 0.0,
- -0.071126, -0.108601, 0.163545, 0.0,
- -0.092579, -0.110020, 0.131282, 0.0,
- 0.0, 0.0, 0.0, 0.0,
- )
-);
-
-const bias_layer2 = vec4<f32>(-1.805686, -0.798340, 0.462318, 0.0);
-