diff options
Diffstat (limited to 'src/effects/cnn_effect.cc')
| -rw-r--r-- | src/effects/cnn_effect.cc | 129 |
1 files changed, 0 insertions, 129 deletions
diff --git a/src/effects/cnn_effect.cc b/src/effects/cnn_effect.cc deleted file mode 100644 index 49c5239..0000000 --- a/src/effects/cnn_effect.cc +++ /dev/null @@ -1,129 +0,0 @@ -// CNN post-processing effect implementation -// Neural network-based stylization with modular WGSL - -#include "effects/cnn_effect.h" -#include "gpu/bind_group_builder.h" -#include "gpu/effect.h" -#include "gpu/pipeline_builder.h" -#include "gpu/post_process_helper.h" -#include "gpu/sampler_cache.h" -#include "gpu/shader_composer.h" -#include "gpu/shaders.h" - -// Create custom pipeline with 5 bindings (includes original texture) -static WGPURenderPipeline create_cnn_pipeline(WGPUDevice device, - WGPUTextureFormat format, - const char* shader_code) { - WGPUBindGroupLayout bgl = - BindGroupLayoutBuilder() - .sampler(0, WGPUShaderStage_Fragment) - .texture(1, WGPUShaderStage_Fragment) - .uniform(2, WGPUShaderStage_Vertex | WGPUShaderStage_Fragment) - .uniform(3, WGPUShaderStage_Fragment) - .texture(4, WGPUShaderStage_Fragment) - .build(device); - - WGPURenderPipeline pipeline = RenderPipelineBuilder(device) - .shader(shader_code) - .bind_group_layout(bgl) - .format(format) - .build(); - - wgpuBindGroupLayoutRelease(bgl); - return pipeline; -} - -CNNEffect::CNNEffect(const GpuContext& ctx) - : PostProcessEffect(ctx), layer_index_(0), total_layers_(1), - blend_amount_(1.0f), input_view_(nullptr), original_view_(nullptr), - bind_group_(nullptr) { - pipeline_ = - create_cnn_pipeline(ctx_.device, ctx_.format, cnn_layer_shader_wgsl); -} - -CNNEffect::CNNEffect(const GpuContext& ctx, const CNNEffectParams& params) - : PostProcessEffect(ctx), layer_index_(params.layer_index), - total_layers_(params.total_layers), blend_amount_(params.blend_amount), - input_view_(nullptr), original_view_(nullptr), bind_group_(nullptr) { - pipeline_ = - create_cnn_pipeline(ctx_.device, ctx_.format, cnn_layer_shader_wgsl); -} - -void CNNEffect::init(MainSequence* demo) { - PostProcessEffect::init(demo); - demo_ = demo; - params_buffer_.init(ctx_.device); - - // Register auxiliary texture for layer 0 (width_/height_ set by resize()) - if (layer_index_ == 0) { - demo_->register_auxiliary_texture("captured_frame", width_, height_); - } - - // Initialize uniforms BEFORE any bind group creation - uniforms_.update(ctx_.queue, get_common_uniforms()); - - CNNLayerParams params = {layer_index_, blend_amount_, {0.0f, 0.0f}}; - params_buffer_.update(ctx_.queue, params); -} - -void CNNEffect::resize(int width, int height) { - if (width == width_ && height == height_) - return; - - PostProcessEffect::resize(width, height); - - // Only layer 0 owns the captured_frame texture - if (layer_index_ == 0 && demo_) { - demo_->resize_auxiliary_texture("captured_frame", width, height); - } -} - -void CNNEffect::render(WGPURenderPassEncoder pass, - const CommonPostProcessUniforms& uniforms) { - if (!bind_group_) { - fprintf(stderr, "CNN render: no bind_group\n"); - return; - } - - float effective_blend = blend_amount_; - if (beat_modulated_) { - effective_blend = blend_amount_ * uniforms.beat_phase * beat_scale_; - } - - CNNLayerParams params = {layer_index_, effective_blend, {0.0f, 0.0f}}; - params_buffer_.update(ctx_.queue, params); - - wgpuRenderPassEncoderSetPipeline(pass, pipeline_); - wgpuRenderPassEncoderSetBindGroup(pass, 0, bind_group_, 0, nullptr); - wgpuRenderPassEncoderDraw(pass, 3, 1, 0, 0); -} - -void CNNEffect::update_bind_group(WGPUTextureView input_view) { - input_view_ = input_view; - - // Update common uniforms (CRITICAL for UV calculation!) - uniforms_.update(ctx_.queue, get_common_uniforms()); - - // All layers: get captured frame (original input from layer 0) - if (demo_) { - original_view_ = demo_->get_auxiliary_view("captured_frame"); - } - - // Create bind group with original texture - if (bind_group_) - wgpuBindGroupRelease(bind_group_); - - WGPUBindGroupLayout bgl = wgpuRenderPipelineGetBindGroupLayout(pipeline_, 0); - // Use clamp (not repeat) to match PyTorch Conv2d zero-padding behavior - WGPUSampler sampler = - SamplerCache::Get().get_or_create(ctx_.device, SamplerCache::clamp()); - - bind_group_ = - BindGroupBuilder() - .sampler(0, sampler) - .texture(1, input_view_) - .buffer(2, uniforms_.get().buffer, uniforms_.get().size) - .buffer(3, params_buffer_.get().buffer, params_buffer_.get().size) - .texture(4, original_view_ ? original_view_ : input_view_) - .build(ctx_.device, bgl); -} |
