summaryrefslogtreecommitdiff
path: root/workspaces/main/shaders/cnn_v2_layer_0.wgsl
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-02-12 12:17:59 +0100
committerskal <pascal.massimino@gmail.com>2026-02-12 12:17:59 +0100
commitff4c1213636e66d4457a95cad12300c58e8d6781 (patch)
treeb47a9fc5c860c4eff39054b2ffc248ffbe19fa10 /workspaces/main/shaders/cnn_v2_layer_0.wgsl
parenteaf0bd855306e70ca03f2d6579b4d6551aff6482 (diff)
Refine training script output and validation
1. Loss printed at every epoch with \r (no scrolling) 2. Validation only on final epoch (not all checkpoints) 3. Process all input images (not just img_000.png) Training output now shows live progress with single line update.
Diffstat (limited to 'workspaces/main/shaders/cnn_v2_layer_0.wgsl')
-rw-r--r--workspaces/main/shaders/cnn_v2_layer_0.wgsl174
1 files changed, 174 insertions, 0 deletions
diff --git a/workspaces/main/shaders/cnn_v2_layer_0.wgsl b/workspaces/main/shaders/cnn_v2_layer_0.wgsl
new file mode 100644
index 0000000..8e14957
--- /dev/null
+++ b/workspaces/main/shaders/cnn_v2_layer_0.wgsl
@@ -0,0 +1,174 @@
+// CNN v2 Layer 0 - Auto-generated
+// Kernel: 3×3, In: 8, Out: 8
+
+const KERNEL_SIZE: u32 = 3u;
+const IN_CHANNELS: u32 = 8u;
+const OUT_CHANNELS: u32 = 8u;
+const KERNEL_RADIUS: i32 = 1;
+
+// Weights quantized to float16 (stored as f32 in WGSL)
+const weights: array<f32, 576> = array(
+ 0.057281, -0.041962, 0.003933, 0.026459, 0.304199, 0.067261, 0.191895, 0.047455,
+ 0.074402, 0.201660, 0.158325, 0.150513, 0.219238, 0.260010, 0.319336, 0.208618,
+ 0.050201, 0.090210, 0.086853, 0.181152, 0.060486, 0.167847, 0.161499, 0.265869,
+ 0.163818, 0.100647, 0.243408, -0.008553, -0.010849, 0.046509, -0.060608, -0.022263,
+ 0.094360, -0.043854, -0.005329, -0.093262, 0.032349, 0.007259, 0.039948, -0.018692,
+ -0.000618, 0.052368, -0.038055, 0.118042, -0.084595, 0.044281, -0.107056, 0.089478,
+ -0.076477, 0.017441, 0.088135, 0.076721, -0.063965, 0.001612, 0.062469, 0.067505,
+ 0.035736, 0.115051, -0.117737, -0.076843, -0.008888, -0.002028, -0.061005, 0.081726,
+ 0.115051, -0.028183, 0.043213, -0.079285, -0.040314, -0.047699, -0.051575, -0.052521,
+ 0.071533, 0.084656, 0.051910, 0.090637, -0.104248, -0.066467, -0.032104, -0.006977,
+ 0.075439, -0.004841, 0.084656, -0.034698, 0.035675, -0.101929, -0.035034, -0.036804,
+ 0.069641, -0.026840, -0.017807, -0.088318, -0.125000, -0.042847, -0.003063, 0.007622,
+ 0.076416, 0.094971, -0.019058, 0.083496, -0.085205, 0.036285, -0.077209, 0.082458,
+ 0.056549, 0.038818, 0.092224, -0.002499, 0.069641, 0.097229, 0.069275, -0.111084,
+ -0.092041, -0.020462, -0.061279, -0.032196, -0.088623, 0.032227, -0.117004, -0.125854,
+ -0.015884, 0.093018, -0.070923, -0.117615, -0.081848, -0.115479, 0.033508, -0.026443,
+ -0.009850, -0.063232, 0.098328, -0.000984, 0.039886, -0.085754, -0.108826, 0.030258,
+ 0.091675, 0.024384, -0.118958, -0.077148, -0.122437, -0.002090, -0.089539, 0.096741,
+ 0.095337, 0.108582, -0.101807, 0.152222, 0.206177, 0.050323, -0.111450, -0.104431,
+ -0.037445, 0.276611, 0.244019, 0.171143, 0.131592, 0.056030, 0.141602, 0.014267,
+ -0.025955, -0.019730, 0.155884, 0.072144, 0.176636, -0.010117, 0.141724, 0.103027,
+ -0.253174, -0.229370, -0.105713, -0.005898, 0.075439, -0.002014, -0.010506, -0.108093,
+ -0.016724, 0.108215, 0.053589, -0.044586, 0.030396, -0.077759, 0.058594, -0.018463,
+ 0.027100, 0.030823, -0.026947, -0.014084, 0.121643, 0.116638, -0.010239, 0.106262,
+ -0.109070, -0.044281, -0.045319, -0.021942, 0.083923, 0.114929, 0.154541, 0.078186,
+ -0.047394, 0.007957, 0.099182, -0.030075, 0.103699, 0.080994, -0.085144, 0.047180,
+ 0.099792, 0.081116, 0.084961, 0.151123, 0.000963, 0.029221, 0.073181, 0.086609,
+ 0.149048, -0.052185, -0.158936, 0.146240, 0.020004, 0.063110, 0.111877, 0.037201,
+ 0.087585, 0.134277, 0.058258, -0.075256, 0.141357, 0.045776, 0.171753, 0.186035,
+ 0.093201, 0.202637, 0.018723, -0.047638, 0.072510, 0.132812, 0.182251, 0.191650,
+ 0.163818, 0.146362, 0.124451, -0.082214, 0.094482, -0.007275, 0.029099, -0.040314,
+ -0.017624, -0.018860, -0.108398, -0.111145, 0.058289, -0.106995, -0.091919, 0.069824,
+ -0.084045, -0.105957, 0.065002, -0.012894, 0.042297, -0.081299, -0.112976, 0.012314,
+ 0.015625, -0.100708, -0.039673, 0.092041, 0.037201, 0.089722, 0.064087, 0.000403,
+ 0.120667, -0.012238, -0.055695, 0.010620, -0.022110, -0.008751, 0.038605, 0.075256,
+ 0.041260, 0.128296, -0.072021, 0.020828, -0.072449, 0.051239, 0.034058, 0.122803,
+ -0.062103, 0.156006, -0.111633, 0.043671, 0.209229, 0.006088, 0.141968, 0.209961,
+ 0.122620, -0.004547, 0.107727, 0.115601, 0.003378, 0.375732, 0.068481, 0.037842,
+ 0.159546, -0.014450, 0.073425, 0.168701, -0.052643, 0.060699, 0.333740, 0.033905,
+ -0.060150, 0.053558, 0.165527, -0.052460, -0.047882, 0.080750, 0.110352, -0.057098,
+ 0.057983, -0.018692, 0.019714, -0.056427, -0.053314, -0.001763, 0.027039, 0.003395,
+ -0.131226, -0.068481, -0.086609, 0.065186, 0.084717, 0.036530, 0.043488, 0.013893,
+ -0.076660, 0.081177, 0.037476, -0.124084, -0.070312, -0.027130, -0.009331, -0.128174,
+ -0.075256, 0.098206, -0.046539, -0.045319, 0.083923, -0.050598, 0.063477, 0.007408,
+ 0.026794, -0.090454, -0.083435, 0.129761, 0.044556, 0.051849, 0.115662, 0.071167,
+ 0.004414, 0.048035, -0.148682, 0.098938, 0.200562, 0.111938, 0.208496, 0.200684,
+ -0.050262, 0.119568, 0.062988, 0.072083, 0.123779, 0.369629, 0.317627, 0.187622,
+ 0.157227, 0.183960, 0.031921, 0.142944, 0.080627, 0.218628, 0.264160, 0.156128,
+ 0.084961, 0.029343, 0.057617, 0.089233, 0.041138, 0.044373, 0.074707, 0.025818,
+ 0.113708, -0.045380, -0.114929, 0.104370, -0.012238, -0.174194, -0.169312, -0.070312,
+ -0.005863, 0.027481, 0.053345, -0.016006, -0.057953, -0.010284, 0.034241, -0.041077,
+ -0.002373, 0.034515, 0.078552, -0.066162, -0.035400, 0.072510, 0.060425, -0.037720,
+ -0.025955, 0.118042, -0.071777, 0.133667, 0.012192, -0.080933, 0.093445, 0.052826,
+ -0.037354, -0.052277, 0.124084, 0.029861, 0.137085, 0.053009, -0.034180, -0.011421,
+ 0.089233, 0.172729, 0.146118, 0.003944, 0.279541, 0.162842, 0.112244, 0.204956,
+ 0.059753, 0.117737, 0.330322, 0.185547, 0.194946, 0.404541, 0.274658, 0.177612,
+ 0.153320, 0.189575, 0.032257, 0.285400, 0.158203, 0.048035, 0.476562, 0.301025,
+ -0.179565, 0.160767, 0.137207, 0.102478, -0.060547, 0.060364, -0.091858, 0.064209,
+ 0.082642, 0.044769, -0.096436, -0.103699, -0.021683, 0.007221, -0.048737, 0.071228,
+ -0.069580, 0.066528, -0.122864, -0.008415, -0.094788, 0.040131, -0.091431, -0.029602,
+ -0.112488, -0.074158, -0.004898, -0.006721, -0.118286, -0.047516, 0.069519, 0.121521,
+ -0.004158, 0.167603, -0.092468, -0.049927, 0.006599, 0.097595, 0.064087, 0.083435,
+ 0.026993, 0.071411, 0.020538, 0.022293, 0.022858, 0.124268, 0.098999, -0.031738,
+ 0.019806, -0.087341, -0.096558, -0.099304, -0.113159, 0.021744, -0.080200, -0.056030,
+ 0.089661, -0.055115, -0.115845, -0.040222, 0.035919, 0.027832, 0.034668, 0.072632,
+ 0.071838, -0.081116, 0.050262, -0.037872, 0.054047, -0.096680, -0.102051, -0.044281,
+ 0.078796, -0.095154, -0.013229, 0.031555, -0.058533, -0.114441, -0.008530, 0.112732,
+ -0.057251, 0.096191, -0.008385, 0.052246, -0.016983, 0.092041, 0.013710, 0.012299,
+ -0.109497, 0.025604, -0.121643, -0.023819, 0.039490, -0.090088, -0.013145, -0.101562,
+ -0.115051, 0.050232, -0.047119, -0.055847, -0.017563, 0.103760, 0.116333, -0.061768,
+ -0.083069, -0.030319, 0.078003, -0.010124, 0.044617, -0.045868, 0.103638, 0.032379,
+ -0.093506, -0.048004, -0.022079, -0.004353, -0.048187, -0.025330, -0.070740, -0.014671
+);
+
+@group(0) @binding(0) var static_features: texture_2d<u32>;
+@group(0) @binding(1) var layer_input: texture_2d<u32>;
+@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>;
+
+fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {
+ let packed = textureLoad(static_features, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let v2 = unpack2x16float(packed.z);
+ let v3 = unpack2x16float(packed.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}
+
+fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> {
+ let packed = textureLoad(layer_input, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let v2 = unpack2x16float(packed.z);
+ let v3 = unpack2x16float(packed.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}
+
+fn pack_channels(values: array<f32, 8>) -> vec4<u32> {
+ return vec4<u32>(
+ pack2x16float(vec2<f32>(values[0], values[1])),
+ pack2x16float(vec2<f32>(values[2], values[3])),
+ pack2x16float(vec2<f32>(values[4], values[5])),
+ pack2x16float(vec2<f32>(values[6], values[7]))
+ );
+}
+
+@compute @workgroup_size(8, 8)
+fn main(@builtin(global_invocation_id) id: vec3<u32>) {
+ let coord = vec2<i32>(id.xy);
+ let dims = textureDimensions(static_features);
+
+ if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {
+ return;
+ }
+
+ // Load static features (always available)
+ let static_feat = unpack_static_features(coord);
+
+ // Convolution
+ var output: array<f32, OUT_CHANNELS>;
+ for (var c: u32 = 0u; c < OUT_CHANNELS; c++) {
+ var sum: f32 = 0.0;
+
+ for (var ky: i32 = -KERNEL_RADIUS; ky <= KERNEL_RADIUS; ky++) {
+ for (var kx: i32 = -KERNEL_RADIUS; kx <= KERNEL_RADIUS; kx++) {
+ let sample_coord = coord + vec2<i32>(kx, ky);
+
+ // Border handling (clamp)
+ let clamped = vec2<i32>(
+ clamp(sample_coord.x, 0, i32(dims.x) - 1),
+ clamp(sample_coord.y, 0, i32(dims.y) - 1)
+ );
+
+ // Load input features
+ let static_local = unpack_static_features(clamped);
+ let layer_local = unpack_layer_channels(clamped);
+
+ // Weight index calculation
+ let ky_idx = u32(ky + KERNEL_RADIUS);
+ let kx_idx = u32(kx + KERNEL_RADIUS);
+ let spatial_idx = ky_idx * KERNEL_SIZE + kx_idx;
+
+ // Accumulate: static features (8D)
+ for (var i: u32 = 0u; i < 8u; i++) {
+ let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE +
+ i * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
+ sum += weights[w_idx] * static_local[i];
+ }
+
+ // Accumulate: layer input channels (if layer_idx > 0)
+ let prev_channels = IN_CHANNELS - 8u;
+ for (var i: u32 = 0u; i < prev_channels; i++) {
+ let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE +
+ (8u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
+ sum += weights[w_idx] * layer_local[i];
+ }
+ }
+ }
+
+ output[c] = max(0.0, sum); // ReLU
+ }
+
+ // Pack and store
+ textureStore(output_tex, coord, pack_channels(output));
+}