summary | refs | log | tree | commit | diff
path: root/workspaces/main/shaders/cnn_v2_layer_1.wgsl
diff options
context:
space:
mode:
Diffstat (limited to 'workspaces/main/shaders/cnn_v2_layer_1.wgsl')
-rw-r--r-- workspaces/main/shaders/cnn_v2_layer_1.wgsl 174
1 files changed, 174 insertions, 0 deletions
diff --git a/workspaces/main/shaders/cnn_v2_layer_1.wgsl b/workspaces/main/shaders/cnn_v2_layer_1.wgsl
new file mode 100644
index 0000000..f490d13
--- /dev/null
+++ b/workspaces/main/shaders/cnn_v2_layer_1.wgsl
@@ -0,0 +1,174 @@
+// CNN v2 Layer 1 - Auto-generated
+// Kernel: 3×3, In: 16, Out: 4
+
+const KERNEL_SIZE = 3u;    // u32: kernel side length, so KERNEL_SIZE * KERNEL_SIZE taps per channel
+const IN_CHANNELS = 16u;   // u32: input channels per tap (8 static features + 8 from layer_input)
+const OUT_CHANNELS = 4u;   // u32: output channels computed by main()
+const KERNEL_RADIUS = 1i;  // i32: tap offsets run from -KERNEL_RADIUS to +KERNEL_RADIUS
+
+// Weights quantized to float16 (stored as f32 in WGSL). Flat layout: [out_channel][in_channel][ky][kx].
+const weights: array<f32, 576> = array(
+ 0.337402, 0.638672, -0.481201, 0.699707, 1.127930, -0.018280, -0.062195, 0.148682,
+ -0.655273, 0.448975, 0.969238, -0.280762, 0.817383, 1.271484, 0.421387, -0.163696,
+ 0.305664, -0.454834, 0.354004, 0.932617, -0.411377, 0.581543, 1.263672, 0.422363,
+ -0.380371, 0.152588, -0.668945, -0.063782, 0.060730, 0.022018, -0.075195, -0.049286,
+ 0.068542, 0.057343, -0.009773, 0.006344, -0.080872, -0.179932, -0.297119, 0.098328,
+ 0.061951, -0.088989, 0.047913, 0.093628, -0.091858, -0.068298, 0.102600, -0.044067,
+ -0.054230, -0.031799, 0.050934, -0.300049, -0.202637, -0.203613, -0.294189, -0.361084,
+ 0.277344, -0.213257, -0.239624, 0.193237, -0.215210, -0.295166, 0.298828, -0.065369,
+ 0.148926, 0.024963, 0.272705, 0.368164, 0.173096, 0.061279, 0.291260, 0.151611,
+ 0.411133, 0.216431, -0.179932, 0.506348, 0.319580, 0.059875, -0.134399, -0.150635,
+ -0.275391, 0.029480, 0.115417, 0.063782, 0.018723, -0.073364, -0.019653, 0.066467,
+ -0.086731, 0.113220, 0.110535, 0.011940, -0.094727, 0.262207, 0.180298, 0.141357,
+ 0.249634, 0.199585, 0.120605, 0.403809, 0.242676, -0.028442, 0.251953, 0.130737,
+ 0.152832, -0.306396, -0.324951, -0.176514, 0.161133, 0.333252, -0.195068, 0.250244,
+ 0.569824, 0.011223, -0.186035, 0.048279, -0.325439, 0.272217, 0.144043, -0.142700,
+ 0.447754, 0.434082, 0.124878, -0.157471, -0.120422, -0.281494, 0.338135, 0.266113,
+ -0.301514, 0.424805, 0.541504, -0.195679, 0.054962, 0.061798, -0.323975, 0.056732,
+ 0.072571, -0.087341, 0.052856, -0.057220, 0.023270, 0.071472, 0.014038, 0.083008,
+ -0.050659, 0.020111, 0.035614, -0.038086, -0.042786, 0.060242, -0.050079, -0.044403,
+ -0.059631, 0.075500, 0.056000, 0.010910, -0.064026, -0.016037, -0.050720, 0.050171,
+ -0.075256, -0.014183, 0.047058, -0.086731, 0.027939, 0.063232, -0.024597, -0.039551,
+ 0.000622, -0.048370, -0.001906, 0.058868, -0.074524, 0.019714, -0.036011, 0.028442,
+ 0.009766, -0.060577, -0.007416, -0.014381, 0.002317, -0.023483, 0.014313, 0.057434,
+ 0.063110, 0.030350, -0.027557, 0.023270, 0.055115, -0.003502, 0.012268, -0.054993,
+ -0.084961, -0.022736, 0.076233, 0.027573, -0.068787, -0.036987, -0.018539, -0.049347,
+ 0.032227, 0.033081, 0.050476, 0.043030, 0.023636, -0.039764, -0.018600, 0.073669,
+ 0.032166, -0.047119, -0.033325, -0.038605, 0.034119, -0.076843, 0.005863, -0.049103,
+ 0.065796, -0.056458, 0.054504, -0.008354, -0.018509, -0.057739, -0.075684, -0.053680,
+ 0.036804, 0.020721, -0.056183, 0.021774, -0.043884, 0.033661, -0.029633, 0.027374,
+ -0.087891, 0.030853, -0.040070, 0.013733, -0.082275, -0.072571, -0.055756, 0.002262,
+ 0.004421, -0.012169, -0.078064, -0.063904, -0.051758, -0.033264, -0.059265, -0.062256,
+ 0.063782, -0.088745, -0.026855, 0.062805, -0.036591, 0.037659, -0.012970, 0.025513,
+ -0.000908, 0.027084, 0.001842, -0.080750, -0.049713, -0.069397, -0.046448, -0.031006,
+ 0.012543, 0.009369, -0.080139, -0.034363, 0.003361, -0.052704, 0.041870, 0.059265,
+ 0.029938, 0.000138, 0.049896, 0.068787, 0.040405, -0.073608, 0.047668, 0.015320,
+ -0.033203, -0.016983, 0.034149, -0.010323, 0.029877, 0.078003, -0.054688, -0.021805,
+ -0.019409, 0.010284, 0.089172, -0.050385, 0.024857, -0.041992, 0.016602, 0.082397,
+ 0.081970, 0.096375, 0.060760, -0.006603, 0.029907, 0.012131, 0.104980, 0.034210,
+ 0.074707, -0.028320, -0.020248, 0.114868, -0.036957, 0.040192, 0.002888, 0.034973,
+ -0.038635, -0.018204, -0.058563, 0.029419, 0.013344, 0.027618, 0.073669, -0.038361,
+ 0.080933, 0.044586, -0.013214, 0.022675, 0.084351, 0.081848, 0.027328, 0.043915,
+ 0.040771, 0.078918, 0.054443, -0.049652, 0.073547, 0.103882, 0.065918, 0.070923,
+ -0.037476, -0.011215, -0.021408, 0.094727, 0.042450, 0.032806, -0.064026, 0.023941,
+ 0.011780, 0.041260, -0.038818, 0.079163, 0.079468, 0.053680, 0.047150, 0.003571,
+ 0.054840, 0.045929, -0.041382, -0.033539, 0.069153, 0.046234, 0.119263, -0.006340,
+ -0.050323, 0.030212, 0.069092, 0.045441, 0.096313, -0.024628, -0.088745, 0.009033,
+ -0.016830, 0.028534, -0.042755, -0.031921, 0.013611, -0.029251, -0.051483, -0.005848,
+ -0.032837, -0.058136, 0.075989, -0.008125, 0.108765, -0.004745, -0.003422, 0.079590,
+ 0.090515, -0.019196, -0.006786, 0.059479, -0.041168, 0.093445, 0.075439, -0.025055,
+ 0.067139, 0.011734, 0.031586, 0.029587, 0.098267, 0.025848, 0.095276, 0.003189,
+ 0.105408, 0.018799, -0.102478, 0.033813, 0.004272, 0.020477, 0.033142, 0.009727,
+ -0.021393, 0.120300, 0.088684, -0.037842, -0.094177, 0.017944, 0.020126, -0.002304,
+ -0.016006, 0.018112, 0.072693, -0.072021, -0.171265, -0.053528, -0.093201, 0.024124,
+ -0.050476, -0.023422, -0.071167, 0.046478, 0.034607, 0.076904, 0.013077, -0.082031,
+ 0.091858, -0.001575, 0.083801, 0.078003, 0.019119, -0.004967, 0.027298, 0.027740,
+ 0.032623, 0.048370, 0.029099, 0.093201, 0.049957, -0.007191, 0.059631, 0.008659,
+ 0.042725, -0.009369, 0.089417, 0.074951, -0.024704, 0.005344, 0.123840, 0.080322,
+ 0.096375, 0.070312, -0.010399, 0.033203, -0.009743, -0.030045, -0.039520, 0.042023,
+ -0.017441, 0.073486, 0.049500, -0.039734, 0.009811, 0.093262, -0.069641, 0.099365,
+ -0.010414, 0.048859, 0.099182, -0.007256, -0.023941, -0.021393, -0.005703, 0.025055,
+ 0.054535, 0.093384, -0.033661, 0.073242, 0.055023, 0.037170, -0.009300, 0.048615,
+ 0.019150, 0.019409, -0.080688, -0.050049, 0.104126, -0.023193, 0.044708, 0.111816,
+ 0.061584, 0.042755, -0.013863, -0.008385, -0.039703, 0.070618, -0.016922, -0.040833,
+ 0.051178, -0.060333, -0.004368, -0.009827, 0.051544, 0.072083, 0.068176, 0.148071,
+ 0.159424, 0.017578, 0.089905, -0.006794, 0.066101, -0.051117, 0.088684, -0.002989,
+ -0.066895, 0.089844, 0.012131, -0.020203, 0.011230, 0.000327, 0.073669, 0.060669,
+ 0.091064, 0.075989, 0.051971, 0.045044, 0.033875, 0.040466, -0.029449, 0.128418,
+ -0.000229, -0.026901, 0.052063, 0.000995, -0.032532, 0.105896, -0.001241, 0.114075,
+ 0.047607, 0.090332, 0.063660, 0.016495, 0.124817, 0.090942, 0.021545, 0.007164,
+ 0.074890, 0.118347, 0.047394, 0.052856, 0.104980, 0.009384, 0.034363, 0.019073,
+ 0.072388, -0.013313, 0.119141, 0.021255, 0.103210, 0.058319, 0.186035, -0.010818,
+ 0.037109, -0.044037, -0.075989, -0.001281, 0.017899, 0.030701, -0.080261, 0.082703
+);
+
+@group(0) @binding(0) var static_features: texture_2d<u32>; // read by unpack_static_features(): each u32 lane is unpacked as two f16 values
+@group(0) @binding(1) var layer_input: texture_2d<u32>; // read by unpack_layer_channels(): same two-f16-per-lane unpacking
+@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>; // written by main() with the result of pack_channels()
+
+// Fetches the static-feature texel at `coord` (mip level 0) and expands its four
+// u32 lanes, each unpacked as two f16 values, into 8 floats in lane order x, y, z, w.
+fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {
+  let texel = textureLoad(static_features, coord, 0);
+  let lo = vec4<f32>(unpack2x16float(texel.x), unpack2x16float(texel.y));
+  let hi = vec4<f32>(unpack2x16float(texel.z), unpack2x16float(texel.w));
+  return array<f32, 8>(lo.x, lo.y, lo.z, lo.w, hi.x, hi.y, hi.z, hi.w);
+}
+
+// Fetches the layer-input texel at `coord` (mip level 0) and expands its four
+// u32 lanes, each unpacked as two f16 values, into 8 floats in lane order x, y, z, w.
+fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> {
+  let texel = textureLoad(layer_input, coord, 0);
+  let first = vec4<f32>(unpack2x16float(texel.x), unpack2x16float(texel.y));
+  let second = vec4<f32>(unpack2x16float(texel.z), unpack2x16float(texel.w));
+  return array<f32, 8>(first.x, first.y, first.z, first.w, second.x, second.y, second.z, second.w);
+}
+
+// Packs 8 floats into one vec4<u32>: each consecutive pair is converted to f16 and shares a u32 lane.
+fn pack_channels(values: array<f32, 8>) -> vec4<u32> {
+  let lane_x = pack2x16float(vec2<f32>(values[0], values[1]));
+  let lane_y = pack2x16float(vec2<f32>(values[2], values[3]));
+  let lane_z = pack2x16float(vec2<f32>(values[4], values[5]));
+  let lane_w = pack2x16float(vec2<f32>(values[6], values[7]));
+  return vec4<u32>(lane_x, lane_y, lane_z, lane_w);
+}
+
+// Compute entry point: one invocation per texel, 8x8 per workgroup. Sums weight * input over the clamped
+// KERNEL_SIZE x KERNEL_SIZE neighbourhood and all IN_CHANNELS inputs, applies ReLU, stores packed f16 pairs.
+@compute @workgroup_size(8, 8)
+fn main(@builtin(global_invocation_id) id: vec3<u32>) {
+  let coord = vec2<i32>(id.xy);
+  let dims = vec2<i32>(textureDimensions(static_features));
+
+  // Invocations that fall outside the texture have nothing to write.
+  if (coord.x >= dims.x || coord.y >= dims.y) {
+    return;
+  }
+
+  let taps = KERNEL_SIZE * KERNEL_SIZE;  // weights per (output, input) channel pair
+  let out_stride = IN_CHANNELS * taps;   // weights per output channel
+  let prev_channels = IN_CHANNELS - 8u;  // inputs read from layer_input (the first 8 are static)
+
+  // pack_channels() takes exactly 8 values, so the accumulator is 8 wide rather than
+  // OUT_CHANNELS wide; entries at and beyond OUT_CHANNELS keep their zero initial value.
+  var acc: array<f32, 8>;
+
+  for (var ky: i32 = -KERNEL_RADIUS; ky <= KERNEL_RADIUS; ky++) {
+    for (var kx: i32 = -KERNEL_RADIUS; kx <= KERNEL_RADIUS; kx++) {
+      // Border handling: taps outside the image reuse the nearest edge texel.
+      let clamped = clamp(coord + vec2<i32>(kx, ky), vec2<i32>(0), dims - vec2<i32>(1));
+
+      // The texels do not depend on the output channel, so fetch them once per tap
+      // and reuse them for every channel.
+      let static_local = unpack_static_features(clamped);
+      let layer_local = unpack_layer_channels(clamped);
+
+      // Position of this tap inside the kernel, row-major.
+      let spatial_idx = u32(ky + KERNEL_RADIUS) * KERNEL_SIZE + u32(kx + KERNEL_RADIUS);
+
+      for (var c: u32 = 0u; c < OUT_CHANNELS; c++) {
+        let base = c * out_stride + spatial_idx;
+        var sum = acc[c];
+
+        // Input channels 0..7: static features.
+        for (var i: u32 = 0u; i < 8u; i++) {
+          sum += weights[base + i * taps] * static_local[i];
+        }
+
+        // Input channels 8..IN_CHANNELS-1: channels read from layer_input.
+        for (var i: u32 = 0u; i < prev_channels; i++) {
+          sum += weights[base + (8u + i) * taps] * layer_local[i];
+        }
+        acc[c] = sum;
+      }
+    }
+  }
+
+  // ReLU on the channels that were actually computed.
+  for (var c: u32 = 0u; c < OUT_CHANNELS; c++) {
+    acc[c] = max(0.0, acc[c]);
+  }
+
+  // Pack and store
+  textureStore(output_tex, coord, pack_channels(acc));
+}