diff options
Diffstat (limited to 'cnn_v3/src/cnn_v3_effect.h')
| -rw-r--r-- | cnn_v3/src/cnn_v3_effect.h | 61 |
1 files changed, 31 insertions, 30 deletions
diff --git a/cnn_v3/src/cnn_v3_effect.h b/cnn_v3/src/cnn_v3_effect.h index 589680c..ac0166f 100644 --- a/cnn_v3/src/cnn_v3_effect.h +++ b/cnn_v3/src/cnn_v3_effect.h @@ -38,12 +38,12 @@ // offset 80: beta_hi (vec4f) // total: 96 bytes struct CnnV3Params8ch { - uint32_t weight_offset; // offset 0 - uint32_t _pad[7]; // offsets 4-31 - float gamma_lo[4]; // offset 32 - float gamma_hi[4]; // offset 48 - float beta_lo[4]; // offset 64 - float beta_hi[4]; // offset 80 + uint32_t weight_offset; // offset 0 + uint32_t _pad[7]; // offsets 4-31 + float gamma_lo[4]; // offset 32 + float gamma_hi[4]; // offset 48 + float beta_lo[4]; // offset 64 + float beta_hi[4]; // offset 80 }; static_assert(sizeof(CnnV3Params8ch) == 96, "CnnV3Params8ch must be 96 bytes"); @@ -56,12 +56,13 @@ static_assert(sizeof(CnnV3Params8ch) == 96, "CnnV3Params8ch must be 96 bytes"); // offset 96: beta_0..3 (4x vec4f = 64 bytes) // total: 160 bytes struct CnnV3Params16ch { - uint32_t weight_offset; // offset 0 - uint32_t _pad[7]; // offsets 4-31 - float gamma[16]; // offsets 32-95 - float beta[16]; // offsets 96-159 + uint32_t weight_offset; // offset 0 + uint32_t _pad[7]; // offsets 4-31 + float gamma[16]; // offsets 32-95 + float beta[16]; // offsets 96-159 }; -static_assert(sizeof(CnnV3Params16ch) == 160, "CnnV3Params16ch must be 160 bytes"); +static_assert(sizeof(CnnV3Params16ch) == 160, + "CnnV3Params16ch must be 160 bytes"); // dec0: 4-channel FiLM // @@ -72,10 +73,10 @@ static_assert(sizeof(CnnV3Params16ch) == 160, "CnnV3Params16ch must be 160 bytes // offset 48: beta (vec4f) // total: 64 bytes struct CnnV3Params4ch { - uint32_t weight_offset; // offset 0 - uint32_t _pad[7]; // offsets 4-31 - float gamma[4]; // offset 32 - float beta[4]; // offset 48 + uint32_t weight_offset; // offset 0 + uint32_t _pad[7]; // offsets 4-31 + float gamma[4]; // offset 32 + float beta[4]; // offset 48 }; static_assert(sizeof(CnnV3Params4ch) == 64, "CnnV3Params4ch must be 64 bytes"); @@ -90,20 +91,20 @@ static_assert(sizeof(CnnV3ParamsBn) == 16, "CnnV3ParamsBn must be 16 bytes"); // FiLM conditioning inputs (CPU-side, uploaded via set_film_params each frame) // --------------------------------------------------------------------------- struct CNNv3FiLMParams { - float beat_phase = 0.0f; // 0-1 within current beat - float beat_norm = 0.0f; // beat_time / 8.0, normalized 8-beat cycle - float audio_intensity = 0.0f; // peak audio level 0-1 - float style_p0 = 0.0f; // user-defined style param - float style_p1 = 0.0f; // user-defined style param + float beat_phase = 0.0f; // 0-1 within current beat + float beat_norm = 0.0f; // beat_time / 8.0, normalized 8-beat cycle + float audio_intensity = 0.0f; // peak audio level 0-1 + float style_p0 = 0.0f; // user-defined style param + float style_p1 = 0.0f; // user-defined style param }; // FiLM MLP weights: Linear(5→16)→ReLU→Linear(16→72). // Loaded from cnn_v3_film_mlp.bin (1320 f32 = 5280 bytes). // Layout: l0_w(80) | l0_b(16) | l1_w(1152) | l1_b(72), all row-major f32. struct CNNv3FilmMlp { - float l0_w[16 * 5]; // (16, 5) row-major + float l0_w[16 * 5]; // (16, 5) row-major float l0_b[16]; - float l1_w[72 * 16]; // (72, 16) row-major + float l1_w[72 * 16]; // (72, 16) row-major float l1_b[72]; }; static_assert(sizeof(CNNv3FilmMlp) == 1320 * 4, "CNNv3FilmMlp size mismatch"); @@ -153,21 +154,21 @@ class CNNv3Effect : public Effect { BindGroup dec0_bg_; // Params uniform buffers (one per pass) - UniformBuffer<CnnV3Params8ch> enc0_params_buf_; + UniformBuffer<CnnV3Params8ch> enc0_params_buf_; UniformBuffer<CnnV3Params16ch> enc1_params_buf_; - UniformBuffer<CnnV3ParamsBn> bn_params_buf_; - UniformBuffer<CnnV3Params8ch> dec1_params_buf_; - UniformBuffer<CnnV3Params4ch> dec0_params_buf_; + UniformBuffer<CnnV3ParamsBn> bn_params_buf_; + UniformBuffer<CnnV3Params8ch> dec1_params_buf_; + UniformBuffer<CnnV3Params4ch> dec0_params_buf_; // Shared packed-f16 weights (storage buffer, read-only in all shaders) GpuBuffer weights_buf_; // Per-pass params shadow (updated by set_film_params, uploaded in render) - CnnV3Params8ch enc0_params_{}; + CnnV3Params8ch enc0_params_{}; CnnV3Params16ch enc1_params_{}; - CnnV3ParamsBn bn_params_{}; - CnnV3Params8ch dec1_params_{}; - CnnV3Params4ch dec0_params_{}; + CnnV3ParamsBn bn_params_{}; + CnnV3Params8ch dec1_params_{}; + CnnV3Params4ch dec0_params_{}; void create_pipelines(); void update_bind_groups(NodeRegistry& nodes); |
