summary | refs | log | tree | commit | diff
path: root/cnn_v3/src/cnn_v3_effect.h
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v3/src/cnn_v3_effect.h')
-rw-r--r-- cnn_v3/src/cnn_v3_effect.h | 132
1 files changed, 132 insertions, 0 deletions
diff --git a/cnn_v3/src/cnn_v3_effect.h b/cnn_v3/src/cnn_v3_effect.h
new file mode 100644
index 0000000..c358990
--- /dev/null
+++ b/cnn_v3/src/cnn_v3_effect.h
@@ -0,0 +1,132 @@
+// CNN v3 Effect — U-Net + FiLM inference pass
+// Runs 5 compute passes (enc0→enc1→bottleneck→dec1→dec0) on G-buffer feature
+// textures produced by GBufferEffect.
+//
+// Inputs: feat_tex0, feat_tex1 (rgba32uint, 20-channel G-buffer)
+// Output: output_tex (rgba16float, 4-channel RGBA)
+
+#pragma once
+
+#include <cstdint>
+
+#include "gpu/effect.h"
+#include "gpu/sequence.h"
+#include "gpu/uniform_helper.h"
+#include "gpu/wgpu_resource.h"
+
+// ---------------------------------------------------------------------------
+// Per-pass params uniform layouts (mirror WGSL Params structs exactly)
+// ---------------------------------------------------------------------------
+
+// enc0, dec1, dec0: params uniform for a pass with 4-channel FiLM.
+// The static_assert below pins the total size only, not the field offsets.
+// WGSL layout (vec3u has align=16, so _pad sits at offset 16):
+// offset 0: weight_offset (u32, 4 bytes)
+// offset 4: (12 bytes implicit padding before vec3u)
+// offset 16: _pad (vec3u, 12 bytes)
+// offset 28: (4 bytes implicit padding before vec4f)
+// offset 32: gamma (vec4f, 16 bytes)
+// offset 48: beta (vec4f, 16 bytes)
+// total: 64 bytes
+struct CnnV3Params4ch {
+ uint32_t weight_offset; // offset 0; presumably into the shared weights (confirm)
+ uint32_t _pad[7]; // offsets 4-31: implicit pad (12) + vec3u (12) + pad (4)
+ float gamma[4]; // offset 32: FiLM gamma, one value per channel
+ float beta[4]; // offset 48: FiLM beta, one value per channel
+};
+static_assert(sizeof(CnnV3Params4ch) == 64, "CnnV3Params4ch must be 64 bytes");
+
+// enc1: params uniform for the pass with 8-channel FiLM; each 8-wide vector
+// is split into a lo and a hi vec4 (presumably channels 0-3 / 4-7; confirm).
+// WGSL layout (u32, then vec3u at its 16-byte alignment, then four vec4f):
+// offset 0: weight_offset (u32, 4 bytes)
+// offset 16: _pad (vec3u, 12 bytes)
+// offset 32: gamma_lo (vec4f, 16 bytes)
+// offset 48: gamma_hi (vec4f, 16 bytes)
+// offset 64: beta_lo (vec4f, 16 bytes)
+// offset 80: beta_hi (vec4f, 16 bytes)
+// total: 96 bytes
+struct CnnV3ParamsEnc1 {
+ uint32_t weight_offset; // offset 0
+ uint32_t _pad[7]; // offsets 4-31: padding up to the first vec4f
+ float gamma_lo[4]; // offset 32
+ float gamma_hi[4]; // offset 48
+ float beta_lo[4]; // offset 64
+ float beta_hi[4]; // offset 80
+};
+static_assert(sizeof(CnnV3ParamsEnc1) == 96,
+ "CnnV3ParamsEnc1 must be 96 bytes");
+
+// bottleneck: no FiLM — 4 plain u32s (16 bytes total), no alignment gap
+struct CnnV3ParamsBn {
+ uint32_t weight_offset; // offset 0
+ uint32_t _pad[3]; // offsets 4-15: rounds the block up to 16 bytes
+};
+static_assert(sizeof(CnnV3ParamsBn) == 16, "CnnV3ParamsBn must be 16 bytes");
+
+// ---------------------------------------------------------------------------
+// FiLM conditioning inputs (CPU-side; handed to set_film_params each frame)
+// ---------------------------------------------------------------------------
+struct CNNv3FiLMParams {
+ float beat_phase = 0.0f; // position within the current beat, 0-1
+ float beat_norm = 0.0f; // beat_time / 8.0 (normalized 8-beat cycle)
+ float audio_intensity = 0.0f; // peak audio level, 0-1
+ float style_p0 = 0.0f; // user-defined style param; defaults to 0
+ float style_p1 = 0.0f; // user-defined style param; defaults to 0
+};
+
+class CNNv3Effect : public Effect { // CNN v3 inference: 5 compute passes
+ public:
+ CNNv3Effect(const GpuContext& ctx, const std::vector<std::string>& inputs,
+ const std::vector<std::string>& outputs, float start_time,
+ float end_time); // NOTE(review): <string>/<vector> are not included here
+
+ void declare_nodes(NodeRegistry& registry) override; // Effect override
+
+ void render(WGPUCommandEncoder encoder, const UniformsSequenceParams& params,
+ NodeRegistry& nodes) override; // Effect override
+
+ // Update FiLM conditioning; call before render() each frame.
+ void set_film_params(const CNNv3FiLMParams& fp);
+
+ private:
+ // Intermediate node names (prefix taken from outputs[0]); dec0 has none.
+ std::string node_enc0_;
+ std::string node_enc1_;
+ std::string node_bottleneck_;
+ std::string node_dec1_;
+
+ // One compute pipeline per pass (enc0, enc1, bottleneck, dec1, dec0)
+ ComputePipeline enc0_pipeline_;
+ ComputePipeline enc1_pipeline_;
+ ComputePipeline bn_pipeline_;
+ ComputePipeline dec1_pipeline_;
+ ComputePipeline dec0_pipeline_;
+
+ // One bind group per pass; rebuilt each render since node views may change
+ BindGroup enc0_bg_;
+ BindGroup enc1_bg_;
+ BindGroup bn_bg_;
+ BindGroup dec1_bg_;
+ BindGroup dec0_bg_;
+
+ // Params uniform buffers, one per pass, typed by that pass's params struct
+ UniformBuffer<CnnV3Params4ch> enc0_params_buf_;
+ UniformBuffer<CnnV3ParamsEnc1> enc1_params_buf_;
+ UniformBuffer<CnnV3ParamsBn> bn_params_buf_;
+ UniformBuffer<CnnV3Params4ch> dec1_params_buf_;
+ UniformBuffer<CnnV3Params4ch> dec0_params_buf_;
+
+ // Packed-f16 weights shared by all passes (storage buffer, read-only in shaders)
+ GpuBuffer weights_buf_;
+
+ // CPU-side shadow of each pass's params (updated by set_film_params,
+ // uploaded in render); `{}` value-initializes every field to zero
+ CnnV3Params4ch enc0_params_{};
+ CnnV3ParamsEnc1 enc1_params_{};
+ CnnV3ParamsBn bn_params_{};
+ CnnV3Params4ch dec1_params_{};
+ CnnV3Params4ch dec0_params_{};
+
+ void create_pipelines(); // NOTE(review): presumably builds the 5 pipelines
+ void update_bind_groups(NodeRegistry& nodes); // presumably the per-render rebuild
+};