author: skal <pascal.massimino@gmail.com> 2026-03-27 07:59:00 +0100
committer: skal <pascal.massimino@gmail.com> 2026-03-27 07:59:00 +0100
commit: fb13e67acbc7d7dd2974a456fcb134966c47cee0
tree: 8dd1c6df371b0ee046792680a14c8bcb3c36510b /cnn_v3/training/train_cnn_v3.py
parent: 8c5e41724fdfc3be24e95f48ae4b2be616404074
fix(cnn_v3): remove dec0 ReLU, load FiLM MLP at runtime

Two bugs blocking training convergence:

1. dec0 ReLU before sigmoid constrained output to [0.5,1.0], so the network could never produce dark pixels. Removed F.relu in train_cnn_v3.py and max(0,…) in cnn_v3_dec0.wgsl. Test vectors regenerated.

2. set_film_params() used hardcoded heuristics instead of the trained MLP. Added CNNv3FilmMlp struct + load_film_mlp() to cnn_v3_effect.h/.cc. MLP auto-loaded from ASSET_WEIGHTS_CNN_V3_FILM_MLP at construction; Linear(5→16)→ReLU→Linear(16→72) runs CPU-side each frame.

36/36 tests pass. Parity max_err=4.88e-4 unchanged.

handoff(Gemini): retrain from scratch; needs ≥50 samples (currently 11). See cnn_v3/docs/HOWTO.md §2-3.
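
The first bug in the commit message can be checked numerically. The following is a minimal standalone sketch using only the Python standard library, not the project's training code; the helper names and the sample pre-activation values are hypothetical. It shows that clamping with ReLU before the sigmoid pins the minimum output at 0.5, while the sigmoid alone can approach 0.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic function, mapping any real input into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def relu(x: float) -> float:
    """Clamp negative inputs to zero."""
    return max(0.0, x)


# Hypothetical pre-activation values covering negative and positive inputs.
pre_activations = [-6.0, -2.0, -0.5, 0.0, 0.5, 2.0, 6.0]

# Old behaviour: ReLU, then sigmoid. Every negative input collapses to
# sigmoid(0) = 0.5, so no output can fall below mid-grey.
with_relu = [sigmoid(relu(v)) for v in pre_activations]

# New behaviour: sigmoid only. Negative inputs map below 0.5.
without_relu = [sigmoid(v) for v in pre_activations]

print("min with ReLU:   ", min(with_relu))
print("min without ReLU:", min(without_relu))
```

With the ReLU in place, every negative pre-activation yields the same output, so the gradient with respect to those inputs is also zero and the network has no way to learn darker values.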
Diffstat (limited to 'cnn_v3/training/train_cnn_v3.py')
-rw-r--r-- cnn_v3/training/train_cnn_v3.py 7
1 file changed, 3 insertions, 4 deletions
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py
index e48f684..fa0d2e2 100644
--- a/cnn_v3/training/train_cnn_v3.py
+++ b/cnn_v3/training/train_cnn_v3.py
@@ -10,8 +10,7 @@ Architecture (enc_channels=[8,16]):
enc1 Conv(8→16, 3×3) + FiLM + ReLU + pool2 H/2×W/2 2× rgba32uint (16ch split)
bottleneck Conv(16→16, 3×3, dilation=2) + ReLU H/4×W/4 2× rgba32uint (16ch split)
dec1 upsample×2 + cat(enc1) Conv(32→8) + FiLM H/2×W/2 rgba32uint (8ch)
- dec0 upsample×2 + cat(enc0) Conv(16→4) + FiLM H×W rgba16float (4ch)
- output sigmoid → RGBA
+ dec0 upsample×2 + cat(enc0) Conv(16→4) + FiLM + sigmoid H×W rgba16float (4ch)
FiLM MLP: Linear(5→16) → ReLU → Linear(16→72)
72 = 2 × (γ+β) for enc0(8) enc1(16) dec1(8) dec0(4)
@@ -93,9 +92,9 @@ class CNNv3(nn.Module):
torch.cat([F.interpolate(x, scale_factor=2, mode='nearest'), skip1], dim=1)
), gd1, bd1))
- x = F.relu(film_apply(self.dec0(
+ x = film_apply(self.dec0(
torch.cat([F.interpolate(x, scale_factor=2, mode='nearest'), skip0], dim=1)
- ), gd0, bd0))
+ ), gd0, bd0)
return torch.sigmoid(x)
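
The docstring in the first hunk describes the FiLM MLP as Linear(5→16) → ReLU → Linear(16→72), with the 72 outputs covering γ and β for enc0(8), enc1(16), dec1(8) and dec0(4). The following is a rough pure-Python sketch of that shape and of one way to split the output. It is not the code in train_cnn_v3.py or cnn_v3_effect.cc: the helper names, the random weights, the conditioning vector, and in particular the [γ, β]-per-layer ordering of the 72 values are assumptions made for illustration.

```python
import random

# Channel counts per FiLM-modulated layer, taken from the docstring.
FILM_CHANNELS = {"enc0": 8, "enc1": 16, "dec1": 8, "dec0": 4}


def linear(x, weight, bias):
    """Compute y = W x + b, with weight stored as one row per output."""
    return [sum(w * v for w, v in zip(row, x)) + b
            for row, b in zip(weight, bias)]


def film_mlp(cond, w1, b1, w2, b2):
    """Linear(5->16) -> ReLU -> Linear(16->72)."""
    hidden = [max(0.0, h) for h in linear(cond, w1, b1)]
    return linear(hidden, w2, b2)


def split_film(params):
    """Split 72 values into per-layer (gamma, beta) lists.

    ASSUMPTION: values are laid out as gamma then beta for each layer,
    in the order enc0, enc1, dec1, dec0. The real layout may differ.
    """
    out, offset = {}, 0
    for name, ch in FILM_CHANNELS.items():
        gamma = params[offset:offset + ch]
        beta = params[offset + ch:offset + 2 * ch]
        out[name] = (gamma, beta)
        offset += 2 * ch
    return out


def rand_matrix(rng, rows, cols):
    """Hypothetical stand-in for trained weights."""
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
w1, b1 = rand_matrix(rng, 16, 5), [0.0] * 16
w2, b2 = rand_matrix(rng, 72, 16), [0.0] * 72

cond = [0.1, 0.2, 0.3, 0.4, 0.5]  # hypothetical 5-value conditioning input
params = film_mlp(cond, w1, b1, w2, b2)
film = split_film(params)

for name, (gamma, beta) in film.items():
    print(name, len(gamma), len(beta))
```

The MLP is small enough (5·16 + 16·72 multiply-adds plus biases) that evaluating it on the CPU once per frame, as the commit message describes, is cheap compared with the convolution passes.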