summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--PROJECT_CONTEXT.md4
-rw-r--r--TODO.md2
-rw-r--r--cnn_v3/docs/HOWTO.md60
-rw-r--r--cnn_v3/shaders/cnn_v3_bottleneck.wgsl73
-rw-r--r--cnn_v3/shaders/cnn_v3_common.wgsl23
-rw-r--r--cnn_v3/shaders/cnn_v3_dec0.wgsl73
-rw-r--r--cnn_v3/shaders/cnn_v3_dec1.wgsl72
-rw-r--r--cnn_v3/shaders/cnn_v3_enc0.wgsl75
-rw-r--r--cnn_v3/shaders/cnn_v3_enc1.wgsl88
-rw-r--r--src/effects/shaders.cc8
-rw-r--r--workspaces/main/assets.txt8
11 files changed, 480 insertions, 6 deletions
diff --git a/PROJECT_CONTEXT.md b/PROJECT_CONTEXT.md
index cadd514..f42ccf4 100644
--- a/PROJECT_CONTEXT.md
+++ b/PROJECT_CONTEXT.md
@@ -36,7 +36,7 @@
- **Audio:** Sample-accurate sync. Zero heap allocations per frame. Variable tempo. OLA-IDCT synthesis (v2 .spec): Hann analysis window, rectangular synthesis, 50% overlap, click-free. V1 (raw DCT-512) preserved for generated notes. .spec files regenerated as v2.
- **Shaders:** Parameterized effects (UniformHelper, .seq syntax). Beat-synchronized animation support (`beat_time`, `beat_phase`). Modular WGSL composition with ShaderComposer. 27 shared common shaders (math, render, compute). Reusable snippets: `render/scratch_lines`, `render/ntsc_common` (NTSC signal processing, RGB and YIQ input variants via `sample_ntsc_signal` hook), `math/color` (YIQ/NTSC), `math/color_c64` (C64 palette, Bayer dither, border animation).
- **3D:** Hybrid SDF/rasterization with BVH. Binary scene loader. Blender pipeline.
-- **Effects:** CNN post-processing: CNNEffect (v1) and CNNv2Effect operational. CNN v2: sigmoid activation, storage buffer weights (~3.2 KB), 7D static features, dynamic layers. Training stable, convergence validated. **CNN v3 Phase 1 complete:** `GBufferEffect` integrated (MRT raster + pack compute, 20-channel feature textures). See `cnn_v3/docs/HOWTO.md`.
+- **Effects:** CNN post-processing: CNNEffect (v1) and CNNv2Effect operational. CNN v2: sigmoid activation, storage buffer weights (~3.2 KB), 7D static features, dynamic layers. Training stable, convergence validated. **CNN v3 Phase 3 complete:** 5 WGSL inference shaders (enc0/enc1/bottleneck/dec1/dec0) + `cnn_v3/common` snippet. Zero-pad convs, AvgPool down, NearestUp, FiLM, skip-concat, sigmoid output. Registered in assets + shaders.cc. See `cnn_v3/docs/HOWTO.md` §7.
- **Tools:** CNN test tool operational. Texture readback utility functional. Timeline editor (web-based, beat-aligned, audio playback).
- **Build:** Asset dependency tracking. Size measurement. Hot-reload (debug-only). WSL (Windows 10) supported: native Linux build and cross-compile to `.exe` via `mingw-w64`.
- **Sequence:** DAG-based effect routing with explicit node system. Python compiler with topological sort and ping-pong optimization. 12 effects operational (Passthrough, Placeholder, GaussianBlur, Heptagon, Particles, RotatingCube, Hybrid3D, Flash, PeakMeter, Scene1, Scene2, Scratch). Effect times are absolute (seq_compiler adds sequence start offset). See `doc/SEQUENCE.md`.
@@ -46,7 +46,7 @@
## Next Up
-**Active:** CNN v3 Phase 3 (WGSL U-Net shaders), Spectral Brush Editor
+**Active:** CNN v3 Phase 4 (C++ CNNv3Effect + FiLM uniform), Spectral Brush Editor
**Ongoing:** Test infrastructure maintenance (35/35 passing)
**Future:** Size optimization (64k target), 3D enhancements
diff --git a/TODO.md b/TODO.md
index 1c405ef..86c3e37 100644
--- a/TODO.md
+++ b/TODO.md
@@ -75,7 +75,7 @@ PyTorch / HTML WebGPU / C++ WebGPU.
- SDF/shadow passes TODO (placeholder: shadow=1, transp=0)
- Howto: `cnn_v3/docs/HOWTO.md`
2. ✅ Training infrastructure: `blender_export.py`, `pack_blender_sample.py`, `pack_photo_sample.py`
-3. WGSL shaders (enc/dec/bottleneck, FiLM, deterministic ops)
+3. ✅ WGSL shaders: cnn_v3_common (snippet), enc0, enc1, bottleneck, dec1, dec0
4. C++ CNNv3Effect + FiLM uniform upload
5. Parity validation (test vectors, ≤1/255 per pixel)
diff --git a/cnn_v3/docs/HOWTO.md b/cnn_v3/docs/HOWTO.md
index 88d4bbc..ad71f1f 100644
--- a/cnn_v3/docs/HOWTO.md
+++ b/cnn_v3/docs/HOWTO.md
@@ -200,13 +200,67 @@ The CNN v3 design requires exact parity between PyTorch, WGSL (HTML), and C++.
| 1 — G-buffer (raster + pack) | ✅ Done | Integrated, 35/35 tests pass |
| 1 — G-buffer (SDF + shadow passes) | TODO | Placeholder in place |
| 2 — Training infrastructure | ✅ Done | blender_export.py, pack_*_sample.py |
-| 3 — WGSL U-Net shaders | TODO | enc/dec/bottleneck/FiLM |
+| 3 — WGSL U-Net shaders | ✅ Done | 5 compute shaders + cnn_v3/common snippet |
| 4 — C++ CNNv3Effect | TODO | FiLM uniform upload |
| 5 — Parity validation | TODO | Test vectors, ≤1/255 |
---
-## 7. Quick Troubleshooting
+## 7. CNN v3 Inference Shaders (Phase 3)
+
+Five compute passes, each a standalone WGSL shader using `#include "cnn_v3/common"`.
+The common snippet provides `get_w()` and `unpack_8ch()`.
+
+| Pass | Shader | Input(s) | Output | Dims |
+|------|--------|----------|--------|------|
+| enc0 | `cnn_v3_enc0.wgsl` | feat_tex0+feat_tex1 (20ch) | enc0_tex rgba16float (4ch) | full |
+| enc1 | `cnn_v3_enc1.wgsl` | enc0_tex (AvgPool2×2 inline) | enc1_tex rgba32uint (8ch) | ½ |
+| bottleneck | `cnn_v3_bottleneck.wgsl` | enc1_tex (AvgPool2×2 inline) | bottleneck_tex rgba32uint (8ch) | ¼ |
+| dec1 | `cnn_v3_dec1.wgsl` | bottleneck_tex + enc1_tex (skip) | dec1_tex rgba16float (4ch) | ½ |
+| dec0 | `cnn_v3_dec0.wgsl` | dec1_tex + enc0_tex (skip) | output_tex rgba16float (4ch) | full |
+
+**Parity rules baked into the shaders:**
+- Zero-padding (not clamp) at conv borders
+- AvgPool 2×2 for downsampling (exact, deterministic)
+- Nearest-neighbor for upsampling (integer `coord / 2`)
+- Skip connections: channel concatenation (not add)
+- FiLM applied after conv+bias, before ReLU: `max(0, γ·x + β)`
+- No batch norm at inference
+- Weight layout: OIHW (out × in × kH × kW), biases after conv weights
+
+**Params uniform per shader** (`group 0, binding 3`):
+```
+struct Params {
+ weight_offset: u32, // f16 index into shared weights buffer
+ _pad: vec3u,
+ gamma: vec4f, // FiLM γ (enc1: gamma_lo+gamma_hi for 8ch)
+ beta: vec4f, // FiLM β (enc1: beta_lo+beta_hi for 8ch)
+}
+```
+FiLM γ/β are computed CPU-side by the FiLM MLP (Phase 4) and uploaded each frame.
+
+**Weight offsets** (f16 units, including bias):
+| Layer | Weights | Bias | Total f16 |
+|-------|---------|------|-----------|
+| enc0 | 20×4×9=720 | +4 | 724 |
+| enc1 | 4×8×9=288 | +8 | 296 |
+| bottleneck | 8×8×1=64 | +8 | 72 |
+| dec1 | 16×4×9=576 | +4 | 580 |
+| dec0 | 8×4×9=288 | +4 | 292 |
+| **Total** | | | **2064 f16 = ~4 KB** |
+
+**Asset IDs** (registered in `workspaces/main/assets.txt` + `src/effects/shaders.cc`):
+`SHADER_CNN_V3_COMMON`, `SHADER_CNN_V3_ENC0`, `SHADER_CNN_V3_ENC1`,
+`SHADER_CNN_V3_BOTTLENECK`, `SHADER_CNN_V3_DEC1`, `SHADER_CNN_V3_DEC0`
+
+**C++ usage (Phase 4):**
+```cpp
+auto src = ShaderComposer::Get().Compose({"cnn_v3/common"}, raw_wgsl);
+```
+
+---
+
+## 8. Quick Troubleshooting
**GBufferEffect renders nothing / albedo is black**
- Check `set_scene()` was called before `render()`
@@ -227,7 +281,7 @@ The CNN v3 design requires exact parity between PyTorch, WGSL (HTML), and C++.
---
-## See Also
+## 9. See Also
- `cnn_v3/docs/CNN_V3.md` — Full architecture design (U-Net, FiLM, feature layout)
- `doc/EFFECT_WORKFLOW.md` — General effect integration guide
diff --git a/cnn_v3/shaders/cnn_v3_bottleneck.wgsl b/cnn_v3/shaders/cnn_v3_bottleneck.wgsl
new file mode 100644
index 0000000..909fd41
--- /dev/null
+++ b/cnn_v3/shaders/cnn_v3_bottleneck.wgsl
@@ -0,0 +1,73 @@
+// CNN v3 — Bottleneck
+// AvgPool2x2(enc1) + Conv(8->8, 1x1) + ReLU (no FiLM)
+//
+// Input: enc1_tex (rgba32uint, 8xf16) half-res
+// Output: bottleneck_out (rgba32uint, 8xf16) quarter-res (dispatch at quarter-res dims)
+//
+// Weight layout (f16, OIHW + bias):
+// [0 .. 8*8*1) conv: w[out][in] (1x1 kernel)
+// [64 .. +8) bias: b[out]
+
+#include "cnn_v3/common"
+
+const BN_IN: u32 = 8u;
+const BN_OUT: u32 = 8u;
+
+struct Params {
+ weight_offset: u32,
+ _pad0: u32,
+ _pad1: u32,
+ _pad2: u32,
+}
+
+@group(0) @binding(0) var enc1_tex: texture_2d<u32>;
+@group(0) @binding(1) var<storage, read> weights: array<u32>;
+@group(0) @binding(2) var<uniform> params: Params;
+@group(0) @binding(3) var bottleneck_out: texture_storage_2d<rgba32uint, write>;
+
+// Avg-pool 2x2 from enc1_tex at quarter-res coord qcoord.
+// Returns zeros for OOB quarter-res coords (zero-padding for the 1x1 conv).
+fn load_enc1_avg(qcoord: vec2i, half_dims: vec2i) -> array<f32, 8> {
+ let quart_dims = half_dims / 2;
+ if (qcoord.x < 0 || qcoord.y < 0 || qcoord.x >= quart_dims.x || qcoord.y >= quart_dims.y) {
+ return array<f32, 8>(0., 0., 0., 0., 0., 0., 0., 0.);
+ }
+ let base = qcoord * 2;
+ var s: array<f32, BN_IN>;
+ for (var dy: i32 = 0; dy < 2; dy++) {
+ for (var dx: i32 = 0; dx < 2; dx++) {
+ let hc = clamp(base + vec2i(dx, dy), vec2i(0), half_dims - vec2i(1));
+ let f = unpack_8ch(enc1_tex, hc);
+ for (var i: u32 = 0u; i < BN_IN; i++) { s[i] += f[i]; }
+ }
+ }
+ for (var i: u32 = 0u; i < BN_IN; i++) { s[i] *= 0.25; }
+ return s;
+}
+
+@compute @workgroup_size(8, 8)
+fn bottleneck_main(@builtin(global_invocation_id) id: vec3u) {
+ let half_dims = vec2i(textureDimensions(enc1_tex));
+ let quart_dims = half_dims / 2;
+ let coord = vec2i(id.xy);
+ if (coord.x >= quart_dims.x || coord.y >= quart_dims.y) { return; }
+
+ let wo = params.weight_offset;
+ let feat = load_enc1_avg(coord, half_dims);
+ var out: array<f32, BN_OUT>;
+
+ for (var o: u32 = 0u; o < BN_OUT; o++) {
+ var sum = get_w(wo, BN_OUT * BN_IN + o); // bias (1x1 kernel: no spatial idx)
+ for (var i: u32 = 0u; i < BN_IN; i++) {
+ sum += get_w(wo, o * BN_IN + i) * feat[i];
+ }
+ out[o] = max(0.0, sum);
+ }
+
+ textureStore(bottleneck_out, coord, vec4u(
+ pack2x16float(vec2f(out[0], out[1])),
+ pack2x16float(vec2f(out[2], out[3])),
+ pack2x16float(vec2f(out[4], out[5])),
+ pack2x16float(vec2f(out[6], out[7]))
+ ));
+}
diff --git a/cnn_v3/shaders/cnn_v3_common.wgsl b/cnn_v3/shaders/cnn_v3_common.wgsl
new file mode 100644
index 0000000..54b0f3d
--- /dev/null
+++ b/cnn_v3/shaders/cnn_v3_common.wgsl
@@ -0,0 +1,23 @@
+// CNN v3 shared helpers — included by all inference compute shaders.
+// Requires the host shader to declare:
+// @group(?) @binding(?) var<storage, read> weights: array<u32>;
+
+// Read one f16 value from the packed-f16 weights buffer.
+// `base` — weight_offset from Params (f16 index of the layer start)
+// `idx` — local f16 index within the layer (conv weight or bias)
+fn get_w(base: u32, idx: u32) -> f32 {
+ let i = base + idx;
+ let v = unpack2x16float(weights[i >> 1u]);
+ return select(v.y, v.x, (i & 1u) == 0u);
+}
+
+// Unpack 8 f16 channels from an rgba32uint texel (pack2x16float layout:
+// u32[0]=ch0|ch1, u32[1]=ch2|ch3, u32[2]=ch4|ch5, u32[3]=ch6|ch7)
+fn unpack_8ch(tex: texture_2d<u32>, coord: vec2i) -> array<f32, 8> {
+ let t = textureLoad(tex, coord, 0);
+ let v0 = unpack2x16float(t.x);
+ let v1 = unpack2x16float(t.y);
+ let v2 = unpack2x16float(t.z);
+ let v3 = unpack2x16float(t.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}
diff --git a/cnn_v3/shaders/cnn_v3_dec0.wgsl b/cnn_v3/shaders/cnn_v3_dec0.wgsl
new file mode 100644
index 0000000..7a4e7c9
--- /dev/null
+++ b/cnn_v3/shaders/cnn_v3_dec0.wgsl
@@ -0,0 +1,73 @@
+// CNN v3 — Decoder level 0 + output
+// NearestUp2x(dec1) + cat(enc0_skip) -> Conv(8->4, 3x3, zero-pad) + FiLM + ReLU + Sigmoid
+//
+// Inputs: dec1_tex (rgba16float, 4ch) half-res
+// enc0_tex (rgba16float, 4ch) full-res (skip connection)
+// Output: output_tex (rgba16float, 4ch) full-res (dispatch at full-res dims)
+//
+// Weight layout (f16, OIHW + bias):
+// [0 .. 8*4*9) conv: w[out][in][ky][kx] (in=8: 4 dec1 + 4 enc0 skip)
+// [288 .. +4) bias: b[out]
+//
+// Parity note: sigmoid applied directly to dec0 output (matches train_cnn_v3.py forward()).
+
+#include "cnn_v3/common"
+
+const DEC0_IN: u32 = 8u;
+const DEC0_OUT: u32 = 4u;
+
+struct Params {
+ weight_offset: u32,
+ _pad: vec3u,
+ gamma: vec4f,
+ beta: vec4f,
+}
+
+@group(0) @binding(0) var dec1_tex: texture_2d<f32>;
+@group(0) @binding(1) var enc0_tex: texture_2d<f32>;
+@group(0) @binding(2) var<storage, read> weights: array<u32>;
+@group(0) @binding(3) var<uniform> params: Params;
+@group(0) @binding(4) var output_tex: texture_storage_2d<rgba16float, write>;
+
+// Load 8 concatenated channels at full-res coord:
+// ch 0-3: dec1 nearest-up (dec1_tex[coord/2])
+// ch 4-7: enc0 skip (enc0_tex[coord])
+// Returns zeros for OOB coord (zero-padding for the conv).
+fn load_dec0_concat(coord: vec2i, full_dims: vec2i) -> array<f32, 8> {
+ if (coord.x < 0 || coord.y < 0 || coord.x >= full_dims.x || coord.y >= full_dims.y) {
+ return array<f32, 8>(0., 0., 0., 0., 0., 0., 0., 0.);
+ }
+ let half_dims = vec2i(textureDimensions(dec1_tex));
+ let hc = clamp(coord / 2, vec2i(0), half_dims - vec2i(1));
+ let d = textureLoad(dec1_tex, hc, 0);
+ let e = textureLoad(enc0_tex, coord, 0);
+ return array<f32, 8>(d.x, d.y, d.z, d.w, e.x, e.y, e.z, e.w);
+}
+
+@compute @workgroup_size(8, 8)
+fn dec0_main(@builtin(global_invocation_id) id: vec3u) {
+ let full_dims = vec2i(textureDimensions(enc0_tex));
+ let coord = vec2i(id.xy);
+ if (coord.x >= full_dims.x || coord.y >= full_dims.y) { return; }
+
+ let wo = params.weight_offset;
+ var out: array<f32, DEC0_OUT>;
+
+ for (var o: u32 = 0u; o < DEC0_OUT; o++) {
+ var sum = get_w(wo, DEC0_OUT * DEC0_IN * 9u + o); // bias
+ for (var ky: i32 = -1; ky <= 1; ky++) {
+ for (var kx: i32 = -1; kx <= 1; kx++) {
+ let feat = load_dec0_concat(coord + vec2i(kx, ky), full_dims);
+ let ki = u32(ky + 1) * 3u + u32(kx + 1);
+ for (var i: u32 = 0u; i < DEC0_IN; i++) {
+ sum += get_w(wo, o * DEC0_IN * 9u + i * 9u + ki) * feat[i];
+ }
+ }
+ }
+ // FiLM + ReLU + Sigmoid (matches training forward())
+ let v = max(0.0, params.gamma[o] * sum + params.beta[o]);
+ out[o] = 1.0 / (1.0 + exp(-v));
+ }
+
+ textureStore(output_tex, coord, vec4f(out[0], out[1], out[2], out[3]));
+}
diff --git a/cnn_v3/shaders/cnn_v3_dec1.wgsl b/cnn_v3/shaders/cnn_v3_dec1.wgsl
new file mode 100644
index 0000000..28ae3dc
--- /dev/null
+++ b/cnn_v3/shaders/cnn_v3_dec1.wgsl
@@ -0,0 +1,72 @@
+// CNN v3 — Decoder level 1
+// NearestUp2x(bottleneck) + cat(enc1_skip) -> Conv(16->4, 3x3, zero-pad) + FiLM + ReLU
+//
+// Inputs: bottleneck_tex (rgba32uint, 8xf16) quarter-res
+// enc1_tex (rgba32uint, 8xf16) half-res (skip connection)
+// Output: dec1_out (rgba16float, 4ch) half-res (dispatch at half-res dims)
+//
+// Weight layout (f16, OIHW + bias):
+// [0 .. 16*4*9) conv: w[out][in][ky][kx] (in=16: 8 bottleneck + 8 enc1 skip)
+// [576 .. +4) bias: b[out]
+
+#include "cnn_v3/common"
+
+const DEC1_IN: u32 = 16u;
+const DEC1_OUT: u32 = 4u;
+
+struct Params {
+ weight_offset: u32,
+ _pad: vec3u,
+ gamma: vec4f,
+ beta: vec4f,
+}
+
+@group(0) @binding(0) var bottleneck_tex: texture_2d<u32>;
+@group(0) @binding(1) var enc1_tex: texture_2d<u32>;
+@group(0) @binding(2) var<storage, read> weights: array<u32>;
+@group(0) @binding(3) var<uniform> params: Params;
+@group(0) @binding(4) var dec1_out: texture_storage_2d<rgba16float, write>;
+
+// Load 16 concatenated channels at half-res coord hcoord:
+// ch 0-7: bottleneck nearest-up (bottleneck_tex[hcoord/2])
+// ch 8-15: enc1 skip (enc1_tex[hcoord])
+// Returns zeros for OOB hcoord (zero-padding for the conv).
+fn load_dec1_concat(hcoord: vec2i, half_dims: vec2i) -> array<f32, 16> {
+ if (hcoord.x < 0 || hcoord.y < 0 || hcoord.x >= half_dims.x || hcoord.y >= half_dims.y) {
+ return array<f32, 16>(0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.);
+ }
+ let quart_dims = half_dims / 2;
+ let qc = clamp(hcoord / 2, vec2i(0), quart_dims - vec2i(1));
+ let b = unpack_8ch(bottleneck_tex, qc);
+ let s = unpack_8ch(enc1_tex, hcoord);
+ return array<f32, 16>(
+ b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7],
+ s[0], s[1], s[2], s[3], s[4], s[5], s[6], s[7]
+ );
+}
+
+@compute @workgroup_size(8, 8)
+fn dec1_main(@builtin(global_invocation_id) id: vec3u) {
+ let half_dims = vec2i(textureDimensions(enc1_tex));
+ let coord = vec2i(id.xy);
+ if (coord.x >= half_dims.x || coord.y >= half_dims.y) { return; }
+
+ let wo = params.weight_offset;
+ var out: array<f32, DEC1_OUT>;
+
+ for (var o: u32 = 0u; o < DEC1_OUT; o++) {
+ var sum = get_w(wo, DEC1_OUT * DEC1_IN * 9u + o); // bias
+ for (var ky: i32 = -1; ky <= 1; ky++) {
+ for (var kx: i32 = -1; kx <= 1; kx++) {
+ let feat = load_dec1_concat(coord + vec2i(kx, ky), half_dims);
+ let ki = u32(ky + 1) * 3u + u32(kx + 1);
+ for (var i: u32 = 0u; i < DEC1_IN; i++) {
+ sum += get_w(wo, o * DEC1_IN * 9u + i * 9u + ki) * feat[i];
+ }
+ }
+ }
+ out[o] = max(0.0, params.gamma[o] * sum + params.beta[o]);
+ }
+
+ textureStore(dec1_out, coord, vec4f(out[0], out[1], out[2], out[3]));
+}
diff --git a/cnn_v3/shaders/cnn_v3_enc0.wgsl b/cnn_v3/shaders/cnn_v3_enc0.wgsl
new file mode 100644
index 0000000..f52a167
--- /dev/null
+++ b/cnn_v3/shaders/cnn_v3_enc0.wgsl
@@ -0,0 +1,75 @@
+// CNN v3 — Encoder level 0
+// Conv(20->4, 3x3, zero-pad) + FiLM + ReLU
+//
+// Input: feat_tex0 (rgba32uint, 8xf16), feat_tex1 (rgba32uint, 12xu8) full-res
+// Output: enc0_out (rgba16float, 4ch) full-res
+//
+// Weight layout (f16, OIHW + bias):
+// [0 .. 20*4*9) conv: w[out][in][ky][kx]
+// [720 .. +4) bias: b[out]
+
+#include "cnn_v3/common"
+
+const ENC0_IN: u32 = 20u;
+const ENC0_OUT: u32 = 4u;
+
+struct Params {
+ weight_offset: u32,
+ _pad: vec3u,
+ gamma: vec4f,
+ beta: vec4f,
+}
+
+@group(0) @binding(0) var feat_tex0: texture_2d<u32>;
+@group(0) @binding(1) var feat_tex1: texture_2d<u32>;
+@group(0) @binding(2) var<storage, read> weights: array<u32>;
+@group(0) @binding(3) var<uniform> params: Params;
+@group(0) @binding(4) var enc0_out: texture_storage_2d<rgba16float, write>;
+
+// Unpack all 20 feature channels at coord. Returns zeros for OOB (zero-padding).
+fn load_feat(coord: vec2i, dims: vec2i) -> array<f32, 20> {
+ if (coord.x < 0 || coord.y < 0 || coord.x >= dims.x || coord.y >= dims.y) {
+ return array<f32, 20>(0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.);
+ }
+ let t0 = textureLoad(feat_tex0, coord, 0);
+ let t1 = textureLoad(feat_tex1, coord, 0);
+ let a = unpack2x16float(t0.x);
+ let b = unpack2x16float(t0.y);
+ let c = unpack2x16float(t0.z);
+ let d = unpack2x16float(t0.w);
+ let e = unpack4x8unorm(t1.x);
+ let f = unpack4x8unorm(t1.y);
+ let g = unpack4x8unorm(t1.z);
+ return array<f32, 20>(
+ a.x, a.y, b.x, b.y, c.x, c.y, d.x, d.y,
+ e.x, e.y, e.z, e.w,
+ f.x, f.y, f.z, f.w,
+ g.x, g.y, g.z, g.w
+ );
+}
+
+@compute @workgroup_size(8, 8)
+fn enc0_main(@builtin(global_invocation_id) id: vec3u) {
+ let coord = vec2i(id.xy);
+ let dims = vec2i(textureDimensions(feat_tex0));
+ if (coord.x >= dims.x || coord.y >= dims.y) { return; }
+
+ let wo = params.weight_offset;
+ var out: array<f32, ENC0_OUT>;
+
+ for (var o: u32 = 0u; o < ENC0_OUT; o++) {
+ var sum = get_w(wo, ENC0_OUT * ENC0_IN * 9u + o); // bias
+ for (var ky: i32 = -1; ky <= 1; ky++) {
+ for (var kx: i32 = -1; kx <= 1; kx++) {
+ let feat = load_feat(coord + vec2i(kx, ky), dims);
+ let ki = u32(ky + 1) * 3u + u32(kx + 1);
+ for (var i: u32 = 0u; i < ENC0_IN; i++) {
+ sum += get_w(wo, o * ENC0_IN * 9u + i * 9u + ki) * feat[i];
+ }
+ }
+ }
+ out[o] = max(0.0, params.gamma[o] * sum + params.beta[o]);
+ }
+
+ textureStore(enc0_out, coord, vec4f(out[0], out[1], out[2], out[3]));
+}
diff --git a/cnn_v3/shaders/cnn_v3_enc1.wgsl b/cnn_v3/shaders/cnn_v3_enc1.wgsl
new file mode 100644
index 0000000..23e485d
--- /dev/null
+++ b/cnn_v3/shaders/cnn_v3_enc1.wgsl
@@ -0,0 +1,88 @@
+// CNN v3 — Encoder level 1
+// AvgPool2x2(enc0) + Conv(4->8, 3x3, zero-pad) + FiLM + ReLU
+//
+// Input: enc0_tex (rgba16float, 4ch) full-res
+// Output: enc1_out (rgba32uint, 8xf16) half-res (dispatch at half-res dims)
+//
+// Weight layout (f16, OIHW + bias):
+// [0 .. 4*8*9) conv: w[out][in][ky][kx]
+// [288 .. +8) bias: b[out]
+
+#include "cnn_v3/common"
+
+const ENC1_IN: u32 = 4u;
+const ENC1_OUT: u32 = 8u;
+
+struct Params {
+ weight_offset: u32,
+ _pad: vec3u,
+ gamma_lo: vec4f, // FiLM gamma ch 0-3
+ gamma_hi: vec4f, // FiLM gamma ch 4-7
+ beta_lo: vec4f, // FiLM beta ch 0-3
+ beta_hi: vec4f, // FiLM beta ch 4-7
+}
+
+@group(0) @binding(0) var enc0_tex: texture_2d<f32>;
+@group(0) @binding(1) var<storage, read> weights: array<u32>;
+@group(0) @binding(2) var<uniform> params: Params;
+@group(0) @binding(3) var enc1_out: texture_storage_2d<rgba32uint, write>;
+
+fn film_gamma(o: u32) -> f32 {
+ if (o < 4u) { return params.gamma_lo[o]; }
+ return params.gamma_hi[o - 4u];
+}
+fn film_beta(o: u32) -> f32 {
+ if (o < 4u) { return params.beta_lo[o]; }
+ return params.beta_hi[o - 4u];
+}
+
+// Avg-pool 2x2 from enc0_tex at half-res coord hcoord.
+// Returns zeros for OOB half-res coords (zero-padding for the conv).
+fn load_enc0_avg(hcoord: vec2i, full_dims: vec2i) -> array<f32, 4> {
+ let half_dims = full_dims / 2;
+ if (hcoord.x < 0 || hcoord.y < 0 || hcoord.x >= half_dims.x || hcoord.y >= half_dims.y) {
+ return array<f32, 4>(0., 0., 0., 0.);
+ }
+ let base = hcoord * 2;
+ var s = vec4f(0.);
+ for (var dy: i32 = 0; dy < 2; dy++) {
+ for (var dx: i32 = 0; dx < 2; dx++) {
+ let fc = clamp(base + vec2i(dx, dy), vec2i(0), full_dims - vec2i(1));
+ s += textureLoad(enc0_tex, fc, 0);
+ }
+ }
+ let avg = s * 0.25;
+ return array<f32, 4>(avg.x, avg.y, avg.z, avg.w);
+}
+
+@compute @workgroup_size(8, 8)
+fn enc1_main(@builtin(global_invocation_id) id: vec3u) {
+ let full_dims = vec2i(textureDimensions(enc0_tex));
+ let half_dims = full_dims / 2;
+ let coord = vec2i(id.xy);
+ if (coord.x >= half_dims.x || coord.y >= half_dims.y) { return; }
+
+ let wo = params.weight_offset;
+ var out: array<f32, ENC1_OUT>;
+
+ for (var o: u32 = 0u; o < ENC1_OUT; o++) {
+ var sum = get_w(wo, ENC1_OUT * ENC1_IN * 9u + o); // bias
+ for (var ky: i32 = -1; ky <= 1; ky++) {
+ for (var kx: i32 = -1; kx <= 1; kx++) {
+ let feat = load_enc0_avg(coord + vec2i(kx, ky), full_dims);
+ let ki = u32(ky + 1) * 3u + u32(kx + 1);
+ for (var i: u32 = 0u; i < ENC1_IN; i++) {
+ sum += get_w(wo, o * ENC1_IN * 9u + i * 9u + ki) * feat[i];
+ }
+ }
+ }
+ out[o] = max(0.0, film_gamma(o) * sum + film_beta(o));
+ }
+
+ textureStore(enc1_out, coord, vec4u(
+ pack2x16float(vec2f(out[0], out[1])),
+ pack2x16float(vec2f(out[2], out[3])),
+ pack2x16float(vec2f(out[4], out[5])),
+ pack2x16float(vec2f(out[6], out[7]))
+ ));
+}
diff --git a/src/effects/shaders.cc b/src/effects/shaders.cc
index 1adbff5..22c6a6d 100644
--- a/src/effects/shaders.cc
+++ b/src/effects/shaders.cc
@@ -68,6 +68,14 @@ void InitShaderComposer() {
AssetId::ASSET_SHADER_RENDER_RAYMARCHING);
register_if_exists("render/raymarching_id",
AssetId::ASSET_SHADER_RENDER_RAYMARCHING_ID);
+ // CNN v3 inference snippets
+ register_if_exists("cnn_v3/common", AssetId::ASSET_SHADER_CNN_V3_COMMON);
+ register_if_exists("cnn_v3/enc0", AssetId::ASSET_SHADER_CNN_V3_ENC0);
+ register_if_exists("cnn_v3/enc1", AssetId::ASSET_SHADER_CNN_V3_ENC1);
+ register_if_exists("cnn_v3/bottleneck", AssetId::ASSET_SHADER_CNN_V3_BOTTLENECK);
+ register_if_exists("cnn_v3/dec1", AssetId::ASSET_SHADER_CNN_V3_DEC1);
+ register_if_exists("cnn_v3/dec0", AssetId::ASSET_SHADER_CNN_V3_DEC0);
+
// CNN shaders (workspace-specific)
// register_if_exists("cnn_activation", AssetId::ASSET_SHADER_CNN_ACTIVATION);
// register_if_exists("cnn_conv1x1", AssetId::ASSET_SHADER_CNN_CONV1X1);
diff --git a/workspaces/main/assets.txt b/workspaces/main/assets.txt
index ad57d2f..4cb4f40 100644
--- a/workspaces/main/assets.txt
+++ b/workspaces/main/assets.txt
@@ -101,6 +101,14 @@ SHADER_RENDER_NTSC_COMMON, WGSL, ../../src/shaders/render/ntsc_common.wgsl, "NTS
# --- CNN v3 G-Buffer ---
SHADER_GBUF_RASTER, WGSL, ../../cnn_v3/shaders/gbuf_raster.wgsl, "CNN v3 G-buffer MRT rasterization shader"
SHADER_GBUF_PACK, WGSL, ../../cnn_v3/shaders/gbuf_pack.wgsl, "CNN v3 G-buffer feature pack compute shader"
+
+# --- CNN v3 Inference ---
+SHADER_CNN_V3_COMMON, WGSL, ../../cnn_v3/shaders/cnn_v3_common.wgsl, "CNN v3 shared helpers snippet (get_w, unpack_8ch)"
+SHADER_CNN_V3_ENC0, WGSL, ../../cnn_v3/shaders/cnn_v3_enc0.wgsl, "CNN v3 encoder level 0"
+SHADER_CNN_V3_ENC1, WGSL, ../../cnn_v3/shaders/cnn_v3_enc1.wgsl, "CNN v3 encoder level 1"
+SHADER_CNN_V3_BOTTLENECK, WGSL, ../../cnn_v3/shaders/cnn_v3_bottleneck.wgsl, "CNN v3 bottleneck"
+SHADER_CNN_V3_DEC1, WGSL, ../../cnn_v3/shaders/cnn_v3_dec1.wgsl, "CNN v3 decoder level 1"
+SHADER_CNN_V3_DEC0, WGSL, ../../cnn_v3/shaders/cnn_v3_dec0.wgsl, "CNN v3 decoder level 0 + sigmoid output"
SHADER_DEBUG_DEBUG_PRINT, WGSL, ../../src/shaders/debug/debug_print.wgsl, "Debug print snippet"
# --- Sequence Shaders ---