diff options
Diffstat (limited to 'src/gpu/effects/cnn_v2_effect.h')
| -rw-r--r-- | src/gpu/effects/cnn_v2_effect.h | 30 |
1 files changed, 25 insertions, 5 deletions
diff --git a/src/gpu/effects/cnn_v2_effect.h b/src/gpu/effects/cnn_v2_effect.h index facf4c3..6005cf5 100644 --- a/src/gpu/effects/cnn_v2_effect.h +++ b/src/gpu/effects/cnn_v2_effect.h @@ -19,8 +19,25 @@ public: void update_bind_group(WGPUTextureView input_view) override; private: + struct LayerInfo { + uint32_t kernel_size; + uint32_t in_channels; + uint32_t out_channels; + uint32_t weight_offset; + uint32_t weight_count; + }; + + struct LayerParams { + uint32_t kernel_size; + uint32_t in_channels; + uint32_t out_channels; + uint32_t weight_offset; + uint32_t is_output_layer; + }; + void create_textures(); void create_pipelines(); + void load_weights(); void cleanup(); // Static features compute @@ -29,16 +46,19 @@ private: WGPUTexture static_features_tex_; WGPUTextureView static_features_view_; - // CNN layers (opaque implementation) - std::vector<WGPUComputePipeline> layer_pipelines_; - std::vector<WGPUBindGroup> layer_bind_groups_; - std::vector<WGPUTexture> layer_textures_; + // CNN layers (storage buffer architecture) + WGPUComputePipeline layer_pipeline_; // Single pipeline for all layers + WGPUBuffer weights_buffer_; // Storage buffer for weights + WGPUBuffer layer_params_buffer_; // Uniform buffer for per-layer params + std::vector<LayerInfo> layer_info_; // Layer metadata + std::vector<WGPUBindGroup> layer_bind_groups_; // Per-layer bind groups + std::vector<WGPUTexture> layer_textures_; // Ping-pong buffers std::vector<WGPUTextureView> layer_views_; // Input mips WGPUTexture input_mip_tex_; WGPUTextureView input_mip_view_[3]; - WGPUTextureView current_input_view_; // Cached input from update_bind_group + WGPUTextureView current_input_view_; bool initialized_; }; |
