diff options
| author | skal <pascal.massimino@gmail.com> | 2026-02-13 16:48:02 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-02-13 16:48:02 +0100 |
| commit | ced248b0a8973db6d11b79e8290e2f5bb17ffcaa (patch) | |
| tree | 938899c02408895124fb8e9751dbd36b50d4a0ba | |
| parent | 3eed58f0cd602298681a2313946ba2a5824d6c6e (diff) | |
CNN v2: Add mip-level support to runtime effect
Binary format v2 includes mip_level in header (20 bytes, was 16).
Effect reads mip_level and passes to static features shader via uniform.
Shader samples from correct mip texture based on mip_level.
Changes:
- export_cnn_v2_weights.py: Header v2 with mip_level field
- cnn_v2_effect.h: Add StaticFeatureParams, mip_level member, params buffer
- cnn_v2_effect.cc: Read mip_level from weights, create/bind params buffer, update per-frame
- cnn_v2_static.wgsl: Accept params uniform, sample from selected mip level
Binary format v2:
- Header: 20 bytes (magic, version=2, num_layers, total_weights, mip_level)
- Backward compatible: v1 weights load with mip_level=0
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
| -rw-r--r-- | src/gpu/effects/cnn_v2_effect.cc | 63 | ||||
| -rw-r--r-- | src/gpu/effects/cnn_v2_effect.h | 7 | ||||
| -rwxr-xr-x | training/export_cnn_v2_weights.py | 16 | ||||
| -rw-r--r-- | workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl | 25 |
4 files changed, 88 insertions, 23 deletions
diff --git a/src/gpu/effects/cnn_v2_effect.cc b/src/gpu/effects/cnn_v2_effect.cc index 9c727ba..97e4790 100644 --- a/src/gpu/effects/cnn_v2_effect.cc +++ b/src/gpu/effects/cnn_v2_effect.cc @@ -16,6 +16,7 @@ CNNv2Effect::CNNv2Effect(const GpuContext& ctx) : PostProcessEffect(ctx), static_pipeline_(nullptr), static_bind_group_(nullptr), + static_params_buffer_(nullptr), static_features_tex_(nullptr), static_features_view_(nullptr), layer_pipeline_(nullptr), @@ -23,6 +24,7 @@ CNNv2Effect::CNNv2Effect(const GpuContext& ctx) input_mip_tex_(nullptr), current_input_view_(nullptr), blend_amount_(1.0f), + mip_level_(0), initialized_(false) { std::memset(input_mip_view_, 0, sizeof(input_mip_view_)); } @@ -31,6 +33,7 @@ CNNv2Effect::CNNv2Effect(const GpuContext& ctx, const CNNv2EffectParams& params) : PostProcessEffect(ctx), static_pipeline_(nullptr), static_bind_group_(nullptr), + static_params_buffer_(nullptr), static_features_tex_(nullptr), static_features_view_(nullptr), layer_pipeline_(nullptr), @@ -38,6 +41,7 @@ CNNv2Effect::CNNv2Effect(const GpuContext& ctx, const CNNv2EffectParams& params) input_mip_tex_(nullptr), current_input_view_(nullptr), blend_amount_(params.blend_amount), + mip_level_(0), initialized_(false) { std::memset(input_mip_view_, 0, sizeof(input_mip_view_)); } @@ -69,12 +73,12 @@ void CNNv2Effect::load_weights() { size_t weights_size = 0; const uint8_t* weights_data = (const uint8_t*)GetAsset(AssetId::ASSET_WEIGHTS_CNN_V2, &weights_size); - if (!weights_data || weights_size < 16) { + if (!weights_data || weights_size < 20) { // Weights not available - effect will skip return; } - // Parse header (16 bytes) + // Parse header const uint32_t* header = (const uint32_t*)weights_data; uint32_t magic = header[0]; uint32_t version = header[1]; @@ -82,10 +86,20 @@ void CNNv2Effect::load_weights() { uint32_t total_weights = header[3]; FATAL_CHECK(magic != 0x324e4e43, "Invalid CNN v2 weights magic\n"); // 'CNN2' - FATAL_CHECK(version != 1, "Unsupported CNN v2 
weights version\n"); + + // Support both version 1 (16-byte header) and version 2 (20-byte header with mip_level) + if (version == 1) { + mip_level_ = 0; // Default for v1 + } else if (version == 2) { + mip_level_ = header[4]; + } else { + FATAL_ERROR("Unsupported CNN v2 weights version: %u\n", version); + } // Parse layer info (20 bytes per layer) - const uint32_t* layer_data = header + 4; + // Offset depends on version: v1=16 bytes (4 u32), v2=20 bytes (5 u32) + const uint32_t header_u32_count = (version == 1) ? 4 : 5; + const uint32_t* layer_data = header + header_u32_count; for (uint32_t i = 0; i < num_layers; ++i) { LayerInfo info; info.kernel_size = layer_data[i * 5 + 0]; @@ -192,6 +206,13 @@ void CNNv2Effect::create_textures() { WGPUTextureView view = wgpuTextureCreateView(tex, &view_desc); layer_views_.push_back(view); } + + // Create uniform buffer for static feature params + WGPUBufferDescriptor params_desc = {}; + params_desc.size = sizeof(StaticFeatureParams); + params_desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst; + params_desc.mappedAtCreation = false; + static_params_buffer_ = wgpuDeviceCreateBuffer(ctx_.device, &params_desc); + } void CNNv2Effect::create_pipelines() { @@ -224,8 +245,8 @@ void CNNv2Effect::create_pipelines() { wgpuShaderModuleRelease(static_module); // Create bind group layout for static features compute - // Bindings: 0=input_tex, 1=input_mip1, 2=input_mip2, 3=depth_tex, 4=output - WGPUBindGroupLayoutEntry bgl_entries[5] = {}; + // Bindings: 0=input_tex, 1=input_mip1, 2=input_mip2, 3=depth_tex, 4=output, 5=params + WGPUBindGroupLayoutEntry bgl_entries[6] = {}; // Binding 0: Input texture (mip 0) bgl_entries[0].binding = 0; @@ -258,8 +279,14 @@ void CNNv2Effect::create_pipelines() { bgl_entries[4].storageTexture.format = WGPUTextureFormat_RGBA32Uint; bgl_entries[4].storageTexture.viewDimension = WGPUTextureViewDimension_2D; + // Binding 5: Params (mip_level) + bgl_entries[5].binding = 5; + bgl_entries[5].visibility = 
WGPUShaderStage_Compute; + bgl_entries[5].buffer.type = WGPUBufferBindingType_Uniform; + bgl_entries[5].buffer.minBindingSize = sizeof(StaticFeatureParams); + WGPUBindGroupLayoutDescriptor bgl_desc = {}; - bgl_desc.entryCount = 5; + bgl_desc.entryCount = 6; bgl_desc.entries = bgl_entries; WGPUBindGroupLayout static_bgl = wgpuDeviceCreateBindGroupLayout(ctx_.device, &bgl_desc); @@ -378,7 +405,7 @@ void CNNv2Effect::update_bind_group(WGPUTextureView input_view) { } // Create bind group for static features compute - WGPUBindGroupEntry bg_entries[5] = {}; + WGPUBindGroupEntry bg_entries[6] = {}; // Binding 0: Input (mip 0) bg_entries[0].binding = 0; @@ -386,11 +413,11 @@ void CNNv2Effect::update_bind_group(WGPUTextureView input_view) { // Binding 1: Input (mip 1) bg_entries[1].binding = 1; - bg_entries[1].textureView = input_mip_view_[0]; // Use mip 0 for now + bg_entries[1].textureView = input_mip_view_[0]; // Binding 2: Input (mip 2) bg_entries[2].binding = 2; - bg_entries[2].textureView = input_mip_view_[0]; // Use mip 0 for now + bg_entries[2].textureView = (input_mip_view_[1]) ? 
input_mip_view_[1] : input_mip_view_[0]; // Binding 3: Depth (use input for now, no depth available) bg_entries[3].binding = 3; @@ -400,9 +427,14 @@ void CNNv2Effect::update_bind_group(WGPUTextureView input_view) { bg_entries[4].binding = 4; bg_entries[4].textureView = static_features_view_; + // Binding 5: Params + bg_entries[5].binding = 5; + bg_entries[5].buffer = static_params_buffer_; + bg_entries[5].size = sizeof(StaticFeatureParams); + WGPUBindGroupDescriptor bg_desc = {}; bg_desc.layout = wgpuComputePipelineGetBindGroupLayout(static_pipeline_, 0); - bg_desc.entryCount = 5; + bg_desc.entryCount = 6; bg_desc.entries = bg_entries; static_bind_group_ = wgpuDeviceCreateBindGroup(ctx_.device, &bg_desc); @@ -473,6 +505,14 @@ void CNNv2Effect::compute(WGPUCommandEncoder encoder, effective_blend = blend_amount_ * uniforms.beat_phase * beat_scale_; } + // Update static feature params + StaticFeatureParams static_params; + static_params.mip_level = mip_level_; + static_params.padding[0] = 0; + static_params.padding[1] = 0; + static_params.padding[2] = 0; + wgpuQueueWriteBuffer(ctx_.queue, static_params_buffer_, 0, &static_params, sizeof(static_params)); + // Pass 1: Compute static features WGPUComputePassEncoder pass = wgpuCommandEncoderBeginComputePass(encoder, nullptr); @@ -527,6 +567,7 @@ void CNNv2Effect::cleanup() { if (static_features_view_) wgpuTextureViewRelease(static_features_view_); if (static_features_tex_) wgpuTextureRelease(static_features_tex_); if (static_bind_group_) wgpuBindGroupRelease(static_bind_group_); + if (static_params_buffer_) wgpuBufferRelease(static_params_buffer_); if (static_pipeline_) wgpuComputePipelineRelease(static_pipeline_); if (layer_pipeline_) wgpuComputePipelineRelease(layer_pipeline_); diff --git a/src/gpu/effects/cnn_v2_effect.h b/src/gpu/effects/cnn_v2_effect.h index 28aec8c..47dedf5 100644 --- a/src/gpu/effects/cnn_v2_effect.h +++ b/src/gpu/effects/cnn_v2_effect.h @@ -47,6 +47,11 @@ private: float blend_amount; }; + struct 
StaticFeatureParams { + uint32_t mip_level; + uint32_t padding[3]; + }; + void create_textures(); void create_pipelines(); void load_weights(); @@ -55,6 +60,7 @@ private: // Static features compute WGPUComputePipeline static_pipeline_; WGPUBindGroup static_bind_group_; + WGPUBuffer static_params_buffer_; WGPUTexture static_features_tex_; WGPUTextureView static_features_view_; @@ -75,5 +81,6 @@ private: float blend_amount_ = 1.0f; bool beat_modulated_ = false; float beat_scale_ = 1.0f; + uint32_t mip_level_ = 0; bool initialized_; }; diff --git a/training/export_cnn_v2_weights.py b/training/export_cnn_v2_weights.py index 9e9e352..1086516 100755 --- a/training/export_cnn_v2_weights.py +++ b/training/export_cnn_v2_weights.py @@ -16,11 +16,12 @@ def export_weights_binary(checkpoint_path, output_path): """Export CNN v2 weights to binary format. Binary format: - Header (16 bytes): + Header (20 bytes): uint32 magic ('CNN2') - uint32 version (1) + uint32 version (2) uint32 num_layers uint32 total_weights (f16 count) + uint32 mip_level (0-3) LayerInfo × num_layers (20 bytes each): uint32 kernel_size @@ -107,19 +108,20 @@ def export_weights_binary(checkpoint_path, output_path): print(f" Total layers: {len(layers)}") print(f" Total weights: {len(all_weights_f16)} (f16)") print(f" Packed: {len(weights_u32)} u32") - print(f" Binary size: {16 + len(layers) * 20 + len(weights_u32) * 4} bytes") + print(f" Binary size: {20 + len(layers) * 20 + len(weights_u32) * 4} bytes") # Write binary file output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) with open(output_path, 'wb') as f: - # Header (16 bytes) - f.write(struct.pack('<4sIII', + # Header (20 bytes) - version 2 with mip_level + f.write(struct.pack('<4sIIII', b'CNN2', # magic - 1, # version + 2, # version (bumped to 2) len(layers), # num_layers - len(all_weights_f16))) # total_weights (f16 count) + len(all_weights_f16), # total_weights (f16 count) + mip_level)) # mip_level # Layer info (20 bytes 
per layer) for layer in layers: diff --git a/workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl b/workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl index 7a9e6de..f71fad2 100644 --- a/workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl +++ b/workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl @@ -1,13 +1,19 @@ // CNN v2 Static Features Compute Shader // Generates 8D parametric features: [p0, p1, p2, p3, uv.x, uv.y, sin10_x, bias] -// p0-p3: Parametric features (currently RGBD from mip0, could be mip1/2, gradients, etc.) +// p0-p3: Parametric features from specified mip level (0=mip0, 1=mip1, 2=mip2, 3=mip3) // Note: Input image RGBD (mip0) fed separately to Layer 0 +struct StaticFeatureParams { + mip_level: u32, + padding: vec3<u32>, +} + @group(0) @binding(0) var input_tex: texture_2d<f32>; @group(0) @binding(1) var input_tex_mip1: texture_2d<f32>; @group(0) @binding(2) var input_tex_mip2: texture_2d<f32>; @group(0) @binding(3) var depth_tex: texture_2d<f32>; @group(0) @binding(4) var output_tex: texture_storage_2d<rgba32uint, write>; +@group(0) @binding(5) var<uniform> params: StaticFeatureParams; @compute @workgroup_size(8, 8) fn main(@builtin(global_invocation_id) id: vec3<u32>) { @@ -18,10 +24,19 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) { return; } - // Parametric features (p0-p3) - // TODO: Experiment with mip1 grayscale, Sobel gradients, etc. - // For now, use RGBD from mip 0 (same as input, but could differ) - let rgba = textureLoad(input_tex, coord, 0); + // Parametric features (p0-p3) - sample from specified mip level + var rgba: vec4<f32>; + if (params.mip_level == 0u) { + rgba = textureLoad(input_tex, coord, 0); + } else if (params.mip_level == 1u) { + rgba = textureLoad(input_tex_mip1, coord, 0); + } else if (params.mip_level == 2u) { + rgba = textureLoad(input_tex_mip2, coord, 0); + } else { + // Mip 3 or higher: use mip 2 as fallback + rgba = textureLoad(input_tex_mip2, coord, 0); + } + let p0 = rgba.r; let p1 = rgba.g; let p2 = rgba.b; |
