diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/gpu/effects/cnn_v2_effect.cc | 78 |
1 files changed, 77 insertions, 1 deletions
diff --git a/src/gpu/effects/cnn_v2_effect.cc b/src/gpu/effects/cnn_v2_effect.cc index 275af68..a49161a 100644 --- a/src/gpu/effects/cnn_v2_effect.cc +++ b/src/gpu/effects/cnn_v2_effect.cc @@ -384,6 +384,56 @@ void CNNv2Effect::update_bind_group(WGPUTextureView input_view) { static_bind_group_ = wgpuDeviceCreateBindGroup(ctx_.device, &bg_desc); wgpuBindGroupLayoutRelease(bg_desc.layout); + + // Create layer bind groups + if (!layer_pipeline_ || layer_info_.empty()) return; + + // Release old layer bind groups + for (auto bg : layer_bind_groups_) { + wgpuBindGroupRelease(bg); + } + layer_bind_groups_.clear(); + + // Get bind group layout from layer pipeline + WGPUBindGroupLayout layer_bgl = wgpuComputePipelineGetBindGroupLayout(layer_pipeline_, 0); + + // Create bind group for each layer + for (size_t i = 0; i < layer_info_.size(); ++i) { + WGPUBindGroupEntry layer_entries[5] = {}; + + // Binding 0: Static features (constant) + layer_entries[0].binding = 0; + layer_entries[0].textureView = static_features_view_; + + // Binding 1: Layer input (ping-pong: use previous layer's output) + // First layer uses static features as input, others use ping-pong buffers + layer_entries[1].binding = 1; + layer_entries[1].textureView = (i == 0) ? 
static_features_view_ : layer_views_[i % 2]; + + // Binding 2: Output texture (ping-pong) + layer_entries[2].binding = 2; + layer_entries[2].textureView = layer_views_[(i + 1) % 2]; + + // Binding 3: Weights buffer (constant) + layer_entries[3].binding = 3; + layer_entries[3].buffer = weights_buffer_; + layer_entries[3].size = wgpuBufferGetSize(weights_buffer_); + + // Binding 4: Layer params (will be updated per dispatch) + layer_entries[4].binding = 4; + layer_entries[4].buffer = layer_params_buffer_; + layer_entries[4].size = sizeof(LayerParams); + + WGPUBindGroupDescriptor layer_bg_desc = {}; + layer_bg_desc.layout = layer_bgl; + layer_bg_desc.entryCount = 5; + layer_bg_desc.entries = layer_entries; + + WGPUBindGroup layer_bg = wgpuDeviceCreateBindGroup(ctx_.device, &layer_bg_desc); + layer_bind_groups_.push_back(layer_bg); + } + + wgpuBindGroupLayoutRelease(layer_bgl); } void CNNv2Effect::compute(WGPUCommandEncoder encoder, @@ -405,7 +455,33 @@ void CNNv2Effect::compute(WGPUCommandEncoder encoder, wgpuComputePassEncoderEnd(pass); wgpuComputePassEncoderRelease(pass); - // TODO: Execute CNN layer passes + // Execute CNN layer passes + if (!layer_pipeline_ || layer_bind_groups_.empty()) return; + + for (size_t i = 0; i < layer_info_.size(); ++i) { + const LayerInfo& info = layer_info_[i]; + + // Update layer params uniform buffer + LayerParams params; + params.kernel_size = info.kernel_size; + params.in_channels = info.in_channels; + params.out_channels = info.out_channels; + params.weight_offset = info.weight_offset; + params.is_output_layer = (i == layer_info_.size() - 1) ? 
1 : 0; + + wgpuQueueWriteBuffer(ctx_.queue, layer_params_buffer_, 0, &params, sizeof(params)); + + // Execute layer compute pass + WGPUComputePassEncoder layer_pass = wgpuCommandEncoderBeginComputePass(encoder, nullptr); + + wgpuComputePassEncoderSetPipeline(layer_pass, layer_pipeline_); + wgpuComputePassEncoderSetBindGroup(layer_pass, 0, layer_bind_groups_[i], 0, nullptr); + + wgpuComputePassEncoderDispatchWorkgroups(layer_pass, workgroups_x, workgroups_y, 1); + + wgpuComputePassEncoderEnd(layer_pass); + wgpuComputePassEncoderRelease(layer_pass); + } } void CNNv2Effect::render(WGPURenderPassEncoder pass, |
