summaryrefslogtreecommitdiff
path: root/src/gpu/effects/cnn_v2_effect.h
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-02-12 11:34:50 +0100
committerskal <pascal.massimino@gmail.com>2026-02-12 11:34:50 +0100
commit91d42f2d057e077c267d6775cc109a801aa315c0 (patch)
tree18cd67c9ce11f24149e6dafa65d176ca7143fcbb /src/gpu/effects/cnn_v2_effect.h
parent301db1f29137d3db7828e7a0103986cc845b7672 (diff)
CNN v2: parametric static features - Phases 1-4
Infrastructure for enhanced CNN post-processing with 7D feature input. Phase 1: Shaders - Static features compute (RGBD + UV + sin10_x + bias → 8×f16) - Layer template (convolution skeleton, packing/unpacking) - 3 mip level support for multi-scale features Phase 2: C++ Effect - CNNv2Effect class (multi-pass architecture) - Texture management (static features, layer buffers) - Build integration (CMakeLists, assets, tests) Phase 3: Training Pipeline - train_cnn_v2.py: PyTorch model with static feature concatenation - export_cnn_v2_shader.py: f32→f16 quantization, WGSL generation - Configurable architecture (kernels, channels) Phase 4: Validation - validate_cnn_v2.sh: End-to-end pipeline - Checkpoint → shaders → build → test images Tests: 36/36 passing Next: Complete render pipeline implementation (bind groups, multi-pass) Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'src/gpu/effects/cnn_v2_effect.h')
-rw-r--r--src/gpu/effects/cnn_v2_effect.h41
1 files changed, 41 insertions, 0 deletions
diff --git a/src/gpu/effects/cnn_v2_effect.h b/src/gpu/effects/cnn_v2_effect.h
new file mode 100644
index 0000000..edf301e
--- /dev/null
+++ b/src/gpu/effects/cnn_v2_effect.h
@@ -0,0 +1,41 @@
+// CNN v2 Effect - Parametric Static Features
+// Multi-pass post-processing with 7D feature input
+
+#pragma once
+#include "gpu/effect.h"
+#include <vector>
+
+class CNNv2Effect : public PostProcessEffect {
+public:
+ explicit CNNv2Effect(const GpuContext& ctx);
+ ~CNNv2Effect();
+
+ void init(MainSequence* demo) override;
+ void resize(int width, int height) override;
+ void render(WGPURenderPassEncoder pass,
+ const CommonPostProcessUniforms& uniforms) override;
+ void update_bind_group(WGPUTextureView input_view) override;
+
+private:
+ void create_textures();
+ void create_pipelines();
+ void cleanup();
+
+ // Static features compute
+ WGPUComputePipeline static_pipeline_;
+ WGPUBindGroup static_bind_group_;
+ WGPUTexture static_features_tex_;
+ WGPUTextureView static_features_view_;
+
+ // CNN layers (opaque implementation)
+ std::vector<WGPUComputePipeline> layer_pipelines_;
+ std::vector<WGPUBindGroup> layer_bind_groups_;
+ std::vector<WGPUTexture> layer_textures_;
+ std::vector<WGPUTextureView> layer_views_;
+
+ // Input mips
+ WGPUTexture input_mip_tex_;
+ WGPUTextureView input_mip_view_[3];
+
+ bool initialized_;
+};