// CNN v3 Effect — U-Net + FiLM inference (5 compute passes) // See cnn_v3/docs/CNN_V3.md for architecture, HOWTO.md for shader details. #include "cnn_v3_effect.h" #if defined(USE_TEST_ASSETS) #include "test_assets.h" #else #include "generated/assets.h" #endif #include "gpu/gpu.h" #include "gpu/shader_composer.h" #include "util/asset_manager.h" #include "util/fatal_error.h" #include #include // --------------------------------------------------------------------------- // Weight layout constants — enc_channels=[8,16] // // Format: Conv(IN→OUT, KxK) has OUT*IN*K*K weights + OUT biases // Layout: OIHW order (out × in × kH × kW), biases appended // --------------------------------------------------------------------------- static const uint32_t kEnc0Weights = 20 * 8 * 9 + 8; // Conv(20→8, 3×3)+bias = 1448 static const uint32_t kEnc1Weights = 8 * 16 * 9 + 16; // Conv(8→16, 3×3)+bias = 1168 static const uint32_t kBnWeights = 16 * 16 * 9 + 16; // Conv(16→16, 3×3,dil=2)+bias = 2320 static const uint32_t kDec1Weights = 32 * 8 * 9 + 8; // Conv(32→8, 3×3)+bias = 2312 static const uint32_t kDec0Weights = 16 * 4 * 9 + 4; // Conv(16→4, 3×3)+bias = 580 static const uint32_t kEnc0Offset = 0; static const uint32_t kEnc1Offset = kEnc0Offset + kEnc0Weights; static const uint32_t kBnOffset = kEnc1Offset + kEnc1Weights; static const uint32_t kDec1Offset = kBnOffset + kBnWeights; static const uint32_t kDec0Offset = kDec1Offset + kDec1Weights; static const uint32_t kTotalF16 = kDec0Offset + kDec0Weights; // = 1448 + 1168 + 2320 + 2312 + 580 = 7828 f16 static const uint32_t kWeightsBufBytes = ((kTotalF16 + 1) / 2) * 4; // --------------------------------------------------------------------------- // Shader source externs // --------------------------------------------------------------------------- extern const char* cnn_v3_enc0_wgsl; extern const char* cnn_v3_enc1_wgsl; extern const char* cnn_v3_bottleneck_wgsl; extern const char* cnn_v3_dec1_wgsl; extern const char* cnn_v3_dec0_wgsl; // --------------------------------------------------------------------------- // Helpers // --------------------------------------------------------------------------- static WGPUShaderModule make_shader(WGPUDevice device, const char* wgsl) { const std::string composed = ShaderComposer::Get().Compose({"cnn_v3/common"}, wgsl); WGPUShaderSourceWGSL src = {}; src.chain.sType = WGPUSType_ShaderSourceWGSL; src.code = str_view(composed.c_str()); WGPUShaderModuleDescriptor desc = {}; desc.nextInChain = &src.chain; return wgpuDeviceCreateShaderModule(device, &desc); } static WGPUBindGroupLayout make_bgl(WGPUDevice device, const WGPUBindGroupLayoutEntry* entries, uint32_t count) { WGPUBindGroupLayoutDescriptor desc = {}; desc.entryCount = count; desc.entries = entries; return wgpuDeviceCreateBindGroupLayout(device, &desc); } static WGPUComputePipeline make_compute_pipeline(WGPUDevice device, WGPUShaderModule shader, const char* entry, WGPUBindGroupLayout bgl) { WGPUPipelineLayoutDescriptor pl_desc = {}; pl_desc.bindGroupLayoutCount = 1; pl_desc.bindGroupLayouts = &bgl; WGPUPipelineLayout pl = wgpuDeviceCreatePipelineLayout(device, &pl_desc); WGPUComputePipelineDescriptor pipe_desc = {}; pipe_desc.layout = pl; pipe_desc.compute.module = shader; pipe_desc.compute.entryPoint = str_view(entry); WGPUComputePipeline pipe = wgpuDeviceCreateComputePipeline(device, &pipe_desc); wgpuPipelineLayoutRelease(pl); return pipe; } // BGL entry helpers static WGPUBindGroupLayoutEntry bgl_uint_tex(uint32_t binding) { WGPUBindGroupLayoutEntry e = {}; e.binding = binding; e.visibility = WGPUShaderStage_Compute; e.texture.sampleType = WGPUTextureSampleType_Uint; e.texture.viewDimension = WGPUTextureViewDimension_2D; return e; } static WGPUBindGroupLayoutEntry bgl_storage_buf(uint32_t binding) { WGPUBindGroupLayoutEntry e = {}; e.binding = binding; e.visibility = WGPUShaderStage_Compute; e.buffer.type = WGPUBufferBindingType_ReadOnlyStorage; return e; } static WGPUBindGroupLayoutEntry bgl_uniform_buf(uint32_t binding, uint64_t min_size) { WGPUBindGroupLayoutEntry e = {}; e.binding = binding; e.visibility = WGPUShaderStage_Compute; e.buffer.type = WGPUBufferBindingType_Uniform; e.buffer.minBindingSize = min_size; return e; } static WGPUBindGroupLayoutEntry bgl_storage_tex_write( uint32_t binding, WGPUTextureFormat fmt) { WGPUBindGroupLayoutEntry e = {}; e.binding = binding; e.visibility = WGPUShaderStage_Compute; e.storageTexture.access = WGPUStorageTextureAccess_WriteOnly; e.storageTexture.format = fmt; e.storageTexture.viewDimension = WGPUTextureViewDimension_2D; return e; } // --------------------------------------------------------------------------- // Constructor // --------------------------------------------------------------------------- CNNv3Effect::CNNv3Effect(const GpuContext& ctx, const std::vector& inputs, const std::vector& outputs, float start_time, float end_time) : Effect(ctx, inputs, outputs, start_time, end_time) { HEADLESS_RETURN_IF_NULL(ctx_.device); const std::string& prefix = outputs.empty() ? std::string("cnn_v3") : outputs[0]; node_enc0_ = prefix + "_enc0"; node_enc1_lo_ = prefix + "_enc1_lo"; node_enc1_hi_ = prefix + "_enc1_hi"; node_bn_lo_ = prefix + "_bn_lo"; node_bn_hi_ = prefix + "_bn_hi"; node_dec1_ = prefix + "_dec1"; weights_buf_ = gpu_create_buffer( ctx_.device, kWeightsBufBytes, WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst); enc0_params_buf_.init(ctx_.device); enc1_params_buf_.init(ctx_.device); bn_params_buf_.init(ctx_.device); dec1_params_buf_.init(ctx_.device); dec0_params_buf_.init(ctx_.device); // Set weight offsets; FiLM γ/β default to identity (γ=1, β=0). enc0_params_.weight_offset = kEnc0Offset; for (int i = 0; i < 4; ++i) { enc0_params_.gamma_lo[i] = 1.0f; enc0_params_.gamma_hi[i] = 1.0f; } enc1_params_.weight_offset = kEnc1Offset; for (int i = 0; i < 16; ++i) { enc1_params_.gamma[i] = 1.0f; } bn_params_.weight_offset = kBnOffset; dec1_params_.weight_offset = kDec1Offset; for (int i = 0; i < 4; ++i) { dec1_params_.gamma_lo[i] = 1.0f; dec1_params_.gamma_hi[i] = 1.0f; } dec0_params_.weight_offset = kDec0Offset; for (int i = 0; i < 4; ++i) { dec0_params_.gamma[i] = 1.0f; } create_pipelines(); size_t weights_size = 0; const void* weights_data = GetAsset(AssetId::ASSET_WEIGHTS_CNN_V3, &weights_size); if (weights_data && weights_size == kWeightsBufBytes) { upload_weights(ctx_.queue, weights_data, (uint32_t)weights_size); } } // --------------------------------------------------------------------------- // declare_nodes // --------------------------------------------------------------------------- void CNNv3Effect::declare_nodes(NodeRegistry& registry) { const int W = registry.default_width(); const int H = registry.default_height(); // enc0: rgba32uint full-res (8ch packed f16) registry.declare_node(node_enc0_, NodeType::GBUF_RGBA32UINT, W, H); // enc1: two rgba32uint half-res (8ch each = 16ch total) registry.declare_node(node_enc1_lo_, NodeType::GBUF_RGBA32UINT, W / 2, H / 2); registry.declare_node(node_enc1_hi_, NodeType::GBUF_RGBA32UINT, W / 2, H / 2); // bottleneck: two rgba32uint quarter-res (8ch each = 16ch total) registry.declare_node(node_bn_lo_, NodeType::GBUF_RGBA32UINT, W / 4, H / 4); registry.declare_node(node_bn_hi_, NodeType::GBUF_RGBA32UINT, W / 4, H / 4); // dec1: rgba32uint half-res (8ch packed f16) registry.declare_node(node_dec1_, NodeType::GBUF_RGBA32UINT, W / 2, H / 2); // output_nodes_[0]: rgba16float full-res — declared externally by caller } // --------------------------------------------------------------------------- // upload_weights / set_film_params // --------------------------------------------------------------------------- void CNNv3Effect::upload_weights(WGPUQueue queue, const void* data, uint32_t size_bytes) { wgpuQueueWriteBuffer(queue, weights_buf_.buffer, 0, data, size_bytes); } void CNNv3Effect::set_film_params(const CNNv3FiLMParams& fp) { const float a = fp.audio_intensity; const float b = fp.beat_phase; for (int i = 0; i < 4; ++i) { enc0_params_.gamma_lo[i] = 1.0f + a * 0.5f; enc0_params_.gamma_hi[i] = 1.0f + a * 0.5f; enc0_params_.beta_lo[i] = b * 0.1f; enc0_params_.beta_hi[i] = b * 0.1f; } for (int i = 0; i < 16; ++i) { enc1_params_.gamma[i] = 1.0f + a * 0.3f; enc1_params_.beta[i] = fp.beat_norm * 0.1f; } for (int i = 0; i < 4; ++i) { dec1_params_.gamma_lo[i] = 1.0f + fp.style_p0 * 0.5f; dec1_params_.gamma_hi[i] = 1.0f + fp.style_p0 * 0.5f; dec1_params_.beta_lo[i] = fp.style_p1 * 0.1f; dec1_params_.beta_hi[i] = fp.style_p1 * 0.1f; dec0_params_.gamma[i] = 1.0f + fp.style_p0 * 0.5f; dec0_params_.beta[i] = fp.style_p1 * 0.1f; } } // --------------------------------------------------------------------------- // render // --------------------------------------------------------------------------- void CNNv3Effect::render(WGPUCommandEncoder encoder, const UniformsSequenceParams& params, NodeRegistry& nodes) { enc0_params_buf_.update(ctx_.queue, enc0_params_); enc1_params_buf_.update(ctx_.queue, enc1_params_); bn_params_buf_.update(ctx_.queue, bn_params_); dec1_params_buf_.update(ctx_.queue, dec1_params_); dec0_params_buf_.update(ctx_.queue, dec0_params_); update_bind_groups(nodes); const int W = (int)params.resolution.x; const int H = (int)params.resolution.y; auto dispatch = [&](WGPUComputePipeline pipe, WGPUBindGroup bg, int w, int h) { WGPUComputePassDescriptor pass_desc = {}; WGPUComputePassEncoder pass = wgpuCommandEncoderBeginComputePass(encoder, &pass_desc); wgpuComputePassEncoderSetPipeline(pass, pipe); wgpuComputePassEncoderSetBindGroup(pass, 0, bg, 0, nullptr); wgpuComputePassEncoderDispatchWorkgroups( pass, (uint32_t)((w + 7) / 8), (uint32_t)((h + 7) / 8), 1); wgpuComputePassEncoderEnd(pass); wgpuComputePassEncoderRelease(pass); }; dispatch(enc0_pipeline_.get(), enc0_bg_.get(), W, H); dispatch(enc1_pipeline_.get(), enc1_bg_.get(), W / 2, H / 2); dispatch(bn_pipeline_.get(), bn_bg_.get(), W / 4, H / 4); dispatch(dec1_pipeline_.get(), dec1_bg_.get(), W / 2, H / 2); dispatch(dec0_pipeline_.get(), dec0_bg_.get(), W, H); } // --------------------------------------------------------------------------- // create_pipelines // --------------------------------------------------------------------------- void CNNv3Effect::create_pipelines() { HEADLESS_RETURN_IF_NULL(ctx_.device); WGPUDevice dev = ctx_.device; // --- enc0 --- // B0: feat_tex0 (u32), B1: feat_tex1 (u32), B2: weights (storage), // B3: params (uniform, 96B), B4: enc0_out (rgba32uint write) { WGPUBindGroupLayoutEntry e[5] = { bgl_uint_tex(0), bgl_uint_tex(1), bgl_storage_buf(2), bgl_uniform_buf(3, sizeof(CnnV3Params8ch)), bgl_storage_tex_write(4, WGPUTextureFormat_RGBA32Uint), }; WGPUBindGroupLayout bgl = make_bgl(dev, e, 5); WGPUShaderModule sh = make_shader(dev, cnn_v3_enc0_wgsl); enc0_pipeline_.set(make_compute_pipeline(dev, sh, "enc0_main", bgl)); wgpuShaderModuleRelease(sh); wgpuBindGroupLayoutRelease(bgl); } // --- enc1 --- // B0: enc0_tex (u32), B1: weights (storage), // B2: params (uniform, 160B), B3: enc1_out_lo (rgba32uint write), // B4: enc1_out_hi (rgba32uint write) { WGPUBindGroupLayoutEntry e[5] = { bgl_uint_tex(0), bgl_storage_buf(1), bgl_uniform_buf(2, sizeof(CnnV3Params16ch)), bgl_storage_tex_write(3, WGPUTextureFormat_RGBA32Uint), bgl_storage_tex_write(4, WGPUTextureFormat_RGBA32Uint), }; WGPUBindGroupLayout bgl = make_bgl(dev, e, 5); WGPUShaderModule sh = make_shader(dev, cnn_v3_enc1_wgsl); enc1_pipeline_.set(make_compute_pipeline(dev, sh, "enc1_main", bgl)); wgpuShaderModuleRelease(sh); wgpuBindGroupLayoutRelease(bgl); } // --- bottleneck --- // B0: enc1_tex_lo (u32), B1: enc1_tex_hi (u32), B2: weights (storage), // B3: params (uniform, 16B), B4: bn_out_lo (rgba32uint write), // B5: bn_out_hi (rgba32uint write) { WGPUBindGroupLayoutEntry e[6] = { bgl_uint_tex(0), bgl_uint_tex(1), bgl_storage_buf(2), bgl_uniform_buf(3, sizeof(CnnV3ParamsBn)), bgl_storage_tex_write(4, WGPUTextureFormat_RGBA32Uint), bgl_storage_tex_write(5, WGPUTextureFormat_RGBA32Uint), }; WGPUBindGroupLayout bgl = make_bgl(dev, e, 6); WGPUShaderModule sh = make_shader(dev, cnn_v3_bottleneck_wgsl); bn_pipeline_.set(make_compute_pipeline(dev, sh, "bottleneck_main", bgl)); wgpuShaderModuleRelease(sh); wgpuBindGroupLayoutRelease(bgl); } // --- dec1 --- // B0: bn_tex_lo (u32), B1: bn_tex_hi (u32), // B2: enc1_tex_lo (u32), B3: enc1_tex_hi (u32), // B4: weights (storage), B5: params (uniform, 96B), // B6: dec1_out (rgba32uint write) { WGPUBindGroupLayoutEntry e[7] = { bgl_uint_tex(0), bgl_uint_tex(1), bgl_uint_tex(2), bgl_uint_tex(3), bgl_storage_buf(4), bgl_uniform_buf(5, sizeof(CnnV3Params8ch)), bgl_storage_tex_write(6, WGPUTextureFormat_RGBA32Uint), }; WGPUBindGroupLayout bgl = make_bgl(dev, e, 7); WGPUShaderModule sh = make_shader(dev, cnn_v3_dec1_wgsl); dec1_pipeline_.set(make_compute_pipeline(dev, sh, "dec1_main", bgl)); wgpuShaderModuleRelease(sh); wgpuBindGroupLayoutRelease(bgl); } // --- dec0 --- // B0: dec1_tex (u32), B1: enc0_tex (u32), B2: weights (storage), // B3: params (uniform, 64B), B4: output_tex (rgba16float write) { WGPUBindGroupLayoutEntry e[5] = { bgl_uint_tex(0), bgl_uint_tex(1), bgl_storage_buf(2), bgl_uniform_buf(3, sizeof(CnnV3Params4ch)), bgl_storage_tex_write(4, WGPUTextureFormat_RGBA16Float), }; WGPUBindGroupLayout bgl = make_bgl(dev, e, 5); WGPUShaderModule sh = make_shader(dev, cnn_v3_dec0_wgsl); dec0_pipeline_.set(make_compute_pipeline(dev, sh, "dec0_main", bgl)); wgpuShaderModuleRelease(sh); wgpuBindGroupLayoutRelease(bgl); } } // --------------------------------------------------------------------------- // update_bind_groups — rebuilt each frame (node views may be recreated) // --------------------------------------------------------------------------- static void bg_tex(WGPUBindGroupEntry& e, uint32_t binding, WGPUTextureView view) { e = {}; e.binding = binding; e.textureView = view; } static void bg_buf(WGPUBindGroupEntry& e, uint32_t binding, WGPUBuffer buf, uint64_t size) { e = {}; e.binding = binding; e.buffer = buf; e.size = size; } void CNNv3Effect::update_bind_groups(NodeRegistry& nodes) { WGPUDevice dev = ctx_.device; WGPUTextureView feat0_view = nodes.get_view(input_nodes_[0]); WGPUTextureView feat1_view = nodes.get_view(input_nodes_[1]); WGPUTextureView enc0_view = nodes.get_view(node_enc0_); WGPUTextureView enc1_lo_view = nodes.get_view(node_enc1_lo_); WGPUTextureView enc1_hi_view = nodes.get_view(node_enc1_hi_); WGPUTextureView bn_lo_view = nodes.get_view(node_bn_lo_); WGPUTextureView bn_hi_view = nodes.get_view(node_bn_hi_); WGPUTextureView dec1_view = nodes.get_view(node_dec1_); WGPUTextureView out_view = nodes.get_view(output_nodes_[0]); WGPUBuffer wb = weights_buf_.buffer; auto make_bg = [&](WGPUComputePipeline pipe, WGPUBindGroupEntry* e, uint32_t count) -> WGPUBindGroup { WGPUBindGroupLayout bgl = wgpuComputePipelineGetBindGroupLayout(pipe, 0); WGPUBindGroupDescriptor desc = {}; desc.layout = bgl; desc.entryCount = count; desc.entries = e; WGPUBindGroup bg = wgpuDeviceCreateBindGroup(dev, &desc); wgpuBindGroupLayoutRelease(bgl); return bg; }; // enc0: feat0(B0), feat1(B1), weights(B2), params(B3), enc0_out(B4) { WGPUBindGroupEntry e[5] = {}; bg_tex(e[0], 0, feat0_view); bg_tex(e[1], 1, feat1_view); bg_buf(e[2], 2, wb, kWeightsBufBytes); bg_buf(e[3], 3, enc0_params_buf_.get().buffer, sizeof(CnnV3Params8ch)); bg_tex(e[4], 4, enc0_view); enc0_bg_.replace(make_bg(enc0_pipeline_.get(), e, 5)); } // enc1: enc0(B0), weights(B1), params(B2), enc1_lo(B3), enc1_hi(B4) { WGPUBindGroupEntry e[5] = {}; bg_tex(e[0], 0, enc0_view); bg_buf(e[1], 1, wb, kWeightsBufBytes); bg_buf(e[2], 2, enc1_params_buf_.get().buffer, sizeof(CnnV3Params16ch)); bg_tex(e[3], 3, enc1_lo_view); bg_tex(e[4], 4, enc1_hi_view); enc1_bg_.replace(make_bg(enc1_pipeline_.get(), e, 5)); } // bottleneck: enc1_lo(B0), enc1_hi(B1), weights(B2), params(B3), bn_lo(B4), bn_hi(B5) { WGPUBindGroupEntry e[6] = {}; bg_tex(e[0], 0, enc1_lo_view); bg_tex(e[1], 1, enc1_hi_view); bg_buf(e[2], 2, wb, kWeightsBufBytes); bg_buf(e[3], 3, bn_params_buf_.get().buffer, sizeof(CnnV3ParamsBn)); bg_tex(e[4], 4, bn_lo_view); bg_tex(e[5], 5, bn_hi_view); bn_bg_.replace(make_bg(bn_pipeline_.get(), e, 6)); } // dec1: bn_lo(B0), bn_hi(B1), enc1_lo(B2), enc1_hi(B3), // weights(B4), params(B5), dec1_out(B6) { WGPUBindGroupEntry e[7] = {}; bg_tex(e[0], 0, bn_lo_view); bg_tex(e[1], 1, bn_hi_view); bg_tex(e[2], 2, enc1_lo_view); bg_tex(e[3], 3, enc1_hi_view); bg_buf(e[4], 4, wb, kWeightsBufBytes); bg_buf(e[5], 5, dec1_params_buf_.get().buffer, sizeof(CnnV3Params8ch)); bg_tex(e[6], 6, dec1_view); dec1_bg_.replace(make_bg(dec1_pipeline_.get(), e, 7)); } // dec0: dec1(B0), enc0(B1), weights(B2), params(B3), output(B4) { WGPUBindGroupEntry e[5] = {}; bg_tex(e[0], 0, dec1_view); bg_tex(e[1], 1, enc0_view); bg_buf(e[2], 2, wb, kWeightsBufBytes); bg_buf(e[3], 3, dec0_params_buf_.get().buffer, sizeof(CnnV3Params4ch)); bg_tex(e[4], 4, out_view); dec0_bg_.replace(make_bg(dec0_pipeline_.get(), e, 5)); } }