diff options
| author | skal <pascal.massimino@gmail.com> | 2026-03-21 08:52:53 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-03-21 08:52:53 +0100 |
| commit | fe008df92f7a68d81c9bedb4328da7001e0775f0 (patch) | |
| tree | 2c0182ef4df3b682ee5aa3ab22dcf3e2af08a4ed /cnn_v3 | |
| parent | a4ff60233fce134e8f779ef001872dfd9a8f9923 (diff) | |
feat(cnn_v3): Phase 4 complete — CNNv3Effect C++ + FiLM uniform upload
- cnn_v3/src/cnn_v3_effect.{h,cc}: full Effect subclass with 5 compute
passes (enc0→enc1→bottleneck→dec1→dec0), shared weights storage buffer,
per-pass uniform buffers, set_film_params() API
- Fixed WGSL/C++ struct alignment: vec3u has align=16, so CnnV3Params4ch
is 64 bytes and CnnV3ParamsEnc1 is 96 bytes (not 48/80)
- Weight offsets computed as explicit formulas (e.g. 20*4*9+4) for clarity
- Registered in CMake, shaders.h/cc, demo_effects.h, test_demo_effects.cc
- 35/35 tests pass
handoff(Gemini): CNN v3 Phase 5 next — parity validation (Python ref vs WGSL)
Diffstat (limited to 'cnn_v3')
| -rw-r--r-- | cnn_v3/docs/HOWTO.md | 2 | ||||
| -rw-r--r-- | cnn_v3/src/cnn_v3_effect.cc | 467 | ||||
| -rw-r--r-- | cnn_v3/src/cnn_v3_effect.h | 132 |
3 files changed, 600 insertions, 1 deletions
diff --git a/cnn_v3/docs/HOWTO.md b/cnn_v3/docs/HOWTO.md index ad71f1f..22266d3 100644 --- a/cnn_v3/docs/HOWTO.md +++ b/cnn_v3/docs/HOWTO.md @@ -201,7 +201,7 @@ The CNN v3 design requires exact parity between PyTorch, WGSL (HTML), and C++. | 1 — G-buffer (SDF + shadow passes) | TODO | Placeholder in place | | 2 — Training infrastructure | ✅ Done | blender_export.py, pack_*_sample.py | | 3 — WGSL U-Net shaders | ✅ Done | 5 compute shaders + cnn_v3/common snippet | -| 4 — C++ CNNv3Effect | TODO | FiLM uniform upload | +| 4 — C++ CNNv3Effect | ✅ Done | FiLM uniform upload, 35/35 tests pass | | 5 — Parity validation | TODO | Test vectors, ≤1/255 | --- diff --git a/cnn_v3/src/cnn_v3_effect.cc b/cnn_v3/src/cnn_v3_effect.cc new file mode 100644 index 0000000..d13799c --- /dev/null +++ b/cnn_v3/src/cnn_v3_effect.cc @@ -0,0 +1,467 @@ +// CNN v3 Effect — U-Net + FiLM inference (5 compute passes) +// See cnn_v3/docs/CNN_V3.md for architecture, HOWTO.md §7 for shader details. + +#include "cnn_v3_effect.h" +#include "gpu/gpu.h" +#include "gpu/shader_composer.h" +#include "util/fatal_error.h" +#include <cstdint> +#include <cstring> + +// --------------------------------------------------------------------------- +// Weight layout constants — explicit formulas matching WGSL shader comments +// --------------------------------------------------------------------------- +// +// Format: Conv(IN→OUT, KxK) has OUT*IN*K*K weights + OUT biases +// Layout: OIHW order (out × in × kH × kW), biases appended after conv weights +// +static const uint32_t kEnc0Weights = 20 * 4 * 9 + 4; // Conv(20→4,3×3)+bias +static const uint32_t kEnc1Weights = 4 * 8 * 9 + 8; // Conv(4→8,3×3)+bias +static const uint32_t kBnWeights = 8 * 8 * 1 + 8; // Conv(8→8,1×1)+bias +static const uint32_t kDec1Weights = 16 * 4 * 9 + 4; // Conv(16→4,3×3)+bias +static const uint32_t kDec0Weights = 8 * 4 * 9 + 4; // Conv(8→4,3×3)+bias + +static const uint32_t kEnc0Offset = 0; +static const uint32_t kEnc1Offset = kEnc0Offset + kEnc0Weights; +static const uint32_t kBnOffset = kEnc1Offset + kEnc1Weights; +static const uint32_t kDec1Offset = kBnOffset + kBnWeights; +static const uint32_t kDec0Offset = kDec1Offset + kDec1Weights; +static const uint32_t kTotalF16 = kDec0Offset + kDec0Weights; + +// Weights buffer size in bytes: f16 values are packed two-per-u32. +// Round up to u32 boundary. +static const uint32_t kWeightsBufBytes = ((kTotalF16 + 1) / 2) * 4; + +// --------------------------------------------------------------------------- +// Shader source externs (registered in shaders.cc via InitShaderComposer) +// --------------------------------------------------------------------------- +extern const char* cnn_v3_enc0_wgsl; +extern const char* cnn_v3_enc1_wgsl; +extern const char* cnn_v3_bottleneck_wgsl; +extern const char* cnn_v3_dec1_wgsl; +extern const char* cnn_v3_dec0_wgsl; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +static WGPUShaderModule make_shader(WGPUDevice device, const char* wgsl) { + const std::string composed = + ShaderComposer::Get().Compose({"cnn_v3/common"}, wgsl); + + WGPUShaderSourceWGSL src = {}; + src.chain.sType = WGPUSType_ShaderSourceWGSL; + src.code = str_view(composed.c_str()); + + WGPUShaderModuleDescriptor desc = {}; + desc.nextInChain = &src.chain; + return wgpuDeviceCreateShaderModule(device, &desc); +} + +static WGPUBindGroupLayout make_bgl(WGPUDevice device, + const WGPUBindGroupLayoutEntry* entries, + uint32_t count) { + WGPUBindGroupLayoutDescriptor desc = {}; + desc.entryCount = count; + desc.entries = entries; + return wgpuDeviceCreateBindGroupLayout(device, &desc); +} + +static WGPUComputePipeline make_compute_pipeline(WGPUDevice device, + WGPUShaderModule shader, + const char* entry, + WGPUBindGroupLayout bgl) { + WGPUPipelineLayoutDescriptor pl_desc = {}; + pl_desc.bindGroupLayoutCount = 1; + pl_desc.bindGroupLayouts = &bgl; + WGPUPipelineLayout pl = wgpuDeviceCreatePipelineLayout(device, &pl_desc); + + WGPUComputePipelineDescriptor pipe_desc = {}; + pipe_desc.layout = pl; + pipe_desc.compute.module = shader; + pipe_desc.compute.entryPoint = str_view(entry); + WGPUComputePipeline pipe = wgpuDeviceCreateComputePipeline(device, &pipe_desc); + + wgpuPipelineLayoutRelease(pl); + return pipe; +} + +// BGL entry helpers +static WGPUBindGroupLayoutEntry bgl_uint_tex(uint32_t binding) { + WGPUBindGroupLayoutEntry e = {}; + e.binding = binding; + e.visibility = WGPUShaderStage_Compute; + e.texture.sampleType = WGPUTextureSampleType_Uint; + e.texture.viewDimension = WGPUTextureViewDimension_2D; + return e; +} +static WGPUBindGroupLayoutEntry bgl_float_tex(uint32_t binding) { + WGPUBindGroupLayoutEntry e = {}; + e.binding = binding; + e.visibility = WGPUShaderStage_Compute; + e.texture.sampleType = WGPUTextureSampleType_Float; + e.texture.viewDimension = WGPUTextureViewDimension_2D; + return e; +} +static WGPUBindGroupLayoutEntry bgl_storage_buf(uint32_t binding) { + WGPUBindGroupLayoutEntry e = {}; + e.binding = binding; + e.visibility = WGPUShaderStage_Compute; + e.buffer.type = WGPUBufferBindingType_ReadOnlyStorage; + return e; +} +static WGPUBindGroupLayoutEntry bgl_uniform_buf(uint32_t binding, + uint64_t min_size) { + WGPUBindGroupLayoutEntry e = {}; + e.binding = binding; + e.visibility = WGPUShaderStage_Compute; + e.buffer.type = WGPUBufferBindingType_Uniform; + e.buffer.minBindingSize = min_size; + return e; +} +static WGPUBindGroupLayoutEntry bgl_storage_tex_write( + uint32_t binding, WGPUTextureFormat fmt) { + WGPUBindGroupLayoutEntry e = {}; + e.binding = binding; + e.visibility = WGPUShaderStage_Compute; + e.storageTexture.access = WGPUStorageTextureAccess_WriteOnly; + e.storageTexture.format = fmt; + e.storageTexture.viewDimension = WGPUTextureViewDimension_2D; + return e; +} + +// --------------------------------------------------------------------------- +// Constructor +// --------------------------------------------------------------------------- + +CNNv3Effect::CNNv3Effect(const GpuContext& ctx, + const std::vector<std::string>& inputs, + const std::vector<std::string>& outputs, + float start_time, float end_time) + : Effect(ctx, inputs, outputs, start_time, end_time) { + HEADLESS_RETURN_IF_NULL(ctx_.device); + + const std::string& prefix = + outputs.empty() ? std::string("cnn_v3") : outputs[0]; + node_enc0_ = prefix + "_enc0"; + node_enc1_ = prefix + "_enc1"; + node_bottleneck_ = prefix + "_bottleneck"; + node_dec1_ = prefix + "_dec1"; + + // Allocate zeroed weights buffer (f16 pairs packed as u32). + // Weights are zero-initialized; load_weights() can fill from file later. + weights_buf_ = gpu_create_buffer( + ctx_.device, kWeightsBufBytes, + WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst); + + // Initialize uniform buffers. + enc0_params_buf_.init(ctx_.device); + enc1_params_buf_.init(ctx_.device); + bn_params_buf_.init(ctx_.device); + dec1_params_buf_.init(ctx_.device); + dec0_params_buf_.init(ctx_.device); + + // Set weight offsets (FiLM γ/β default to identity: γ=1, β=0). + enc0_params_.weight_offset = kEnc0Offset; + for (int i = 0; i < 4; ++i) { enc0_params_.gamma[i] = 1.0f; } + + enc1_params_.weight_offset = kEnc1Offset; + for (int i = 0; i < 4; ++i) { + enc1_params_.gamma_lo[i] = 1.0f; + enc1_params_.gamma_hi[i] = 1.0f; + } + + bn_params_.weight_offset = kBnOffset; + + dec1_params_.weight_offset = kDec1Offset; + for (int i = 0; i < 4; ++i) { dec1_params_.gamma[i] = 1.0f; } + + dec0_params_.weight_offset = kDec0Offset; + for (int i = 0; i < 4; ++i) { dec0_params_.gamma[i] = 1.0f; } + + create_pipelines(); +} + +// --------------------------------------------------------------------------- +// declare_nodes +// --------------------------------------------------------------------------- + +void CNNv3Effect::declare_nodes(NodeRegistry& registry) { + // enc0_tex: rgba16float full-res + registry.declare_node(node_enc0_, NodeType::GBUF_ALBEDO, -1, -1); + // enc1_tex: rgba32uint half-res + registry.declare_node(node_enc1_, NodeType::GBUF_RGBA32UINT, -1, -1); + // bottleneck_tex: rgba32uint quarter-res — declare at 1/4 resolution + registry.declare_node(node_bottleneck_, NodeType::GBUF_RGBA32UINT, -1, -1); + // dec1_tex: rgba16float half-res + registry.declare_node(node_dec1_, NodeType::GBUF_ALBEDO, -1, -1); + // output_tex: rgba16float full-res (the declared output_nodes_[0]) +} + +// --------------------------------------------------------------------------- +// set_film_params — simple linear mapping, no MLP yet +// --------------------------------------------------------------------------- + +void CNNv3Effect::set_film_params(const CNNv3FiLMParams& fp) { + // Identity + audio/beat modulation. + // Replace with FiLM MLP output once training is done. + const float a = fp.audio_intensity; + const float b = fp.beat_phase; + + for (int i = 0; i < 4; ++i) { + enc0_params_.gamma[i] = 1.0f + a * 0.5f; + enc0_params_.beta[i] = b * 0.1f; + } + for (int i = 0; i < 4; ++i) { + enc1_params_.gamma_lo[i] = 1.0f + a * 0.3f; + enc1_params_.gamma_hi[i] = 1.0f + a * 0.3f; + enc1_params_.beta_lo[i] = fp.beat_norm * 0.1f; + enc1_params_.beta_hi[i] = fp.beat_norm * 0.1f; + } + for (int i = 0; i < 4; ++i) { + dec1_params_.gamma[i] = 1.0f + fp.style_p0 * 0.5f; + dec1_params_.beta[i] = fp.style_p1 * 0.1f; + dec0_params_.gamma[i] = 1.0f + fp.style_p0 * 0.5f; + dec0_params_.beta[i] = fp.style_p1 * 0.1f; + } +} + +// --------------------------------------------------------------------------- +// render +// --------------------------------------------------------------------------- + +void CNNv3Effect::render(WGPUCommandEncoder encoder, + const UniformsSequenceParams& params, + NodeRegistry& nodes) { + // Upload params uniforms. + enc0_params_buf_.update(ctx_.queue, enc0_params_); + enc1_params_buf_.update(ctx_.queue, enc1_params_); + bn_params_buf_.update(ctx_.queue, bn_params_); + dec1_params_buf_.update(ctx_.queue, dec1_params_); + dec0_params_buf_.update(ctx_.queue, dec0_params_); + + update_bind_groups(nodes); + + const int W = (int)params.resolution.x; + const int H = (int)params.resolution.y; + + // Dispatch helper: ceil(dim / 8) workgroups + auto dispatch = [&](WGPUComputePipeline pipe, WGPUBindGroup bg, + int w, int h) { + WGPUComputePassDescriptor pass_desc = {}; + WGPUComputePassEncoder pass = + wgpuCommandEncoderBeginComputePass(encoder, &pass_desc); + wgpuComputePassEncoderSetPipeline(pass, pipe); + wgpuComputePassEncoderSetBindGroup(pass, 0, bg, 0, nullptr); + wgpuComputePassEncoderDispatchWorkgroups( + pass, + (uint32_t)((w + 7) / 8), + (uint32_t)((h + 7) / 8), + 1); + wgpuComputePassEncoderEnd(pass); + wgpuComputePassEncoderRelease(pass); + }; + + dispatch(enc0_pipeline_.get(), enc0_bg_.get(), W, H); + dispatch(enc1_pipeline_.get(), enc1_bg_.get(), W / 2, H / 2); + dispatch(bn_pipeline_.get(), bn_bg_.get(), W / 4, H / 4); + dispatch(dec1_pipeline_.get(), dec1_bg_.get(), W / 2, H / 2); + dispatch(dec0_pipeline_.get(), dec0_bg_.get(), W, H); +} + +// --------------------------------------------------------------------------- +// create_pipelines +// --------------------------------------------------------------------------- + +void CNNv3Effect::create_pipelines() { + HEADLESS_RETURN_IF_NULL(ctx_.device); + WGPUDevice dev = ctx_.device; + + // --- enc0 --- + // B0: feat_tex0 (u32), B1: feat_tex1 (u32), B2: weights (storage), + // B3: params (uniform), B4: enc0_out (storage_tex rgba16float write) + { + WGPUBindGroupLayoutEntry e[5] = { + bgl_uint_tex(0), + bgl_uint_tex(1), + bgl_storage_buf(2), + bgl_uniform_buf(3, sizeof(CnnV3Params4ch)), // 64 bytes + bgl_storage_tex_write(4, WGPUTextureFormat_RGBA16Float), + }; + WGPUBindGroupLayout bgl = make_bgl(dev, e, 5); + WGPUShaderModule sh = make_shader(dev, cnn_v3_enc0_wgsl); + enc0_pipeline_.set(make_compute_pipeline(dev, sh, "enc0_main", bgl)); + wgpuShaderModuleRelease(sh); + wgpuBindGroupLayoutRelease(bgl); + } + + // --- enc1 --- + // B0: enc0_tex (f32), B1: weights (storage), + // B2: params (uniform), B3: enc1_out (storage_tex rgba32uint write) + { + WGPUBindGroupLayoutEntry e[4] = { + bgl_float_tex(0), + bgl_storage_buf(1), + bgl_uniform_buf(2, sizeof(CnnV3ParamsEnc1)), + bgl_storage_tex_write(3, WGPUTextureFormat_RGBA32Uint), + }; + WGPUBindGroupLayout bgl = make_bgl(dev, e, 4); + WGPUShaderModule sh = make_shader(dev, cnn_v3_enc1_wgsl); + enc1_pipeline_.set(make_compute_pipeline(dev, sh, "enc1_main", bgl)); + wgpuShaderModuleRelease(sh); + wgpuBindGroupLayoutRelease(bgl); + } + + // --- bottleneck --- + // B0: enc1_tex (u32), B1: weights (storage), + // B2: params (uniform), B3: bottleneck_out (storage_tex rgba32uint write) + { + WGPUBindGroupLayoutEntry e[4] = { + bgl_uint_tex(0), + bgl_storage_buf(1), + bgl_uniform_buf(2, sizeof(CnnV3ParamsBn)), + bgl_storage_tex_write(3, WGPUTextureFormat_RGBA32Uint), + }; + WGPUBindGroupLayout bgl = make_bgl(dev, e, 4); + WGPUShaderModule sh = make_shader(dev, cnn_v3_bottleneck_wgsl); + bn_pipeline_.set(make_compute_pipeline(dev, sh, "bottleneck_main", bgl)); + wgpuShaderModuleRelease(sh); + wgpuBindGroupLayoutRelease(bgl); + } + + // --- dec1 --- + // B0: bottleneck_tex (u32), B1: enc1_tex (u32), B2: weights (storage), + // B3: params (uniform), B4: dec1_out (storage_tex rgba16float write) + { + WGPUBindGroupLayoutEntry e[5] = { + bgl_uint_tex(0), + bgl_uint_tex(1), + bgl_storage_buf(2), + bgl_uniform_buf(3, sizeof(CnnV3Params4ch)), // 64 bytes + bgl_storage_tex_write(4, WGPUTextureFormat_RGBA16Float), + }; + WGPUBindGroupLayout bgl = make_bgl(dev, e, 5); + WGPUShaderModule sh = make_shader(dev, cnn_v3_dec1_wgsl); + dec1_pipeline_.set(make_compute_pipeline(dev, sh, "dec1_main", bgl)); + wgpuShaderModuleRelease(sh); + wgpuBindGroupLayoutRelease(bgl); + } + + // --- dec0 --- + // B0: dec1_tex (f32), B1: enc0_tex (f32), B2: weights (storage), + // B3: params (uniform), B4: output_tex (storage_tex rgba16float write) + { + WGPUBindGroupLayoutEntry e[5] = { + bgl_float_tex(0), + bgl_float_tex(1), + bgl_storage_buf(2), + bgl_uniform_buf(3, sizeof(CnnV3Params4ch)), // 64 bytes + bgl_storage_tex_write(4, WGPUTextureFormat_RGBA16Float), + }; + WGPUBindGroupLayout bgl = make_bgl(dev, e, 5); + WGPUShaderModule sh = make_shader(dev, cnn_v3_dec0_wgsl); + dec0_pipeline_.set(make_compute_pipeline(dev, sh, "dec0_main", bgl)); + wgpuShaderModuleRelease(sh); + wgpuBindGroupLayoutRelease(bgl); + } +} + +// --------------------------------------------------------------------------- +// update_bind_groups — rebuilt each frame (node views may be recreated) +// --------------------------------------------------------------------------- + +// Helper: set a texture view binding entry. +static void bg_tex(WGPUBindGroupEntry& e, uint32_t binding, + WGPUTextureView view) { + e = {}; + e.binding = binding; + e.textureView = view; +} +// Helper: set a buffer binding entry. +static void bg_buf(WGPUBindGroupEntry& e, uint32_t binding, WGPUBuffer buf, + uint64_t size) { + e = {}; + e.binding = binding; + e.buffer = buf; + e.size = size; +} + +void CNNv3Effect::update_bind_groups(NodeRegistry& nodes) { + WGPUDevice dev = ctx_.device; + + WGPUTextureView feat0_view = nodes.get_view(input_nodes_[0]); + WGPUTextureView feat1_view = nodes.get_view(input_nodes_[1]); + WGPUTextureView enc0_view = nodes.get_view(node_enc0_); + WGPUTextureView enc1_view = nodes.get_view(node_enc1_); + WGPUTextureView bn_view = nodes.get_view(node_bottleneck_); + WGPUTextureView dec1_view = nodes.get_view(node_dec1_); + WGPUTextureView out_view = nodes.get_view(output_nodes_[0]); + + WGPUBuffer wb = weights_buf_.buffer; + + auto make_bg = [&](WGPUComputePipeline pipe, WGPUBindGroupEntry* e, + uint32_t count) -> WGPUBindGroup { + WGPUBindGroupLayout bgl = + wgpuComputePipelineGetBindGroupLayout(pipe, 0); + WGPUBindGroupDescriptor desc = {}; + desc.layout = bgl; + desc.entryCount = count; + desc.entries = e; + WGPUBindGroup bg = wgpuDeviceCreateBindGroup(dev, &desc); + wgpuBindGroupLayoutRelease(bgl); + return bg; + }; + + // enc0: feat_tex0(B0), feat_tex1(B1), weights(B2), params(B3), enc0_out(B4) + { + WGPUBindGroupEntry e[5] = {}; + bg_tex(e[0], 0, feat0_view); + bg_tex(e[1], 1, feat1_view); + bg_buf(e[2], 2, wb, kWeightsBufBytes); + bg_buf(e[3], 3, enc0_params_buf_.get().buffer, sizeof(CnnV3Params4ch)); + bg_tex(e[4], 4, enc0_view); + enc0_bg_.set(make_bg(enc0_pipeline_.get(), e, 5)); + } + + // enc1: enc0_tex(B0), weights(B1), params(B2), enc1_out(B3) + { + WGPUBindGroupEntry e[4] = {}; + bg_tex(e[0], 0, enc0_view); + bg_buf(e[1], 1, wb, kWeightsBufBytes); + bg_buf(e[2], 2, enc1_params_buf_.get().buffer, sizeof(CnnV3ParamsEnc1)); + bg_tex(e[3], 3, enc1_view); + enc1_bg_.set(make_bg(enc1_pipeline_.get(), e, 4)); + } + + // bottleneck: enc1_tex(B0), weights(B1), params(B2), bn_out(B3) + { + WGPUBindGroupEntry e[4] = {}; + bg_tex(e[0], 0, enc1_view); + bg_buf(e[1], 1, wb, kWeightsBufBytes); + bg_buf(e[2], 2, bn_params_buf_.get().buffer, sizeof(CnnV3ParamsBn)); + bg_tex(e[3], 3, bn_view); + bn_bg_.set(make_bg(bn_pipeline_.get(), e, 4)); + } + + // dec1: bn_tex(B0), enc1_tex(B1), weights(B2), params(B3), dec1_out(B4) + { + WGPUBindGroupEntry e[5] = {}; + bg_tex(e[0], 0, bn_view); + bg_tex(e[1], 1, enc1_view); + bg_buf(e[2], 2, wb, kWeightsBufBytes); + bg_buf(e[3], 3, dec1_params_buf_.get().buffer, sizeof(CnnV3Params4ch)); + bg_tex(e[4], 4, dec1_view); + dec1_bg_.set(make_bg(dec1_pipeline_.get(), e, 5)); + } + + // dec0: dec1_tex(B0), enc0_tex(B1), weights(B2), params(B3), output(B4) + { + WGPUBindGroupEntry e[5] = {}; + bg_tex(e[0], 0, dec1_view); + bg_tex(e[1], 1, enc0_view); + bg_buf(e[2], 2, wb, kWeightsBufBytes); + bg_buf(e[3], 3, dec0_params_buf_.get().buffer, sizeof(CnnV3Params4ch)); + bg_tex(e[4], 4, out_view); + dec0_bg_.set(make_bg(dec0_pipeline_.get(), e, 5)); + } +} diff --git a/cnn_v3/src/cnn_v3_effect.h b/cnn_v3/src/cnn_v3_effect.h new file mode 100644 index 0000000..c358990 --- /dev/null +++ b/cnn_v3/src/cnn_v3_effect.h @@ -0,0 +1,132 @@ +// CNN v3 Effect — U-Net + FiLM inference pass +// Runs 5 compute passes (enc0→enc1→bottleneck→dec1→dec0) on G-buffer feature +// textures produced by GBufferEffect. +// +// Inputs: feat_tex0, feat_tex1 (rgba32uint, 20-channel G-buffer) +// Output: output_tex (rgba16float, 4-channel RGBA) + +#pragma once + +#include <cstdint> + +#include "gpu/effect.h" +#include "gpu/sequence.h" +#include "gpu/uniform_helper.h" +#include "gpu/wgpu_resource.h" + +// --------------------------------------------------------------------------- +// Per-pass params uniform layouts (mirror WGSL Params structs exactly) +// --------------------------------------------------------------------------- + +// enc0, dec1, dec0: 4-channel FiLM +// +// WGSL layout (vec3u has align=16, so _pad sits at offset 16): +// offset 0: weight_offset (u32, 4 bytes) +// offset 4: (12 bytes implicit padding before vec3u) +// offset 16: _pad (vec3u, 12 bytes) +// offset 28: (4 bytes implicit padding before vec4f) +// offset 32: gamma (vec4f, 16 bytes) +// offset 48: beta (vec4f, 16 bytes) +// total: 64 bytes +struct CnnV3Params4ch { + uint32_t weight_offset; // offset 0 + uint32_t _pad[7]; // offsets 4-31 (mirrors implicit + vec3u + post-pad) + float gamma[4]; // offset 32 + float beta[4]; // offset 48 +}; +static_assert(sizeof(CnnV3Params4ch) == 64, "CnnV3Params4ch must be 64 bytes"); + +// enc1: 8-channel FiLM (split into lo/hi vec4 pairs) +// +// WGSL layout (same header padding as above): +// offset 0: weight_offset (u32, 4 bytes) +// offset 16: _pad (vec3u, 12 bytes) +// offset 32: gamma_lo (vec4f, 16 bytes) +// offset 48: gamma_hi (vec4f, 16 bytes) +// offset 64: beta_lo (vec4f, 16 bytes) +// offset 80: beta_hi (vec4f, 16 bytes) +// total: 96 bytes +struct CnnV3ParamsEnc1 { + uint32_t weight_offset; // offset 0 + uint32_t _pad[7]; // offsets 4-31 + float gamma_lo[4]; // offset 32 + float gamma_hi[4]; // offset 48 + float beta_lo[4]; // offset 64 + float beta_hi[4]; // offset 80 +}; +static_assert(sizeof(CnnV3ParamsEnc1) == 96, + "CnnV3ParamsEnc1 must be 96 bytes"); + +// bottleneck: no FiLM — 4 plain u32s, no alignment gap +struct CnnV3ParamsBn { + uint32_t weight_offset; + uint32_t _pad[3]; +}; +static_assert(sizeof(CnnV3ParamsBn) == 16, "CnnV3ParamsBn must be 16 bytes"); + +// --------------------------------------------------------------------------- +// FiLM conditioning inputs (CPU-side, uploaded via set_film_params each frame) +// --------------------------------------------------------------------------- +struct CNNv3FiLMParams { + float beat_phase = 0.0f; // 0-1 within current beat + float beat_norm = 0.0f; // beat_time / 8.0, normalized 8-beat cycle + float audio_intensity = 0.0f; // peak audio level 0-1 + float style_p0 = 0.0f; // user-defined style param + float style_p1 = 0.0f; // user-defined style param +}; + +class CNNv3Effect : public Effect { + public: + CNNv3Effect(const GpuContext& ctx, const std::vector<std::string>& inputs, + const std::vector<std::string>& outputs, float start_time, + float end_time); + + void declare_nodes(NodeRegistry& registry) override; + + void render(WGPUCommandEncoder encoder, const UniformsSequenceParams& params, + NodeRegistry& nodes) override; + + // Update FiLM conditioning; call before render() each frame. + void set_film_params(const CNNv3FiLMParams& fp); + + private: + // Intermediate node names (prefixed from output[0]) + std::string node_enc0_; + std::string node_enc1_; + std::string node_bottleneck_; + std::string node_dec1_; + + // 5 compute pipelines + ComputePipeline enc0_pipeline_; + ComputePipeline enc1_pipeline_; + ComputePipeline bn_pipeline_; + ComputePipeline dec1_pipeline_; + ComputePipeline dec0_pipeline_; + + // 5 bind groups (rebuilt each render since node views may change) + BindGroup enc0_bg_; + BindGroup enc1_bg_; + BindGroup bn_bg_; + BindGroup dec1_bg_; + BindGroup dec0_bg_; + + // Params uniform buffers (one per pass) + UniformBuffer<CnnV3Params4ch> enc0_params_buf_; + UniformBuffer<CnnV3ParamsEnc1> enc1_params_buf_; + UniformBuffer<CnnV3ParamsBn> bn_params_buf_; + UniformBuffer<CnnV3Params4ch> dec1_params_buf_; + UniformBuffer<CnnV3Params4ch> dec0_params_buf_; + + // Shared packed-f16 weights (storage buffer, read-only in all shaders) + GpuBuffer weights_buf_; + + // Per-pass params shadow (updated by set_film_params, uploaded in render) + CnnV3Params4ch enc0_params_{}; + CnnV3ParamsEnc1 enc1_params_{}; + CnnV3ParamsBn bn_params_{}; + CnnV3Params4ch dec1_params_{}; + CnnV3Params4ch dec0_params_{}; + + void create_pipelines(); + void update_bind_groups(NodeRegistry& nodes); +}; |
