Diffstat (limited to 'cnn_v1/src/cnn_v1_effect.cc')
| -rw-r--r-- | cnn_v1/src/cnn_v1_effect.cc | 129 |
1 file changed, 129 insertions, 0 deletions
diff --git a/cnn_v1/src/cnn_v1_effect.cc b/cnn_v1/src/cnn_v1_effect.cc
new file mode 100644
index 0000000..1f44619
--- /dev/null
+++ b/cnn_v1/src/cnn_v1_effect.cc
@@ -0,0 +1,129 @@
+// CNN post-processing effect implementation
+// Neural network-based stylization with modular WGSL
+
+#include "cnn_v1_effect.h"
+#include "gpu/bind_group_builder.h"
+#include "gpu/effect.h"
+#include "gpu/pipeline_builder.h"
+#include "gpu/post_process_helper.h"
+#include "gpu/sampler_cache.h"
+#include "gpu/shader_composer.h"
+#include "gpu/shaders.h"
+
+// Create custom pipeline with 5 bindings (includes original texture)
+static WGPURenderPipeline create_cnn_pipeline(WGPUDevice device,
+                                              WGPUTextureFormat format,
+                                              const char* shader_code) {
+  WGPUBindGroupLayout bgl =
+      BindGroupLayoutBuilder()
+          .sampler(0, WGPUShaderStage_Fragment)
+          .texture(1, WGPUShaderStage_Fragment)
+          .uniform(2, WGPUShaderStage_Vertex | WGPUShaderStage_Fragment)
+          .uniform(3, WGPUShaderStage_Fragment)
+          .texture(4, WGPUShaderStage_Fragment)
+          .build(device);
+
+  WGPURenderPipeline pipeline = RenderPipelineBuilder(device)
+                                    .shader(shader_code)
+                                    .bind_group_layout(bgl)
+                                    .format(format)
+                                    .build();
+
+  wgpuBindGroupLayoutRelease(bgl);
+  return pipeline;
+}
+
+CNNv1Effect::CNNv1Effect(const GpuContext& ctx)
+    : PostProcessEffect(ctx), layer_index_(0), total_layers_(1),
+      blend_amount_(1.0f), input_view_(nullptr), original_view_(nullptr),
+      bind_group_(nullptr) {
+  pipeline_ =
+      create_cnn_pipeline(ctx_.device, ctx_.format, cnn_layer_shader_wgsl);
+}
+
+CNNv1Effect::CNNv1Effect(const GpuContext& ctx, const CNNv1EffectParams& params)
+    : PostProcessEffect(ctx), layer_index_(params.layer_index),
+      total_layers_(params.total_layers), blend_amount_(params.blend_amount),
+      input_view_(nullptr), original_view_(nullptr), bind_group_(nullptr) {
+  pipeline_ =
+      create_cnn_pipeline(ctx_.device, ctx_.format, cnn_layer_shader_wgsl);
+}
+
+void CNNv1Effect::init(MainSequence* demo) {
+  PostProcessEffect::init(demo);
+  demo_ = demo;
+  params_buffer_.init(ctx_.device);
+
+  // Register auxiliary texture for layer 0 (width_/height_ set by resize())
+  if (layer_index_ == 0) {
+    demo_->register_auxiliary_texture("captured_frame", width_, height_);
+  }
+
+  // Initialize uniforms BEFORE any bind group creation
+  uniforms_.update(ctx_.queue, get_common_uniforms());
+
+  CNNv1LayerParams params = {layer_index_, blend_amount_, {0.0f, 0.0f}};
+  params_buffer_.update(ctx_.queue, params);
+}
+
+void CNNv1Effect::resize(int width, int height) {
+  if (width == width_ && height == height_)
+    return;
+
+  PostProcessEffect::resize(width, height);
+
+  // Only layer 0 owns the captured_frame texture
+  if (layer_index_ == 0 && demo_) {
+    demo_->resize_auxiliary_texture("captured_frame", width, height);
+  }
+}
+
+void CNNv1Effect::render(WGPURenderPassEncoder pass,
+                         const CommonPostProcessUniforms& uniforms) {
+  if (!bind_group_) {
+    fprintf(stderr, "CNN render: no bind_group\n");
+    return;
+  }
+
+  float effective_blend = blend_amount_;
+  if (beat_modulated_) {
+    effective_blend = blend_amount_ * uniforms.beat_phase * beat_scale_;
+  }
+
+  CNNv1LayerParams params = {layer_index_, effective_blend, {0.0f, 0.0f}};
+  params_buffer_.update(ctx_.queue, params);
+
+  wgpuRenderPassEncoderSetPipeline(pass, pipeline_);
+  wgpuRenderPassEncoderSetBindGroup(pass, 0, bind_group_, 0, nullptr);
+  wgpuRenderPassEncoderDraw(pass, 3, 1, 0, 0);
+}
+
+void CNNv1Effect::update_bind_group(WGPUTextureView input_view) {
+  input_view_ = input_view;
+
+  // Update common uniforms (CRITICAL for UV calculation!)
+  uniforms_.update(ctx_.queue, get_common_uniforms());
+
+  // All layers: get captured frame (original input from layer 0)
+  if (demo_) {
+    original_view_ = demo_->get_auxiliary_view("captured_frame");
+  }
+
+  // Create bind group with original texture
+  if (bind_group_)
+    wgpuBindGroupRelease(bind_group_);
+
+  WGPUBindGroupLayout bgl = wgpuRenderPipelineGetBindGroupLayout(pipeline_, 0);
+  // Use clamp (not repeat) to match PyTorch Conv2d zero-padding behavior
+  WGPUSampler sampler =
+      SamplerCache::Get().get_or_create(ctx_.device, SamplerCache::clamp());
+
+  bind_group_ =
+      BindGroupBuilder()
+          .sampler(0, sampler)
+          .texture(1, input_view_)
+          .buffer(2, uniforms_.get().buffer, uniforms_.get().size)
+          .buffer(3, params_buffer_.get().buffer, params_buffer_.get().size)
+          .texture(4, original_view_ ? original_view_ : input_view_)
+          .build(ctx_.device, bgl);
+}
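For orientation, a minimal, hypothetical driver sketch of how one of these layers might be wired up; it assumes only the headers in this commit, and the names demo, pass, frame_view, and the 1920x1080 size are caller-supplied stand-ins, not part of the commit:

    // Hypothetical usage sketch (not in this commit): assumes a valid
    // GpuContext/MainSequence and an already-recorded render pass.
    #include "cnn_v1_effect.h"

    void run_cnn_layer(const GpuContext& ctx, MainSequence* demo,
                       WGPURenderPassEncoder pass, WGPUTextureView frame_view) {
      CNNv1EffectParams params = {};
      params.layer_index = 0;    // layer 0 also owns "captured_frame"
      params.total_layers = 3;
      params.blend_amount = 1.0f;

      CNNv1Effect layer0(ctx, params);
      layer0.init(demo);          // registers the auxiliary texture for layer 0
      layer0.resize(1920, 1080);  // sets width_/height_ before first use

      // Per frame: rebind the current input, then draw a fullscreen triangle.
      layer0.update_bind_group(frame_view);
      layer0.render(pass, CommonPostProcessUniforms{});
    }

Chaining total_layers of these effects, each reading the previous layer's output at binding 1 while binding 4 always sees the layer-0 capture, is consistent with how the constructor parameters and the auxiliary-texture ownership checks are used above.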

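A note on the {layer_index_, effective_blend, {0.0f, 0.0f}} initializer in render(): the trailing zero pair reads as explicit padding so the CPU struct matches WGSL uniform layout. A hypothetical reconstruction of the struct, whose authoritative definition lives in cnn_v1_effect.h:

    // Hypothetical reconstruction (the real definition is in cnn_v1_effect.h).
    // The trailing vec2-sized pad rounds the struct to 16 bytes so it can
    // mirror the WGSL uniform struct byte-for-byte.
    #include <cstdint>

    struct CNNv1LayerParams {
      int32_t layer_index;   // which network layer this pass evaluates
      float blend_amount;    // 0.0 = pass input through, 1.0 = full effect
      float pad[2];          // matches the {0.0f, 0.0f} written on every update
    };
    static_assert(sizeof(CNNv1LayerParams) == 16,
                  "CPU struct must match the WGSL uniform layout");

A WGSL counterpart such as struct { layer_index: i32, blend_amount: f32, _pad: vec2f } lays out to the same 16 bytes, which would explain why the zero pair is rewritten on every params_buffer_ update.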