summaryrefslogtreecommitdiff
path: root/cnn_v3
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-03-21 08:38:29 +0100
committerskal <pascal.massimino@gmail.com>2026-03-21 08:38:29 +0100
commita4ff60233fce134e8f779ef001872dfd9a8f9923 (patch)
tree3a5466273ecb42269b4d6443c893c61b84ee7d93 /cnn_v3
parent4d055080d2ab4b674d5f0fd611ea051e87454a31 (diff)
feat(cnn_v3): Phase 3 complete — WGSL U-Net inference shaders
5 compute shaders + cnn_v3/common snippet: enc0: Conv(20→4,3×3) + FiLM + ReLU full-res enc1: AvgPool + Conv(4→8,3×3) + FiLM + ReLU half-res bottleneck: AvgPool + Conv(8→8,1×1) + ReLU quarter-res dec1: NearestUp + cat(enc1) + Conv(16→4) + FiLM half-res dec0: NearestUp + cat(enc0) + Conv(8→4) + FiLM + Sigmoid full-res Parity rules: zero-pad conv, AvgPool down, NearestUp, FiLM after conv+bias, skip=concat, OIHW weights+bias layout. Matches PyTorch train_cnn_v3.py forward() exactly. Registered in workspaces/main/assets.txt + src/effects/shaders.cc. Weight layout + Params struct documented in cnn_v3/docs/HOWTO.md §7. Next: Phase 4 — C++ CNNv3Effect + FiLM uniform upload. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Diffstat (limited to 'cnn_v3')
-rw-r--r--cnn_v3/docs/HOWTO.md60
-rw-r--r--cnn_v3/shaders/cnn_v3_bottleneck.wgsl73
-rw-r--r--cnn_v3/shaders/cnn_v3_common.wgsl23
-rw-r--r--cnn_v3/shaders/cnn_v3_dec0.wgsl73
-rw-r--r--cnn_v3/shaders/cnn_v3_dec1.wgsl72
-rw-r--r--cnn_v3/shaders/cnn_v3_enc0.wgsl75
-rw-r--r--cnn_v3/shaders/cnn_v3_enc1.wgsl88
7 files changed, 461 insertions, 3 deletions
diff --git a/cnn_v3/docs/HOWTO.md b/cnn_v3/docs/HOWTO.md
index 88d4bbc..ad71f1f 100644
--- a/cnn_v3/docs/HOWTO.md
+++ b/cnn_v3/docs/HOWTO.md
@@ -200,13 +200,67 @@ The CNN v3 design requires exact parity between PyTorch, WGSL (HTML), and C++.
| 1 — G-buffer (raster + pack) | ✅ Done | Integrated, 35/35 tests pass |
| 1 — G-buffer (SDF + shadow passes) | TODO | Placeholder in place |
| 2 — Training infrastructure | ✅ Done | blender_export.py, pack_*_sample.py |
-| 3 — WGSL U-Net shaders | TODO | enc/dec/bottleneck/FiLM |
+| 3 — WGSL U-Net shaders | ✅ Done | 5 compute shaders + cnn_v3/common snippet |
| 4 — C++ CNNv3Effect | TODO | FiLM uniform upload |
| 5 — Parity validation | TODO | Test vectors, ≤1/255 |
---
-## 7. Quick Troubleshooting
+## 7. CNN v3 Inference Shaders (Phase 3)
+
+Five compute passes, each a standalone WGSL shader using `#include "cnn_v3/common"`.
+The common snippet provides `get_w()` and `unpack_8ch()`.
+
+| Pass | Shader | Input(s) | Output | Dims |
+|------|--------|----------|--------|------|
+| enc0 | `cnn_v3_enc0.wgsl` | feat_tex0+feat_tex1 (20ch) | enc0_tex rgba16float (4ch) | full |
+| enc1 | `cnn_v3_enc1.wgsl` | enc0_tex (AvgPool2×2 inline) | enc1_tex rgba32uint (8ch) | ½ |
+| bottleneck | `cnn_v3_bottleneck.wgsl` | enc1_tex (AvgPool2×2 inline) | bottleneck_tex rgba32uint (8ch) | ¼ |
+| dec1 | `cnn_v3_dec1.wgsl` | bottleneck_tex + enc1_tex (skip) | dec1_tex rgba16float (4ch) | ½ |
+| dec0 | `cnn_v3_dec0.wgsl` | dec1_tex + enc0_tex (skip) | output_tex rgba16float (4ch) | full |
+
+**Parity rules baked into the shaders:**
+- Zero-padding (not clamp) at conv borders
+- AvgPool 2×2 for downsampling (exact, deterministic)
+- Nearest-neighbor for upsampling (integer `coord / 2`)
+- Skip connections: channel concatenation (not add)
+- FiLM applied after conv+bias, before ReLU: `max(0, γ·x + β)`
+- No batch norm at inference
+- Weight layout: OIHW (out × in × kH × kW), biases after conv weights
+
+**Params uniform per shader** (`group 0, binding 3`):
+```
+struct Params {
+ weight_offset: u32, // f16 index into shared weights buffer
+ _pad: vec3u,
+ gamma: vec4f, // FiLM γ (enc1: gamma_lo+gamma_hi for 8ch)
+ beta: vec4f, // FiLM β (enc1: beta_lo+beta_hi for 8ch)
+}
+```
+FiLM γ/β are computed CPU-side by the FiLM MLP (Phase 4) and uploaded each frame.
+
+**Weight offsets** (f16 units, including bias):
+| Layer | Weights | Bias | Total f16 |
+|-------|---------|------|-----------|
+| enc0 | 20×4×9=720 | +4 | 724 |
+| enc1 | 4×8×9=288 | +8 | 296 |
+| bottleneck | 8×8×1=64 | +8 | 72 |
+| dec1 | 16×4×9=576 | +4 | 580 |
+| dec0 | 8×4×9=288 | +4 | 292 |
+| **Total** | | | **2064 f16 = ~4 KB** |
+
+**Asset IDs** (registered in `workspaces/main/assets.txt` + `src/effects/shaders.cc`):
+`SHADER_CNN_V3_COMMON`, `SHADER_CNN_V3_ENC0`, `SHADER_CNN_V3_ENC1`,
+`SHADER_CNN_V3_BOTTLENECK`, `SHADER_CNN_V3_DEC1`, `SHADER_CNN_V3_DEC0`
+
+**C++ usage (Phase 4):**
+```cpp
+auto src = ShaderComposer::Get().Compose({"cnn_v3/common"}, raw_wgsl);
+```
+
+---
+
+## 8. Quick Troubleshooting
**GBufferEffect renders nothing / albedo is black**
- Check `set_scene()` was called before `render()`
@@ -227,7 +281,7 @@ The CNN v3 design requires exact parity between PyTorch, WGSL (HTML), and C++.
---
-## See Also
+## 9. See Also
- `cnn_v3/docs/CNN_V3.md` — Full architecture design (U-Net, FiLM, feature layout)
- `doc/EFFECT_WORKFLOW.md` — General effect integration guide
diff --git a/cnn_v3/shaders/cnn_v3_bottleneck.wgsl b/cnn_v3/shaders/cnn_v3_bottleneck.wgsl
new file mode 100644
index 0000000..909fd41
--- /dev/null
+++ b/cnn_v3/shaders/cnn_v3_bottleneck.wgsl
@@ -0,0 +1,73 @@
+// CNN v3 — Bottleneck
+// AvgPool2x2(enc1) + Conv(8->8, 1x1) + ReLU (no FiLM)
+//
+// Input: enc1_tex (rgba32uint, 8xf16) half-res
+// Output: bottleneck_out (rgba32uint, 8xf16) quarter-res (dispatch at quarter-res dims)
+//
+// Weight layout (f16, OIHW + bias):
+// [0 .. 8*8*1) conv: w[out][in] (1x1 kernel)
+// [64 .. +8) bias: b[out]
+
+#include "cnn_v3/common"
+
+const BN_IN: u32 = 8u;
+const BN_OUT: u32 = 8u;
+
+struct Params {
+ weight_offset: u32,
+ _pad0: u32,
+ _pad1: u32,
+ _pad2: u32,
+}
+
+@group(0) @binding(0) var enc1_tex: texture_2d<u32>;
+@group(0) @binding(1) var<storage, read> weights: array<u32>;
+@group(0) @binding(2) var<uniform> params: Params;
+@group(0) @binding(3) var bottleneck_out: texture_storage_2d<rgba32uint, write>;
+
+// Avg-pool 2x2 from enc1_tex at quarter-res coord qcoord.
+// Returns zeros for OOB quarter-res coords (zero-padding for the 1x1 conv).
+fn load_enc1_avg(qcoord: vec2i, half_dims: vec2i) -> array<f32, 8> {
+ let quart_dims = half_dims / 2;
+ if (qcoord.x < 0 || qcoord.y < 0 || qcoord.x >= quart_dims.x || qcoord.y >= quart_dims.y) {
+ return array<f32, 8>(0., 0., 0., 0., 0., 0., 0., 0.);
+ }
+ let base = qcoord * 2;
+ var s: array<f32, BN_IN>;
+ for (var dy: i32 = 0; dy < 2; dy++) {
+ for (var dx: i32 = 0; dx < 2; dx++) {
+ let hc = clamp(base + vec2i(dx, dy), vec2i(0), half_dims - vec2i(1));
+ let f = unpack_8ch(enc1_tex, hc);
+ for (var i: u32 = 0u; i < BN_IN; i++) { s[i] += f[i]; }
+ }
+ }
+ for (var i: u32 = 0u; i < BN_IN; i++) { s[i] *= 0.25; }
+ return s;
+}
+
+@compute @workgroup_size(8, 8)
+fn bottleneck_main(@builtin(global_invocation_id) id: vec3u) {
+ let half_dims = vec2i(textureDimensions(enc1_tex));
+ let quart_dims = half_dims / 2;
+ let coord = vec2i(id.xy);
+ if (coord.x >= quart_dims.x || coord.y >= quart_dims.y) { return; }
+
+ let wo = params.weight_offset;
+ let feat = load_enc1_avg(coord, half_dims);
+ var out: array<f32, BN_OUT>;
+
+ for (var o: u32 = 0u; o < BN_OUT; o++) {
+ var sum = get_w(wo, BN_OUT * BN_IN + o); // bias (1x1 kernel: no spatial idx)
+ for (var i: u32 = 0u; i < BN_IN; i++) {
+ sum += get_w(wo, o * BN_IN + i) * feat[i];
+ }
+ out[o] = max(0.0, sum);
+ }
+
+ textureStore(bottleneck_out, coord, vec4u(
+ pack2x16float(vec2f(out[0], out[1])),
+ pack2x16float(vec2f(out[2], out[3])),
+ pack2x16float(vec2f(out[4], out[5])),
+ pack2x16float(vec2f(out[6], out[7]))
+ ));
+}
diff --git a/cnn_v3/shaders/cnn_v3_common.wgsl b/cnn_v3/shaders/cnn_v3_common.wgsl
new file mode 100644
index 0000000..54b0f3d
--- /dev/null
+++ b/cnn_v3/shaders/cnn_v3_common.wgsl
@@ -0,0 +1,23 @@
+// CNN v3 shared helpers — included by all inference compute shaders.
+// Requires the host shader to declare:
+// @group(?) @binding(?) var<storage, read> weights: array<u32>;
+
+// Read one f16 value from the packed-f16 weights buffer.
+// `base` — weight_offset from Params (f16 index of the layer start)
+// `idx` — local f16 index within the layer (conv weight or bias)
+fn get_w(base: u32, idx: u32) -> f32 {
+ let i = base + idx;
+ let v = unpack2x16float(weights[i >> 1u]);
+ return select(v.y, v.x, (i & 1u) == 0u);
+}
+
+// Unpack 8 f16 channels from an rgba32uint texel (pack2x16float layout:
+// u32[0]=ch0|ch1, u32[1]=ch2|ch3, u32[2]=ch4|ch5, u32[3]=ch6|ch7)
+fn unpack_8ch(tex: texture_2d<u32>, coord: vec2i) -> array<f32, 8> {
+ let t = textureLoad(tex, coord, 0);
+ let v0 = unpack2x16float(t.x);
+ let v1 = unpack2x16float(t.y);
+ let v2 = unpack2x16float(t.z);
+ let v3 = unpack2x16float(t.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}
diff --git a/cnn_v3/shaders/cnn_v3_dec0.wgsl b/cnn_v3/shaders/cnn_v3_dec0.wgsl
new file mode 100644
index 0000000..7a4e7c9
--- /dev/null
+++ b/cnn_v3/shaders/cnn_v3_dec0.wgsl
@@ -0,0 +1,73 @@
+// CNN v3 — Decoder level 0 + output
+// NearestUp2x(dec1) + cat(enc0_skip) -> Conv(8->4, 3x3, zero-pad) + FiLM + ReLU + Sigmoid
+//
+// Inputs: dec1_tex (rgba16float, 4ch) half-res
+// enc0_tex (rgba16float, 4ch) full-res (skip connection)
+// Output: output_tex (rgba16float, 4ch) full-res (dispatch at full-res dims)
+//
+// Weight layout (f16, OIHW + bias):
+// [0 .. 8*4*9) conv: w[out][in][ky][kx] (in=8: 4 dec1 + 4 enc0 skip)
+// [288 .. +4) bias: b[out]
+//
+// Parity note: sigmoid applied directly to dec0 output (matches train_cnn_v3.py forward()).
+
+#include "cnn_v3/common"
+
+const DEC0_IN: u32 = 8u;
+const DEC0_OUT: u32 = 4u;
+
+struct Params {
+ weight_offset: u32,
+ _pad: vec3u,
+ gamma: vec4f,
+ beta: vec4f,
+}
+
+@group(0) @binding(0) var dec1_tex: texture_2d<f32>;
+@group(0) @binding(1) var enc0_tex: texture_2d<f32>;
+@group(0) @binding(2) var<storage, read> weights: array<u32>;
+@group(0) @binding(3) var<uniform> params: Params;
+@group(0) @binding(4) var output_tex: texture_storage_2d<rgba16float, write>;
+
+// Load 8 concatenated channels at full-res coord:
+// ch 0-3: dec1 nearest-up (dec1_tex[coord/2])
+// ch 4-7: enc0 skip (enc0_tex[coord])
+// Returns zeros for OOB coord (zero-padding for the conv).
+fn load_dec0_concat(coord: vec2i, full_dims: vec2i) -> array<f32, 8> {
+ if (coord.x < 0 || coord.y < 0 || coord.x >= full_dims.x || coord.y >= full_dims.y) {
+ return array<f32, 8>(0., 0., 0., 0., 0., 0., 0., 0.);
+ }
+ let half_dims = vec2i(textureDimensions(dec1_tex));
+ let hc = clamp(coord / 2, vec2i(0), half_dims - vec2i(1));
+ let d = textureLoad(dec1_tex, hc, 0);
+ let e = textureLoad(enc0_tex, coord, 0);
+ return array<f32, 8>(d.x, d.y, d.z, d.w, e.x, e.y, e.z, e.w);
+}
+
+@compute @workgroup_size(8, 8)
+fn dec0_main(@builtin(global_invocation_id) id: vec3u) {
+ let full_dims = vec2i(textureDimensions(enc0_tex));
+ let coord = vec2i(id.xy);
+ if (coord.x >= full_dims.x || coord.y >= full_dims.y) { return; }
+
+ let wo = params.weight_offset;
+ var out: array<f32, DEC0_OUT>;
+
+ for (var o: u32 = 0u; o < DEC0_OUT; o++) {
+ var sum = get_w(wo, DEC0_OUT * DEC0_IN * 9u + o); // bias
+ for (var ky: i32 = -1; ky <= 1; ky++) {
+ for (var kx: i32 = -1; kx <= 1; kx++) {
+ let feat = load_dec0_concat(coord + vec2i(kx, ky), full_dims);
+ let ki = u32(ky + 1) * 3u + u32(kx + 1);
+ for (var i: u32 = 0u; i < DEC0_IN; i++) {
+ sum += get_w(wo, o * DEC0_IN * 9u + i * 9u + ki) * feat[i];
+ }
+ }
+ }
+ // FiLM + ReLU + Sigmoid (matches training forward())
+ let v = max(0.0, params.gamma[o] * sum + params.beta[o]);
+ out[o] = 1.0 / (1.0 + exp(-v));
+ }
+
+ textureStore(output_tex, coord, vec4f(out[0], out[1], out[2], out[3]));
+}
diff --git a/cnn_v3/shaders/cnn_v3_dec1.wgsl b/cnn_v3/shaders/cnn_v3_dec1.wgsl
new file mode 100644
index 0000000..28ae3dc
--- /dev/null
+++ b/cnn_v3/shaders/cnn_v3_dec1.wgsl
@@ -0,0 +1,72 @@
+// CNN v3 — Decoder level 1
+// NearestUp2x(bottleneck) + cat(enc1_skip) -> Conv(16->4, 3x3, zero-pad) + FiLM + ReLU
+//
+// Inputs: bottleneck_tex (rgba32uint, 8xf16) quarter-res
+// enc1_tex (rgba32uint, 8xf16) half-res (skip connection)
+// Output: dec1_out (rgba16float, 4ch) half-res (dispatch at half-res dims)
+//
+// Weight layout (f16, OIHW + bias):
+// [0 .. 16*4*9) conv: w[out][in][ky][kx] (in=16: 8 bottleneck + 8 enc1 skip)
+// [576 .. +4) bias: b[out]
+
+#include "cnn_v3/common"
+
+const DEC1_IN: u32 = 16u;
+const DEC1_OUT: u32 = 4u;
+
+struct Params {
+ weight_offset: u32,
+ _pad: vec3u,
+ gamma: vec4f,
+ beta: vec4f,
+}
+
+@group(0) @binding(0) var bottleneck_tex: texture_2d<u32>;
+@group(0) @binding(1) var enc1_tex: texture_2d<u32>;
+@group(0) @binding(2) var<storage, read> weights: array<u32>;
+@group(0) @binding(3) var<uniform> params: Params;
+@group(0) @binding(4) var dec1_out: texture_storage_2d<rgba16float, write>;
+
+// Load 16 concatenated channels at half-res coord hcoord:
+// ch 0-7: bottleneck nearest-up (bottleneck_tex[hcoord/2])
+// ch 8-15: enc1 skip (enc1_tex[hcoord])
+// Returns zeros for OOB hcoord (zero-padding for the conv).
+fn load_dec1_concat(hcoord: vec2i, half_dims: vec2i) -> array<f32, 16> {
+ if (hcoord.x < 0 || hcoord.y < 0 || hcoord.x >= half_dims.x || hcoord.y >= half_dims.y) {
+ return array<f32, 16>(0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.);
+ }
+ let quart_dims = half_dims / 2;
+ let qc = clamp(hcoord / 2, vec2i(0), quart_dims - vec2i(1));
+ let b = unpack_8ch(bottleneck_tex, qc);
+ let s = unpack_8ch(enc1_tex, hcoord);
+ return array<f32, 16>(
+ b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7],
+ s[0], s[1], s[2], s[3], s[4], s[5], s[6], s[7]
+ );
+}
+
+@compute @workgroup_size(8, 8)
+fn dec1_main(@builtin(global_invocation_id) id: vec3u) {
+ let half_dims = vec2i(textureDimensions(enc1_tex));
+ let coord = vec2i(id.xy);
+ if (coord.x >= half_dims.x || coord.y >= half_dims.y) { return; }
+
+ let wo = params.weight_offset;
+ var out: array<f32, DEC1_OUT>;
+
+ for (var o: u32 = 0u; o < DEC1_OUT; o++) {
+ var sum = get_w(wo, DEC1_OUT * DEC1_IN * 9u + o); // bias
+ for (var ky: i32 = -1; ky <= 1; ky++) {
+ for (var kx: i32 = -1; kx <= 1; kx++) {
+ let feat = load_dec1_concat(coord + vec2i(kx, ky), half_dims);
+ let ki = u32(ky + 1) * 3u + u32(kx + 1);
+ for (var i: u32 = 0u; i < DEC1_IN; i++) {
+ sum += get_w(wo, o * DEC1_IN * 9u + i * 9u + ki) * feat[i];
+ }
+ }
+ }
+ out[o] = max(0.0, params.gamma[o] * sum + params.beta[o]);
+ }
+
+ textureStore(dec1_out, coord, vec4f(out[0], out[1], out[2], out[3]));
+}
diff --git a/cnn_v3/shaders/cnn_v3_enc0.wgsl b/cnn_v3/shaders/cnn_v3_enc0.wgsl
new file mode 100644
index 0000000..f52a167
--- /dev/null
+++ b/cnn_v3/shaders/cnn_v3_enc0.wgsl
@@ -0,0 +1,75 @@
+// CNN v3 — Encoder level 0
+// Conv(20->4, 3x3, zero-pad) + FiLM + ReLU
+//
+// Input: feat_tex0 (rgba32uint, 8xf16), feat_tex1 (rgba32uint, 12xu8) full-res
+// Output: enc0_out (rgba16float, 4ch) full-res
+//
+// Weight layout (f16, OIHW + bias):
+// [0 .. 20*4*9) conv: w[out][in][ky][kx]
+// [720 .. +4) bias: b[out]
+
+#include "cnn_v3/common"
+
+const ENC0_IN: u32 = 20u;
+const ENC0_OUT: u32 = 4u;
+
+struct Params {
+ weight_offset: u32,
+ _pad: vec3u,
+ gamma: vec4f,
+ beta: vec4f,
+}
+
+@group(0) @binding(0) var feat_tex0: texture_2d<u32>;
+@group(0) @binding(1) var feat_tex1: texture_2d<u32>;
+@group(0) @binding(2) var<storage, read> weights: array<u32>;
+@group(0) @binding(3) var<uniform> params: Params;
+@group(0) @binding(4) var enc0_out: texture_storage_2d<rgba16float, write>;
+
+// Unpack all 20 feature channels at coord. Returns zeros for OOB (zero-padding).
+fn load_feat(coord: vec2i, dims: vec2i) -> array<f32, 20> {
+ if (coord.x < 0 || coord.y < 0 || coord.x >= dims.x || coord.y >= dims.y) {
+ return array<f32, 20>(0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.);
+ }
+ let t0 = textureLoad(feat_tex0, coord, 0);
+ let t1 = textureLoad(feat_tex1, coord, 0);
+ let a = unpack2x16float(t0.x);
+ let b = unpack2x16float(t0.y);
+ let c = unpack2x16float(t0.z);
+ let d = unpack2x16float(t0.w);
+ let e = unpack4x8unorm(t1.x);
+ let f = unpack4x8unorm(t1.y);
+ let g = unpack4x8unorm(t1.z);
+ return array<f32, 20>(
+ a.x, a.y, b.x, b.y, c.x, c.y, d.x, d.y,
+ e.x, e.y, e.z, e.w,
+ f.x, f.y, f.z, f.w,
+ g.x, g.y, g.z, g.w
+ );
+}
+
+@compute @workgroup_size(8, 8)
+fn enc0_main(@builtin(global_invocation_id) id: vec3u) {
+ let coord = vec2i(id.xy);
+ let dims = vec2i(textureDimensions(feat_tex0));
+ if (coord.x >= dims.x || coord.y >= dims.y) { return; }
+
+ let wo = params.weight_offset;
+ var out: array<f32, ENC0_OUT>;
+
+ for (var o: u32 = 0u; o < ENC0_OUT; o++) {
+ var sum = get_w(wo, ENC0_OUT * ENC0_IN * 9u + o); // bias
+ for (var ky: i32 = -1; ky <= 1; ky++) {
+ for (var kx: i32 = -1; kx <= 1; kx++) {
+ let feat = load_feat(coord + vec2i(kx, ky), dims);
+ let ki = u32(ky + 1) * 3u + u32(kx + 1);
+ for (var i: u32 = 0u; i < ENC0_IN; i++) {
+ sum += get_w(wo, o * ENC0_IN * 9u + i * 9u + ki) * feat[i];
+ }
+ }
+ }
+ out[o] = max(0.0, params.gamma[o] * sum + params.beta[o]);
+ }
+
+ textureStore(enc0_out, coord, vec4f(out[0], out[1], out[2], out[3]));
+}
diff --git a/cnn_v3/shaders/cnn_v3_enc1.wgsl b/cnn_v3/shaders/cnn_v3_enc1.wgsl
new file mode 100644
index 0000000..23e485d
--- /dev/null
+++ b/cnn_v3/shaders/cnn_v3_enc1.wgsl
@@ -0,0 +1,88 @@
+// CNN v3 — Encoder level 1
+// AvgPool2x2(enc0) + Conv(4->8, 3x3, zero-pad) + FiLM + ReLU
+//
+// Input: enc0_tex (rgba16float, 4ch) full-res
+// Output: enc1_out (rgba32uint, 8xf16) half-res (dispatch at half-res dims)
+//
+// Weight layout (f16, OIHW + bias):
+// [0 .. 4*8*9) conv: w[out][in][ky][kx]
+// [288 .. +8) bias: b[out]
+
+#include "cnn_v3/common"
+
+const ENC1_IN: u32 = 4u;
+const ENC1_OUT: u32 = 8u;
+
+struct Params {
+ weight_offset: u32,
+ _pad: vec3u,
+ gamma_lo: vec4f, // FiLM gamma ch 0-3
+ gamma_hi: vec4f, // FiLM gamma ch 4-7
+ beta_lo: vec4f, // FiLM beta ch 0-3
+ beta_hi: vec4f, // FiLM beta ch 4-7
+}
+
+@group(0) @binding(0) var enc0_tex: texture_2d<f32>;
+@group(0) @binding(1) var<storage, read> weights: array<u32>;
+@group(0) @binding(2) var<uniform> params: Params;
+@group(0) @binding(3) var enc1_out: texture_storage_2d<rgba32uint, write>;
+
+fn film_gamma(o: u32) -> f32 {
+ if (o < 4u) { return params.gamma_lo[o]; }
+ return params.gamma_hi[o - 4u];
+}
+fn film_beta(o: u32) -> f32 {
+ if (o < 4u) { return params.beta_lo[o]; }
+ return params.beta_hi[o - 4u];
+}
+
+// Avg-pool 2x2 from enc0_tex at half-res coord hcoord.
+// Returns zeros for OOB half-res coords (zero-padding for the conv).
+fn load_enc0_avg(hcoord: vec2i, full_dims: vec2i) -> array<f32, 4> {
+ let half_dims = full_dims / 2;
+ if (hcoord.x < 0 || hcoord.y < 0 || hcoord.x >= half_dims.x || hcoord.y >= half_dims.y) {
+ return array<f32, 4>(0., 0., 0., 0.);
+ }
+ let base = hcoord * 2;
+ var s = vec4f(0.);
+ for (var dy: i32 = 0; dy < 2; dy++) {
+ for (var dx: i32 = 0; dx < 2; dx++) {
+ let fc = clamp(base + vec2i(dx, dy), vec2i(0), full_dims - vec2i(1));
+ s += textureLoad(enc0_tex, fc, 0);
+ }
+ }
+ let avg = s * 0.25;
+ return array<f32, 4>(avg.x, avg.y, avg.z, avg.w);
+}
+
+@compute @workgroup_size(8, 8)
+fn enc1_main(@builtin(global_invocation_id) id: vec3u) {
+ let full_dims = vec2i(textureDimensions(enc0_tex));
+ let half_dims = full_dims / 2;
+ let coord = vec2i(id.xy);
+ if (coord.x >= half_dims.x || coord.y >= half_dims.y) { return; }
+
+ let wo = params.weight_offset;
+ var out: array<f32, ENC1_OUT>;
+
+ for (var o: u32 = 0u; o < ENC1_OUT; o++) {
+ var sum = get_w(wo, ENC1_OUT * ENC1_IN * 9u + o); // bias
+ for (var ky: i32 = -1; ky <= 1; ky++) {
+ for (var kx: i32 = -1; kx <= 1; kx++) {
+ let feat = load_enc0_avg(coord + vec2i(kx, ky), full_dims);
+ let ki = u32(ky + 1) * 3u + u32(kx + 1);
+ for (var i: u32 = 0u; i < ENC1_IN; i++) {
+ sum += get_w(wo, o * ENC1_IN * 9u + i * 9u + ki) * feat[i];
+ }
+ }
+ }
+ out[o] = max(0.0, film_gamma(o) * sum + film_beta(o));
+ }
+
+ textureStore(enc1_out, coord, vec4u(
+ pack2x16float(vec2f(out[0], out[1])),
+ pack2x16float(vec2f(out[2], out[3])),
+ pack2x16float(vec2f(out[4], out[5])),
+ pack2x16float(vec2f(out[6], out[7]))
+ ));
+}