summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/gpu/effects/cnn_v2_effect.cc63
-rw-r--r--src/gpu/effects/cnn_v2_effect.h7
-rwxr-xr-xtraining/export_cnn_v2_weights.py16
-rw-r--r--workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl25
4 files changed, 88 insertions, 23 deletions
diff --git a/src/gpu/effects/cnn_v2_effect.cc b/src/gpu/effects/cnn_v2_effect.cc
index 9c727ba..97e4790 100644
--- a/src/gpu/effects/cnn_v2_effect.cc
+++ b/src/gpu/effects/cnn_v2_effect.cc
@@ -16,6 +16,7 @@ CNNv2Effect::CNNv2Effect(const GpuContext& ctx)
: PostProcessEffect(ctx),
static_pipeline_(nullptr),
static_bind_group_(nullptr),
+ static_params_buffer_(nullptr),
static_features_tex_(nullptr),
static_features_view_(nullptr),
layer_pipeline_(nullptr),
@@ -23,6 +24,7 @@ CNNv2Effect::CNNv2Effect(const GpuContext& ctx)
input_mip_tex_(nullptr),
current_input_view_(nullptr),
blend_amount_(1.0f),
+ mip_level_(0),
initialized_(false) {
std::memset(input_mip_view_, 0, sizeof(input_mip_view_));
}
@@ -31,6 +33,7 @@ CNNv2Effect::CNNv2Effect(const GpuContext& ctx, const CNNv2EffectParams& params)
: PostProcessEffect(ctx),
static_pipeline_(nullptr),
static_bind_group_(nullptr),
+ static_params_buffer_(nullptr),
static_features_tex_(nullptr),
static_features_view_(nullptr),
layer_pipeline_(nullptr),
@@ -38,6 +41,7 @@ CNNv2Effect::CNNv2Effect(const GpuContext& ctx, const CNNv2EffectParams& params)
input_mip_tex_(nullptr),
current_input_view_(nullptr),
blend_amount_(params.blend_amount),
+ mip_level_(0),
initialized_(false) {
std::memset(input_mip_view_, 0, sizeof(input_mip_view_));
}
@@ -69,12 +73,12 @@ void CNNv2Effect::load_weights() {
size_t weights_size = 0;
const uint8_t* weights_data = (const uint8_t*)GetAsset(AssetId::ASSET_WEIGHTS_CNN_V2, &weights_size);
- if (!weights_data || weights_size < 16) {
+ if (!weights_data || weights_size < 20) {
// Weights not available - effect will skip
return;
}
- // Parse header (16 bytes)
+ // Parse header
const uint32_t* header = (const uint32_t*)weights_data;
uint32_t magic = header[0];
uint32_t version = header[1];
@@ -82,10 +86,20 @@ void CNNv2Effect::load_weights() {
uint32_t total_weights = header[3];
FATAL_CHECK(magic != 0x324e4e43, "Invalid CNN v2 weights magic\n"); // 'CNN2'
- FATAL_CHECK(version != 1, "Unsupported CNN v2 weights version\n");
+
+ // Support both version 1 (16-byte header) and version 2 (20-byte header with mip_level)
+ if (version == 1) {
+ mip_level_ = 0; // Default for v1
+ } else if (version == 2) {
+ mip_level_ = header[4];
+ } else {
+ FATAL_ERROR("Unsupported CNN v2 weights version: %u\n", version);
+ }
// Parse layer info (20 bytes per layer)
- const uint32_t* layer_data = header + 4;
+ // Offset depends on version: v1=16 bytes (4 u32), v2=20 bytes (5 u32)
+ const uint32_t header_u32_count = (version == 1) ? 4 : 5;
+ const uint32_t* layer_data = header + header_u32_count;
for (uint32_t i = 0; i < num_layers; ++i) {
LayerInfo info;
info.kernel_size = layer_data[i * 5 + 0];
@@ -192,6 +206,13 @@ void CNNv2Effect::create_textures() {
WGPUTextureView view = wgpuTextureCreateView(tex, &view_desc);
layer_views_.push_back(view);
}
+
+ // Create uniform buffer for static feature params
+ WGPUBufferDescriptor params_desc = {};
+ params_desc.size = sizeof(StaticFeatureParams);
+ params_desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst;
+ params_desc.mappedAtCreation = false;
+ static_params_buffer_ = wgpuDeviceCreateBuffer(ctx_.device, &params_desc);
}
void CNNv2Effect::create_pipelines() {
@@ -224,8 +245,8 @@ void CNNv2Effect::create_pipelines() {
wgpuShaderModuleRelease(static_module);
// Create bind group layout for static features compute
- // Bindings: 0=input_tex, 1=input_mip1, 2=input_mip2, 3=depth_tex, 4=output
- WGPUBindGroupLayoutEntry bgl_entries[5] = {};
+ // Bindings: 0=input_tex, 1=input_mip1, 2=input_mip2, 3=depth_tex, 4=output, 5=params
+ WGPUBindGroupLayoutEntry bgl_entries[6] = {};
// Binding 0: Input texture (mip 0)
bgl_entries[0].binding = 0;
@@ -258,8 +279,14 @@ void CNNv2Effect::create_pipelines() {
bgl_entries[4].storageTexture.format = WGPUTextureFormat_RGBA32Uint;
bgl_entries[4].storageTexture.viewDimension = WGPUTextureViewDimension_2D;
+ // Binding 5: Params (mip_level)
+ bgl_entries[5].binding = 5;
+ bgl_entries[5].visibility = WGPUShaderStage_Compute;
+ bgl_entries[5].buffer.type = WGPUBufferBindingType_Uniform;
+ bgl_entries[5].buffer.minBindingSize = sizeof(StaticFeatureParams);
+
WGPUBindGroupLayoutDescriptor bgl_desc = {};
- bgl_desc.entryCount = 5;
+ bgl_desc.entryCount = 6;
bgl_desc.entries = bgl_entries;
WGPUBindGroupLayout static_bgl = wgpuDeviceCreateBindGroupLayout(ctx_.device, &bgl_desc);
@@ -378,7 +405,7 @@ void CNNv2Effect::update_bind_group(WGPUTextureView input_view) {
}
// Create bind group for static features compute
- WGPUBindGroupEntry bg_entries[5] = {};
+ WGPUBindGroupEntry bg_entries[6] = {};
// Binding 0: Input (mip 0)
bg_entries[0].binding = 0;
@@ -386,11 +413,11 @@ void CNNv2Effect::update_bind_group(WGPUTextureView input_view) {
// Binding 1: Input (mip 1)
bg_entries[1].binding = 1;
- bg_entries[1].textureView = input_mip_view_[0]; // Use mip 0 for now
+ bg_entries[1].textureView = input_mip_view_[0];
// Binding 2: Input (mip 2)
bg_entries[2].binding = 2;
- bg_entries[2].textureView = input_mip_view_[0]; // Use mip 0 for now
+ bg_entries[2].textureView = (input_mip_view_[1]) ? input_mip_view_[1] : input_mip_view_[0];
// Binding 3: Depth (use input for now, no depth available)
bg_entries[3].binding = 3;
@@ -400,9 +427,14 @@ void CNNv2Effect::update_bind_group(WGPUTextureView input_view) {
bg_entries[4].binding = 4;
bg_entries[4].textureView = static_features_view_;
+ // Binding 5: Params
+ bg_entries[5].binding = 5;
+ bg_entries[5].buffer = static_params_buffer_;
+ bg_entries[5].size = sizeof(StaticFeatureParams);
+
WGPUBindGroupDescriptor bg_desc = {};
bg_desc.layout = wgpuComputePipelineGetBindGroupLayout(static_pipeline_, 0);
- bg_desc.entryCount = 5;
+ bg_desc.entryCount = 6;
bg_desc.entries = bg_entries;
static_bind_group_ = wgpuDeviceCreateBindGroup(ctx_.device, &bg_desc);
@@ -473,6 +505,14 @@ void CNNv2Effect::compute(WGPUCommandEncoder encoder,
effective_blend = blend_amount_ * uniforms.beat_phase * beat_scale_;
}
+ // Update static feature params
+ StaticFeatureParams static_params;
+ static_params.mip_level = mip_level_;
+ static_params.padding[0] = 0;
+ static_params.padding[1] = 0;
+ static_params.padding[2] = 0;
+ wgpuQueueWriteBuffer(ctx_.queue, static_params_buffer_, 0, &static_params, sizeof(static_params));
+
// Pass 1: Compute static features
WGPUComputePassEncoder pass = wgpuCommandEncoderBeginComputePass(encoder, nullptr);
@@ -527,6 +567,7 @@ void CNNv2Effect::cleanup() {
if (static_features_view_) wgpuTextureViewRelease(static_features_view_);
if (static_features_tex_) wgpuTextureRelease(static_features_tex_);
if (static_bind_group_) wgpuBindGroupRelease(static_bind_group_);
+ if (static_params_buffer_) wgpuBufferRelease(static_params_buffer_);
if (static_pipeline_) wgpuComputePipelineRelease(static_pipeline_);
if (layer_pipeline_) wgpuComputePipelineRelease(layer_pipeline_);
diff --git a/src/gpu/effects/cnn_v2_effect.h b/src/gpu/effects/cnn_v2_effect.h
index 28aec8c..47dedf5 100644
--- a/src/gpu/effects/cnn_v2_effect.h
+++ b/src/gpu/effects/cnn_v2_effect.h
@@ -47,6 +47,11 @@ private:
float blend_amount;
};
+ struct StaticFeatureParams {
+ uint32_t mip_level;
+ uint32_t padding[3];
+ };
+
void create_textures();
void create_pipelines();
void load_weights();
@@ -55,6 +60,7 @@ private:
// Static features compute
WGPUComputePipeline static_pipeline_;
WGPUBindGroup static_bind_group_;
+ WGPUBuffer static_params_buffer_;
WGPUTexture static_features_tex_;
WGPUTextureView static_features_view_;
@@ -75,5 +81,6 @@ private:
float blend_amount_ = 1.0f;
bool beat_modulated_ = false;
float beat_scale_ = 1.0f;
+ uint32_t mip_level_ = 0;
bool initialized_;
};
diff --git a/training/export_cnn_v2_weights.py b/training/export_cnn_v2_weights.py
index 9e9e352..1086516 100755
--- a/training/export_cnn_v2_weights.py
+++ b/training/export_cnn_v2_weights.py
@@ -16,11 +16,12 @@ def export_weights_binary(checkpoint_path, output_path):
"""Export CNN v2 weights to binary format.
Binary format:
- Header (16 bytes):
+ Header (20 bytes):
uint32 magic ('CNN2')
- uint32 version (1)
+ uint32 version (2)
uint32 num_layers
uint32 total_weights (f16 count)
+ uint32 mip_level (0-3)
LayerInfo × num_layers (20 bytes each):
uint32 kernel_size
@@ -107,19 +108,20 @@ def export_weights_binary(checkpoint_path, output_path):
print(f" Total layers: {len(layers)}")
print(f" Total weights: {len(all_weights_f16)} (f16)")
print(f" Packed: {len(weights_u32)} u32")
- print(f" Binary size: {16 + len(layers) * 20 + len(weights_u32) * 4} bytes")
+ print(f" Binary size: {20 + len(layers) * 20 + len(weights_u32) * 4} bytes")
# Write binary file
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'wb') as f:
- # Header (16 bytes)
- f.write(struct.pack('<4sIII',
+ # Header (20 bytes) - version 2 with mip_level
+ f.write(struct.pack('<4sIIII',
b'CNN2', # magic
- 1, # version
+ 2, # version (bumped to 2)
len(layers), # num_layers
- len(all_weights_f16))) # total_weights (f16 count)
+ len(all_weights_f16), # total_weights (f16 count)
+ mip_level)) # mip_level
# Layer info (20 bytes per layer)
for layer in layers:
diff --git a/workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl b/workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl
index 7a9e6de..f71fad2 100644
--- a/workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl
+++ b/workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl
@@ -1,13 +1,19 @@
// CNN v2 Static Features Compute Shader
// Generates 8D parametric features: [p0, p1, p2, p3, uv.x, uv.y, sin10_x, bias]
-// p0-p3: Parametric features (currently RGBD from mip0, could be mip1/2, gradients, etc.)
+// p0-p3: Parametric features from specified mip level (0=mip0, 1=mip1, 2=mip2, 3=mip3)
// Note: Input image RGBD (mip0) fed separately to Layer 0
+struct StaticFeatureParams {
+ mip_level: u32,
+ padding: vec3<u32>,
+}
+
@group(0) @binding(0) var input_tex: texture_2d<f32>;
@group(0) @binding(1) var input_tex_mip1: texture_2d<f32>;
@group(0) @binding(2) var input_tex_mip2: texture_2d<f32>;
@group(0) @binding(3) var depth_tex: texture_2d<f32>;
@group(0) @binding(4) var output_tex: texture_storage_2d<rgba32uint, write>;
+@group(0) @binding(5) var<uniform> params: StaticFeatureParams;
@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
@@ -18,10 +24,19 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) {
return;
}
- // Parametric features (p0-p3)
- // TODO: Experiment with mip1 grayscale, Sobel gradients, etc.
- // For now, use RGBD from mip 0 (same as input, but could differ)
- let rgba = textureLoad(input_tex, coord, 0);
+ // Parametric features (p0-p3) - sample from specified mip level
+ var rgba: vec4<f32>;
+ if (params.mip_level == 0u) {
+ rgba = textureLoad(input_tex, coord, 0);
+ } else if (params.mip_level == 1u) {
+ rgba = textureLoad(input_tex_mip1, coord, 0);
+ } else if (params.mip_level == 2u) {
+ rgba = textureLoad(input_tex_mip2, coord, 0);
+ } else {
+ // Mip 3 or higher: use mip 2 as fallback
+ rgba = textureLoad(input_tex_mip2, coord, 0);
+ }
+
let p0 = rgba.r;
let p1 = rgba.g;
let p2 = rgba.b;