// CNN v3 Effect — U-Net + FiLM inference pass // Runs 5 compute passes (enc0→enc1→bottleneck→dec1→dec0) on G-buffer feature // textures produced by GBufferEffect. // // Inputs: feat_tex0, feat_tex1 (rgba32uint, 20-channel G-buffer) // Output: output_tex (rgba16float, 4-channel RGBA) #pragma once #include #include "gpu/effect.h" #include "gpu/sequence.h" #include "gpu/uniform_helper.h" #include "gpu/wgpu_resource.h" // --------------------------------------------------------------------------- // Per-pass params uniform layouts (mirror WGSL Params structs exactly) // --------------------------------------------------------------------------- // enc0, dec1, dec0: 4-channel FiLM // // WGSL layout (vec3u has align=16, so _pad sits at offset 16): // offset 0: weight_offset (u32, 4 bytes) // offset 4: (12 bytes implicit padding before vec3u) // offset 16: _pad (vec3u, 12 bytes) // offset 28: (4 bytes implicit padding before vec4f) // offset 32: gamma (vec4f, 16 bytes) // offset 48: beta (vec4f, 16 bytes) // total: 64 bytes struct CnnV3Params4ch { uint32_t weight_offset; // offset 0 uint32_t _pad[7]; // offsets 4-31 (mirrors implicit + vec3u + post-pad) float gamma[4]; // offset 32 float beta[4]; // offset 48 }; static_assert(sizeof(CnnV3Params4ch) == 64, "CnnV3Params4ch must be 64 bytes"); // enc1: 8-channel FiLM (split into lo/hi vec4 pairs) // // WGSL layout (same header padding as above): // offset 0: weight_offset (u32, 4 bytes) // offset 16: _pad (vec3u, 12 bytes) // offset 32: gamma_lo (vec4f, 16 bytes) // offset 48: gamma_hi (vec4f, 16 bytes) // offset 64: beta_lo (vec4f, 16 bytes) // offset 80: beta_hi (vec4f, 16 bytes) // total: 96 bytes struct CnnV3ParamsEnc1 { uint32_t weight_offset; // offset 0 uint32_t _pad[7]; // offsets 4-31 float gamma_lo[4]; // offset 32 float gamma_hi[4]; // offset 48 float beta_lo[4]; // offset 64 float beta_hi[4]; // offset 80 }; static_assert(sizeof(CnnV3ParamsEnc1) == 96, "CnnV3ParamsEnc1 must be 96 bytes"); // bottleneck: no FiLM — 4 plain u32s, no alignment gap struct CnnV3ParamsBn { uint32_t weight_offset; uint32_t _pad[3]; }; static_assert(sizeof(CnnV3ParamsBn) == 16, "CnnV3ParamsBn must be 16 bytes"); // --------------------------------------------------------------------------- // FiLM conditioning inputs (CPU-side, uploaded via set_film_params each frame) // --------------------------------------------------------------------------- struct CNNv3FiLMParams { float beat_phase = 0.0f; // 0-1 within current beat float beat_norm = 0.0f; // beat_time / 8.0, normalized 8-beat cycle float audio_intensity = 0.0f; // peak audio level 0-1 float style_p0 = 0.0f; // user-defined style param float style_p1 = 0.0f; // user-defined style param }; class CNNv3Effect : public Effect { public: CNNv3Effect(const GpuContext& ctx, const std::vector& inputs, const std::vector& outputs, float start_time, float end_time); void declare_nodes(NodeRegistry& registry) override; void render(WGPUCommandEncoder encoder, const UniformsSequenceParams& params, NodeRegistry& nodes) override; // Update FiLM conditioning; call before render() each frame. void set_film_params(const CNNv3FiLMParams& fp); private: // Intermediate node names (prefixed from output[0]) std::string node_enc0_; std::string node_enc1_; std::string node_bottleneck_; std::string node_dec1_; // 5 compute pipelines ComputePipeline enc0_pipeline_; ComputePipeline enc1_pipeline_; ComputePipeline bn_pipeline_; ComputePipeline dec1_pipeline_; ComputePipeline dec0_pipeline_; // 5 bind groups (rebuilt each render since node views may change) BindGroup enc0_bg_; BindGroup enc1_bg_; BindGroup bn_bg_; BindGroup dec1_bg_; BindGroup dec0_bg_; // Params uniform buffers (one per pass) UniformBuffer enc0_params_buf_; UniformBuffer enc1_params_buf_; UniformBuffer bn_params_buf_; UniformBuffer dec1_params_buf_; UniformBuffer dec0_params_buf_; // Shared packed-f16 weights (storage buffer, read-only in all shaders) GpuBuffer weights_buf_; // Per-pass params shadow (updated by set_film_params, uploaded in render) CnnV3Params4ch enc0_params_{}; CnnV3ParamsEnc1 enc1_params_{}; CnnV3ParamsBn bn_params_{}; CnnV3Params4ch dec1_params_{}; CnnV3Params4ch dec0_params_{}; void create_pipelines(); void update_bind_groups(NodeRegistry& nodes); };