diff options
Diffstat (limited to 'src/effects/cnn_effect.cc')
| -rw-r--r-- | src/effects/cnn_effect.cc | 65 |
1 files changed, 34 insertions, 31 deletions
diff --git a/src/effects/cnn_effect.cc b/src/effects/cnn_effect.cc index 4475180..49c5239 100644 --- a/src/effects/cnn_effect.cc +++ b/src/effects/cnn_effect.cc @@ -2,31 +2,32 @@ // Neural network-based stylization with modular WGSL #include "effects/cnn_effect.h" -#include "gpu/post_process_helper.h" -#include "gpu/shaders.h" -#include "gpu/shader_composer.h" -#include "gpu/effect.h" #include "gpu/bind_group_builder.h" -#include "gpu/sampler_cache.h" +#include "gpu/effect.h" #include "gpu/pipeline_builder.h" +#include "gpu/post_process_helper.h" +#include "gpu/sampler_cache.h" +#include "gpu/shader_composer.h" +#include "gpu/shaders.h" // Create custom pipeline with 5 bindings (includes original texture) static WGPURenderPipeline create_cnn_pipeline(WGPUDevice device, - WGPUTextureFormat format, - const char* shader_code) { - WGPUBindGroupLayout bgl = BindGroupLayoutBuilder() - .sampler(0, WGPUShaderStage_Fragment) - .texture(1, WGPUShaderStage_Fragment) - .uniform(2, WGPUShaderStage_Vertex | WGPUShaderStage_Fragment) - .uniform(3, WGPUShaderStage_Fragment) - .texture(4, WGPUShaderStage_Fragment) - .build(device); + WGPUTextureFormat format, + const char* shader_code) { + WGPUBindGroupLayout bgl = + BindGroupLayoutBuilder() + .sampler(0, WGPUShaderStage_Fragment) + .texture(1, WGPUShaderStage_Fragment) + .uniform(2, WGPUShaderStage_Vertex | WGPUShaderStage_Fragment) + .uniform(3, WGPUShaderStage_Fragment) + .texture(4, WGPUShaderStage_Fragment) + .build(device); WGPURenderPipeline pipeline = RenderPipelineBuilder(device) - .shader(shader_code) - .bind_group_layout(bgl) - .format(format) - .build(); + .shader(shader_code) + .bind_group_layout(bgl) + .format(format) + .build(); wgpuBindGroupLayoutRelease(bgl); return pipeline; @@ -36,16 +37,16 @@ CNNEffect::CNNEffect(const GpuContext& ctx) : PostProcessEffect(ctx), layer_index_(0), total_layers_(1), blend_amount_(1.0f), input_view_(nullptr), original_view_(nullptr), bind_group_(nullptr) { - pipeline_ = create_cnn_pipeline(ctx_.device, ctx_.format, - cnn_layer_shader_wgsl); + pipeline_ = + create_cnn_pipeline(ctx_.device, ctx_.format, cnn_layer_shader_wgsl); } CNNEffect::CNNEffect(const GpuContext& ctx, const CNNEffectParams& params) : PostProcessEffect(ctx), layer_index_(params.layer_index), total_layers_(params.total_layers), blend_amount_(params.blend_amount), input_view_(nullptr), original_view_(nullptr), bind_group_(nullptr) { - pipeline_ = create_cnn_pipeline(ctx_.device, ctx_.format, - cnn_layer_shader_wgsl); + pipeline_ = + create_cnn_pipeline(ctx_.device, ctx_.format, cnn_layer_shader_wgsl); } void CNNEffect::init(MainSequence* demo) { @@ -78,7 +79,7 @@ void CNNEffect::resize(int width, int height) { } void CNNEffect::render(WGPURenderPassEncoder pass, - const CommonPostProcessUniforms& uniforms) { + const CommonPostProcessUniforms& uniforms) { if (!bind_group_) { fprintf(stderr, "CNN render: no bind_group\n"); return; @@ -114,13 +115,15 @@ void CNNEffect::update_bind_group(WGPUTextureView input_view) { WGPUBindGroupLayout bgl = wgpuRenderPipelineGetBindGroupLayout(pipeline_, 0); // Use clamp (not repeat) to match PyTorch Conv2d zero-padding behavior - WGPUSampler sampler = SamplerCache::Get().get_or_create(ctx_.device, SamplerCache::clamp()); + WGPUSampler sampler = + SamplerCache::Get().get_or_create(ctx_.device, SamplerCache::clamp()); - bind_group_ = BindGroupBuilder() - .sampler(0, sampler) - .texture(1, input_view_) - .buffer(2, uniforms_.get().buffer, uniforms_.get().size) - .buffer(3, params_buffer_.get().buffer, params_buffer_.get().size) - .texture(4, original_view_ ? original_view_ : input_view_) - .build(ctx_.device, bgl); + bind_group_ = + BindGroupBuilder() + .sampler(0, sampler) + .texture(1, input_view_) + .buffer(2, uniforms_.get().buffer, uniforms_.get().size) + .buffer(3, params_buffer_.get().buffer, params_buffer_.get().size) + .texture(4, original_view_ ? original_view_ : input_view_) + .build(ctx_.device, bgl); } |
