summary refs log tree commit diff
path: root/src/gpu/effects/cnn_v2_effect.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpu/effects/cnn_v2_effect.cc')
-rw-r--r-- src/gpu/effects/cnn_v2_effect.cc 78
1 files changed, 77 insertions, 1 deletions
diff --git a/src/gpu/effects/cnn_v2_effect.cc b/src/gpu/effects/cnn_v2_effect.cc
index 275af68..a49161a 100644
--- a/src/gpu/effects/cnn_v2_effect.cc
+++ b/src/gpu/effects/cnn_v2_effect.cc
@@ -384,6 +384,56 @@ void CNNv2Effect::update_bind_group(WGPUTextureView input_view) {
static_bind_group_ = wgpuDeviceCreateBindGroup(ctx_.device, &bg_desc);
wgpuBindGroupLayoutRelease(bg_desc.layout);
+
+ // Create layer bind groups
+ if (!layer_pipeline_ || layer_info_.empty()) return;
+
+ // Release old layer bind groups
+ for (auto bg : layer_bind_groups_) {
+ wgpuBindGroupRelease(bg);
+ }
+ layer_bind_groups_.clear();
+
+ // Get bind group layout from layer pipeline
+ WGPUBindGroupLayout layer_bgl = wgpuComputePipelineGetBindGroupLayout(layer_pipeline_, 0);
+
+ // Create bind group for each layer
+ for (size_t i = 0; i < layer_info_.size(); ++i) {
+ WGPUBindGroupEntry layer_entries[5] = {};
+
+ // Binding 0: Static features (constant)
+ layer_entries[0].binding = 0;
+ layer_entries[0].textureView = static_features_view_;
+
+ // Binding 1: Layer input (ping-pong: use previous layer's output)
+ // First layer uses static features as input, others use ping-pong buffers
+ layer_entries[1].binding = 1;
+ layer_entries[1].textureView = (i == 0) ? static_features_view_ : layer_views_[i % 2];
+
+ // Binding 2: Output texture (ping-pong)
+ layer_entries[2].binding = 2;
+ layer_entries[2].textureView = layer_views_[(i + 1) % 2];
+
+ // Binding 3: Weights buffer (constant)
+ layer_entries[3].binding = 3;
+ layer_entries[3].buffer = weights_buffer_;
+ layer_entries[3].size = wgpuBufferGetSize(weights_buffer_);
+
+ // Binding 4: Layer params (will be updated per dispatch)
+ layer_entries[4].binding = 4;
+ layer_entries[4].buffer = layer_params_buffer_;
+ layer_entries[4].size = sizeof(LayerParams);
+
+ WGPUBindGroupDescriptor layer_bg_desc = {};
+ layer_bg_desc.layout = layer_bgl;
+ layer_bg_desc.entryCount = 5;
+ layer_bg_desc.entries = layer_entries;
+
+ WGPUBindGroup layer_bg = wgpuDeviceCreateBindGroup(ctx_.device, &layer_bg_desc);
+ layer_bind_groups_.push_back(layer_bg);
+ }
+
+ wgpuBindGroupLayoutRelease(layer_bgl);
}
void CNNv2Effect::compute(WGPUCommandEncoder encoder,
@@ -405,7 +455,33 @@ void CNNv2Effect::compute(WGPUCommandEncoder encoder,
wgpuComputePassEncoderEnd(pass);
wgpuComputePassEncoderRelease(pass);
- // TODO: Execute CNN layer passes
+ // Execute CNN layer passes
+ if (!layer_pipeline_ || layer_bind_groups_.empty()) return;
+
+ for (size_t i = 0; i < layer_info_.size(); ++i) {
+ const LayerInfo& info = layer_info_[i];
+
+ // Update layer params uniform buffer
+ LayerParams params;
+ params.kernel_size = info.kernel_size;
+ params.in_channels = info.in_channels;
+ params.out_channels = info.out_channels;
+ params.weight_offset = info.weight_offset;
+ params.is_output_layer = (i == layer_info_.size() - 1) ? 1 : 0;
+
+ wgpuQueueWriteBuffer(ctx_.queue, layer_params_buffer_, 0, &params, sizeof(params));
+
+ // Execute layer compute pass
+ WGPUComputePassEncoder layer_pass = wgpuCommandEncoderBeginComputePass(encoder, nullptr);
+
+ wgpuComputePassEncoderSetPipeline(layer_pass, layer_pipeline_);
+ wgpuComputePassEncoderSetBindGroup(layer_pass, 0, layer_bind_groups_[i], 0, nullptr);
+
+ wgpuComputePassEncoderDispatchWorkgroups(layer_pass, workgroups_x, workgroups_y, 1);
+
+ wgpuComputePassEncoderEnd(layer_pass);
+ wgpuComputePassEncoderRelease(layer_pass);
+ }
}
void CNNv2Effect::render(WGPURenderPassEncoder pass,