// CNN v3 Effect — U-Net + FiLM inference (5 compute passes) // See cnn_v3/docs/CNN_V3.md for architecture, HOWTO.md §7 for shader details. #include "cnn_v3_effect.h" #include "gpu/gpu.h" #include "gpu/shader_composer.h" #include "util/fatal_error.h" #include #include // --------------------------------------------------------------------------- // Weight layout constants — explicit formulas matching WGSL shader comments // --------------------------------------------------------------------------- // // Format: Conv(IN→OUT, KxK) has OUT*IN*K*K weights + OUT biases // Layout: OIHW order (out × in × kH × kW), biases appended after conv weights // static const uint32_t kEnc0Weights = 20 * 4 * 9 + 4; // Conv(20→4,3×3)+bias static const uint32_t kEnc1Weights = 4 * 8 * 9 + 8; // Conv(4→8,3×3)+bias static const uint32_t kBnWeights = 8 * 8 * 1 + 8; // Conv(8→8,1×1)+bias static const uint32_t kDec1Weights = 16 * 4 * 9 + 4; // Conv(16→4,3×3)+bias static const uint32_t kDec0Weights = 8 * 4 * 9 + 4; // Conv(8→4,3×3)+bias static const uint32_t kEnc0Offset = 0; static const uint32_t kEnc1Offset = kEnc0Offset + kEnc0Weights; static const uint32_t kBnOffset = kEnc1Offset + kEnc1Weights; static const uint32_t kDec1Offset = kBnOffset + kBnWeights; static const uint32_t kDec0Offset = kDec1Offset + kDec1Weights; static const uint32_t kTotalF16 = kDec0Offset + kDec0Weights; // Weights buffer size in bytes: f16 values are packed two-per-u32. // Round up to u32 boundary. static const uint32_t kWeightsBufBytes = ((kTotalF16 + 1) / 2) * 4; // --------------------------------------------------------------------------- // Shader source externs (registered in shaders.cc via InitShaderComposer) // --------------------------------------------------------------------------- extern const char* cnn_v3_enc0_wgsl; extern const char* cnn_v3_enc1_wgsl; extern const char* cnn_v3_bottleneck_wgsl; extern const char* cnn_v3_dec1_wgsl; extern const char* cnn_v3_dec0_wgsl; // --------------------------------------------------------------------------- // Helpers // --------------------------------------------------------------------------- static WGPUShaderModule make_shader(WGPUDevice device, const char* wgsl) { const std::string composed = ShaderComposer::Get().Compose({"cnn_v3/common"}, wgsl); WGPUShaderSourceWGSL src = {}; src.chain.sType = WGPUSType_ShaderSourceWGSL; src.code = str_view(composed.c_str()); WGPUShaderModuleDescriptor desc = {}; desc.nextInChain = &src.chain; return wgpuDeviceCreateShaderModule(device, &desc); } static WGPUBindGroupLayout make_bgl(WGPUDevice device, const WGPUBindGroupLayoutEntry* entries, uint32_t count) { WGPUBindGroupLayoutDescriptor desc = {}; desc.entryCount = count; desc.entries = entries; return wgpuDeviceCreateBindGroupLayout(device, &desc); } static WGPUComputePipeline make_compute_pipeline(WGPUDevice device, WGPUShaderModule shader, const char* entry, WGPUBindGroupLayout bgl) { WGPUPipelineLayoutDescriptor pl_desc = {}; pl_desc.bindGroupLayoutCount = 1; pl_desc.bindGroupLayouts = &bgl; WGPUPipelineLayout pl = wgpuDeviceCreatePipelineLayout(device, &pl_desc); WGPUComputePipelineDescriptor pipe_desc = {}; pipe_desc.layout = pl; pipe_desc.compute.module = shader; pipe_desc.compute.entryPoint = str_view(entry); WGPUComputePipeline pipe = wgpuDeviceCreateComputePipeline(device, &pipe_desc); wgpuPipelineLayoutRelease(pl); return pipe; } // BGL entry helpers static WGPUBindGroupLayoutEntry bgl_uint_tex(uint32_t binding) { WGPUBindGroupLayoutEntry e = {}; e.binding = binding; e.visibility = WGPUShaderStage_Compute; e.texture.sampleType = WGPUTextureSampleType_Uint; e.texture.viewDimension = WGPUTextureViewDimension_2D; return e; } static WGPUBindGroupLayoutEntry bgl_float_tex(uint32_t binding) { WGPUBindGroupLayoutEntry e = {}; e.binding = binding; e.visibility = WGPUShaderStage_Compute; e.texture.sampleType = WGPUTextureSampleType_Float; e.texture.viewDimension = WGPUTextureViewDimension_2D; return e; } static WGPUBindGroupLayoutEntry bgl_storage_buf(uint32_t binding) { WGPUBindGroupLayoutEntry e = {}; e.binding = binding; e.visibility = WGPUShaderStage_Compute; e.buffer.type = WGPUBufferBindingType_ReadOnlyStorage; return e; } static WGPUBindGroupLayoutEntry bgl_uniform_buf(uint32_t binding, uint64_t min_size) { WGPUBindGroupLayoutEntry e = {}; e.binding = binding; e.visibility = WGPUShaderStage_Compute; e.buffer.type = WGPUBufferBindingType_Uniform; e.buffer.minBindingSize = min_size; return e; } static WGPUBindGroupLayoutEntry bgl_storage_tex_write( uint32_t binding, WGPUTextureFormat fmt) { WGPUBindGroupLayoutEntry e = {}; e.binding = binding; e.visibility = WGPUShaderStage_Compute; e.storageTexture.access = WGPUStorageTextureAccess_WriteOnly; e.storageTexture.format = fmt; e.storageTexture.viewDimension = WGPUTextureViewDimension_2D; return e; } // --------------------------------------------------------------------------- // Constructor // --------------------------------------------------------------------------- CNNv3Effect::CNNv3Effect(const GpuContext& ctx, const std::vector& inputs, const std::vector& outputs, float start_time, float end_time) : Effect(ctx, inputs, outputs, start_time, end_time) { HEADLESS_RETURN_IF_NULL(ctx_.device); const std::string& prefix = outputs.empty() ? std::string("cnn_v3") : outputs[0]; node_enc0_ = prefix + "_enc0"; node_enc1_ = prefix + "_enc1"; node_bottleneck_ = prefix + "_bottleneck"; node_dec1_ = prefix + "_dec1"; // Allocate zeroed weights buffer (f16 pairs packed as u32). // Weights are zero-initialized; load_weights() can fill from file later. weights_buf_ = gpu_create_buffer( ctx_.device, kWeightsBufBytes, WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst); // Initialize uniform buffers. enc0_params_buf_.init(ctx_.device); enc1_params_buf_.init(ctx_.device); bn_params_buf_.init(ctx_.device); dec1_params_buf_.init(ctx_.device); dec0_params_buf_.init(ctx_.device); // Set weight offsets (FiLM γ/β default to identity: γ=1, β=0). enc0_params_.weight_offset = kEnc0Offset; for (int i = 0; i < 4; ++i) { enc0_params_.gamma[i] = 1.0f; } enc1_params_.weight_offset = kEnc1Offset; for (int i = 0; i < 4; ++i) { enc1_params_.gamma_lo[i] = 1.0f; enc1_params_.gamma_hi[i] = 1.0f; } bn_params_.weight_offset = kBnOffset; dec1_params_.weight_offset = kDec1Offset; for (int i = 0; i < 4; ++i) { dec1_params_.gamma[i] = 1.0f; } dec0_params_.weight_offset = kDec0Offset; for (int i = 0; i < 4; ++i) { dec0_params_.gamma[i] = 1.0f; } create_pipelines(); } // --------------------------------------------------------------------------- // declare_nodes // --------------------------------------------------------------------------- void CNNv3Effect::declare_nodes(NodeRegistry& registry) { const int W = registry.default_width(); const int H = registry.default_height(); // enc0_tex: rgba16float full-res registry.declare_node(node_enc0_, NodeType::GBUF_ALBEDO, W, H); // enc1_tex: rgba32uint half-res — shaders use textureDimensions() for bounds registry.declare_node(node_enc1_, NodeType::GBUF_RGBA32UINT, W / 2, H / 2); // bottleneck_tex: rgba32uint quarter-res registry.declare_node(node_bottleneck_, NodeType::GBUF_RGBA32UINT, W / 4, H / 4); // dec1_tex: rgba16float half-res registry.declare_node(node_dec1_, NodeType::GBUF_ALBEDO, W / 2, H / 2); // output_tex: rgba16float full-res (the declared output_nodes_[0]) } // --------------------------------------------------------------------------- // set_film_params — simple linear mapping, no MLP yet // --------------------------------------------------------------------------- void CNNv3Effect::upload_weights(WGPUQueue queue, const void* data, uint32_t size_bytes) { wgpuQueueWriteBuffer(queue, weights_buf_.buffer, 0, data, size_bytes); } void CNNv3Effect::set_film_params(const CNNv3FiLMParams& fp) { // Identity + audio/beat modulation. // Replace with FiLM MLP output once training is done. const float a = fp.audio_intensity; const float b = fp.beat_phase; for (int i = 0; i < 4; ++i) { enc0_params_.gamma[i] = 1.0f + a * 0.5f; enc0_params_.beta[i] = b * 0.1f; } for (int i = 0; i < 4; ++i) { enc1_params_.gamma_lo[i] = 1.0f + a * 0.3f; enc1_params_.gamma_hi[i] = 1.0f + a * 0.3f; enc1_params_.beta_lo[i] = fp.beat_norm * 0.1f; enc1_params_.beta_hi[i] = fp.beat_norm * 0.1f; } for (int i = 0; i < 4; ++i) { dec1_params_.gamma[i] = 1.0f + fp.style_p0 * 0.5f; dec1_params_.beta[i] = fp.style_p1 * 0.1f; dec0_params_.gamma[i] = 1.0f + fp.style_p0 * 0.5f; dec0_params_.beta[i] = fp.style_p1 * 0.1f; } } // --------------------------------------------------------------------------- // render // --------------------------------------------------------------------------- void CNNv3Effect::render(WGPUCommandEncoder encoder, const UniformsSequenceParams& params, NodeRegistry& nodes) { // Upload params uniforms. enc0_params_buf_.update(ctx_.queue, enc0_params_); enc1_params_buf_.update(ctx_.queue, enc1_params_); bn_params_buf_.update(ctx_.queue, bn_params_); dec1_params_buf_.update(ctx_.queue, dec1_params_); dec0_params_buf_.update(ctx_.queue, dec0_params_); update_bind_groups(nodes); const int W = (int)params.resolution.x; const int H = (int)params.resolution.y; // Dispatch helper: ceil(dim / 8) workgroups auto dispatch = [&](WGPUComputePipeline pipe, WGPUBindGroup bg, int w, int h) { WGPUComputePassDescriptor pass_desc = {}; WGPUComputePassEncoder pass = wgpuCommandEncoderBeginComputePass(encoder, &pass_desc); wgpuComputePassEncoderSetPipeline(pass, pipe); wgpuComputePassEncoderSetBindGroup(pass, 0, bg, 0, nullptr); wgpuComputePassEncoderDispatchWorkgroups( pass, (uint32_t)((w + 7) / 8), (uint32_t)((h + 7) / 8), 1); wgpuComputePassEncoderEnd(pass); wgpuComputePassEncoderRelease(pass); }; dispatch(enc0_pipeline_.get(), enc0_bg_.get(), W, H); dispatch(enc1_pipeline_.get(), enc1_bg_.get(), W / 2, H / 2); dispatch(bn_pipeline_.get(), bn_bg_.get(), W / 4, H / 4); dispatch(dec1_pipeline_.get(), dec1_bg_.get(), W / 2, H / 2); dispatch(dec0_pipeline_.get(), dec0_bg_.get(), W, H); } // --------------------------------------------------------------------------- // create_pipelines // --------------------------------------------------------------------------- void CNNv3Effect::create_pipelines() { HEADLESS_RETURN_IF_NULL(ctx_.device); WGPUDevice dev = ctx_.device; // --- enc0 --- // B0: feat_tex0 (u32), B1: feat_tex1 (u32), B2: weights (storage), // B3: params (uniform), B4: enc0_out (storage_tex rgba16float write) { WGPUBindGroupLayoutEntry e[5] = { bgl_uint_tex(0), bgl_uint_tex(1), bgl_storage_buf(2), bgl_uniform_buf(3, sizeof(CnnV3Params4ch)), // 64 bytes bgl_storage_tex_write(4, WGPUTextureFormat_RGBA16Float), }; WGPUBindGroupLayout bgl = make_bgl(dev, e, 5); WGPUShaderModule sh = make_shader(dev, cnn_v3_enc0_wgsl); enc0_pipeline_.set(make_compute_pipeline(dev, sh, "enc0_main", bgl)); wgpuShaderModuleRelease(sh); wgpuBindGroupLayoutRelease(bgl); } // --- enc1 --- // B0: enc0_tex (f32), B1: weights (storage), // B2: params (uniform), B3: enc1_out (storage_tex rgba32uint write) { WGPUBindGroupLayoutEntry e[4] = { bgl_float_tex(0), bgl_storage_buf(1), bgl_uniform_buf(2, sizeof(CnnV3ParamsEnc1)), bgl_storage_tex_write(3, WGPUTextureFormat_RGBA32Uint), }; WGPUBindGroupLayout bgl = make_bgl(dev, e, 4); WGPUShaderModule sh = make_shader(dev, cnn_v3_enc1_wgsl); enc1_pipeline_.set(make_compute_pipeline(dev, sh, "enc1_main", bgl)); wgpuShaderModuleRelease(sh); wgpuBindGroupLayoutRelease(bgl); } // --- bottleneck --- // B0: enc1_tex (u32), B1: weights (storage), // B2: params (uniform), B3: bottleneck_out (storage_tex rgba32uint write) { WGPUBindGroupLayoutEntry e[4] = { bgl_uint_tex(0), bgl_storage_buf(1), bgl_uniform_buf(2, sizeof(CnnV3ParamsBn)), bgl_storage_tex_write(3, WGPUTextureFormat_RGBA32Uint), }; WGPUBindGroupLayout bgl = make_bgl(dev, e, 4); WGPUShaderModule sh = make_shader(dev, cnn_v3_bottleneck_wgsl); bn_pipeline_.set(make_compute_pipeline(dev, sh, "bottleneck_main", bgl)); wgpuShaderModuleRelease(sh); wgpuBindGroupLayoutRelease(bgl); } // --- dec1 --- // B0: bottleneck_tex (u32), B1: enc1_tex (u32), B2: weights (storage), // B3: params (uniform), B4: dec1_out (storage_tex rgba16float write) { WGPUBindGroupLayoutEntry e[5] = { bgl_uint_tex(0), bgl_uint_tex(1), bgl_storage_buf(2), bgl_uniform_buf(3, sizeof(CnnV3Params4ch)), // 64 bytes bgl_storage_tex_write(4, WGPUTextureFormat_RGBA16Float), }; WGPUBindGroupLayout bgl = make_bgl(dev, e, 5); WGPUShaderModule sh = make_shader(dev, cnn_v3_dec1_wgsl); dec1_pipeline_.set(make_compute_pipeline(dev, sh, "dec1_main", bgl)); wgpuShaderModuleRelease(sh); wgpuBindGroupLayoutRelease(bgl); } // --- dec0 --- // B0: dec1_tex (f32), B1: enc0_tex (f32), B2: weights (storage), // B3: params (uniform), B4: output_tex (storage_tex rgba16float write) { WGPUBindGroupLayoutEntry e[5] = { bgl_float_tex(0), bgl_float_tex(1), bgl_storage_buf(2), bgl_uniform_buf(3, sizeof(CnnV3Params4ch)), // 64 bytes bgl_storage_tex_write(4, WGPUTextureFormat_RGBA16Float), }; WGPUBindGroupLayout bgl = make_bgl(dev, e, 5); WGPUShaderModule sh = make_shader(dev, cnn_v3_dec0_wgsl); dec0_pipeline_.set(make_compute_pipeline(dev, sh, "dec0_main", bgl)); wgpuShaderModuleRelease(sh); wgpuBindGroupLayoutRelease(bgl); } } // --------------------------------------------------------------------------- // update_bind_groups — rebuilt each frame (node views may be recreated) // --------------------------------------------------------------------------- // Helper: set a texture view binding entry. static void bg_tex(WGPUBindGroupEntry& e, uint32_t binding, WGPUTextureView view) { e = {}; e.binding = binding; e.textureView = view; } // Helper: set a buffer binding entry. static void bg_buf(WGPUBindGroupEntry& e, uint32_t binding, WGPUBuffer buf, uint64_t size) { e = {}; e.binding = binding; e.buffer = buf; e.size = size; } void CNNv3Effect::update_bind_groups(NodeRegistry& nodes) { WGPUDevice dev = ctx_.device; WGPUTextureView feat0_view = nodes.get_view(input_nodes_[0]); WGPUTextureView feat1_view = nodes.get_view(input_nodes_[1]); WGPUTextureView enc0_view = nodes.get_view(node_enc0_); WGPUTextureView enc1_view = nodes.get_view(node_enc1_); WGPUTextureView bn_view = nodes.get_view(node_bottleneck_); WGPUTextureView dec1_view = nodes.get_view(node_dec1_); WGPUTextureView out_view = nodes.get_view(output_nodes_[0]); WGPUBuffer wb = weights_buf_.buffer; auto make_bg = [&](WGPUComputePipeline pipe, WGPUBindGroupEntry* e, uint32_t count) -> WGPUBindGroup { WGPUBindGroupLayout bgl = wgpuComputePipelineGetBindGroupLayout(pipe, 0); WGPUBindGroupDescriptor desc = {}; desc.layout = bgl; desc.entryCount = count; desc.entries = e; WGPUBindGroup bg = wgpuDeviceCreateBindGroup(dev, &desc); wgpuBindGroupLayoutRelease(bgl); return bg; }; // enc0: feat_tex0(B0), feat_tex1(B1), weights(B2), params(B3), enc0_out(B4) { WGPUBindGroupEntry e[5] = {}; bg_tex(e[0], 0, feat0_view); bg_tex(e[1], 1, feat1_view); bg_buf(e[2], 2, wb, kWeightsBufBytes); bg_buf(e[3], 3, enc0_params_buf_.get().buffer, sizeof(CnnV3Params4ch)); bg_tex(e[4], 4, enc0_view); enc0_bg_.set(make_bg(enc0_pipeline_.get(), e, 5)); } // enc1: enc0_tex(B0), weights(B1), params(B2), enc1_out(B3) { WGPUBindGroupEntry e[4] = {}; bg_tex(e[0], 0, enc0_view); bg_buf(e[1], 1, wb, kWeightsBufBytes); bg_buf(e[2], 2, enc1_params_buf_.get().buffer, sizeof(CnnV3ParamsEnc1)); bg_tex(e[3], 3, enc1_view); enc1_bg_.set(make_bg(enc1_pipeline_.get(), e, 4)); } // bottleneck: enc1_tex(B0), weights(B1), params(B2), bn_out(B3) { WGPUBindGroupEntry e[4] = {}; bg_tex(e[0], 0, enc1_view); bg_buf(e[1], 1, wb, kWeightsBufBytes); bg_buf(e[2], 2, bn_params_buf_.get().buffer, sizeof(CnnV3ParamsBn)); bg_tex(e[3], 3, bn_view); bn_bg_.set(make_bg(bn_pipeline_.get(), e, 4)); } // dec1: bn_tex(B0), enc1_tex(B1), weights(B2), params(B3), dec1_out(B4) { WGPUBindGroupEntry e[5] = {}; bg_tex(e[0], 0, bn_view); bg_tex(e[1], 1, enc1_view); bg_buf(e[2], 2, wb, kWeightsBufBytes); bg_buf(e[3], 3, dec1_params_buf_.get().buffer, sizeof(CnnV3Params4ch)); bg_tex(e[4], 4, dec1_view); dec1_bg_.set(make_bg(dec1_pipeline_.get(), e, 5)); } // dec0: dec1_tex(B0), enc0_tex(B1), weights(B2), params(B3), output(B4) { WGPUBindGroupEntry e[5] = {}; bg_tex(e[0], 0, dec1_view); bg_tex(e[1], 1, enc0_view); bg_buf(e[2], 2, wb, kWeightsBufBytes); bg_buf(e[3], 3, dec0_params_buf_.get().buffer, sizeof(CnnV3Params4ch)); bg_tex(e[4], 4, out_view); dec0_bg_.set(make_bg(dec0_pipeline_.get(), e, 5)); } }