diff options
Diffstat (limited to 'src/gpu/effects/cnn_v2_effect.h')
| -rw-r--r-- | src/gpu/effects/cnn_v2_effect.h | 64 |
1 files changed, 64 insertions, 0 deletions
diff --git a/src/gpu/effects/cnn_v2_effect.h b/src/gpu/effects/cnn_v2_effect.h new file mode 100644 index 0000000..6005cf5 --- /dev/null +++ b/src/gpu/effects/cnn_v2_effect.h @@ -0,0 +1,64 @@ +// CNN v2 Effect - Parametric Static Features +// Multi-pass post-processing with 7D feature input + +#pragma once +#include "gpu/effect.h" +#include <vector> + +class CNNv2Effect : public PostProcessEffect { +public: + explicit CNNv2Effect(const GpuContext& ctx); + ~CNNv2Effect(); + + void init(MainSequence* demo) override; + void resize(int width, int height) override; + void compute(WGPUCommandEncoder encoder, + const CommonPostProcessUniforms& uniforms) override; + void render(WGPURenderPassEncoder pass, + const CommonPostProcessUniforms& uniforms) override; + void update_bind_group(WGPUTextureView input_view) override; + +private: + struct LayerInfo { + uint32_t kernel_size; + uint32_t in_channels; + uint32_t out_channels; + uint32_t weight_offset; + uint32_t weight_count; + }; + + struct LayerParams { + uint32_t kernel_size; + uint32_t in_channels; + uint32_t out_channels; + uint32_t weight_offset; + uint32_t is_output_layer; + }; + + void create_textures(); + void create_pipelines(); + void load_weights(); + void cleanup(); + + // Static features compute + WGPUComputePipeline static_pipeline_; + WGPUBindGroup static_bind_group_; + WGPUTexture static_features_tex_; + WGPUTextureView static_features_view_; + + // CNN layers (storage buffer architecture) + WGPUComputePipeline layer_pipeline_; // Single pipeline for all layers + WGPUBuffer weights_buffer_; // Storage buffer for weights + WGPUBuffer layer_params_buffer_; // Uniform buffer for per-layer params + std::vector<LayerInfo> layer_info_; // Layer metadata + std::vector<WGPUBindGroup> layer_bind_groups_; // Per-layer bind groups + std::vector<WGPUTexture> layer_textures_; // Ping-pong buffers + std::vector<WGPUTextureView> layer_views_; + + // Input mips + WGPUTexture input_mip_tex_; + WGPUTextureView input_mip_view_[3]; + WGPUTextureView current_input_view_; + + bool initialized_; +}; |
