path: root/cnn_v3/docs/CNN_V3.md
author	skal <pascal.massimino@gmail.com>	2026-03-26 07:03:01 +0100
committer	skal <pascal.massimino@gmail.com>	2026-03-26 07:03:01 +0100
commit	8f14bdd66cb002b2f89265b2a578ad93249089c9 (patch)
tree	2ccdb3939b673ebc3a5df429160631240239cee2 /cnn_v3/docs/CNN_V3.md
parent	4ca498277b033ae10134045dae9c8c249a8d2b2b (diff)
feat(cnn_v3): upgrade architecture to enc_channels=[8,16]
Double encoder capacity: enc0 4→8ch, enc1 8→16ch, bottleneck 16→16ch,
dec1 32→8ch, dec0 16→4ch. Total weights 2476→7828 f16 (~15.3 KB). FiLM MLP
output 40→72 params (L1: 16×40→16×72). 16-ch textures split into _lo/_hi
rgba32uint pairs (enc1, bottleneck). enc0 and dec1 textures changed from
rgba16float to rgba32uint (8ch). GBUF_RGBA32UINT node gains CopySrc for
parity test readback.

- WGSL shaders: all 5 passes rewritten for new channel counts
- C++ CNNv3Effect: new weight offsets/sizes, 8ch uniform structs
- Web tool (shaders.js + tester.js): matching texture formats and bindings
- Parity test: readback_rgba32uint_8ch helper, updated vector counts
- Training scripts: default enc_channels=[8,16], updated docstrings
- Docs + architecture PNG regenerated

handoff(Gemini): CNN v3 [8,16] upgrade complete. All code, tests, web tool,
training scripts, and docs updated. Next: run training pass.
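As a sanity check on the counts quoted in the commit message, the per-layer totals can be recomputed from the conv shapes listed there (a sketch, not project code):

```python
# Recompute the f16 weight counts for the [8,16] architecture from the
# per-layer conv shapes (in_ch, out_ch, kernel); params = weights + bias.
layers = [
    (20, 8, 3),   # enc0: Conv(20->8, 3x3)
    (8, 16, 3),   # enc1: Conv(8->16, 3x3)
    (16, 16, 3),  # bottleneck: Conv(16->16, 3x3, dil=2)
    (32, 8, 3),   # dec1: Conv(32->8, 3x3) — 16ch upsample + 16ch skip
    (16, 4, 3),   # dec0: Conv(16->4, 3x3) — 8ch upsample + 8ch skip
]

conv_total = sum(cin * cout * k * k + cout for cin, cout, k in layers)
mlp_total = 5 * 16 + 16 + 16 * 72 + 72    # FiLM MLP 5->16->72, with biases

print(conv_total)                           # 7828 f16 -> ~15.3 KB
print((conv_total + mlp_total) * 2 / 1024)  # ~17.9 KB incl. MLP, at 2 bytes/f16
```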
Diffstat (limited to 'cnn_v3/docs/CNN_V3.md')
-rw-r--r--  cnn_v3/docs/CNN_V3.md  87
1 file changed, 45 insertions(+), 42 deletions(-)
diff --git a/cnn_v3/docs/CNN_V3.md b/cnn_v3/docs/CNN_V3.md
index d775e2b..a197a1d 100644
--- a/cnn_v3/docs/CNN_V3.md
+++ b/cnn_v3/docs/CNN_V3.md
@@ -19,7 +19,7 @@ CNN v3 is a next-generation post-processing effect using:
- Training from both Blender renders and real photos
- Strict test framework: per-pixel bit-exact validation across all implementations
-**Status:** Phases 1–5 complete. Parity validated (max_err=4.88e-4 ≤ 1/255). Next: `train_cnn_v3.py` for FiLM MLP training.
+**Status:** Phases 1–7 complete. Architecture upgraded to enc_channels=[8,16] for improved capacity. Parity test and runtime updated. Next: training pass.
---
@@ -52,14 +52,14 @@ A small MLP takes a conditioning vector `c` and outputs all γ/β:
c = [beat_phase, beat_time/8, audio_intensity, style_p0, style_p1] (5D)
↓ Linear(5 → 16) → ReLU
↓ Linear(16 → N_film_params)
- → [γ_enc0(4ch), β_enc0(4ch), γ_enc1(8ch), β_enc1(8ch),
- γ_dec1(4ch), β_dec1(4ch), γ_dec0(4ch), β_dec0(4ch)]
- = 2 × (4+8+4+4) = 40 parameters output
+ → [γ_enc0(8ch), β_enc0(8ch), γ_enc1(16ch), β_enc1(16ch),
+ γ_dec1(8ch), β_dec1(8ch), γ_dec0(4ch), β_dec0(4ch)]
+ = 2 × (8+16+8+4) = 72 parameters output
```
**Runtime cost:** trivial (one MLP forward pass per frame, CPU-side).
**Training:** jointly trained with U-Net — backprop through FiLM to MLP.
-**Size:** MLP weights ~(5×16 + 16×40) × 2 bytes f16 ≈ 1.4 KB.
+**Size:** MLP weights ~(5×16 + 16×72) × 2 bytes f16 ≈ 2.5 KB.
**Why FiLM instead of just uniform parameters?**
- γ/β are per-channel, enabling fine-grained style control
@@ -318,22 +318,25 @@ All f16, little-endian, same packing as v2 (`pack2x16float`).
## Size Budget
-**CNN v3 target: ≤ 6 KB weights**
+**CNN v3 target: ≤ 6 KB weights (conv only); current arch prioritises quality**
-**Implemented architecture (fits ≤ 4 KB):**
+**Implemented architecture (enc_channels=[8,16] — ~15.3 KB conv f16):**
| Component | Weights | Bias | Total f16 |
|-----------|---------|------|-----------|
-| enc0: Conv(20→4, 3×3) | 20×4×9=720 | +4 | 724 |
-| enc1: Conv(4→8, 3×3) | 4×8×9=288 | +8 | 296 |
-| bottleneck: Conv(8→8, 3×3, dil=2) | 8×8×9=576 | +8 | 584 |
-| dec1: Conv(16→4, 3×3) | 16×4×9=576 | +4 | 580 |
-| dec0: Conv(8→4, 3×3) | 8×4×9=288 | +4 | 292 |
-| FiLM MLP (5→16→40) | 5×16+16×40=720 | +16+40 | 776 |
-| **Total conv** | | | **~4.84 KB f16** |
+| enc0: Conv(20→8, 3×3) | 20×8×9=1440 | +8 | 1448 |
+| enc1: Conv(8→16, 3×3) | 8×16×9=1152 | +16 | 1168 |
+| bottleneck: Conv(16→16, 3×3, dil=2) | 16×16×9=2304 | +16 | 2320 |
+| dec1: Conv(32→8, 3×3) | 32×8×9=2304 | +8 | 2312 |
+| dec0: Conv(16→4, 3×3) | 16×4×9=576 | +4 | 580 |
+| **Total conv** | | | **7828 f16 = ~15.3 KB** |
+| FiLM MLP (5→16→72) | 5×16+16×72=1232 | +16+72 | 1320 |
+| **Total incl. MLP** | | | **9148 f16 = ~17.9 KB** |
-Skip connections: dec1 input = 8ch (bottleneck) + 8ch (enc1 skip) = 16ch.
-dec0 input = 4ch (dec1) + 4ch (enc0 skip) = 8ch.
+Skip connections: dec1 input = 16ch (bottleneck up) + 16ch (enc1 skip) = 32ch.
+dec0 input = 8ch (dec1 up) + 8ch (enc0 skip) = 16ch.
+
+**Smaller variant (enc_channels=[4,8] — ~4.84 KB conv f16):** fits 6 KB target but has lower representational capacity. Train with `--enc-channels 4,8` if size-critical.
---
@@ -507,7 +510,7 @@ All tests: max per-pixel per-channel absolute error ≤ 1/255 (PyTorch f32 vs We
```python
class CNNv3(nn.Module):
- def __init__(self, enc_channels=[4,8], film_cond_dim=5):
+ def __init__(self, enc_channels=[8,16], film_cond_dim=5):
super().__init__()
# Encoder
self.enc = nn.ModuleList([
@@ -681,11 +684,11 @@ Parity results:
```
Pass 0: pack_gbuffer.wgsl — assemble G-buffer channels into storage texture
-Pass 1: cnn_v3_enc0.wgsl — encoder level 0 (20→4ch, 3×3)
-Pass 2: cnn_v3_enc1.wgsl — encoder level 1 (4→8ch, 3×3) + downsample
-Pass 3: cnn_v3_bottleneck.wgsl — bottleneck (8→8, 3×3, dilation=2)
-Pass 4: cnn_v3_dec1.wgsl — decoder level 1: upsample + skip + (16→4, 3×3)
-Pass 5: cnn_v3_dec0.wgsl — decoder level 0: upsample + skip + (8→4, 3×3)
+Pass 1: cnn_v3_enc0.wgsl — encoder level 0 (20→8ch, 3×3)
+Pass 2: cnn_v3_enc1.wgsl — encoder level 1 (8→16ch, 3×3) + downsample
+Pass 3: cnn_v3_bottleneck.wgsl — bottleneck (16→16, 3×3, dilation=2)
+Pass 4: cnn_v3_dec1.wgsl — decoder level 1: upsample + skip + (32→8, 3×3)
+Pass 5: cnn_v3_dec0.wgsl — decoder level 0: upsample + skip + (16→4, 3×3)
Pass 6: cnn_v3_output.wgsl — sigmoid + composite to framebuffer
```
@@ -788,11 +791,11 @@ Status bar shows which channels are loaded.
| Shader | Replaces | Notes |
|--------|----------|-------|
| `PACK_SHADER` | `STATIC_SHADER` | 20ch into feat_tex0 + feat_tex1 (rgba32uint each) |
-| `ENC0_SHADER` | part of `CNN_SHADER` | Conv(20→4, 3×3) + FiLM + ReLU; writes enc0_tex |
-| `ENC1_SHADER` | | Conv(4→8, 3×3) + FiLM + ReLU + avg_pool2×2; writes enc1_tex (half-res) |
-| `BOTTLENECK_SHADER` | | Conv(8→8, 3×3, dilation=2) + ReLU; writes bn_tex |
-| `DEC1_SHADER` | | nearest upsample×2 + concat(bn, enc1_skip) + Conv(16→4, 3×3) + FiLM + ReLU |
-| `DEC0_SHADER` | | nearest upsample×2 + concat(dec1, enc0_skip) + Conv(8→4, 3×3) + FiLM + ReLU |
+| `ENC0_SHADER` | part of `CNN_SHADER` | Conv(20→8, 3×3) + FiLM + ReLU; writes enc0_tex (rgba32uint, 8ch) |
+| `ENC1_SHADER` | | Conv(8→16, 3×3) + FiLM + ReLU + avg_pool2×2; writes enc1_lo+enc1_hi (2× rgba32uint, 16ch split) |
+| `BOTTLENECK_SHADER` | | Conv(16→16, 3×3, dilation=2) + ReLU; writes bn_lo+bn_hi (2× rgba32uint, 16ch split) |
+| `DEC1_SHADER` | | nearest upsample×2 + concat(bn, enc1_skip) + Conv(32→8, 3×3) + FiLM + ReLU; writes dec1_tex (rgba32uint, 8ch) |
+| `DEC0_SHADER` | | nearest upsample×2 + concat(dec1, enc0_skip) + Conv(16→4, 3×3) + FiLM + ReLU; writes rgba16float |
| `OUTPUT_SHADER` | | Conv(4→4, 1×1) + sigmoid → composites to canvas |
FiLM γ/β computed JS-side from sliders (tiny MLP forward pass in JS), uploaded as uniform.
@@ -805,15 +808,15 @@ FiLM γ/β computed JS-side from sliders (tiny MLP forward pass in JS), uploaded
|------|------|--------|----------|
| `feat_tex0` | W×H | rgba32uint | feature buffer slots 0–7 (f16) |
| `feat_tex1` | W×H | rgba32uint | feature buffer slots 8–19 (u8+spare) |
-| `enc0_tex` | W×H | rgba32uint | 4 channels f16 (enc0 output, skip) |
-| `enc1_tex` | W/2×H/2 | rgba32uint | 8 channels f16 (enc1 out, skip) — 2 texels per pixel |
-| `bn_tex` | W/2×H/2 | rgba32uint | 8 channels f16 (bottleneck output) |
-| `dec1_tex` | W×H | rgba32uint | 4 channels f16 (dec1 output) |
-| `dec0_tex` | W×H | rgba32uint | 4 channels f16 (dec0 output) |
+| `enc0_tex` | W×H | rgba32uint | 8 channels f16 (enc0 output, skip) |
+| `enc1_lo` + `enc1_hi` | W/2×H/2 each | rgba32uint | 16 channels f16 split (enc1 out, skip) |
+| `bn_lo` + `bn_hi` | W/4×H/4 each | rgba32uint | 16 channels f16 split (bottleneck output) |
+| `dec1_tex` | W/2×H/2 | rgba32uint | 8 channels f16 (dec1 output) |
+| `dec0_tex` | W×H | rgba16float | 4 channels f16 (final RGBA output) |
| `prev_tex` | W×H | rgba16float | previous CNN output (temporal, `F16X8`) |
-Skip connections: enc0_tex and enc1_tex are **kept alive** across the full forward pass
-(not ping-ponged away). DEC1 and DEC0 read them directly.
+Skip connections: enc0_tex (8ch) and enc1_lo/enc1_hi (16ch split) are **kept alive** across the
+full forward pass (not ping-ponged away). DEC1 and DEC0 read them directly.
---
@@ -856,7 +859,7 @@ python3 -m http.server 8000
Ordered for parallel execution where possible. Phases 1 and 2 are independent.
-**Architecture locked:** enc_channels = [4, 8]. See Size Budget for weight counts.
+**Architecture:** enc_channels = [8, 16]. See Size Budget for weight counts.
---
@@ -881,7 +884,7 @@ before the real G-buffer exists. Wire real G-buffer in Phase 5.
**1a. PyTorch model**
- [ ] `cnn_v3/training/train_cnn_v3.py`
- - [ ] `CNNv3` class: U-Net [4,8], FiLM MLP (5→16→48), channel dropout
+ - [ ] `CNNv3` class: U-Net [8,16], FiLM MLP (5→16→72), channel dropout
- [ ] `GBufferDataset`: loads 20-channel feature tensors from packed PNGs
- [ ] Training loop, checkpointing, grayscale/RGBA loss option
@@ -919,11 +922,11 @@ no batch norm at inference, `#include` existing snippets where possible.
- writes feat_tex0 (f16×8) + feat_tex1 (u8×12, spare)
**2b. U-Net compute shaders**
-- [ ] `src/effects/cnn_v3_enc0.wgsl` — Conv(20→4, 3×3) + FiLM + ReLU
-- [ ] `src/effects/cnn_v3_enc1.wgsl` — Conv(4→8, 3×3) + FiLM + ReLU + avg_pool 2×2
-- [ ] `src/effects/cnn_v3_bottleneck.wgsl` — Conv(8→8, 1×1) + FiLM + ReLU
-- [ ] `src/effects/cnn_v3_dec1.wgsl` — nearest upsample×2 + concat enc1_skip + Conv(16→4, 3×3) + FiLM + ReLU
-- [ ] `src/effects/cnn_v3_dec0.wgsl` — nearest upsample×2 + concat enc0_skip + Conv(8→4, 3×3) + FiLM + ReLU
+- [ ] `src/effects/cnn_v3_enc0.wgsl` — Conv(20→8, 3×3) + FiLM + ReLU
+- [ ] `src/effects/cnn_v3_enc1.wgsl` — Conv(8→16, 3×3) + FiLM + ReLU + avg_pool 2×2
+- [ ] `src/effects/cnn_v3_bottleneck.wgsl` — Conv(16→16, 3×3, dilation=2) + ReLU
+- [ ] `src/effects/cnn_v3_dec1.wgsl` — nearest upsample×2 + concat enc1_skip + Conv(32→8, 3×3) + FiLM + ReLU
+- [ ] `src/effects/cnn_v3_dec0.wgsl` — nearest upsample×2 + concat enc0_skip + Conv(16→4, 3×3) + FiLM + ReLU
- [ ] `src/effects/cnn_v3_output.wgsl` — Conv(4→4, 1×1) + sigmoid → composite to framebuffer
Reuse from existing shaders:
@@ -941,7 +944,7 @@ Reuse from existing shaders:
- [ ] `src/effects/cnn_v3_effect.h` — class declaration
- textures: feat_tex0, feat_tex1, enc0_tex, enc1_tex (half-res), bn_tex (half-res), dec1_tex, dec0_tex
- **`WGPUTexture prev_cnn_tex_`** — persistent RGBA8, owned by effect, initialized black
- - `FilmParams` uniform buffer (γ/β for 4 levels = 48 floats = 192 bytes)
+ - `FilmParams` uniform buffer (γ/β for 4 levels = 72 floats = 288 bytes)
- FiLM MLP weights (loaded from .bin, run CPU-side per frame)
- [ ] `src/effects/cnn_v3_effect.cc` — implementation
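The `_lo`/`_hi` split introduced by this commit (16 f16 channels per pixel across two rgba32uint texels) can be illustrated with a short sketch. The packing mirrors WGSL `pack2x16float` (first value in the low 16 bits); the helper names are hypothetical, not taken from the codebase:

```python
# Hypothetical sketch of the _lo/_hi texture split: 16 f16 channel values
# packed into two rgba32uint texels, each u32 holding two f16 halves.
import struct

def pack2x16float(a, b):
    # struct format 'e' is IEEE binary16; low half first, as in WGSL
    # pack2x16float.
    lo = struct.unpack("<H", struct.pack("<e", a))[0]
    hi = struct.unpack("<H", struct.pack("<e", b))[0]
    return lo | (hi << 16)

def pack_16ch(values):
    # 16 channels -> 8 u32 words -> two rgba32uint texels (lo = ch 0..7,
    # hi = ch 8..15).
    assert len(values) == 16
    words = [pack2x16float(values[i], values[i + 1]) for i in range(0, 16, 2)]
    return words[:4], words[4:]

lo_texel, hi_texel = pack_16ch([0.5] * 16)  # 0.5 in f16 is 0x3800
```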