summaryrefslogtreecommitdiff
path: root/cnn_v3/src/cnn_v3_effect.cc
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v3/src/cnn_v3_effect.cc')
-rw-r--r--cnn_v3/src/cnn_v3_effect.cc76
1 files changed, 61 insertions, 15 deletions
diff --git a/cnn_v3/src/cnn_v3_effect.cc b/cnn_v3/src/cnn_v3_effect.cc
index dc26751..e576ceb 100644
--- a/cnn_v3/src/cnn_v3_effect.cc
+++ b/cnn_v3/src/cnn_v3_effect.cc
@@ -187,6 +187,13 @@ CNNv3Effect::CNNv3Effect(const GpuContext& ctx,
if (weights_data && weights_size == kWeightsBufBytes) {
upload_weights(ctx_.queue, weights_data, (uint32_t)weights_size);
}
+
+ size_t mlp_size = 0;
+ const void* mlp_data =
+ GetAsset(AssetId::ASSET_WEIGHTS_CNN_V3_FILM_MLP, &mlp_size);
+ if (mlp_data) {
+ load_film_mlp(mlp_data, (uint32_t)mlp_size);
+ }
}
// ---------------------------------------------------------------------------
@@ -219,28 +226,67 @@ void CNNv3Effect::upload_weights(WGPUQueue queue, const void* data,
wgpuQueueWriteBuffer(queue, weights_buf_.buffer, 0, data, size_bytes);
}
+void CNNv3Effect::load_film_mlp(const void* data, uint32_t size_bytes) {
+ if (size_bytes != sizeof(CNNv3FilmMlp)) return;
+ memcpy(&mlp_, data, sizeof(CNNv3FilmMlp));
+ mlp_loaded_ = true;
+}
+
void CNNv3Effect::set_film_params(const CNNv3FiLMParams& fp) {
- const float a = fp.audio_intensity;
- const float b = fp.beat_phase;
+ if (!mlp_loaded_) {
+ // Identity FiLM (γ=1, β=0) — no learned conditioning available.
+ return;
+ }
+
+ // cond[5] = {beat_phase, beat_norm, audio_intensity, style_p0, style_p1}
+ const float cond[5] = {fp.beat_phase, fp.beat_norm, fp.audio_intensity,
+ fp.style_p0, fp.style_p1};
+ // Layer 0: Linear(5→16) + ReLU
+ float h[16];
+ for (int j = 0; j < 16; ++j) {
+ float s = mlp_.l0_b[j];
+ for (int i = 0; i < 5; ++i) s += mlp_.l0_w[j * 5 + i] * cond[i];
+ h[j] = s > 0.f ? s : 0.f;
+ }
+
+ // Layer 1: Linear(16→72)
+ // Output split: g_enc0(8)|b_enc0(8)|g_enc1(16)|b_enc1(16)|g_dec1(8)|b_dec1(8)|g_dec0(4)|b_dec0(4)
+ float film[72];
+ for (int j = 0; j < 72; ++j) {
+ float s = mlp_.l1_b[j];
+ for (int i = 0; i < 16; ++i) s += mlp_.l1_w[j * 16 + i] * h[i];
+ film[j] = s;
+ }
+
+ const float* p = film;
for (int i = 0; i < 4; ++i) {
- enc0_params_.gamma_lo[i] = 1.0f + a * 0.5f;
- enc0_params_.gamma_hi[i] = 1.0f + a * 0.5f;
- enc0_params_.beta_lo[i] = b * 0.1f;
- enc0_params_.beta_hi[i] = b * 0.1f;
+ enc0_params_.gamma_lo[i] = p[i];
+ enc0_params_.gamma_hi[i] = p[i + 4];
}
- for (int i = 0; i < 16; ++i) {
- enc1_params_.gamma[i] = 1.0f + a * 0.3f;
- enc1_params_.beta[i] = fp.beat_norm * 0.1f;
+ p += 8;
+ for (int i = 0; i < 4; ++i) {
+ enc0_params_.beta_lo[i] = p[i];
+ enc0_params_.beta_hi[i] = p[i + 4];
+ }
+ p += 8;
+ for (int i = 0; i < 16; ++i) enc1_params_.gamma[i] = p[i];
+ p += 16;
+ for (int i = 0; i < 16; ++i) enc1_params_.beta[i] = p[i];
+ p += 16;
+ for (int i = 0; i < 4; ++i) {
+ dec1_params_.gamma_lo[i] = p[i];
+ dec1_params_.gamma_hi[i] = p[i + 4];
}
+ p += 8;
for (int i = 0; i < 4; ++i) {
- dec1_params_.gamma_lo[i] = 1.0f + fp.style_p0 * 0.5f;
- dec1_params_.gamma_hi[i] = 1.0f + fp.style_p0 * 0.5f;
- dec1_params_.beta_lo[i] = fp.style_p1 * 0.1f;
- dec1_params_.beta_hi[i] = fp.style_p1 * 0.1f;
- dec0_params_.gamma[i] = 1.0f + fp.style_p0 * 0.5f;
- dec0_params_.beta[i] = fp.style_p1 * 0.1f;
+ dec1_params_.beta_lo[i] = p[i];
+ dec1_params_.beta_hi[i] = p[i + 4];
}
+ p += 8;
+ for (int i = 0; i < 4; ++i) dec0_params_.gamma[i] = p[i];
+ p += 4;
+ for (int i = 0; i < 4; ++i) dec0_params_.beta[i] = p[i];
}
// ---------------------------------------------------------------------------