diff options
| author | skal <pascal.massimino@gmail.com> | 2026-03-21 08:38:29 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-03-21 08:38:29 +0100 |
| commit | a4ff60233fce134e8f779ef001872dfd9a8f9923 (patch) | |
| tree | 3a5466273ecb42269b4d6443c893c61b84ee7d93 /cnn_v3/docs | |
| parent | 4d055080d2ab4b674d5f0fd611ea051e87454a31 (diff) | |
feat(cnn_v3): Phase 3 complete — WGSL U-Net inference shaders
5 compute shaders + cnn_v3/common snippet:
enc0: Conv(20→4,3×3) + FiLM + ReLU full-res
enc1: AvgPool + Conv(4→8,3×3) + FiLM + ReLU half-res
bottleneck: AvgPool + Conv(8→8,1×1) + ReLU quarter-res
dec1: NearestUp + cat(enc1) + Conv(16→4) + FiLM half-res
dec0: NearestUp + cat(enc0) + Conv(8→4) + FiLM + Sigmoid full-res
Parity rules: zero-pad conv, AvgPool down, NearestUp, FiLM after
conv+bias, skip=concat, OIHW weights+bias layout. Matches PyTorch
train_cnn_v3.py forward() exactly.
Registered in workspaces/main/assets.txt + src/effects/shaders.cc.
Weight layout + Params struct documented in cnn_v3/docs/HOWTO.md §7.
Next: Phase 4 — C++ CNNv3Effect + FiLM uniform upload.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Diffstat (limited to 'cnn_v3/docs')
| -rw-r--r-- | cnn_v3/docs/HOWTO.md | 60 |
1 files changed, 57 insertions, 3 deletions
diff --git a/cnn_v3/docs/HOWTO.md b/cnn_v3/docs/HOWTO.md index 88d4bbc..ad71f1f 100644 --- a/cnn_v3/docs/HOWTO.md +++ b/cnn_v3/docs/HOWTO.md @@ -200,13 +200,67 @@ The CNN v3 design requires exact parity between PyTorch, WGSL (HTML), and C++. | 1 — G-buffer (raster + pack) | ✅ Done | Integrated, 35/35 tests pass | | 1 — G-buffer (SDF + shadow passes) | TODO | Placeholder in place | | 2 — Training infrastructure | ✅ Done | blender_export.py, pack_*_sample.py | -| 3 — WGSL U-Net shaders | TODO | enc/dec/bottleneck/FiLM | +| 3 — WGSL U-Net shaders | ✅ Done | 5 compute shaders + cnn_v3/common snippet | | 4 — C++ CNNv3Effect | TODO | FiLM uniform upload | | 5 — Parity validation | TODO | Test vectors, ≤1/255 | --- -## 7. Quick Troubleshooting +## 7. CNN v3 Inference Shaders (Phase 3) + +Five compute passes, each a standalone WGSL shader using `#include "cnn_v3/common"`. +The common snippet provides `get_w()` and `unpack_8ch()`. + +| Pass | Shader | Input(s) | Output | Dims | +|------|--------|----------|--------|------| +| enc0 | `cnn_v3_enc0.wgsl` | feat_tex0+feat_tex1 (20ch) | enc0_tex rgba16float (4ch) | full | +| enc1 | `cnn_v3_enc1.wgsl` | enc0_tex (AvgPool2×2 inline) | enc1_tex rgba32uint (8ch) | ½ | +| bottleneck | `cnn_v3_bottleneck.wgsl` | enc1_tex (AvgPool2×2 inline) | bottleneck_tex rgba32uint (8ch) | ¼ | +| dec1 | `cnn_v3_dec1.wgsl` | bottleneck_tex + enc1_tex (skip) | dec1_tex rgba16float (4ch) | ½ | +| dec0 | `cnn_v3_dec0.wgsl` | dec1_tex + enc0_tex (skip) | output_tex rgba16float (4ch) | full | + +**Parity rules baked into the shaders:** +- Zero-padding (not clamp) at conv borders +- AvgPool 2×2 for downsampling (exact, deterministic) +- Nearest-neighbor for upsampling (integer `coord / 2`) +- Skip connections: channel concatenation (not add) +- FiLM applied after conv+bias, before ReLU: `max(0, γ·x + β)` +- No batch norm at inference +- Weight layout: OIHW (out × in × kH × kW), biases after conv weights + +**Params uniform per shader** (`group 0, binding 3`): +``` +struct Params { + weight_offset: u32, // f16 index into shared weights buffer + _pad: vec3u, + gamma: vec4f, // FiLM γ (enc1: gamma_lo+gamma_hi for 8ch) + beta: vec4f, // FiLM β (enc1: beta_lo+beta_hi for 8ch) +} +``` +FiLM γ/β are computed CPU-side by the FiLM MLP (Phase 4) and uploaded each frame. + +**Weight offsets** (f16 units, including bias): +| Layer | Weights | Bias | Total f16 | +|-------|---------|------|-----------| +| enc0 | 20×4×9=720 | +4 | 724 | +| enc1 | 4×8×9=288 | +8 | 296 | +| bottleneck | 8×8×1=64 | +8 | 72 | +| dec1 | 16×4×9=576 | +4 | 580 | +| dec0 | 8×4×9=288 | +4 | 292 | +| **Total** | | | **2064 f16 = ~4 KB** | + +**Asset IDs** (registered in `workspaces/main/assets.txt` + `src/effects/shaders.cc`): +`SHADER_CNN_V3_COMMON`, `SHADER_CNN_V3_ENC0`, `SHADER_CNN_V3_ENC1`, +`SHADER_CNN_V3_BOTTLENECK`, `SHADER_CNN_V3_DEC1`, `SHADER_CNN_V3_DEC0` + +**C++ usage (Phase 4):** +```cpp +auto src = ShaderComposer::Get().Compose({"cnn_v3/common"}, raw_wgsl); +``` + +--- + +## 8. Quick Troubleshooting **GBufferEffect renders nothing / albedo is black** - Check `set_scene()` was called before `render()` @@ -227,7 +281,7 @@ The CNN v3 design requires exact parity between PyTorch, WGSL (HTML), and C++. --- -## See Also +## 9. See Also - `cnn_v3/docs/CNN_V3.md` — Full architecture design (U-Net, FiLM, feature layout) - `doc/EFFECT_WORKFLOW.md` — General effect integration guide |
