diff options
| author | skal <pascal.massimino@gmail.com> | 2026-02-12 12:17:59 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-02-12 12:17:59 +0100 |
| commit | ff4c1213636e66d4457a95cad12300c58e8d6781 (patch) | |
| tree | b47a9fc5c860c4eff39054b2ffc248ffbe19fa10 /workspaces/main | |
| parent | eaf0bd855306e70ca03f2d6579b4d6551aff6482 (diff) | |
Refine training script output and validation
1. Loss printed at every epoch with \r (no scrolling)
2. Validation only on final epoch (not all checkpoints)
3. Process all input images (not just img_000.png)
Training output now shows live progress with single line update.
Diffstat (limited to 'workspaces/main')
| -rw-r--r-- | workspaces/main/shaders/cnn_v2_layer_0.wgsl | 174 | ||||
| -rw-r--r-- | workspaces/main/shaders/cnn_v2_layer_1.wgsl | 174 | ||||
| -rw-r--r-- | workspaces/main/shaders/cnn_v2_layer_2.wgsl | 156 |
3 files changed, 504 insertions, 0 deletions
diff --git a/workspaces/main/shaders/cnn_v2_layer_0.wgsl b/workspaces/main/shaders/cnn_v2_layer_0.wgsl new file mode 100644 index 0000000..8e14957 --- /dev/null +++ b/workspaces/main/shaders/cnn_v2_layer_0.wgsl @@ -0,0 +1,174 @@ +// CNN v2 Layer 0 - Auto-generated +// Kernel: 3×3, In: 8, Out: 8 + +const KERNEL_SIZE: u32 = 3u; +const IN_CHANNELS: u32 = 8u; +const OUT_CHANNELS: u32 = 8u; +const KERNEL_RADIUS: i32 = 1; + +// Weights quantized to float16 (stored as f32 in WGSL) +const weights: array<f32, 576> = array( + 0.057281, -0.041962, 0.003933, 0.026459, 0.304199, 0.067261, 0.191895, 0.047455, + 0.074402, 0.201660, 0.158325, 0.150513, 0.219238, 0.260010, 0.319336, 0.208618, + 0.050201, 0.090210, 0.086853, 0.181152, 0.060486, 0.167847, 0.161499, 0.265869, + 0.163818, 0.100647, 0.243408, -0.008553, -0.010849, 0.046509, -0.060608, -0.022263, + 0.094360, -0.043854, -0.005329, -0.093262, 0.032349, 0.007259, 0.039948, -0.018692, + -0.000618, 0.052368, -0.038055, 0.118042, -0.084595, 0.044281, -0.107056, 0.089478, + -0.076477, 0.017441, 0.088135, 0.076721, -0.063965, 0.001612, 0.062469, 0.067505, + 0.035736, 0.115051, -0.117737, -0.076843, -0.008888, -0.002028, -0.061005, 0.081726, + 0.115051, -0.028183, 0.043213, -0.079285, -0.040314, -0.047699, -0.051575, -0.052521, + 0.071533, 0.084656, 0.051910, 0.090637, -0.104248, -0.066467, -0.032104, -0.006977, + 0.075439, -0.004841, 0.084656, -0.034698, 0.035675, -0.101929, -0.035034, -0.036804, + 0.069641, -0.026840, -0.017807, -0.088318, -0.125000, -0.042847, -0.003063, 0.007622, + 0.076416, 0.094971, -0.019058, 0.083496, -0.085205, 0.036285, -0.077209, 0.082458, + 0.056549, 0.038818, 0.092224, -0.002499, 0.069641, 0.097229, 0.069275, -0.111084, + -0.092041, -0.020462, -0.061279, -0.032196, -0.088623, 0.032227, -0.117004, -0.125854, + -0.015884, 0.093018, -0.070923, -0.117615, -0.081848, -0.115479, 0.033508, -0.026443, + -0.009850, -0.063232, 0.098328, -0.000984, 0.039886, -0.085754, -0.108826, 0.030258, + 0.091675, 0.024384, -0.118958, -0.077148, -0.122437, -0.002090, -0.089539, 0.096741, + 0.095337, 0.108582, -0.101807, 0.152222, 0.206177, 0.050323, -0.111450, -0.104431, + -0.037445, 0.276611, 0.244019, 0.171143, 0.131592, 0.056030, 0.141602, 0.014267, + -0.025955, -0.019730, 0.155884, 0.072144, 0.176636, -0.010117, 0.141724, 0.103027, + -0.253174, -0.229370, -0.105713, -0.005898, 0.075439, -0.002014, -0.010506, -0.108093, + -0.016724, 0.108215, 0.053589, -0.044586, 0.030396, -0.077759, 0.058594, -0.018463, + 0.027100, 0.030823, -0.026947, -0.014084, 0.121643, 0.116638, -0.010239, 0.106262, + -0.109070, -0.044281, -0.045319, -0.021942, 0.083923, 0.114929, 0.154541, 0.078186, + -0.047394, 0.007957, 0.099182, -0.030075, 0.103699, 0.080994, -0.085144, 0.047180, + 0.099792, 0.081116, 0.084961, 0.151123, 0.000963, 0.029221, 0.073181, 0.086609, + 0.149048, -0.052185, -0.158936, 0.146240, 0.020004, 0.063110, 0.111877, 0.037201, + 0.087585, 0.134277, 0.058258, -0.075256, 0.141357, 0.045776, 0.171753, 0.186035, + 0.093201, 0.202637, 0.018723, -0.047638, 0.072510, 0.132812, 0.182251, 0.191650, + 0.163818, 0.146362, 0.124451, -0.082214, 0.094482, -0.007275, 0.029099, -0.040314, + -0.017624, -0.018860, -0.108398, -0.111145, 0.058289, -0.106995, -0.091919, 0.069824, + -0.084045, -0.105957, 0.065002, -0.012894, 0.042297, -0.081299, -0.112976, 0.012314, + 0.015625, -0.100708, -0.039673, 0.092041, 0.037201, 0.089722, 0.064087, 0.000403, + 0.120667, -0.012238, -0.055695, 0.010620, -0.022110, -0.008751, 0.038605, 0.075256, + 0.041260, 0.128296, -0.072021, 0.020828, -0.072449, 0.051239, 0.034058, 0.122803, + -0.062103, 0.156006, -0.111633, 0.043671, 0.209229, 0.006088, 0.141968, 0.209961, + 0.122620, -0.004547, 0.107727, 0.115601, 0.003378, 0.375732, 0.068481, 0.037842, + 0.159546, -0.014450, 0.073425, 0.168701, -0.052643, 0.060699, 0.333740, 0.033905, + -0.060150, 0.053558, 0.165527, -0.052460, -0.047882, 0.080750, 0.110352, -0.057098, + 0.057983, -0.018692, 0.019714, -0.056427, -0.053314, -0.001763, 0.027039, 0.003395, + -0.131226, -0.068481, -0.086609, 0.065186, 0.084717, 0.036530, 0.043488, 0.013893, + -0.076660, 0.081177, 0.037476, -0.124084, -0.070312, -0.027130, -0.009331, -0.128174, + -0.075256, 0.098206, -0.046539, -0.045319, 0.083923, -0.050598, 0.063477, 0.007408, + 0.026794, -0.090454, -0.083435, 0.129761, 0.044556, 0.051849, 0.115662, 0.071167, + 0.004414, 0.048035, -0.148682, 0.098938, 0.200562, 0.111938, 0.208496, 0.200684, + -0.050262, 0.119568, 0.062988, 0.072083, 0.123779, 0.369629, 0.317627, 0.187622, + 0.157227, 0.183960, 0.031921, 0.142944, 0.080627, 0.218628, 0.264160, 0.156128, + 0.084961, 0.029343, 0.057617, 0.089233, 0.041138, 0.044373, 0.074707, 0.025818, + 0.113708, -0.045380, -0.114929, 0.104370, -0.012238, -0.174194, -0.169312, -0.070312, + -0.005863, 0.027481, 0.053345, -0.016006, -0.057953, -0.010284, 0.034241, -0.041077, + -0.002373, 0.034515, 0.078552, -0.066162, -0.035400, 0.072510, 0.060425, -0.037720, + -0.025955, 0.118042, -0.071777, 0.133667, 0.012192, -0.080933, 0.093445, 0.052826, + -0.037354, -0.052277, 0.124084, 0.029861, 0.137085, 0.053009, -0.034180, -0.011421, + 0.089233, 0.172729, 0.146118, 0.003944, 0.279541, 0.162842, 0.112244, 0.204956, + 0.059753, 0.117737, 0.330322, 0.185547, 0.194946, 0.404541, 0.274658, 0.177612, + 0.153320, 0.189575, 0.032257, 0.285400, 0.158203, 0.048035, 0.476562, 0.301025, + -0.179565, 0.160767, 0.137207, 0.102478, -0.060547, 0.060364, -0.091858, 0.064209, + 0.082642, 0.044769, -0.096436, -0.103699, -0.021683, 0.007221, -0.048737, 0.071228, + -0.069580, 0.066528, -0.122864, -0.008415, -0.094788, 0.040131, -0.091431, -0.029602, + -0.112488, -0.074158, -0.004898, -0.006721, -0.118286, -0.047516, 0.069519, 0.121521, + -0.004158, 0.167603, -0.092468, -0.049927, 0.006599, 0.097595, 0.064087, 0.083435, + 0.026993, 0.071411, 0.020538, 0.022293, 0.022858, 0.124268, 0.098999, -0.031738, + 0.019806, -0.087341, -0.096558, -0.099304, -0.113159, 0.021744, -0.080200, -0.056030, + 0.089661, -0.055115, -0.115845, -0.040222, 0.035919, 0.027832, 0.034668, 0.072632, + 0.071838, -0.081116, 0.050262, -0.037872, 0.054047, -0.096680, -0.102051, -0.044281, + 0.078796, -0.095154, -0.013229, 0.031555, -0.058533, -0.114441, -0.008530, 0.112732, + -0.057251, 0.096191, -0.008385, 0.052246, -0.016983, 0.092041, 0.013710, 0.012299, + -0.109497, 0.025604, -0.121643, -0.023819, 0.039490, -0.090088, -0.013145, -0.101562, + -0.115051, 0.050232, -0.047119, -0.055847, -0.017563, 0.103760, 0.116333, -0.061768, + -0.083069, -0.030319, 0.078003, -0.010124, 0.044617, -0.045868, 0.103638, 0.032379, + -0.093506, -0.048004, -0.022079, -0.004353, -0.048187, -0.025330, -0.070740, -0.014671 +); + +@group(0) @binding(0) var static_features: texture_2d<u32>; +@group(0) @binding(1) var layer_input: texture_2d<u32>; +@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>; + +fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> { + let packed = textureLoad(static_features, coord, 0); + let v0 = unpack2x16float(packed.x); + let v1 = unpack2x16float(packed.y); + let v2 = unpack2x16float(packed.z); + let v3 = unpack2x16float(packed.w); + return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y); +} + +fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> { + let packed = textureLoad(layer_input, coord, 0); + let v0 = unpack2x16float(packed.x); + let v1 = unpack2x16float(packed.y); + let v2 = unpack2x16float(packed.z); + let v3 = unpack2x16float(packed.w); + return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y); +} + +fn pack_channels(values: array<f32, 8>) -> vec4<u32> { + return vec4<u32>( + pack2x16float(vec2<f32>(values[0], values[1])), + pack2x16float(vec2<f32>(values[2], values[3])), + pack2x16float(vec2<f32>(values[4], values[5])), + pack2x16float(vec2<f32>(values[6], values[7])) + ); +} + +@compute @workgroup_size(8, 8) +fn main(@builtin(global_invocation_id) id: vec3<u32>) { + let coord = vec2<i32>(id.xy); + let dims = textureDimensions(static_features); + + if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) { + return; + } + + // Load static features (always available) + let static_feat = unpack_static_features(coord); + + // Convolution + var output: array<f32, OUT_CHANNELS>; + for (var c: u32 = 0u; c < OUT_CHANNELS; c++) { + var sum: f32 = 0.0; + + for (var ky: i32 = -KERNEL_RADIUS; ky <= KERNEL_RADIUS; ky++) { + for (var kx: i32 = -KERNEL_RADIUS; kx <= KERNEL_RADIUS; kx++) { + let sample_coord = coord + vec2<i32>(kx, ky); + + // Border handling (clamp) + let clamped = vec2<i32>( + clamp(sample_coord.x, 0, i32(dims.x) - 1), + clamp(sample_coord.y, 0, i32(dims.y) - 1) + ); + + // Load input features + let static_local = unpack_static_features(clamped); + let layer_local = unpack_layer_channels(clamped); + + // Weight index calculation + let ky_idx = u32(ky + KERNEL_RADIUS); + let kx_idx = u32(kx + KERNEL_RADIUS); + let spatial_idx = ky_idx * KERNEL_SIZE + kx_idx; + + // Accumulate: static features (8D) + for (var i: u32 = 0u; i < 8u; i++) { + let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE + + i * KERNEL_SIZE * KERNEL_SIZE + spatial_idx; + sum += weights[w_idx] * static_local[i]; + } + + // Accumulate: layer input channels (if layer_idx > 0) + let prev_channels = IN_CHANNELS - 8u; + for (var i: u32 = 0u; i < prev_channels; i++) { + let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE + + (8u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx; + sum += weights[w_idx] * layer_local[i]; + } + } + } + + output[c] = max(0.0, sum); // ReLU + } + + // Pack and store + textureStore(output_tex, coord, pack_channels(output)); +} diff --git a/workspaces/main/shaders/cnn_v2_layer_1.wgsl b/workspaces/main/shaders/cnn_v2_layer_1.wgsl new file mode 100644 index 0000000..f490d13 --- /dev/null +++ b/workspaces/main/shaders/cnn_v2_layer_1.wgsl @@ -0,0 +1,174 @@ +// CNN v2 Layer 1 - Auto-generated +// Kernel: 3×3, In: 16, Out: 4 + +const KERNEL_SIZE: u32 = 3u; +const IN_CHANNELS: u32 = 16u; +const OUT_CHANNELS: u32 = 4u; +const KERNEL_RADIUS: i32 = 1; + +// Weights quantized to float16 (stored as f32 in WGSL) +const weights: array<f32, 576> = array( + 0.337402, 0.638672, -0.481201, 0.699707, 1.127930, -0.018280, -0.062195, 0.148682, + -0.655273, 0.448975, 0.969238, -0.280762, 0.817383, 1.271484, 0.421387, -0.163696, + 0.305664, -0.454834, 0.354004, 0.932617, -0.411377, 0.581543, 1.263672, 0.422363, + -0.380371, 0.152588, -0.668945, -0.063782, 0.060730, 0.022018, -0.075195, -0.049286, + 0.068542, 0.057343, -0.009773, 0.006344, -0.080872, -0.179932, -0.297119, 0.098328, + 0.061951, -0.088989, 0.047913, 0.093628, -0.091858, -0.068298, 0.102600, -0.044067, + -0.054230, -0.031799, 0.050934, -0.300049, -0.202637, -0.203613, -0.294189, -0.361084, + 0.277344, -0.213257, -0.239624, 0.193237, -0.215210, -0.295166, 0.298828, -0.065369, + 0.148926, 0.024963, 0.272705, 0.368164, 0.173096, 0.061279, 0.291260, 0.151611, + 0.411133, 0.216431, -0.179932, 0.506348, 0.319580, 0.059875, -0.134399, -0.150635, + -0.275391, 0.029480, 0.115417, 0.063782, 0.018723, -0.073364, -0.019653, 0.066467, + -0.086731, 0.113220, 0.110535, 0.011940, -0.094727, 0.262207, 0.180298, 0.141357, + 0.249634, 0.199585, 0.120605, 0.403809, 0.242676, -0.028442, 0.251953, 0.130737, + 0.152832, -0.306396, -0.324951, -0.176514, 0.161133, 0.333252, -0.195068, 0.250244, + 0.569824, 0.011223, -0.186035, 0.048279, -0.325439, 0.272217, 0.144043, -0.142700, + 0.447754, 0.434082, 0.124878, -0.157471, -0.120422, -0.281494, 0.338135, 0.266113, + -0.301514, 0.424805, 0.541504, -0.195679, 0.054962, 0.061798, -0.323975, 0.056732, + 0.072571, -0.087341, 0.052856, -0.057220, 0.023270, 0.071472, 0.014038, 0.083008, + -0.050659, 0.020111, 0.035614, -0.038086, -0.042786, 0.060242, -0.050079, -0.044403, + -0.059631, 0.075500, 0.056000, 0.010910, -0.064026, -0.016037, -0.050720, 0.050171, + -0.075256, -0.014183, 0.047058, -0.086731, 0.027939, 0.063232, -0.024597, -0.039551, + 0.000622, -0.048370, -0.001906, 0.058868, -0.074524, 0.019714, -0.036011, 0.028442, + 0.009766, -0.060577, -0.007416, -0.014381, 0.002317, -0.023483, 0.014313, 0.057434, + 0.063110, 0.030350, -0.027557, 0.023270, 0.055115, -0.003502, 0.012268, -0.054993, + -0.084961, -0.022736, 0.076233, 0.027573, -0.068787, -0.036987, -0.018539, -0.049347, + 0.032227, 0.033081, 0.050476, 0.043030, 0.023636, -0.039764, -0.018600, 0.073669, + 0.032166, -0.047119, -0.033325, -0.038605, 0.034119, -0.076843, 0.005863, -0.049103, + 0.065796, -0.056458, 0.054504, -0.008354, -0.018509, -0.057739, -0.075684, -0.053680, + 0.036804, 0.020721, -0.056183, 0.021774, -0.043884, 0.033661, -0.029633, 0.027374, + -0.087891, 0.030853, -0.040070, 0.013733, -0.082275, -0.072571, -0.055756, 0.002262, + 0.004421, -0.012169, -0.078064, -0.063904, -0.051758, -0.033264, -0.059265, -0.062256, + 0.063782, -0.088745, -0.026855, 0.062805, -0.036591, 0.037659, -0.012970, 0.025513, + -0.000908, 0.027084, 0.001842, -0.080750, -0.049713, -0.069397, -0.046448, -0.031006, + 0.012543, 0.009369, -0.080139, -0.034363, 0.003361, -0.052704, 0.041870, 0.059265, + 0.029938, 0.000138, 0.049896, 0.068787, 0.040405, -0.073608, 0.047668, 0.015320, + -0.033203, -0.016983, 0.034149, -0.010323, 0.029877, 0.078003, -0.054688, -0.021805, + -0.019409, 0.010284, 0.089172, -0.050385, 0.024857, -0.041992, 0.016602, 0.082397, + 0.081970, 0.096375, 0.060760, -0.006603, 0.029907, 0.012131, 0.104980, 0.034210, + 0.074707, -0.028320, -0.020248, 0.114868, -0.036957, 0.040192, 0.002888, 0.034973, + -0.038635, -0.018204, -0.058563, 0.029419, 0.013344, 0.027618, 0.073669, -0.038361, + 0.080933, 0.044586, -0.013214, 0.022675, 0.084351, 0.081848, 0.027328, 0.043915, + 0.040771, 0.078918, 0.054443, -0.049652, 0.073547, 0.103882, 0.065918, 0.070923, + -0.037476, -0.011215, -0.021408, 0.094727, 0.042450, 0.032806, -0.064026, 0.023941, + 0.011780, 0.041260, -0.038818, 0.079163, 0.079468, 0.053680, 0.047150, 0.003571, + 0.054840, 0.045929, -0.041382, -0.033539, 0.069153, 0.046234, 0.119263, -0.006340, + -0.050323, 0.030212, 0.069092, 0.045441, 0.096313, -0.024628, -0.088745, 0.009033, + -0.016830, 0.028534, -0.042755, -0.031921, 0.013611, -0.029251, -0.051483, -0.005848, + -0.032837, -0.058136, 0.075989, -0.008125, 0.108765, -0.004745, -0.003422, 0.079590, + 0.090515, -0.019196, -0.006786, 0.059479, -0.041168, 0.093445, 0.075439, -0.025055, + 0.067139, 0.011734, 0.031586, 0.029587, 0.098267, 0.025848, 0.095276, 0.003189, + 0.105408, 0.018799, -0.102478, 0.033813, 0.004272, 0.020477, 0.033142, 0.009727, + -0.021393, 0.120300, 0.088684, -0.037842, -0.094177, 0.017944, 0.020126, -0.002304, + -0.016006, 0.018112, 0.072693, -0.072021, -0.171265, -0.053528, -0.093201, 0.024124, + -0.050476, -0.023422, -0.071167, 0.046478, 0.034607, 0.076904, 0.013077, -0.082031, + 0.091858, -0.001575, 0.083801, 0.078003, 0.019119, -0.004967, 0.027298, 0.027740, + 0.032623, 0.048370, 0.029099, 0.093201, 0.049957, -0.007191, 0.059631, 0.008659, + 0.042725, -0.009369, 0.089417, 0.074951, -0.024704, 0.005344, 0.123840, 0.080322, + 0.096375, 0.070312, -0.010399, 0.033203, -0.009743, -0.030045, -0.039520, 0.042023, + -0.017441, 0.073486, 0.049500, -0.039734, 0.009811, 0.093262, -0.069641, 0.099365, + -0.010414, 0.048859, 0.099182, -0.007256, -0.023941, -0.021393, -0.005703, 0.025055, + 0.054535, 0.093384, -0.033661, 0.073242, 0.055023, 0.037170, -0.009300, 0.048615, + 0.019150, 0.019409, -0.080688, -0.050049, 0.104126, -0.023193, 0.044708, 0.111816, + 0.061584, 0.042755, -0.013863, -0.008385, -0.039703, 0.070618, -0.016922, -0.040833, + 0.051178, -0.060333, -0.004368, -0.009827, 0.051544, 0.072083, 0.068176, 0.148071, + 0.159424, 0.017578, 0.089905, -0.006794, 0.066101, -0.051117, 0.088684, -0.002989, + -0.066895, 0.089844, 0.012131, -0.020203, 0.011230, 0.000327, 0.073669, 0.060669, + 0.091064, 0.075989, 0.051971, 0.045044, 0.033875, 0.040466, -0.029449, 0.128418, + -0.000229, -0.026901, 0.052063, 0.000995, -0.032532, 0.105896, -0.001241, 0.114075, + 0.047607, 0.090332, 0.063660, 0.016495, 0.124817, 0.090942, 0.021545, 0.007164, + 0.074890, 0.118347, 0.047394, 0.052856, 0.104980, 0.009384, 0.034363, 0.019073, + 0.072388, -0.013313, 0.119141, 0.021255, 0.103210, 0.058319, 0.186035, -0.010818, + 0.037109, -0.044037, -0.075989, -0.001281, 0.017899, 0.030701, -0.080261, 0.082703 +); + +@group(0) @binding(0) var static_features: texture_2d<u32>; +@group(0) @binding(1) var layer_input: texture_2d<u32>; +@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>; + +fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> { + let packed = textureLoad(static_features, coord, 0); + let v0 = unpack2x16float(packed.x); + let v1 = unpack2x16float(packed.y); + let v2 = unpack2x16float(packed.z); + let v3 = unpack2x16float(packed.w); + return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y); +} + +fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> { + let packed = textureLoad(layer_input, coord, 0); + let v0 = unpack2x16float(packed.x); + let v1 = unpack2x16float(packed.y); + let v2 = unpack2x16float(packed.z); + let v3 = unpack2x16float(packed.w); + return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y); +} + +fn pack_channels(values: array<f32, 8>) -> vec4<u32> { + return vec4<u32>( + pack2x16float(vec2<f32>(values[0], values[1])), + pack2x16float(vec2<f32>(values[2], values[3])), + pack2x16float(vec2<f32>(values[4], values[5])), + pack2x16float(vec2<f32>(values[6], values[7])) + ); +} + +@compute @workgroup_size(8, 8) +fn main(@builtin(global_invocation_id) id: vec3<u32>) { + let coord = vec2<i32>(id.xy); + let dims = textureDimensions(static_features); + + if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) { + return; + } + + // Load static features (always available) + let static_feat = unpack_static_features(coord); + + // Convolution + var output: array<f32, OUT_CHANNELS>; + for (var c: u32 = 0u; c < OUT_CHANNELS; c++) { + var sum: f32 = 0.0; + + for (var ky: i32 = -KERNEL_RADIUS; ky <= KERNEL_RADIUS; ky++) { + for (var kx: i32 = -KERNEL_RADIUS; kx <= KERNEL_RADIUS; kx++) { + let sample_coord = coord + vec2<i32>(kx, ky); + + // Border handling (clamp) + let clamped = vec2<i32>( + clamp(sample_coord.x, 0, i32(dims.x) - 1), + clamp(sample_coord.y, 0, i32(dims.y) - 1) + ); + + // Load input features + let static_local = unpack_static_features(clamped); + let layer_local = unpack_layer_channels(clamped); + + // Weight index calculation + let ky_idx = u32(ky + KERNEL_RADIUS); + let kx_idx = u32(kx + KERNEL_RADIUS); + let spatial_idx = ky_idx * KERNEL_SIZE + kx_idx; + + // Accumulate: static features (8D) + for (var i: u32 = 0u; i < 8u; i++) { + let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE + + i * KERNEL_SIZE * KERNEL_SIZE + spatial_idx; + sum += weights[w_idx] * static_local[i]; + } + + // Accumulate: layer input channels (if layer_idx > 0) + let prev_channels = IN_CHANNELS - 8u; + for (var i: u32 = 0u; i < prev_channels; i++) { + let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE + + (8u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx; + sum += weights[w_idx] * layer_local[i]; + } + } + } + + output[c] = max(0.0, sum); // ReLU + } + + // Pack and store + textureStore(output_tex, coord, pack_channels(output)); +} diff --git a/workspaces/main/shaders/cnn_v2_layer_2.wgsl b/workspaces/main/shaders/cnn_v2_layer_2.wgsl new file mode 100644 index 0000000..2f9836a --- /dev/null +++ b/workspaces/main/shaders/cnn_v2_layer_2.wgsl @@ -0,0 +1,156 @@ +// CNN v2 Layer 2 - Auto-generated +// Kernel: 3×3, In: 12, Out: 4 + +const KERNEL_SIZE: u32 = 3u; +const IN_CHANNELS: u32 = 12u; +const OUT_CHANNELS: u32 = 4u; +const KERNEL_RADIUS: i32 = 1; + +// Weights quantized to float16 (stored as f32 in WGSL) +const weights: array<f32, 432> = array( + 0.030212, -0.041351, 0.053864, -0.025635, 0.099976, -0.016830, -0.068665, 0.112488, + -0.069824, 0.030197, 0.020142, 0.101807, 0.061920, 0.022415, -0.025864, -0.056366, + 0.085571, -0.053650, 0.109802, 0.129272, 0.023438, 0.087341, 0.066284, 0.037079, + -0.067566, 0.021530, -0.046814, 0.029343, -0.028534, 0.047150, -0.079346, -0.022675, + -0.019669, -0.024185, 0.029587, 0.068970, 0.108826, 0.050598, -0.072144, 0.083008, + -0.002201, 0.006275, 0.056396, 0.001884, 0.097168, -0.028503, -0.002499, 0.008919, + -0.013771, -0.017502, -0.033478, 0.105530, 0.032898, 0.068726, -0.036285, -0.021011, + -0.018250, 0.073914, 0.024277, 0.061066, 0.008682, -0.022766, 0.074219, 0.094421, + 0.050903, 0.072571, 0.117493, -0.033234, 0.067993, -0.008049, 0.046997, -0.064209, + -0.381104, 0.107788, -0.213867, 0.145142, 0.514160, 0.407715, -0.317871, 0.249023, + 0.055634, -0.006294, -0.067444, 0.025131, 0.012939, -0.074158, -0.013741, -0.033020, + 0.026871, -0.007671, 0.089661, -0.003016, 0.029007, -0.038483, 0.045044, 0.104065, + 0.077148, 0.092468, -0.090027, -0.048126, 0.096863, -0.088013, 0.009483, 0.075012, + -0.076843, -0.085449, -0.066040, 0.019165, -0.019958, 0.083496, 0.069275, -0.019714, + 0.027786, -0.042389, 0.054718, 0.010635, -0.071777, 0.029282, -0.003605, 0.113770, + 0.080994, 0.106079, 0.047333, -0.013733, 0.034760, 0.099365, -0.020813, 0.095886, + 0.052490, -0.049194, 0.047394, 0.072510, -0.030930, -0.003782, -0.038025, -0.019318, + -0.047852, -0.043915, 0.026810, -0.041138, 0.038422, 0.009605, -0.080688, -0.019653, + 0.075256, -0.013817, -0.022400, 0.050629, 0.048462, 0.072998, -0.009109, 0.070923, + 0.079895, 0.071350, 0.002869, 0.081543, 0.037231, 0.020767, -0.017929, 0.042328, + -0.075134, -0.010681, -0.009079, 0.057007, -0.040253, -0.025574, -0.041534, 0.105835, + -0.039703, 0.032104, 0.076050, 0.070923, -0.013046, -0.054108, -0.024582, -0.033997, + 0.092285, 0.000525, 0.114685, 0.036926, -0.419434, 0.087891, -0.187866, 0.128906, + 0.665527, 0.268311, -0.337891, 0.195557, 0.140503, 0.014465, -0.043671, 0.031677, + 0.073059, 0.085144, 0.014290, -0.046967, 0.033356, 0.004177, 0.102844, 0.015259, + 0.026627, -0.005032, 0.111694, -0.010590, 0.029816, 0.108154, -0.072327, 0.056213, + 0.022903, 0.053772, 0.084473, -0.059845, -0.032776, -0.000015, -0.093872, -0.085815, + 0.081604, 0.069336, 0.034149, -0.067322, -0.020859, 0.120911, 0.077209, -0.016388, + 0.050140, -0.045563, -0.046326, 0.032623, -0.005009, 0.008003, 0.109192, 0.086548, + 0.096558, 0.118530, 0.035034, 0.110352, -0.041748, 0.009178, 0.049957, 0.084839, + 0.042053, -0.069153, -0.024796, -0.094604, -0.047028, -0.053802, 0.024979, 0.049591, + -0.016373, -0.047607, -0.008797, -0.058868, 0.107178, 0.055695, 0.092407, 0.092346, + 0.053894, 0.054657, -0.039703, -0.073792, 0.041779, -0.044159, 0.099182, 0.037109, + 0.097778, 0.098206, -0.057831, -0.054016, -0.068604, -0.061584, -0.054382, 0.005268, + 0.096008, -0.007118, -0.063049, 0.059113, 0.076904, 0.045288, -0.055695, -0.052612, + -0.022110, 0.049103, 0.095276, 0.014572, 0.064819, 0.014671, 0.029800, 0.066284, + -0.383301, 0.071838, -0.207275, 0.099365, 0.640137, 0.393311, -0.334229, 0.275391, + -0.013977, -0.025269, -0.007065, -0.033478, -0.017349, 0.026764, 0.005192, 0.093384, + 0.014313, 0.018906, 0.006962, 0.094849, 0.005390, 0.101624, -0.041199, 0.026245, + 0.027588, 0.062408, 0.033356, -0.010826, 0.067993, -0.054199, 0.076416, 0.023315, + -0.002886, -0.112061, -0.041473, -0.012703, 0.016022, 0.010506, -0.021362, -0.037750, + 0.062927, 0.061920, 0.038177, -0.037201, -0.011620, 0.014015, -0.062164, -0.045441, + -0.063416, -0.040100, 0.035950, 0.045563, -0.017227, -0.060547, -0.017593, 0.111877, + 0.121521, 0.073853, 0.023331, -0.012428, 0.018478, -0.010948, 0.030716, 0.043427, + 0.003117, -0.069092, 0.038361, -0.053497, 0.039154, -0.085754, 0.012642, -0.051208, + 0.022934, 0.127197, 0.117920, 0.074036, 0.083313, -0.061951, 0.079224, 0.091248, + 0.009132, 0.069946, 0.123474, 0.130127, 0.118835, 0.020874, -0.045380, -0.000111, + 0.111206, 0.054688, 0.008995, 0.085693, 0.005562, 0.103088, -0.034698, 0.119934, + -0.067200, 0.065430, -0.021942, 0.089783, 0.033112, -0.025467, 0.040161, -0.052155, + -0.048920, 0.031250, 0.112549, 0.122192, 0.126587, 0.180908, 0.194946, 0.121704, + 0.217529, 0.224243, 0.269287, 0.222656, 0.288086, 0.035492, 0.066711, -0.046600, + 0.085144, 0.013855, -0.065979, -0.083252, -0.058289, 0.104126, 0.013702, -0.018188, + 0.036591, 0.099854, 0.056061, 0.151855, 0.062134, 0.133789, 0.084045, 0.095825, + 0.036987, 0.022308, 0.070923, 0.031036, 0.101868, 0.062347, 0.141235, 0.066650 +); + +@group(0) @binding(0) var static_features: texture_2d<u32>; +@group(0) @binding(1) var layer_input: texture_2d<u32>; +@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>; + +fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> { + let packed = textureLoad(static_features, coord, 0); + let v0 = unpack2x16float(packed.x); + let v1 = unpack2x16float(packed.y); + let v2 = unpack2x16float(packed.z); + let v3 = unpack2x16float(packed.w); + return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y); +} + +fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> { + let packed = textureLoad(layer_input, coord, 0); + let v0 = unpack2x16float(packed.x); + let v1 = unpack2x16float(packed.y); + let v2 = unpack2x16float(packed.z); + let v3 = unpack2x16float(packed.w); + return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y); +} + +fn pack_channels(values: array<f32, 8>) -> vec4<u32> { + return vec4<u32>( + pack2x16float(vec2<f32>(values[0], values[1])), + pack2x16float(vec2<f32>(values[2], values[3])), + pack2x16float(vec2<f32>(values[4], values[5])), + pack2x16float(vec2<f32>(values[6], values[7])) + ); +} + +@compute @workgroup_size(8, 8) +fn main(@builtin(global_invocation_id) id: vec3<u32>) { + let coord = vec2<i32>(id.xy); + let dims = textureDimensions(static_features); + + if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) { + return; + } + + // Load static features (always available) + let static_feat = unpack_static_features(coord); + + // Convolution + var output: array<f32, OUT_CHANNELS>; + for (var c: u32 = 0u; c < OUT_CHANNELS; c++) { + var sum: f32 = 0.0; + + for (var ky: i32 = -KERNEL_RADIUS; ky <= KERNEL_RADIUS; ky++) { + for (var kx: i32 = -KERNEL_RADIUS; kx <= KERNEL_RADIUS; kx++) { + let sample_coord = coord + vec2<i32>(kx, ky); + + // Border handling (clamp) + let clamped = vec2<i32>( + clamp(sample_coord.x, 0, i32(dims.x) - 1), + clamp(sample_coord.y, 0, i32(dims.y) - 1) + ); + + // Load input features + let static_local = unpack_static_features(clamped); + let layer_local = unpack_layer_channels(clamped); + + // Weight index calculation + let ky_idx = u32(ky + KERNEL_RADIUS); + let kx_idx = u32(kx + KERNEL_RADIUS); + let spatial_idx = ky_idx * KERNEL_SIZE + kx_idx; + + // Accumulate: static features (8D) + for (var i: u32 = 0u; i < 8u; i++) { + let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE + + i * KERNEL_SIZE * KERNEL_SIZE + spatial_idx; + sum += weights[w_idx] * static_local[i]; + } + + // Accumulate: layer input channels (if layer_idx > 0) + let prev_channels = IN_CHANNELS - 8u; + for (var i: u32 = 0u; i < prev_channels; i++) { + let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE + + (8u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx; + sum += weights[w_idx] * layer_local[i]; + } + } + } + + output[c] = clamp(sum, 0.0, 1.0); // Sigmoid approximation + } + + // Pack and store + textureStore(output_tex, coord, pack_channels(output)); +} |
