summaryrefslogtreecommitdiff
path: root/cnn_v2/src/cnn_v2_effect.h
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-02-15 18:44:17 +0100
committerskal <pascal.massimino@gmail.com>2026-02-15 18:44:17 +0100
commit161a59fa50bb92e3664c389fa03b95aefe349b3f (patch)
tree71548f64b2bdea958388f9063b74137659d70306 /cnn_v2/src/cnn_v2_effect.h
parent9c3b72c710bf1ffa7e18f7c7390a425d57487eba (diff)
refactor(cnn): isolate CNN v2 to cnn_v2/ subdirectory
Move all CNN v2 files to dedicated cnn_v2/ directory to prepare for CNN v3 development. Zero functional changes. Structure: - cnn_v2/src/ - C++ effect implementation - cnn_v2/shaders/ - WGSL shaders (6 files) - cnn_v2/weights/ - Binary weights (3 files) - cnn_v2/training/ - Python training scripts (4 files) - cnn_v2/scripts/ - Shell scripts (train_cnn_v2_full.sh) - cnn_v2/tools/ - Validation tools (HTML) - cnn_v2/docs/ - Documentation (4 markdown files) Changes: - Update CMake source list to cnn_v2/src/cnn_v2_effect.cc - Update assets.txt with relative paths to cnn_v2/ - Update includes to ../../cnn_v2/src/cnn_v2_effect.h - Add PROJECT_ROOT resolution to Python/shell scripts - Update doc references in HOWTO.md, TODO.md - Add cnn_v2/README.md Verification: 34/34 tests passing, demo runs correctly. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'cnn_v2/src/cnn_v2_effect.h')
-rw-r--r--cnn_v2/src/cnn_v2_effect.h89
1 files changed, 89 insertions, 0 deletions
diff --git a/cnn_v2/src/cnn_v2_effect.h b/cnn_v2/src/cnn_v2_effect.h
new file mode 100644
index 0000000..7960b4f
--- /dev/null
+++ b/cnn_v2/src/cnn_v2_effect.h
@@ -0,0 +1,89 @@
+// CNN v2 Effect - Parametric Static Features
+// Multi-pass post-processing with 7D feature input
+// Supports per-layer kernel sizes (e.g., 1×1, 3×3, 5×5)
+
+#pragma once
+#include "gpu/effect.h"
+#include <vector>
+
+struct CNNv2EffectParams {
+ float blend_amount = 1.0f;
+};
+
+class CNNv2Effect : public PostProcessEffect {
+ public:
+ explicit CNNv2Effect(const GpuContext& ctx);
+ explicit CNNv2Effect(const GpuContext& ctx, const CNNv2EffectParams& params);
+ ~CNNv2Effect();
+
+ void init(MainSequence* demo) override;
+ void resize(int width, int height) override;
+ void compute(WGPUCommandEncoder encoder,
+ const CommonPostProcessUniforms& uniforms) override;
+ void render(WGPURenderPassEncoder pass,
+ const CommonPostProcessUniforms& uniforms) override;
+ void update_bind_group(WGPUTextureView input_view) override;
+
+ void set_beat_modulation(bool enabled, float scale = 1.0f) {
+ beat_modulated_ = enabled;
+ beat_scale_ = scale;
+ }
+
+ private:
+ struct LayerInfo {
+ uint32_t kernel_size;
+ uint32_t in_channels;
+ uint32_t out_channels;
+ uint32_t weight_offset;
+ uint32_t weight_count;
+ };
+
+ struct LayerParams {
+ uint32_t kernel_size;
+ uint32_t in_channels;
+ uint32_t out_channels;
+ uint32_t weight_offset;
+ uint32_t is_output_layer;
+ float blend_amount;
+ uint32_t is_layer_0;
+ };
+
+ struct StaticFeatureParams {
+ uint32_t mip_level;
+ uint32_t padding[3];
+ };
+
+ void create_textures();
+ void create_pipelines();
+ void load_weights();
+ void cleanup();
+
+ // Static features compute
+ WGPUComputePipeline static_pipeline_;
+ WGPUBindGroup static_bind_group_;
+ WGPUBuffer static_params_buffer_;
+ WGPUTexture static_features_tex_;
+ WGPUTextureView static_features_view_;
+ WGPUSampler linear_sampler_;
+
+ // CNN layers (storage buffer architecture)
+ WGPUComputePipeline layer_pipeline_; // Single pipeline for all layers
+ WGPUBuffer weights_buffer_; // Storage buffer for weights
+ std::vector<WGPUBuffer>
+ layer_params_buffers_; // Uniform buffers (one per layer)
+ std::vector<LayerInfo> layer_info_; // Layer metadata
+ std::vector<WGPUBindGroup> layer_bind_groups_; // Per-layer bind groups
+ std::vector<WGPUTexture> layer_textures_; // Ping-pong buffers
+ std::vector<WGPUTextureView> layer_views_;
+
+ // Input mips
+ WGPUTexture input_mip_tex_;
+ WGPUTextureView input_mip_view_[3];
+ WGPUTextureView current_input_view_;
+
+ float blend_amount_ = 1.0f;
+ bool beat_modulated_ = false;
+ float beat_scale_ = 1.0f;
+ uint32_t mip_level_ = 0;
+ bool initialized_;
+};