diff options
Diffstat (limited to 'src/gpu/effects/cnn_v2_effect.cc')
| -rw-r--r-- | src/gpu/effects/cnn_v2_effect.cc | 63 |
1 files changed, 52 insertions, 11 deletions
diff --git a/src/gpu/effects/cnn_v2_effect.cc b/src/gpu/effects/cnn_v2_effect.cc index 9c727ba..97e4790 100644 --- a/src/gpu/effects/cnn_v2_effect.cc +++ b/src/gpu/effects/cnn_v2_effect.cc @@ -16,6 +16,7 @@ CNNv2Effect::CNNv2Effect(const GpuContext& ctx) : PostProcessEffect(ctx), static_pipeline_(nullptr), static_bind_group_(nullptr), + static_params_buffer_(nullptr), static_features_tex_(nullptr), static_features_view_(nullptr), layer_pipeline_(nullptr), @@ -23,6 +24,7 @@ CNNv2Effect::CNNv2Effect(const GpuContext& ctx) input_mip_tex_(nullptr), current_input_view_(nullptr), blend_amount_(1.0f), + mip_level_(0), initialized_(false) { std::memset(input_mip_view_, 0, sizeof(input_mip_view_)); } @@ -31,6 +33,7 @@ CNNv2Effect::CNNv2Effect(const GpuContext& ctx, const CNNv2EffectParams& params) : PostProcessEffect(ctx), static_pipeline_(nullptr), static_bind_group_(nullptr), + static_params_buffer_(nullptr), static_features_tex_(nullptr), static_features_view_(nullptr), layer_pipeline_(nullptr), @@ -38,6 +41,7 @@ CNNv2Effect::CNNv2Effect(const GpuContext& ctx, const CNNv2EffectParams& params) input_mip_tex_(nullptr), current_input_view_(nullptr), blend_amount_(params.blend_amount), + mip_level_(0), initialized_(false) { std::memset(input_mip_view_, 0, sizeof(input_mip_view_)); } @@ -69,12 +73,12 @@ void CNNv2Effect::load_weights() { size_t weights_size = 0; const uint8_t* weights_data = (const uint8_t*)GetAsset(AssetId::ASSET_WEIGHTS_CNN_V2, &weights_size); - if (!weights_data || weights_size < 16) { + if (!weights_data || weights_size < 20) { // Weights not available - effect will skip return; } - // Parse header (16 bytes) + // Parse header const uint32_t* header = (const uint32_t*)weights_data; uint32_t magic = header[0]; uint32_t version = header[1]; @@ -82,10 +86,20 @@ void CNNv2Effect::load_weights() { uint32_t total_weights = header[3]; FATAL_CHECK(magic != 0x324e4e43, "Invalid CNN v2 weights magic\n"); // 'CNN2' - FATAL_CHECK(version != 1, "Unsupported CNN v2 
weights version\n"); + + // Support both version 1 (16-byte header) and version 2 (20-byte header with mip_level) + if (version == 1) { + mip_level_ = 0; // Default for v1 + } else if (version == 2) { + mip_level_ = header[4]; + } else { + FATAL_ERROR("Unsupported CNN v2 weights version: %u\n", version); + } // Parse layer info (20 bytes per layer) - const uint32_t* layer_data = header + 4; + // Offset depends on version: v1=16 bytes (4 u32), v2=20 bytes (5 u32) + const uint32_t header_u32_count = (version == 1) ? 4 : 5; + const uint32_t* layer_data = header + header_u32_count; for (uint32_t i = 0; i < num_layers; ++i) { LayerInfo info; info.kernel_size = layer_data[i * 5 + 0]; @@ -192,6 +206,13 @@ void CNNv2Effect::create_textures() { WGPUTextureView view = wgpuTextureCreateView(tex, &view_desc); layer_views_.push_back(view); } + + // Create uniform buffer for static feature params + WGPUBufferDescriptor params_desc = {}; + params_desc.size = sizeof(StaticFeatureParams); + params_desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst; + params_desc.mappedAtCreation = false; + static_params_buffer_ = wgpuDeviceCreateBuffer(ctx_.device, &params_desc); } void CNNv2Effect::create_pipelines() { @@ -224,8 +245,8 @@ void CNNv2Effect::create_pipelines() { wgpuShaderModuleRelease(static_module); // Create bind group layout for static features compute - // Bindings: 0=input_tex, 1=input_mip1, 2=input_mip2, 3=depth_tex, 4=output - WGPUBindGroupLayoutEntry bgl_entries[5] = {}; + // Bindings: 0=input_tex, 1=input_mip1, 2=input_mip2, 3=depth_tex, 4=output, 5=params + WGPUBindGroupLayoutEntry bgl_entries[6] = {}; // Binding 0: Input texture (mip 0) bgl_entries[0].binding = 0; @@ -258,8 +279,14 @@ void CNNv2Effect::create_pipelines() { bgl_entries[4].storageTexture.format = WGPUTextureFormat_RGBA32Uint; bgl_entries[4].storageTexture.viewDimension = WGPUTextureViewDimension_2D; + // Binding 5: Params (mip_level) + bgl_entries[5].binding = 5; + bgl_entries[5].visibility = 
WGPUShaderStage_Compute; + bgl_entries[5].buffer.type = WGPUBufferBindingType_Uniform; + bgl_entries[5].buffer.minBindingSize = sizeof(StaticFeatureParams); + WGPUBindGroupLayoutDescriptor bgl_desc = {}; - bgl_desc.entryCount = 5; + bgl_desc.entryCount = 6; bgl_desc.entries = bgl_entries; WGPUBindGroupLayout static_bgl = wgpuDeviceCreateBindGroupLayout(ctx_.device, &bgl_desc); @@ -378,7 +405,7 @@ void CNNv2Effect::update_bind_group(WGPUTextureView input_view) { } // Create bind group for static features compute - WGPUBindGroupEntry bg_entries[5] = {}; + WGPUBindGroupEntry bg_entries[6] = {}; // Binding 0: Input (mip 0) bg_entries[0].binding = 0; @@ -386,11 +413,11 @@ void CNNv2Effect::update_bind_group(WGPUTextureView input_view) { // Binding 1: Input (mip 1) bg_entries[1].binding = 1; - bg_entries[1].textureView = input_mip_view_[0]; // Use mip 0 for now + bg_entries[1].textureView = input_mip_view_[0]; // Binding 2: Input (mip 2) bg_entries[2].binding = 2; - bg_entries[2].textureView = input_mip_view_[0]; // Use mip 0 for now + bg_entries[2].textureView = (input_mip_view_[1]) ? 
input_mip_view_[1] : input_mip_view_[0]; // Binding 3: Depth (use input for now, no depth available) bg_entries[3].binding = 3; @@ -400,9 +427,14 @@ void CNNv2Effect::update_bind_group(WGPUTextureView input_view) { bg_entries[4].binding = 4; bg_entries[4].textureView = static_features_view_; + // Binding 5: Params + bg_entries[5].binding = 5; + bg_entries[5].buffer = static_params_buffer_; + bg_entries[5].size = sizeof(StaticFeatureParams); + WGPUBindGroupDescriptor bg_desc = {}; bg_desc.layout = wgpuComputePipelineGetBindGroupLayout(static_pipeline_, 0); - bg_desc.entryCount = 5; + bg_desc.entryCount = 6; bg_desc.entries = bg_entries; static_bind_group_ = wgpuDeviceCreateBindGroup(ctx_.device, &bg_desc); @@ -473,6 +505,14 @@ void CNNv2Effect::compute(WGPUCommandEncoder encoder, effective_blend = blend_amount_ * uniforms.beat_phase * beat_scale_; } + // Update static feature params + StaticFeatureParams static_params; + static_params.mip_level = mip_level_; + static_params.padding[0] = 0; + static_params.padding[1] = 0; + static_params.padding[2] = 0; + wgpuQueueWriteBuffer(ctx_.queue, static_params_buffer_, 0, &static_params, sizeof(static_params)); + // Pass 1: Compute static features WGPUComputePassEncoder pass = wgpuCommandEncoderBeginComputePass(encoder, nullptr); @@ -527,6 +567,7 @@ void CNNv2Effect::cleanup() { if (static_features_view_) wgpuTextureViewRelease(static_features_view_); if (static_features_tex_) wgpuTextureRelease(static_features_tex_); if (static_bind_group_) wgpuBindGroupRelease(static_bind_group_); + if (static_params_buffer_) wgpuBufferRelease(static_params_buffer_); if (static_pipeline_) wgpuComputePipelineRelease(static_pipeline_); if (layer_pipeline_) wgpuComputePipelineRelease(layer_pipeline_); |
