summaryrefslogtreecommitdiff
path: root/src/effects/cnn_effect.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/effects/cnn_effect.cc')
-rw-r--r--src/effects/cnn_effect.cc65
1 files changed, 34 insertions, 31 deletions
diff --git a/src/effects/cnn_effect.cc b/src/effects/cnn_effect.cc
index 4475180..49c5239 100644
--- a/src/effects/cnn_effect.cc
+++ b/src/effects/cnn_effect.cc
@@ -2,31 +2,32 @@
// Neural network-based stylization with modular WGSL
#include "effects/cnn_effect.h"
-#include "gpu/post_process_helper.h"
-#include "gpu/shaders.h"
-#include "gpu/shader_composer.h"
-#include "gpu/effect.h"
#include "gpu/bind_group_builder.h"
-#include "gpu/sampler_cache.h"
+#include "gpu/effect.h"
#include "gpu/pipeline_builder.h"
+#include "gpu/post_process_helper.h"
+#include "gpu/sampler_cache.h"
+#include "gpu/shader_composer.h"
+#include "gpu/shaders.h"
// Create custom pipeline with 5 bindings (includes original texture)
static WGPURenderPipeline create_cnn_pipeline(WGPUDevice device,
- WGPUTextureFormat format,
- const char* shader_code) {
- WGPUBindGroupLayout bgl = BindGroupLayoutBuilder()
- .sampler(0, WGPUShaderStage_Fragment)
- .texture(1, WGPUShaderStage_Fragment)
- .uniform(2, WGPUShaderStage_Vertex | WGPUShaderStage_Fragment)
- .uniform(3, WGPUShaderStage_Fragment)
- .texture(4, WGPUShaderStage_Fragment)
- .build(device);
+ WGPUTextureFormat format,
+ const char* shader_code) {
+ WGPUBindGroupLayout bgl =
+ BindGroupLayoutBuilder()
+ .sampler(0, WGPUShaderStage_Fragment)
+ .texture(1, WGPUShaderStage_Fragment)
+ .uniform(2, WGPUShaderStage_Vertex | WGPUShaderStage_Fragment)
+ .uniform(3, WGPUShaderStage_Fragment)
+ .texture(4, WGPUShaderStage_Fragment)
+ .build(device);
WGPURenderPipeline pipeline = RenderPipelineBuilder(device)
- .shader(shader_code)
- .bind_group_layout(bgl)
- .format(format)
- .build();
+ .shader(shader_code)
+ .bind_group_layout(bgl)
+ .format(format)
+ .build();
wgpuBindGroupLayoutRelease(bgl);
return pipeline;
@@ -36,16 +37,16 @@ CNNEffect::CNNEffect(const GpuContext& ctx)
: PostProcessEffect(ctx), layer_index_(0), total_layers_(1),
blend_amount_(1.0f), input_view_(nullptr), original_view_(nullptr),
bind_group_(nullptr) {
- pipeline_ = create_cnn_pipeline(ctx_.device, ctx_.format,
- cnn_layer_shader_wgsl);
+ pipeline_ =
+ create_cnn_pipeline(ctx_.device, ctx_.format, cnn_layer_shader_wgsl);
}
CNNEffect::CNNEffect(const GpuContext& ctx, const CNNEffectParams& params)
: PostProcessEffect(ctx), layer_index_(params.layer_index),
total_layers_(params.total_layers), blend_amount_(params.blend_amount),
input_view_(nullptr), original_view_(nullptr), bind_group_(nullptr) {
- pipeline_ = create_cnn_pipeline(ctx_.device, ctx_.format,
- cnn_layer_shader_wgsl);
+ pipeline_ =
+ create_cnn_pipeline(ctx_.device, ctx_.format, cnn_layer_shader_wgsl);
}
void CNNEffect::init(MainSequence* demo) {
@@ -78,7 +79,7 @@ void CNNEffect::resize(int width, int height) {
}
void CNNEffect::render(WGPURenderPassEncoder pass,
- const CommonPostProcessUniforms& uniforms) {
+ const CommonPostProcessUniforms& uniforms) {
if (!bind_group_) {
fprintf(stderr, "CNN render: no bind_group\n");
return;
@@ -114,13 +115,15 @@ void CNNEffect::update_bind_group(WGPUTextureView input_view) {
WGPUBindGroupLayout bgl = wgpuRenderPipelineGetBindGroupLayout(pipeline_, 0);
// Use clamp (not repeat) to match PyTorch Conv2d zero-padding behavior
- WGPUSampler sampler = SamplerCache::Get().get_or_create(ctx_.device, SamplerCache::clamp());
+ WGPUSampler sampler =
+ SamplerCache::Get().get_or_create(ctx_.device, SamplerCache::clamp());
- bind_group_ = BindGroupBuilder()
- .sampler(0, sampler)
- .texture(1, input_view_)
- .buffer(2, uniforms_.get().buffer, uniforms_.get().size)
- .buffer(3, params_buffer_.get().buffer, params_buffer_.get().size)
- .texture(4, original_view_ ? original_view_ : input_view_)
- .build(ctx_.device, bgl);
+ bind_group_ =
+ BindGroupBuilder()
+ .sampler(0, sampler)
+ .texture(1, input_view_)
+ .buffer(2, uniforms_.get().buffer, uniforms_.get().size)
+ .buffer(3, params_buffer_.get().buffer, params_buffer_.get().size)
+ .texture(4, original_view_ ? original_view_ : input_view_)
+ .build(ctx_.device, bgl);
}