summaryrefslogtreecommitdiff
path: root/src/effects/cnn_effect.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/effects/cnn_effect.cc')
-rw-r--r--src/effects/cnn_effect.cc126
1 files changed, 126 insertions, 0 deletions
diff --git a/src/effects/cnn_effect.cc b/src/effects/cnn_effect.cc
new file mode 100644
index 0000000..4475180
--- /dev/null
+++ b/src/effects/cnn_effect.cc
@@ -0,0 +1,126 @@
+// CNN post-processing effect implementation
+// Neural network-based stylization with modular WGSL
+
+#include "effects/cnn_effect.h"
+#include "gpu/post_process_helper.h"
+#include "gpu/shaders.h"
+#include "gpu/shader_composer.h"
+#include "gpu/effect.h"
+#include "gpu/bind_group_builder.h"
+#include "gpu/sampler_cache.h"
+#include "gpu/pipeline_builder.h"
+
+// Create custom pipeline with 5 bindings (includes original texture)
+static WGPURenderPipeline create_cnn_pipeline(WGPUDevice device,
+ WGPUTextureFormat format,
+ const char* shader_code) {
+ WGPUBindGroupLayout bgl = BindGroupLayoutBuilder()
+ .sampler(0, WGPUShaderStage_Fragment)
+ .texture(1, WGPUShaderStage_Fragment)
+ .uniform(2, WGPUShaderStage_Vertex | WGPUShaderStage_Fragment)
+ .uniform(3, WGPUShaderStage_Fragment)
+ .texture(4, WGPUShaderStage_Fragment)
+ .build(device);
+
+ WGPURenderPipeline pipeline = RenderPipelineBuilder(device)
+ .shader(shader_code)
+ .bind_group_layout(bgl)
+ .format(format)
+ .build();
+
+ wgpuBindGroupLayoutRelease(bgl);
+ return pipeline;
+}
+
+CNNEffect::CNNEffect(const GpuContext& ctx)
+ : PostProcessEffect(ctx), layer_index_(0), total_layers_(1),
+ blend_amount_(1.0f), input_view_(nullptr), original_view_(nullptr),
+ bind_group_(nullptr) {
+ pipeline_ = create_cnn_pipeline(ctx_.device, ctx_.format,
+ cnn_layer_shader_wgsl);
+}
+
+CNNEffect::CNNEffect(const GpuContext& ctx, const CNNEffectParams& params)
+ : PostProcessEffect(ctx), layer_index_(params.layer_index),
+ total_layers_(params.total_layers), blend_amount_(params.blend_amount),
+ input_view_(nullptr), original_view_(nullptr), bind_group_(nullptr) {
+ pipeline_ = create_cnn_pipeline(ctx_.device, ctx_.format,
+ cnn_layer_shader_wgsl);
+}
+
+void CNNEffect::init(MainSequence* demo) {
+ PostProcessEffect::init(demo);
+ demo_ = demo;
+ params_buffer_.init(ctx_.device);
+
+ // Register auxiliary texture for layer 0 (width_/height_ set by resize())
+ if (layer_index_ == 0) {
+ demo_->register_auxiliary_texture("captured_frame", width_, height_);
+ }
+
+ // Initialize uniforms BEFORE any bind group creation
+ uniforms_.update(ctx_.queue, get_common_uniforms());
+
+ CNNLayerParams params = {layer_index_, blend_amount_, {0.0f, 0.0f}};
+ params_buffer_.update(ctx_.queue, params);
+}
+
+void CNNEffect::resize(int width, int height) {
+ if (width == width_ && height == height_)
+ return;
+
+ PostProcessEffect::resize(width, height);
+
+ // Only layer 0 owns the captured_frame texture
+ if (layer_index_ == 0 && demo_) {
+ demo_->resize_auxiliary_texture("captured_frame", width, height);
+ }
+}
+
+void CNNEffect::render(WGPURenderPassEncoder pass,
+ const CommonPostProcessUniforms& uniforms) {
+ if (!bind_group_) {
+ fprintf(stderr, "CNN render: no bind_group\n");
+ return;
+ }
+
+ float effective_blend = blend_amount_;
+ if (beat_modulated_) {
+ effective_blend = blend_amount_ * uniforms.beat_phase * beat_scale_;
+ }
+
+ CNNLayerParams params = {layer_index_, effective_blend, {0.0f, 0.0f}};
+ params_buffer_.update(ctx_.queue, params);
+
+ wgpuRenderPassEncoderSetPipeline(pass, pipeline_);
+ wgpuRenderPassEncoderSetBindGroup(pass, 0, bind_group_, 0, nullptr);
+ wgpuRenderPassEncoderDraw(pass, 3, 1, 0, 0);
+}
+
+void CNNEffect::update_bind_group(WGPUTextureView input_view) {
+ input_view_ = input_view;
+
+ // Update common uniforms (CRITICAL for UV calculation!)
+ uniforms_.update(ctx_.queue, get_common_uniforms());
+
+ // All layers: get captured frame (original input from layer 0)
+ if (demo_) {
+ original_view_ = demo_->get_auxiliary_view("captured_frame");
+ }
+
+ // Create bind group with original texture
+ if (bind_group_)
+ wgpuBindGroupRelease(bind_group_);
+
+ WGPUBindGroupLayout bgl = wgpuRenderPipelineGetBindGroupLayout(pipeline_, 0);
+ // Use clamp (not repeat) to match PyTorch Conv2d zero-padding behavior
+ WGPUSampler sampler = SamplerCache::Get().get_or_create(ctx_.device, SamplerCache::clamp());
+
+ bind_group_ = BindGroupBuilder()
+ .sampler(0, sampler)
+ .texture(1, input_view_)
+ .buffer(2, uniforms_.get().buffer, uniforms_.get().size)
+ .buffer(3, params_buffer_.get().buffer, params_buffer_.get().size)
+ .texture(4, original_view_ ? original_view_ : input_view_)
+ .build(ctx_.device, bgl);
+}