summaryrefslogtreecommitdiff
path: root/cnn_v3/shaders/cnn_v3_common.wgsl
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-03-21 08:38:29 +0100
committerskal <pascal.massimino@gmail.com>2026-03-21 08:38:29 +0100
commita4ff60233fce134e8f779ef001872dfd9a8f9923 (patch)
tree3a5466273ecb42269b4d6443c893c61b84ee7d93 /cnn_v3/shaders/cnn_v3_common.wgsl
parent4d055080d2ab4b674d5f0fd611ea051e87454a31 (diff)
feat(cnn_v3): Phase 3 complete — WGSL U-Net inference shaders
5 compute shaders + cnn_v3/common snippet: enc0: Conv(20→4,3×3) + FiLM + ReLU full-res enc1: AvgPool + Conv(4→8,3×3) + FiLM + ReLU half-res bottleneck: AvgPool + Conv(8→8,1×1) + ReLU quarter-res dec1: NearestUp + cat(enc1) + Conv(16→4) + FiLM half-res dec0: NearestUp + cat(enc0) + Conv(8→4) + FiLM + Sigmoid full-res Parity rules: zero-pad conv, AvgPool down, NearestUp, FiLM after conv+bias, skip=concat, OIHW weights+bias layout. Matches PyTorch train_cnn_v3.py forward() exactly. Registered in workspaces/main/assets.txt + src/effects/shaders.cc. Weight layout + Params struct documented in cnn_v3/docs/HOWTO.md §7. Next: Phase 4 — C++ CNNv3Effect + FiLM uniform upload. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Diffstat (limited to 'cnn_v3/shaders/cnn_v3_common.wgsl')
-rw-r--r--cnn_v3/shaders/cnn_v3_common.wgsl23
1 files changed, 23 insertions, 0 deletions
diff --git a/cnn_v3/shaders/cnn_v3_common.wgsl b/cnn_v3/shaders/cnn_v3_common.wgsl
new file mode 100644
index 0000000..54b0f3d
--- /dev/null
+++ b/cnn_v3/shaders/cnn_v3_common.wgsl
@@ -0,0 +1,23 @@
+// CNN v3 shared helpers — included by all inference compute shaders.
+// Requires the host shader to declare:
+// @group(?) @binding(?) var<storage, read> weights: array<u32>;
+
+// Read one f16 value from the packed-f16 weights buffer.
+// `base` — weight_offset from Params (f16 index of the layer start)
+// `idx` — local f16 index within the layer (conv weight or bias)
+fn get_w(base: u32, idx: u32) -> f32 {
+ let i = base + idx;
+ let v = unpack2x16float(weights[i >> 1u]);
+ return select(v.y, v.x, (i & 1u) == 0u);
+}
+
+// Unpack 8 f16 channels from an rgba32uint texel (pack2x16float layout:
+// u32[0]=ch0|ch1, u32[1]=ch2|ch3, u32[2]=ch4|ch5, u32[3]=ch6|ch7)
+fn unpack_8ch(tex: texture_2d<u32>, coord: vec2i) -> array<f32, 8> {
+ let t = textureLoad(tex, coord, 0);
+ let v0 = unpack2x16float(t.x);
+ let v1 = unpack2x16float(t.y);
+ let v2 = unpack2x16float(t.z);
+ let v3 = unpack2x16float(t.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}