summary | refs | log | tree | commit | diff
path: root/src/gpu/effects/cnn_effect.cc
diff options
context:
space:
mode:
author: skal <pascal.massimino@gmail.com>, 2026-02-10 12:48:43 +0100
committer: skal <pascal.massimino@gmail.com>, 2026-02-10 12:48:43 +0100
commit: 6944733a6a2f05c18e7e0b73f847a4c9144801fd (patch)
tree: 10713cd41a0e038a016a2e6b357471690f232834 /src/gpu/effects/cnn_effect.cc
parent: cc9cbeb75353181193e3afb880dc890aa8bf8985 (diff)
feat: Add multi-layer CNN support with framebuffer capture and blend control
Implements automatic layer chaining and generic framebuffer capture API for
multi-layer neural network effects with proper original input preservation.

Key changes:
- Effect::needs_framebuffer_capture() - generic API for pre-render capture
- MainSequence: auto-capture to "captured_frame" auxiliary texture
- CNNEffect: multi-layer support via layer_index/total_layers params
- seq_compiler: expands "layers=N" to N chained effect instances
- Shader: @binding(4) original_input available to all layers
- Training: generates layer switches and original input binding
- Blend: mix(original, result, blend_amount) uses layer 0 input

Timeline syntax: CNNEffect layers=3 blend=0.7

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'src/gpu/effects/cnn_effect.cc')
-rw-r--r-- src/gpu/effects/cnn_effect.cc 130
1 files changed, 123 insertions, 7 deletions
diff --git a/src/gpu/effects/cnn_effect.cc b/src/gpu/effects/cnn_effect.cc
index 25db0c2..f5d0a51 100644
--- a/src/gpu/effects/cnn_effect.cc
+++ b/src/gpu/effects/cnn_effect.cc
@@ -4,19 +4,101 @@
#include "gpu/effects/cnn_effect.h"
#include "gpu/effects/post_process_helper.h"
#include "gpu/effects/shaders.h"
+#include "gpu/effects/shader_composer.h"
+#include "gpu/effect.h"
-CNNEffect::CNNEffect(const GpuContext& ctx, int num_layers)
- : PostProcessEffect(ctx), num_layers_(num_layers), input_view_(nullptr),
+// Create custom pipeline with 5 bindings (includes original texture)
+static WGPURenderPipeline create_cnn_pipeline(WGPUDevice device,
+ WGPUTextureFormat format,
+ const char* shader_code) {
+ std::string composed_shader = ShaderComposer::Get().Compose({}, shader_code);
+
+ WGPUShaderModuleDescriptor shader_desc = {};
+ WGPUShaderSourceWGSL wgsl_src = {};
+ wgsl_src.chain.sType = WGPUSType_ShaderSourceWGSL;
+ wgsl_src.code = str_view(composed_shader.c_str());
+ shader_desc.nextInChain = &wgsl_src.chain;
+ WGPUShaderModule shader_module =
+ wgpuDeviceCreateShaderModule(device, &shader_desc);
+
+ WGPUBindGroupLayoutEntry bgl_entries[5] = {};
+ bgl_entries[0].binding = 0; // sampler
+ bgl_entries[0].visibility = WGPUShaderStage_Fragment;
+ bgl_entries[0].sampler.type = WGPUSamplerBindingType_Filtering;
+ bgl_entries[1].binding = 1; // input texture
+ bgl_entries[1].visibility = WGPUShaderStage_Fragment;
+ bgl_entries[1].texture.sampleType = WGPUTextureSampleType_Float;
+ bgl_entries[1].texture.viewDimension = WGPUTextureViewDimension_2D;
+ bgl_entries[2].binding = 2; // uniforms
+ bgl_entries[2].visibility = WGPUShaderStage_Vertex | WGPUShaderStage_Fragment;
+ bgl_entries[2].buffer.type = WGPUBufferBindingType_Uniform;
+ bgl_entries[3].binding = 3; // effect params
+ bgl_entries[3].visibility = WGPUShaderStage_Fragment;
+ bgl_entries[3].buffer.type = WGPUBufferBindingType_Uniform;
+ bgl_entries[4].binding = 4; // original texture
+ bgl_entries[4].visibility = WGPUShaderStage_Fragment;
+ bgl_entries[4].texture.sampleType = WGPUTextureSampleType_Float;
+ bgl_entries[4].texture.viewDimension = WGPUTextureViewDimension_2D;
+
+ WGPUBindGroupLayoutDescriptor bgl_desc = {};
+ bgl_desc.entryCount = 5;
+ bgl_desc.entries = bgl_entries;
+ WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc);
+
+ WGPUPipelineLayoutDescriptor pl_desc = {};
+ pl_desc.bindGroupLayoutCount = 1;
+ pl_desc.bindGroupLayouts = &bgl;
+ WGPUPipelineLayout pl = wgpuDeviceCreatePipelineLayout(device, &pl_desc);
+
+ WGPUColorTargetState color_target = {};
+ color_target.format = format;
+ color_target.writeMask = WGPUColorWriteMask_All;
+
+ WGPUFragmentState fragment_state = {};
+ fragment_state.module = shader_module;
+ fragment_state.entryPoint = str_view("fs_main");
+ fragment_state.targetCount = 1;
+ fragment_state.targets = &color_target;
+
+ WGPURenderPipelineDescriptor pipeline_desc = {};
+ pipeline_desc.layout = pl;
+ pipeline_desc.vertex.module = shader_module;
+ pipeline_desc.vertex.entryPoint = str_view("vs_main");
+ pipeline_desc.fragment = &fragment_state;
+ pipeline_desc.primitive.topology = WGPUPrimitiveTopology_TriangleList;
+ pipeline_desc.multisample.count = 1;
+ pipeline_desc.multisample.mask = 0xFFFFFFFF;
+
+ return wgpuDeviceCreateRenderPipeline(device, &pipeline_desc);
+}
+
+CNNEffect::CNNEffect(const GpuContext& ctx)
+ : PostProcessEffect(ctx), layer_index_(0), total_layers_(1),
+ blend_amount_(1.0f), input_view_(nullptr), original_view_(nullptr),
bind_group_(nullptr) {
- pipeline_ = create_post_process_pipeline(ctx_.device, ctx_.format,
- cnn_layer_shader_wgsl);
+ pipeline_ = create_cnn_pipeline(ctx_.device, ctx_.format,
+ cnn_layer_shader_wgsl);
+}
+
+CNNEffect::CNNEffect(const GpuContext& ctx, const CNNEffectParams& params)
+ : PostProcessEffect(ctx), layer_index_(params.layer_index),
+ total_layers_(params.total_layers), blend_amount_(params.blend_amount),
+ input_view_(nullptr), original_view_(nullptr), bind_group_(nullptr) {
+ pipeline_ = create_cnn_pipeline(ctx_.device, ctx_.format,
+ cnn_layer_shader_wgsl);
}
void CNNEffect::init(MainSequence* demo) {
PostProcessEffect::init(demo);
+ demo_ = demo;
params_buffer_.init(ctx_.device);
- CNNLayerParams params = {0, 1, {0.0f, 0.0f}};
+ // Register captured_frame texture (used by all layers for original input)
+ if (layer_index_ == 0) {
+ demo_->register_auxiliary_texture("captured_frame", width_, height_);
+ }
+
+ CNNLayerParams params = {layer_index_, blend_amount_, {0.0f, 0.0f}};
params_buffer_.update(ctx_.queue, params);
}
@@ -31,6 +113,40 @@ void CNNEffect::render(WGPURenderPassEncoder pass, float time, float beat,
void CNNEffect::update_bind_group(WGPUTextureView input_view) {
input_view_ = input_view;
- pp_update_bind_group(ctx_.device, pipeline_, &bind_group_,
- input_view_, uniforms_.get(), params_buffer_.get());
+
+ // All layers: get captured frame (original input from layer 0)
+ if (demo_) {
+ original_view_ = demo_->get_auxiliary_view("captured_frame");
+ }
+
+ // Create bind group with original texture
+ if (bind_group_)
+ wgpuBindGroupRelease(bind_group_);
+
+ WGPUBindGroupLayout bgl = wgpuRenderPipelineGetBindGroupLayout(pipeline_, 0);
+ WGPUSamplerDescriptor sd = {};
+ sd.magFilter = WGPUFilterMode_Linear;
+ sd.minFilter = WGPUFilterMode_Linear;
+ sd.maxAnisotropy = 1;
+ WGPUSampler sampler = wgpuDeviceCreateSampler(ctx_.device, &sd);
+
+ WGPUBindGroupEntry bge[5] = {};
+ bge[0].binding = 0;
+ bge[0].sampler = sampler;
+ bge[1].binding = 1;
+ bge[1].textureView = input_view_;
+ bge[2].binding = 2;
+ bge[2].buffer = uniforms_.get().buffer;
+ bge[2].size = uniforms_.get().size;
+ bge[3].binding = 3;
+ bge[3].buffer = params_buffer_.get().buffer;
+ bge[3].size = params_buffer_.get().size;
+ bge[4].binding = 4;
+ bge[4].textureView = original_view_ ? original_view_ : input_view_; // Fallback
+
+ WGPUBindGroupDescriptor bgd = {};
+ bgd.layout = bgl;
+ bgd.entryCount = 5;
+ bgd.entries = bge;
+ bind_group_ = wgpuDeviceCreateBindGroup(ctx_.device, &bgd);
}