summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/gpu/effects/cnn_v2_effect.cc63
-rw-r--r--src/gpu/effects/cnn_v2_effect.h7
2 files changed, 59 insertions, 11 deletions
diff --git a/src/gpu/effects/cnn_v2_effect.cc b/src/gpu/effects/cnn_v2_effect.cc
index 9c727ba..97e4790 100644
--- a/src/gpu/effects/cnn_v2_effect.cc
+++ b/src/gpu/effects/cnn_v2_effect.cc
@@ -16,6 +16,7 @@ CNNv2Effect::CNNv2Effect(const GpuContext& ctx)
: PostProcessEffect(ctx),
static_pipeline_(nullptr),
static_bind_group_(nullptr),
+ static_params_buffer_(nullptr),
static_features_tex_(nullptr),
static_features_view_(nullptr),
layer_pipeline_(nullptr),
@@ -23,6 +24,7 @@ CNNv2Effect::CNNv2Effect(const GpuContext& ctx)
input_mip_tex_(nullptr),
current_input_view_(nullptr),
blend_amount_(1.0f),
+ mip_level_(0),
initialized_(false) {
std::memset(input_mip_view_, 0, sizeof(input_mip_view_));
}
@@ -31,6 +33,7 @@ CNNv2Effect::CNNv2Effect(const GpuContext& ctx, const CNNv2EffectParams& params)
: PostProcessEffect(ctx),
static_pipeline_(nullptr),
static_bind_group_(nullptr),
+ static_params_buffer_(nullptr),
static_features_tex_(nullptr),
static_features_view_(nullptr),
layer_pipeline_(nullptr),
@@ -38,6 +41,7 @@ CNNv2Effect::CNNv2Effect(const GpuContext& ctx, const CNNv2EffectParams& params)
input_mip_tex_(nullptr),
current_input_view_(nullptr),
blend_amount_(params.blend_amount),
+ mip_level_(0),
initialized_(false) {
std::memset(input_mip_view_, 0, sizeof(input_mip_view_));
}
@@ -69,12 +73,12 @@ void CNNv2Effect::load_weights() {
size_t weights_size = 0;
const uint8_t* weights_data = (const uint8_t*)GetAsset(AssetId::ASSET_WEIGHTS_CNN_V2, &weights_size);
- if (!weights_data || weights_size < 16) {
+ if (!weights_data || weights_size < 20) {
// Weights not available - effect will skip
return;
}
- // Parse header (16 bytes)
+ // Parse header
const uint32_t* header = (const uint32_t*)weights_data;
uint32_t magic = header[0];
uint32_t version = header[1];
@@ -82,10 +86,20 @@ void CNNv2Effect::load_weights() {
uint32_t total_weights = header[3];
FATAL_CHECK(magic != 0x324e4e43, "Invalid CNN v2 weights magic\n"); // 'CNN2'
- FATAL_CHECK(version != 1, "Unsupported CNN v2 weights version\n");
+
+ // Support both version 1 (16-byte header) and version 2 (20-byte header with mip_level)
+ if (version == 1) {
+ mip_level_ = 0; // Default for v1
+ } else if (version == 2) {
+ mip_level_ = header[4];
+ } else {
+ FATAL_ERROR("Unsupported CNN v2 weights version: %u\n", version);
+ }
// Parse layer info (20 bytes per layer)
- const uint32_t* layer_data = header + 4;
+ // Offset depends on version: v1=16 bytes (4 u32), v2=20 bytes (5 u32)
+ const uint32_t header_u32_count = (version == 1) ? 4 : 5;
+ const uint32_t* layer_data = header + header_u32_count;
for (uint32_t i = 0; i < num_layers; ++i) {
LayerInfo info;
info.kernel_size = layer_data[i * 5 + 0];
@@ -192,6 +206,13 @@ void CNNv2Effect::create_textures() {
WGPUTextureView view = wgpuTextureCreateView(tex, &view_desc);
layer_views_.push_back(view);
}
+
+ // Create uniform buffer for static feature params
+ WGPUBufferDescriptor params_desc = {};
+ params_desc.size = sizeof(StaticFeatureParams);
+ params_desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst;
+ params_desc.mappedAtCreation = false;
+ static_params_buffer_ = wgpuDeviceCreateBuffer(ctx_.device, &params_desc);
}
void CNNv2Effect::create_pipelines() {
@@ -224,8 +245,8 @@ void CNNv2Effect::create_pipelines() {
wgpuShaderModuleRelease(static_module);
// Create bind group layout for static features compute
- // Bindings: 0=input_tex, 1=input_mip1, 2=input_mip2, 3=depth_tex, 4=output
- WGPUBindGroupLayoutEntry bgl_entries[5] = {};
+ // Bindings: 0=input_tex, 1=input_mip1, 2=input_mip2, 3=depth_tex, 4=output, 5=params
+ WGPUBindGroupLayoutEntry bgl_entries[6] = {};
// Binding 0: Input texture (mip 0)
bgl_entries[0].binding = 0;
@@ -258,8 +279,14 @@ void CNNv2Effect::create_pipelines() {
bgl_entries[4].storageTexture.format = WGPUTextureFormat_RGBA32Uint;
bgl_entries[4].storageTexture.viewDimension = WGPUTextureViewDimension_2D;
+ // Binding 5: Params (mip_level)
+ bgl_entries[5].binding = 5;
+ bgl_entries[5].visibility = WGPUShaderStage_Compute;
+ bgl_entries[5].buffer.type = WGPUBufferBindingType_Uniform;
+ bgl_entries[5].buffer.minBindingSize = sizeof(StaticFeatureParams);
+
WGPUBindGroupLayoutDescriptor bgl_desc = {};
- bgl_desc.entryCount = 5;
+ bgl_desc.entryCount = 6;
bgl_desc.entries = bgl_entries;
WGPUBindGroupLayout static_bgl = wgpuDeviceCreateBindGroupLayout(ctx_.device, &bgl_desc);
@@ -378,7 +405,7 @@ void CNNv2Effect::update_bind_group(WGPUTextureView input_view) {
}
// Create bind group for static features compute
- WGPUBindGroupEntry bg_entries[5] = {};
+ WGPUBindGroupEntry bg_entries[6] = {};
// Binding 0: Input (mip 0)
bg_entries[0].binding = 0;
@@ -386,11 +413,11 @@ void CNNv2Effect::update_bind_group(WGPUTextureView input_view) {
// Binding 1: Input (mip 1)
bg_entries[1].binding = 1;
- bg_entries[1].textureView = input_mip_view_[0]; // Use mip 0 for now
+ bg_entries[1].textureView = input_mip_view_[0];
// Binding 2: Input (mip 2)
bg_entries[2].binding = 2;
- bg_entries[2].textureView = input_mip_view_[0]; // Use mip 0 for now
+ bg_entries[2].textureView = (input_mip_view_[1]) ? input_mip_view_[1] : input_mip_view_[0];
// Binding 3: Depth (use input for now, no depth available)
bg_entries[3].binding = 3;
@@ -400,9 +427,14 @@ void CNNv2Effect::update_bind_group(WGPUTextureView input_view) {
bg_entries[4].binding = 4;
bg_entries[4].textureView = static_features_view_;
+ // Binding 5: Params
+ bg_entries[5].binding = 5;
+ bg_entries[5].buffer = static_params_buffer_;
+ bg_entries[5].size = sizeof(StaticFeatureParams);
+
WGPUBindGroupDescriptor bg_desc = {};
bg_desc.layout = wgpuComputePipelineGetBindGroupLayout(static_pipeline_, 0);
- bg_desc.entryCount = 5;
+ bg_desc.entryCount = 6;
bg_desc.entries = bg_entries;
static_bind_group_ = wgpuDeviceCreateBindGroup(ctx_.device, &bg_desc);
@@ -473,6 +505,14 @@ void CNNv2Effect::compute(WGPUCommandEncoder encoder,
effective_blend = blend_amount_ * uniforms.beat_phase * beat_scale_;
}
+ // Update static feature params
+ StaticFeatureParams static_params;
+ static_params.mip_level = mip_level_;
+ static_params.padding[0] = 0;
+ static_params.padding[1] = 0;
+ static_params.padding[2] = 0;
+ wgpuQueueWriteBuffer(ctx_.queue, static_params_buffer_, 0, &static_params, sizeof(static_params));
+
// Pass 1: Compute static features
WGPUComputePassEncoder pass = wgpuCommandEncoderBeginComputePass(encoder, nullptr);
@@ -527,6 +567,7 @@ void CNNv2Effect::cleanup() {
if (static_features_view_) wgpuTextureViewRelease(static_features_view_);
if (static_features_tex_) wgpuTextureRelease(static_features_tex_);
if (static_bind_group_) wgpuBindGroupRelease(static_bind_group_);
+ if (static_params_buffer_) wgpuBufferRelease(static_params_buffer_);
if (static_pipeline_) wgpuComputePipelineRelease(static_pipeline_);
if (layer_pipeline_) wgpuComputePipelineRelease(layer_pipeline_);
diff --git a/src/gpu/effects/cnn_v2_effect.h b/src/gpu/effects/cnn_v2_effect.h
index 28aec8c..47dedf5 100644
--- a/src/gpu/effects/cnn_v2_effect.h
+++ b/src/gpu/effects/cnn_v2_effect.h
@@ -47,6 +47,11 @@ private:
float blend_amount;
};
+ struct StaticFeatureParams {
+ uint32_t mip_level;
+ uint32_t padding[3];
+ };
+
void create_textures();
void create_pipelines();
void load_weights();
@@ -55,6 +60,7 @@ private:
// Static features compute
WGPUComputePipeline static_pipeline_;
WGPUBindGroup static_bind_group_;
+ WGPUBuffer static_params_buffer_;
WGPUTexture static_features_tex_;
WGPUTextureView static_features_view_;
@@ -75,5 +81,6 @@ private:
float blend_amount_ = 1.0f;
bool beat_modulated_ = false;
float beat_scale_ = 1.0f;
+ uint32_t mip_level_ = 0;
bool initialized_;
};