diff options
| author | skal <pascal.massimino@gmail.com> | 2026-03-27 07:59:00 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-03-27 07:59:00 +0100 |
| commit | fb13e67acbc7d7dd2974a456fcb134966c47cee0 (patch) | |
| tree | 8dd1c6df371b0ee046792680a14c8bcb3c36510b /cnn_v3/training | |
| parent | 8c5e41724fdfc3be24e95f48ae4b2be616404074 (diff) | |
fix(cnn_v3): remove dec0 ReLU, load FiLM MLP at runtime
Two bugs blocking training convergence:
1. dec0 applied ReLU before sigmoid. Since ReLU output is ≥ 0 and
sigmoid(0) = 0.5, the final output was constrained to [0.5, 1.0) — the
network could never produce dark pixels. Removed F.relu in train_cnn_v3.py
and max(0,…) in cnn_v3_dec0.wgsl. Test vectors regenerated.
2. set_film_params() used hardcoded heuristics instead of the trained MLP.
Added CNNv3FilmMlp struct + load_film_mlp() to cnn_v3_effect.h/.cc.
MLP auto-loaded from ASSET_WEIGHTS_CNN_V3_FILM_MLP at construction;
Linear(5→16)→ReLU→Linear(16→72) runs CPU-side each frame.
36/36 tests pass. Parity max_err=4.88e-4 unchanged.
handoff(Gemini): retrain from scratch — needs ≥50 samples (currently 11).
See cnn_v3/docs/HOWTO.md §2-3.
Diffstat (limited to 'cnn_v3/training')
| -rw-r--r-- | cnn_v3/training/gen_test_vectors.py | 8 |
| -rw-r--r-- | cnn_v3/training/train_cnn_v3.py | 7 |
2 files changed, 7 insertions, 8 deletions
```diff
diff --git a/cnn_v3/training/gen_test_vectors.py b/cnn_v3/training/gen_test_vectors.py
index 3f81247..96b175a 100644
--- a/cnn_v3/training/gen_test_vectors.py
+++ b/cnn_v3/training/gen_test_vectors.py
@@ -193,8 +193,8 @@ def dec1_forward(bn, enc1, w, gamma, beta):
 
 def dec0_forward(dec1, enc0, w, gamma, beta):
     """
-    NearestUp2x(dec1) + cat(enc0_skip) → Conv(16->4, 3x3, zero-pad) + FiLM + ReLU + sigmoid
-    → rgba16float (full-res, final output).
+    NearestUp2x(dec1) + cat(enc0_skip) → Conv(16->4, 3x3, zero-pad) + FiLM + sigmoid
+    → rgba16float (full-res, final output). No ReLU before sigmoid.
     dec1: (hH, hW, 8) f32 — half-res
     enc0: (H, W, 8) f32 — full-res enc0 skip
     """
@@ -219,8 +219,8 @@ def dec0_forward(dec1, enc0, w, gamma, beta):
                 for kx in range(3):
                     wv = get_w(w, wo, o * DEC0_IN * 9 + i * 9 + ky * 3 + kx)
                     s += wv * fp[ky:ky+H, kx:kx+W, i]
-        # FiLM + ReLU + sigmoid (matches WGSL dec0 shader)
-        v = np.maximum(0.0, gamma[o] * s + beta[o])
+        # FiLM + sigmoid (matches WGSL dec0 shader — no ReLU before sigmoid)
+        v = gamma[o] * s + beta[o]
         out[:, :, o] = 1.0 / (1.0 + np.exp(-v.astype(np.float64))).astype(np.float32)
     return np.float16(out).astype(np.float32) # rgba16float boundary
 
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py
index e48f684..fa0d2e2 100644
--- a/cnn_v3/training/train_cnn_v3.py
+++ b/cnn_v3/training/train_cnn_v3.py
@@ -10,8 +10,7 @@ Architecture (enc_channels=[8,16]):
 enc1 Conv(8→16, 3×3) + FiLM + ReLU + pool2 H/2×W/2 2× rgba32uint (16ch split)
 bottleneck Conv(16→16, 3×3, dilation=2) + ReLU H/4×W/4 2× rgba32uint (16ch split)
 dec1 upsample×2 + cat(enc1) Conv(32→8) + FiLM H/2×W/2 rgba32uint (8ch)
-dec0 upsample×2 + cat(enc0) Conv(16→4) + FiLM H×W rgba16float (4ch)
-output sigmoid → RGBA
+dec0 upsample×2 + cat(enc0) Conv(16→4) + FiLM + sigmoid H×W rgba16float (4ch)
 
 FiLM MLP: Linear(5→16) → ReLU → Linear(16→72)
 72 = 2 × (γ+β) for enc0(8) enc1(16) dec1(8) dec0(4)
@@ -93,9 +92,9 @@ class CNNv3(nn.Module):
             torch.cat([F.interpolate(x, scale_factor=2, mode='nearest'), skip1], dim=1)
         ), gd1, bd1))
 
-        x = F.relu(film_apply(self.dec0(
+        x = film_apply(self.dec0(
             torch.cat([F.interpolate(x, scale_factor=2, mode='nearest'), skip0], dim=1)
-        ), gd0, bd0))
+        ), gd0, bd0)
 
         return torch.sigmoid(x)
 
```
