From 8f14bdd66cb002b2f89265b2a578ad93249089c9 Mon Sep 17 00:00:00 2001 From: skal Date: Thu, 26 Mar 2026 07:03:01 +0100 Subject: feat(cnn_v3): upgrade architecture to enc_channels=[8,16] MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Double encoder capacity: enc0 output 4→8ch, enc1 output 8→16ch; new conv shapes: bottleneck 16→16ch, dec1 32→8ch, dec0 16→4ch. Total weights 2476→7828 f16 (~15.3 KB). FiLM MLP output 40→72 params (L1: 16×40→16×72). 16-ch textures split into _lo/_hi rgba32uint pairs (enc1, bottleneck). enc0 and dec1 textures changed from rgba16float to rgba32uint (8ch). GBUF_RGBA32UINT node gains CopySrc for parity test readback. - WGSL shaders: all 5 passes rewritten for new channel counts - C++ CNNv3Effect: new weight offsets/sizes, 8ch uniform structs - Web tool (shaders.js + tester.js): matching texture formats and bindings - Parity test: readback_rgba32uint_8ch helper, updated vector counts - Training scripts: default enc_channels=[8,16], updated docstrings - Docs + architecture PNG regenerated handoff(Gemini): CNN v3 [8,16] upgrade complete. All code, tests, web tool, training scripts, and docs updated. Next: run training pass. --- cnn_v3/docs/CNN_V3.md | 87 ++++++++++++++++++++++++++------------------------- 1 file changed, 45 insertions(+), 42 deletions(-) (limited to 'cnn_v3/docs/CNN_V3.md') diff --git a/cnn_v3/docs/CNN_V3.md b/cnn_v3/docs/CNN_V3.md index d775e2b..a197a1d 100644 --- a/cnn_v3/docs/CNN_V3.md +++ b/cnn_v3/docs/CNN_V3.md @@ -19,7 +19,7 @@ CNN v3 is a next-generation post-processing effect using: - Training from both Blender renders and real photos - Strict test framework: per-pixel bit-exact validation across all implementations -**Status:** Phases 1–5 complete. Parity validated (max_err=4.88e-4 ≤ 1/255). Next: `train_cnn_v3.py` for FiLM MLP training. +**Status:** Phases 1–7 complete. Architecture upgraded to enc_channels=[8,16] for improved capacity. Parity test and runtime updated. Next: training pass. 
--- @@ -52,14 +52,14 @@ A small MLP takes a conditioning vector `c` and outputs all γ/β: c = [beat_phase, beat_time/8, audio_intensity, style_p0, style_p1] (5D) ↓ Linear(5 → 16) → ReLU ↓ Linear(16 → N_film_params) - → [γ_enc0(4ch), β_enc0(4ch), γ_enc1(8ch), β_enc1(8ch), - γ_dec1(4ch), β_dec1(4ch), γ_dec0(4ch), β_dec0(4ch)] - = 2 × (4+8+4+4) = 40 parameters output + → [γ_enc0(8ch), β_enc0(8ch), γ_enc1(16ch), β_enc1(16ch), + γ_dec1(8ch), β_dec1(8ch), γ_dec0(4ch), β_dec0(4ch)] + = 2 × (8+16+8+4) = 72 parameters output ``` **Runtime cost:** trivial (one MLP forward pass per frame, CPU-side). **Training:** jointly trained with U-Net — backprop through FiLM to MLP. -**Size:** MLP weights ~(5×16 + 16×40) × 2 bytes f16 ≈ 1.4 KB. +**Size:** MLP weights ~(5×16 + 16×72) × 2 bytes f16 ≈ 2.5 KB. **Why FiLM instead of just uniform parameters?** - γ/β are per-channel, enabling fine-grained style control @@ -318,22 +318,25 @@ All f16, little-endian, same packing as v2 (`pack2x16float`). ## Size Budget -**CNN v3 target: ≤ 6 KB weights** +**CNN v3 target: ≤ 6 KB weights (conv only); current arch prioritises quality** -**Implemented architecture (fits ≤ 4 KB):** +**Implemented architecture (enc_channels=[8,16] — ~15.3 KB conv f16):** | Component | Weights | Bias | Total f16 | |-----------|---------|------|-----------| -| enc0: Conv(20→4, 3×3) | 20×4×9=720 | +4 | 724 | -| enc1: Conv(4→8, 3×3) | 4×8×9=288 | +8 | 296 | -| bottleneck: Conv(8→8, 3×3, dil=2) | 8×8×9=576 | +8 | 584 | -| dec1: Conv(16→4, 3×3) | 16×4×9=576 | +4 | 580 | -| dec0: Conv(8→4, 3×3) | 8×4×9=288 | +4 | 292 | -| FiLM MLP (5→16→40) | 5×16+16×40=720 | +16+40 | 776 | -| **Total conv** | | | **~4.84 KB f16** | +| enc0: Conv(20→8, 3×3) | 20×8×9=1440 | +8 | 1448 | +| enc1: Conv(8→16, 3×3) | 8×16×9=1152 | +16 | 1168 | +| bottleneck: Conv(16→16, 3×3, dil=2) | 16×16×9=2304 | +16 | 2320 | +| dec1: Conv(32→8, 3×3) | 32×8×9=2304 | +8 | 2312 | +| dec0: Conv(16→4, 3×3) | 16×4×9=576 | +4 | 580 | +| **Total conv** | | | **7828 f16 = 
~15.3 KB** | +| FiLM MLP (5→16→72) | 5×16+16×72=1232 | +16+72 | 1320 | +| **Total incl. MLP** | | | **9148 f16 = ~17.9 KB** | -Skip connections: dec1 input = 8ch (bottleneck) + 8ch (enc1 skip) = 16ch. -dec0 input = 4ch (dec1) + 4ch (enc0 skip) = 8ch. +Skip connections: dec1 input = 16ch (bottleneck up) + 16ch (enc1 skip) = 32ch. +dec0 input = 8ch (dec1 up) + 8ch (enc0 skip) = 16ch. + +**Smaller variant (enc_channels=[4,8] — ~4.84 KB conv f16):** fits 6 KB target but has lower representational capacity. Train with `--enc-channels 4,8` if size-critical. --- @@ -507,7 +510,7 @@ All tests: max per-pixel per-channel absolute error ≤ 1/255 (PyTorch f32 vs We ```python class CNNv3(nn.Module): - def __init__(self, enc_channels=[4,8], film_cond_dim=5): + def __init__(self, enc_channels=[8,16], film_cond_dim=5): super().__init__() # Encoder self.enc = nn.ModuleList([ @@ -681,11 +684,11 @@ Parity results: ``` Pass 0: pack_gbuffer.wgsl — assemble G-buffer channels into storage texture -Pass 1: cnn_v3_enc0.wgsl — encoder level 0 (20→4ch, 3×3) -Pass 2: cnn_v3_enc1.wgsl — encoder level 1 (4→8ch, 3×3) + downsample -Pass 3: cnn_v3_bottleneck.wgsl — bottleneck (8→8, 3×3, dilation=2) -Pass 4: cnn_v3_dec1.wgsl — decoder level 1: upsample + skip + (16→4, 3×3) -Pass 5: cnn_v3_dec0.wgsl — decoder level 0: upsample + skip + (8→4, 3×3) +Pass 1: cnn_v3_enc0.wgsl — encoder level 0 (20→8ch, 3×3) +Pass 2: cnn_v3_enc1.wgsl — encoder level 1 (8→16ch, 3×3) + downsample +Pass 3: cnn_v3_bottleneck.wgsl — bottleneck (16→16, 3×3, dilation=2) +Pass 4: cnn_v3_dec1.wgsl — decoder level 1: upsample + skip + (32→8, 3×3) +Pass 5: cnn_v3_dec0.wgsl — decoder level 0: upsample + skip + (16→4, 3×3) Pass 6: cnn_v3_output.wgsl — sigmoid + composite to framebuffer ``` @@ -788,11 +791,11 @@ Status bar shows which channels are loaded. 
| Shader | Replaces | Notes | |--------|----------|-------| | `PACK_SHADER` | `STATIC_SHADER` | 20ch into feat_tex0 + feat_tex1 (rgba32uint each) | -| `ENC0_SHADER` | part of `CNN_SHADER` | Conv(20→4, 3×3) + FiLM + ReLU; writes enc0_tex | -| `ENC1_SHADER` | | Conv(4→8, 3×3) + FiLM + ReLU + avg_pool2×2; writes enc1_tex (half-res) | -| `BOTTLENECK_SHADER` | | Conv(8→8, 3×3, dilation=2) + ReLU; writes bn_tex | -| `DEC1_SHADER` | | nearest upsample×2 + concat(bn, enc1_skip) + Conv(16→4, 3×3) + FiLM + ReLU | -| `DEC0_SHADER` | | nearest upsample×2 + concat(dec1, enc0_skip) + Conv(8→4, 3×3) + FiLM + ReLU | +| `ENC0_SHADER` | part of `CNN_SHADER` | Conv(20→8, 3×3) + FiLM + ReLU; writes enc0_tex (rgba32uint, 8ch) | +| `ENC1_SHADER` | | Conv(8→16, 3×3) + FiLM + ReLU + avg_pool2×2; writes enc1_lo+enc1_hi (2× rgba32uint, 16ch split) | +| `BOTTLENECK_SHADER` | | Conv(16→16, 3×3, dilation=2) + ReLU; writes bn_lo+bn_hi (2× rgba32uint, 16ch split) | +| `DEC1_SHADER` | | nearest upsample×2 + concat(bn, enc1_skip) + Conv(32→8, 3×3) + FiLM + ReLU; writes dec1_tex (rgba32uint, 8ch) | +| `DEC0_SHADER` | | nearest upsample×2 + concat(dec1, enc0_skip) + Conv(16→4, 3×3) + FiLM + ReLU; writes rgba16float | | `OUTPUT_SHADER` | | Conv(4→4, 1×1) + sigmoid → composites to canvas | FiLM γ/β computed JS-side from sliders (tiny MLP forward pass in JS), uploaded as uniform. 
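+The 16-channel textures above (`enc1_lo`/`enc1_hi`, `bn_lo`/`bn_hi`) use the same f16 packing as v2 (`pack2x16float`). A minimal NumPy sketch of the lo/hi split, assuming the low-index channel lands in the low 16 bits of each u32 word (function names are illustrative, not identifiers from the codebase):
+
+```python
+import numpy as np
+
+def pack_16ch_f16(feat):
+    """Pack 16 f16 channels into two rgba32uint texels (lo = ch 0-7, hi = ch 8-15)."""
+    assert feat.shape == (16,)
+    halves = feat.astype(np.float16).view(np.uint16).astype(np.uint32)
+    words = halves[0::2] | (halves[1::2] << 16)  # 8 u32 words, 2 channels each
+    return words[:4], words[4:]                  # (lo rgba, hi rgba)
+
+def unpack_16ch_f16(lo, hi):
+    """Inverse: recover the 16 f16 channels from the lo/hi texel pair."""
+    words = np.concatenate([lo, hi]).astype(np.uint32)
+    halves = np.empty(16, np.uint16)
+    halves[0::2] = words & 0xFFFF
+    halves[1::2] = words >> 16
+    return halves.view(np.float16)
+```
+
+The round trip preserves the f16 bit patterns exactly, which is the property bit-exact readback comparisons in the parity test depend on.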
@@ -805,15 +808,15 @@ FiLM γ/β computed JS-side from sliders (tiny MLP forward pass in JS), uploaded |------|------|--------|----------| | `feat_tex0` | W×H | rgba32uint | feature buffer slots 0–7 (f16) | | `feat_tex1` | W×H | rgba32uint | feature buffer slots 8–19 (u8+spare) | -| `enc0_tex` | W×H | rgba32uint | 4 channels f16 (enc0 output, skip) | -| `enc1_tex` | W/2×H/2 | rgba32uint | 8 channels f16 (enc1 out, skip) — 2 texels per pixel | -| `bn_tex` | W/2×H/2 | rgba32uint | 8 channels f16 (bottleneck output) | -| `dec1_tex` | W×H | rgba32uint | 4 channels f16 (dec1 output) | -| `dec0_tex` | W×H | rgba32uint | 4 channels f16 (dec0 output) | +| `enc0_tex` | W×H | rgba32uint | 8 channels f16 (enc0 output, skip) | +| `enc1_lo` + `enc1_hi` | W/2×H/2 each | rgba32uint | 16 channels f16 split (enc1 out, skip) | +| `bn_lo` + `bn_hi` | W/4×H/4 each | rgba32uint | 16 channels f16 split (bottleneck output) | +| `dec1_tex` | W/2×H/2 | rgba32uint | 8 channels f16 (dec1 output) | +| `dec0_tex` | W×H | rgba16float | 4 channels f16 (final RGBA output) | | `prev_tex` | W×H | rgba16float | previous CNN output (temporal, `F16X8`) | -Skip connections: enc0_tex and enc1_tex are **kept alive** across the full forward pass -(not ping-ponged away). DEC1 and DEC0 read them directly. +Skip connections: enc0_tex (8ch) and enc1_lo/enc1_hi (16ch split) are **kept alive** across the +full forward pass (not ping-ponged away). DEC1 and DEC0 read them directly. --- @@ -856,7 +859,7 @@ python3 -m http.server 8000 Ordered for parallel execution where possible. Phases 1 and 2 are independent. -**Architecture locked:** enc_channels = [4, 8]. See Size Budget for weight counts. +**Architecture:** enc_channels = [8, 16]. See Size Budget for weight counts. --- @@ -881,7 +884,7 @@ before the real G-buffer exists. Wire real G-buffer in Phase 5. **1a. 
PyTorch model** - [ ] `cnn_v3/training/train_cnn_v3.py` - - [ ] `CNNv3` class: U-Net [4,8], FiLM MLP (5→16→48), channel dropout + - [ ] `CNNv3` class: U-Net [8,16], FiLM MLP (5→16→72), channel dropout - [ ] `GBufferDataset`: loads 20-channel feature tensors from packed PNGs - [ ] Training loop, checkpointing, grayscale/RGBA loss option @@ -919,11 +922,11 @@ no batch norm at inference, `#include` existing snippets where possible. - writes feat_tex0 (f16×8) + feat_tex1 (u8×12, spare) **2b. U-Net compute shaders** -- [ ] `src/effects/cnn_v3_enc0.wgsl` — Conv(20→4, 3×3) + FiLM + ReLU -- [ ] `src/effects/cnn_v3_enc1.wgsl` — Conv(4→8, 3×3) + FiLM + ReLU + avg_pool 2×2 -- [ ] `src/effects/cnn_v3_bottleneck.wgsl` — Conv(8→8, 1×1) + FiLM + ReLU -- [ ] `src/effects/cnn_v3_dec1.wgsl` — nearest upsample×2 + concat enc1_skip + Conv(16→4, 3×3) + FiLM + ReLU -- [ ] `src/effects/cnn_v3_dec0.wgsl` — nearest upsample×2 + concat enc0_skip + Conv(8→4, 3×3) + FiLM + ReLU +- [ ] `src/effects/cnn_v3_enc0.wgsl` — Conv(20→8, 3×3) + FiLM + ReLU +- [ ] `src/effects/cnn_v3_enc1.wgsl` — Conv(8→16, 3×3) + FiLM + ReLU + avg_pool 2×2 +- [ ] `src/effects/cnn_v3_bottleneck.wgsl` — Conv(16→16, 3×3, dilation=2) + ReLU +- [ ] `src/effects/cnn_v3_dec1.wgsl` — nearest upsample×2 + concat enc1_skip + Conv(32→8, 3×3) + FiLM + ReLU +- [ ] `src/effects/cnn_v3_dec0.wgsl` — nearest upsample×2 + concat enc0_skip + Conv(16→4, 3×3) + FiLM + ReLU - [ ] `src/effects/cnn_v3_output.wgsl` — Conv(4→4, 1×1) + sigmoid → composite to framebuffer Reuse from existing shaders: @@ -941,7 +944,7 @@ Reuse from existing shaders: - [ ] `src/effects/cnn_v3_effect.h` — class declaration - textures: feat_tex0, feat_tex1, enc0_tex, enc1_tex (half-res), bn_tex (half-res), dec1_tex, dec0_tex - **`WGPUTexture prev_cnn_tex_`** — persistent RGBA8, owned by effect, initialized black - - `FilmParams` uniform buffer (γ/β for 4 levels = 48 floats = 192 bytes) + - `FilmParams` uniform buffer (γ/β for 4 levels = 72 floats = 288 bytes) - FiLM 
MLP weights (loaded from .bin, run CPU-side per frame) - [ ] `src/effects/cnn_v3_effect.cc` — implementation -- cgit v1.2.3
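For reference, the CPU-side FiLM MLP forward pass described in the doc (Linear(5→16) → ReLU → Linear(16→72), γ/β emitted per level in the documented order γ_enc0, β_enc0, γ_enc1, β_enc1, γ_dec1, β_dec1, γ_dec0, β_dec0) can be sketched in NumPy. Weight shapes and names here are assumptions for illustration, not the runtime's actual .bin layout:

```python
import numpy as np

# Per-level channel counts, in the documented gamma/beta output order.
LEVELS = [("enc0", 8), ("enc1", 16), ("dec1", 8), ("dec0", 4)]
N_FILM = 2 * sum(ch for _, ch in LEVELS)  # 2*(8+16+8+4) = 72

def film_forward(c, w1, b1, w2, b2):
    """Map a 5-D conditioning vector c to per-level (gamma, beta) pairs."""
    h = np.maximum(w1 @ c + b1, 0.0)  # Linear(5->16) + ReLU
    out = w2 @ h + b2                 # Linear(16->72)
    params, i = {}, 0
    for name, ch in LEVELS:
        params[name] = (out[i:i + ch], out[i + ch:i + 2 * ch])  # (gamma, beta)
        i += 2 * ch
    return params
```

With these shapes the weight count is 5×16 + 16 + 16×72 + 72 = 1320 f16 values, matching the "FiLM MLP" row of the size-budget table.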