summaryrefslogtreecommitdiff
path: root/src/gpu/effects/cnn_v2_effect.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpu/effects/cnn_v2_effect.h')
-rw-r--r--src/gpu/effects/cnn_v2_effect.h30
1 files changed, 25 insertions, 5 deletions
diff --git a/src/gpu/effects/cnn_v2_effect.h b/src/gpu/effects/cnn_v2_effect.h
index facf4c3..6005cf5 100644
--- a/src/gpu/effects/cnn_v2_effect.h
+++ b/src/gpu/effects/cnn_v2_effect.h
@@ -19,8 +19,25 @@ public:
void update_bind_group(WGPUTextureView input_view) override;
private:
+ struct LayerInfo {
+ uint32_t kernel_size;
+ uint32_t in_channels;
+ uint32_t out_channels;
+ uint32_t weight_offset;
+ uint32_t weight_count;
+ };
+
+ struct LayerParams {
+ uint32_t kernel_size;
+ uint32_t in_channels;
+ uint32_t out_channels;
+ uint32_t weight_offset;
+ uint32_t is_output_layer;
+ };
+
void create_textures();
void create_pipelines();
+ void load_weights();
void cleanup();
// Static features compute
@@ -29,16 +46,19 @@ private:
WGPUTexture static_features_tex_;
WGPUTextureView static_features_view_;
- // CNN layers (opaque implementation)
- std::vector<WGPUComputePipeline> layer_pipelines_;
- std::vector<WGPUBindGroup> layer_bind_groups_;
- std::vector<WGPUTexture> layer_textures_;
+ // CNN layers (storage buffer architecture)
+ WGPUComputePipeline layer_pipeline_; // Single pipeline for all layers
+ WGPUBuffer weights_buffer_; // Storage buffer for weights
+ WGPUBuffer layer_params_buffer_; // Uniform buffer for per-layer params
+ std::vector<LayerInfo> layer_info_; // Layer metadata
+ std::vector<WGPUBindGroup> layer_bind_groups_; // Per-layer bind groups
+ std::vector<WGPUTexture> layer_textures_; // Ping-pong buffers
std::vector<WGPUTextureView> layer_views_;
// Input mips
WGPUTexture input_mip_tex_;
WGPUTextureView input_mip_view_[3];
- WGPUTextureView current_input_view_; // Cached input from update_bind_group
+ WGPUTextureView current_input_view_;
bool initialized_;
};