summaryrefslogtreecommitdiff
path: root/cnn_v1/src/cnn_v1_effect.cc
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v1/src/cnn_v1_effect.cc')
-rw-r--r--cnn_v1/src/cnn_v1_effect.cc129
1 files changed, 129 insertions, 0 deletions
diff --git a/cnn_v1/src/cnn_v1_effect.cc b/cnn_v1/src/cnn_v1_effect.cc
new file mode 100644
index 0000000..1f44619
--- /dev/null
+++ b/cnn_v1/src/cnn_v1_effect.cc
@@ -0,0 +1,129 @@
+// CNN post-processing effect implementation
+// Neural network-based stylization with modular WGSL
+
+#include "cnn_v1_effect.h"
+#include "gpu/bind_group_builder.h"
+#include "gpu/effect.h"
+#include "gpu/pipeline_builder.h"
+#include "gpu/post_process_helper.h"
+#include "gpu/sampler_cache.h"
+#include "gpu/shader_composer.h"
+#include "gpu/shaders.h"
+
+// Create custom pipeline with 5 bindings (includes original texture)
+static WGPURenderPipeline create_cnn_pipeline(WGPUDevice device,
+ WGPUTextureFormat format,
+ const char* shader_code) {
+ WGPUBindGroupLayout bgl =
+ BindGroupLayoutBuilder()
+ .sampler(0, WGPUShaderStage_Fragment)
+ .texture(1, WGPUShaderStage_Fragment)
+ .uniform(2, WGPUShaderStage_Vertex | WGPUShaderStage_Fragment)
+ .uniform(3, WGPUShaderStage_Fragment)
+ .texture(4, WGPUShaderStage_Fragment)
+ .build(device);
+
+ WGPURenderPipeline pipeline = RenderPipelineBuilder(device)
+ .shader(shader_code)
+ .bind_group_layout(bgl)
+ .format(format)
+ .build();
+
+ wgpuBindGroupLayoutRelease(bgl);
+ return pipeline;
+}
+
+CNNv1Effect::CNNv1Effect(const GpuContext& ctx)
+ : PostProcessEffect(ctx), layer_index_(0), total_layers_(1),
+ blend_amount_(1.0f), input_view_(nullptr), original_view_(nullptr),
+ bind_group_(nullptr) {
+ pipeline_ =
+ create_cnn_pipeline(ctx_.device, ctx_.format, cnn_layer_shader_wgsl);
+}
+
+CNNv1Effect::CNNv1Effect(const GpuContext& ctx, const CNNv1EffectParams& params)
+ : PostProcessEffect(ctx), layer_index_(params.layer_index),
+ total_layers_(params.total_layers), blend_amount_(params.blend_amount),
+ input_view_(nullptr), original_view_(nullptr), bind_group_(nullptr) {
+ pipeline_ =
+ create_cnn_pipeline(ctx_.device, ctx_.format, cnn_layer_shader_wgsl);
+}
+
+void CNNv1Effect::init(MainSequence* demo) {
+ PostProcessEffect::init(demo);
+ demo_ = demo;
+ params_buffer_.init(ctx_.device);
+
+ // Register auxiliary texture for layer 0 (width_/height_ set by resize())
+ if (layer_index_ == 0) {
+ demo_->register_auxiliary_texture("captured_frame", width_, height_);
+ }
+
+ // Initialize uniforms BEFORE any bind group creation
+ uniforms_.update(ctx_.queue, get_common_uniforms());
+
+ CNNv1LayerParams params = {layer_index_, blend_amount_, {0.0f, 0.0f}};
+ params_buffer_.update(ctx_.queue, params);
+}
+
+void CNNv1Effect::resize(int width, int height) {
+ if (width == width_ && height == height_)
+ return;
+
+ PostProcessEffect::resize(width, height);
+
+ // Only layer 0 owns the captured_frame texture
+ if (layer_index_ == 0 && demo_) {
+ demo_->resize_auxiliary_texture("captured_frame", width, height);
+ }
+}
+
+void CNNv1Effect::render(WGPURenderPassEncoder pass,
+ const CommonPostProcessUniforms& uniforms) {
+ if (!bind_group_) {
+ fprintf(stderr, "CNN render: no bind_group\n");
+ return;
+ }
+
+ float effective_blend = blend_amount_;
+ if (beat_modulated_) {
+ effective_blend = blend_amount_ * uniforms.beat_phase * beat_scale_;
+ }
+
+ CNNv1LayerParams params = {layer_index_, effective_blend, {0.0f, 0.0f}};
+ params_buffer_.update(ctx_.queue, params);
+
+ wgpuRenderPassEncoderSetPipeline(pass, pipeline_);
+ wgpuRenderPassEncoderSetBindGroup(pass, 0, bind_group_, 0, nullptr);
+ wgpuRenderPassEncoderDraw(pass, 3, 1, 0, 0);
+}
+
+void CNNv1Effect::update_bind_group(WGPUTextureView input_view) {
+ input_view_ = input_view;
+
+ // Update common uniforms (CRITICAL for UV calculation!)
+ uniforms_.update(ctx_.queue, get_common_uniforms());
+
+ // All layers: get captured frame (original input from layer 0)
+ if (demo_) {
+ original_view_ = demo_->get_auxiliary_view("captured_frame");
+ }
+
+ // Create bind group with original texture
+ if (bind_group_)
+ wgpuBindGroupRelease(bind_group_);
+
+ WGPUBindGroupLayout bgl = wgpuRenderPipelineGetBindGroupLayout(pipeline_, 0);
+ // Use clamp (not repeat) to match PyTorch Conv2d zero-padding behavior
+ WGPUSampler sampler =
+ SamplerCache::Get().get_or_create(ctx_.device, SamplerCache::clamp());
+
+ bind_group_ =
+ BindGroupBuilder()
+ .sampler(0, sampler)
+ .texture(1, input_view_)
+ .buffer(2, uniforms_.get().buffer, uniforms_.get().size)
+ .buffer(3, params_buffer_.get().buffer, params_buffer_.get().size)
+ .texture(4, original_view_ ? original_view_ : input_view_)
+ .build(ctx_.device, bgl);
+}