diff options
| author | skal <pascal.massimino@gmail.com> | 2026-03-27 07:59:00 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-03-27 07:59:00 +0100 |
| commit | fb13e67acbc7d7dd2974a456fcb134966c47cee0 (patch) | |
| tree | 8dd1c6df371b0ee046792680a14c8bcb3c36510b /cnn_v3/src/cnn_v3_effect.cc | |
| parent | 8c5e41724fdfc3be24e95f48ae4b2be616404074 (diff) | |
fix(cnn_v3): remove dec0 ReLU, load FiLM MLP at runtime
Two bugs blocking training convergence:
1. dec0 ReLU before sigmoid constrained output to [0.5,1.0] — network
could never produce dark pixels. Removed F.relu in train_cnn_v3.py
and max(0,…) in cnn_v3_dec0.wgsl. Test vectors regenerated.
2. set_film_params() used hardcoded heuristics instead of the trained MLP.
Added CNNv3FilmMlp struct + load_film_mlp() to cnn_v3_effect.h/.cc.
MLP auto-loaded from ASSET_WEIGHTS_CNN_V3_FILM_MLP at construction;
Linear(5→16)→ReLU→Linear(16→72) runs CPU-side each frame.
36/36 tests pass. Parity max_err=4.88e-4 unchanged.
handoff(Gemini): retrain from scratch — needs ≥50 samples (currently 11).
See cnn_v3/docs/HOWTO.md §2-3.
Diffstat (limited to 'cnn_v3/src/cnn_v3_effect.cc')
| -rw-r--r-- | cnn_v3/src/cnn_v3_effect.cc | 76 |
1 files changed, 61 insertions, 15 deletions
diff --git a/cnn_v3/src/cnn_v3_effect.cc b/cnn_v3/src/cnn_v3_effect.cc index dc26751..e576ceb 100644 --- a/cnn_v3/src/cnn_v3_effect.cc +++ b/cnn_v3/src/cnn_v3_effect.cc @@ -187,6 +187,13 @@ CNNv3Effect::CNNv3Effect(const GpuContext& ctx, if (weights_data && weights_size == kWeightsBufBytes) { upload_weights(ctx_.queue, weights_data, (uint32_t)weights_size); } + + size_t mlp_size = 0; + const void* mlp_data = + GetAsset(AssetId::ASSET_WEIGHTS_CNN_V3_FILM_MLP, &mlp_size); + if (mlp_data) { + load_film_mlp(mlp_data, (uint32_t)mlp_size); + } } // --------------------------------------------------------------------------- @@ -219,28 +226,67 @@ void CNNv3Effect::upload_weights(WGPUQueue queue, const void* data, wgpuQueueWriteBuffer(queue, weights_buf_.buffer, 0, data, size_bytes); } +void CNNv3Effect::load_film_mlp(const void* data, uint32_t size_bytes) { + if (size_bytes != sizeof(CNNv3FilmMlp)) return; + memcpy(&mlp_, data, sizeof(CNNv3FilmMlp)); + mlp_loaded_ = true; +} + void CNNv3Effect::set_film_params(const CNNv3FiLMParams& fp) { - const float a = fp.audio_intensity; - const float b = fp.beat_phase; + if (!mlp_loaded_) { + // Identity FiLM (γ=1, β=0) — no learned conditioning available. + return; + } + + // cond[5] = {beat_phase, beat_norm, audio_intensity, style_p0, style_p1} + const float cond[5] = {fp.beat_phase, fp.beat_norm, fp.audio_intensity, + fp.style_p0, fp.style_p1}; + // Layer 0: Linear(5→16) + ReLU + float h[16]; + for (int j = 0; j < 16; ++j) { + float s = mlp_.l0_b[j]; + for (int i = 0; i < 5; ++i) s += mlp_.l0_w[j * 5 + i] * cond[i]; + h[j] = s > 0.f ? s : 0.f; + } + + // Layer 1: Linear(16→72) + // Output split: g_enc0(8)|b_enc0(8)|g_enc1(16)|b_enc1(16)|g_dec1(8)|b_dec1(8)|g_dec0(4)|b_dec0(4) + float film[72]; + for (int j = 0; j < 72; ++j) { + float s = mlp_.l1_b[j]; + for (int i = 0; i < 16; ++i) s += mlp_.l1_w[j * 16 + i] * h[i]; + film[j] = s; + } + + const float* p = film; for (int i = 0; i < 4; ++i) { - enc0_params_.gamma_lo[i] = 1.0f + a * 0.5f; - enc0_params_.gamma_hi[i] = 1.0f + a * 0.5f; - enc0_params_.beta_lo[i] = b * 0.1f; - enc0_params_.beta_hi[i] = b * 0.1f; + enc0_params_.gamma_lo[i] = p[i]; + enc0_params_.gamma_hi[i] = p[i + 4]; } - for (int i = 0; i < 16; ++i) { - enc1_params_.gamma[i] = 1.0f + a * 0.3f; - enc1_params_.beta[i] = fp.beat_norm * 0.1f; + p += 8; + for (int i = 0; i < 4; ++i) { + enc0_params_.beta_lo[i] = p[i]; + enc0_params_.beta_hi[i] = p[i + 4]; + } + p += 8; + for (int i = 0; i < 16; ++i) enc1_params_.gamma[i] = p[i]; + p += 16; + for (int i = 0; i < 16; ++i) enc1_params_.beta[i] = p[i]; + p += 16; + for (int i = 0; i < 4; ++i) { + dec1_params_.gamma_lo[i] = p[i]; + dec1_params_.gamma_hi[i] = p[i + 4]; } + p += 8; for (int i = 0; i < 4; ++i) { - dec1_params_.gamma_lo[i] = 1.0f + fp.style_p0 * 0.5f; - dec1_params_.gamma_hi[i] = 1.0f + fp.style_p0 * 0.5f; - dec1_params_.beta_lo[i] = fp.style_p1 * 0.1f; - dec1_params_.beta_hi[i] = fp.style_p1 * 0.1f; - dec0_params_.gamma[i] = 1.0f + fp.style_p0 * 0.5f; - dec0_params_.beta[i] = fp.style_p1 * 0.1f; + dec1_params_.beta_lo[i] = p[i]; + dec1_params_.beta_hi[i] = p[i + 4]; } + p += 8; + for (int i = 0; i < 4; ++i) dec0_params_.gamma[i] = p[i]; + p += 4; + for (int i = 0; i < 4; ++i) dec0_params_.beta[i] = p[i]; } // --------------------------------------------------------------------------- |
