summaryrefslogtreecommitdiff
path: root/cnn_v3/training/gen_test_vectors.py
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-03-27 07:59:00 +0100
committerskal <pascal.massimino@gmail.com>2026-03-27 07:59:00 +0100
commitfb13e67acbc7d7dd2974a456fcb134966c47cee0 (patch)
tree8dd1c6df371b0ee046792680a14c8bcb3c36510b /cnn_v3/training/gen_test_vectors.py
parent8c5e41724fdfc3be24e95f48ae4b2be616404074 (diff)
fix(cnn_v3): remove dec0 ReLU, load FiLM MLP at runtime
Two bugs blocking training convergence: 1. dec0 ReLU before sigmoid constrained output to [0.5,1.0] — network could never produce dark pixels. Removed F.relu in train_cnn_v3.py and max(0,…) in cnn_v3_dec0.wgsl. Test vectors regenerated. 2. set_film_params() used hardcoded heuristics instead of the trained MLP. Added CNNv3FilmMlp struct + load_film_mlp() to cnn_v3_effect.h/.cc. MLP auto-loaded from ASSET_WEIGHTS_CNN_V3_FILM_MLP at construction; Linear(5→16)→ReLU→Linear(16→72) runs CPU-side each frame. 36/36 tests pass. Parity max_err=4.88e-4 unchanged. handoff(Gemini): retrain from scratch — needs ≥50 samples (currently 11). See cnn_v3/docs/HOWTO.md §2-3.
Diffstat (limited to 'cnn_v3/training/gen_test_vectors.py')
-rw-r--r--cnn_v3/training/gen_test_vectors.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/cnn_v3/training/gen_test_vectors.py b/cnn_v3/training/gen_test_vectors.py
index 3f81247..96b175a 100644
--- a/cnn_v3/training/gen_test_vectors.py
+++ b/cnn_v3/training/gen_test_vectors.py
@@ -193,8 +193,8 @@ def dec1_forward(bn, enc1, w, gamma, beta):
def dec0_forward(dec1, enc0, w, gamma, beta):
"""
- NearestUp2x(dec1) + cat(enc0_skip) → Conv(16->4, 3x3, zero-pad) + FiLM + ReLU + sigmoid
- → rgba16float (full-res, final output).
+ NearestUp2x(dec1) + cat(enc0_skip) → Conv(16->4, 3x3, zero-pad) + FiLM + sigmoid
+ → rgba16float (full-res, final output). No ReLU before sigmoid.
dec1: (hH, hW, 8) f32 — half-res
enc0: (H, W, 8) f32 — full-res enc0 skip
"""
@@ -219,8 +219,8 @@ def dec0_forward(dec1, enc0, w, gamma, beta):
for kx in range(3):
wv = get_w(w, wo, o * DEC0_IN * 9 + i * 9 + ky * 3 + kx)
s += wv * fp[ky:ky+H, kx:kx+W, i]
- # FiLM + ReLU + sigmoid (matches WGSL dec0 shader)
- v = np.maximum(0.0, gamma[o] * s + beta[o])
+ # FiLM + sigmoid (matches WGSL dec0 shader — no ReLU before sigmoid)
+ v = gamma[o] * s + beta[o]
out[:, :, o] = 1.0 / (1.0 + np.exp(-v.astype(np.float64))).astype(np.float32)
return np.float16(out).astype(np.float32) # rgba16float boundary