summaryrefslogtreecommitdiff
path: root/src/gpu/effects/cnn_effect.cc
blob: 25db0c2cacfe503e07e2d77e7087b8ab12b79cc2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
// CNN post-processing effect implementation
// Neural network-based stylization with modular WGSL

#include "gpu/effects/cnn_effect.h"
#include "gpu/effects/post_process_helper.h"
#include "gpu/effects/shaders.h"

// Constructs the effect on top of PostProcessEffect and creates the render
// pipeline from the CNN layer WGSL shader source.
//
// ctx        - GPU context forwarded to the PostProcessEffect base.
// num_layers - stored in num_layers_; not otherwise used by the constructor.
//
// input_view_ and bind_group_ start out null. update_bind_group() is the
// method that assigns input_view_ and hands &bind_group_ to
// pp_update_bind_group.
CNNEffect::CNNEffect(const GpuContext& ctx, int num_layers)
    : PostProcessEffect(ctx),
      num_layers_(num_layers),
      input_view_(nullptr),
      bind_group_(nullptr) {
  this->pipeline_ = create_post_process_pipeline(
      this->ctx_.device, this->ctx_.format, cnn_layer_shader_wgsl);
}

// Runs the base-class initialization, then sets up the layer-parameter
// buffer on the device and uploads an initial CNNLayerParams value to it.
//
// demo - forwarded unchanged to PostProcessEffect::init.
void CNNEffect::init(MainSequence* demo) {
  PostProcessEffect::init(demo);

  // The buffer has to be initialized on the device before the first update.
  params_buffer_.init(ctx_.device);

  // Initial parameter block; the meaning of each field is defined by the
  // CNNLayerParams declaration (not visible in this file).
  CNNLayerParams initial_params = {0, 1, {0.0f, 0.0f}};
  params_buffer_.update(ctx_.queue, initial_params);
}

// Records the draw for this effect into the given render pass: binds the
// pipeline, binds bind group 0, and issues a 3-vertex, 1-instance draw.
//
// Nothing is recorded while bind_group_ is still null.
//
// pass - render pass encoder that receives the commands.
// time, beat, intensity, aspect_ratio - not read by this implementation.
void CNNEffect::render(WGPURenderPassEncoder pass, float time, float beat,
                       float intensity, float aspect_ratio) {
  // These arguments are part of the signature but are not used here.
  (void)time;
  (void)beat;
  (void)intensity;
  (void)aspect_ratio;

  if (bind_group_ != nullptr) {
    wgpuRenderPassEncoderSetPipeline(pass, pipeline_);
    wgpuRenderPassEncoderSetBindGroup(pass, 0, bind_group_, 0, nullptr);
    wgpuRenderPassEncoderDraw(pass, 3, 1, 0, 0);
  }
}

// Caches the given input texture view in input_view_ and calls
// pp_update_bind_group with the device, the pipeline, the address of
// bind_group_, that view, and the uniform and layer-parameter buffers.
//
// input_view - texture view passed on to pp_update_bind_group.
void CNNEffect::update_bind_group(WGPUTextureView input_view) {
  input_view_ = input_view;

  pp_update_bind_group(
      ctx_.device,
      pipeline_,
      &bind_group_,
      input_view,  // same value that was just stored in input_view_
      uniforms_.get(),
      params_buffer_.get());
}