summaryrefslogtreecommitdiff
path: root/cnn_v3/src/cnn_v3_effect.h
blob: 589680c597c8daeb22d232cc43a36cb96e450c38 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
// CNN v3 Effect — U-Net + FiLM inference pass
// Runs 5 compute passes (enc0→enc1→bottleneck→dec1→dec0) on G-buffer feature
// textures produced by GBufferEffect.
//
// Architecture: enc_channels=[8,16]
//   enc0:       Conv(20→8,  3×3) + FiLM8  + ReLU       H×W      rgba32uint
//   enc1:       Conv(8→16,  3×3) + FiLM16 + ReLU       H/2×W/2  2× rgba32uint
//   bottleneck: Conv(16→16, 3×3, dil=2)   + ReLU       H/4×W/4  2× rgba32uint
//   dec1:       Conv(32→8,  3×3) + FiLM8  + ReLU       H/2×W/2  rgba32uint
//   dec0:       Conv(16→4,  3×3) + FiLM4  + sigmoid    H×W      rgba16float
//
// Inputs:  feat_tex0, feat_tex1   (rgba32uint, 20-channel G-buffer)
// Output:  output_tex             (rgba16float, 4-channel RGBA)

#pragma once

#include <cstddef>
#include <cstdint>
#include <string>
#include <type_traits>
#include <vector>

#include "gpu/effect.h"
#include "gpu/sequence.h"
#include "gpu/uniform_helper.h"
#include "gpu/wgpu_resource.h"

// ---------------------------------------------------------------------------
// Per-pass params uniform layouts (mirror WGSL Params structs exactly)
// ---------------------------------------------------------------------------

// enc0, dec1: 8-channel FiLM (lo/hi vec4 split)
//
// WGSL layout (uniform address space):
//   offset  0: weight_offset (u32)
//   offset  4-15: implicit pad (the following vec3u is 16-byte aligned)
//   offset 16: _pad (vec3u, 12 bytes)
//   offset 28-31: implicit pad (the following vec4f is 16-byte aligned)
//   offset 32: gamma_lo (vec4f)
//   offset 48: gamma_hi (vec4f)
//   offset 64: beta_lo  (vec4f)
//   offset 80: beta_hi  (vec4f)
//   total: 96 bytes
//
// The C++ mirror folds both implicit pads and the vec3u into a single u32[7]
// run. Every documented offset is pinned by a static_assert below, so a field
// reorder or a pad miscount fails at compile time instead of silently
// feeding the shader shifted FiLM coefficients.
struct CnnV3Params8ch {
  uint32_t weight_offset;  // offset 0
  uint32_t _pad[7];        // offsets 4-31
  float gamma_lo[4];       // offset 32
  float gamma_hi[4];       // offset 48
  float beta_lo[4];        // offset 64
  float beta_hi[4];        // offset 80
};
static_assert(sizeof(CnnV3Params8ch) == 96, "CnnV3Params8ch must be 96 bytes");
// Standard layout makes offsetof well-defined; trivial copyability makes the
// byte-for-byte mirror of the WGSL struct meaningful.
static_assert(std::is_standard_layout<CnnV3Params8ch>::value &&
                  std::is_trivially_copyable<CnnV3Params8ch>::value,
              "CnnV3Params8ch must be standard-layout and trivially copyable");
static_assert(offsetof(CnnV3Params8ch, weight_offset) == 0,
              "CnnV3Params8ch::weight_offset must be at offset 0");
static_assert(offsetof(CnnV3Params8ch, _pad) == 4,
              "CnnV3Params8ch::_pad must start at offset 4");
static_assert(offsetof(CnnV3Params8ch, gamma_lo) == 32,
              "CnnV3Params8ch::gamma_lo must be at offset 32");
static_assert(offsetof(CnnV3Params8ch, gamma_hi) == 48,
              "CnnV3Params8ch::gamma_hi must be at offset 48");
static_assert(offsetof(CnnV3Params8ch, beta_lo) == 64,
              "CnnV3Params8ch::beta_lo must be at offset 64");
static_assert(offsetof(CnnV3Params8ch, beta_hi) == 80,
              "CnnV3Params8ch::beta_hi must be at offset 80");

// enc1: 16-channel FiLM (four vec4 groups for gamma + four for beta)
//
// WGSL layout (uniform address space):
//   offset  0: weight_offset (u32)
//   offset  4-15: implicit pad (the following vec3u is 16-byte aligned)
//   offset 16: _pad (vec3u, 12 bytes)
//   offset 28-31: implicit pad (the following vec4f is 16-byte aligned)
//   offset 32: gamma_0..3 (4x vec4f = 64 bytes)
//   offset 96: beta_0..3  (4x vec4f = 64 bytes)
//   total: 160 bytes
//
// The C++ mirror flattens each group of four vec4s into one float[16]:
// gamma[4*i .. 4*i+3] lands on WGSL gamma_i (likewise for beta). Every
// documented offset is pinned by a static_assert below.
struct CnnV3Params16ch {
  uint32_t weight_offset;  // offset 0
  uint32_t _pad[7];        // offsets 4-31
  float gamma[16];         // offsets 32-95
  float beta[16];          // offsets 96-159
};
static_assert(sizeof(CnnV3Params16ch) == 160, "CnnV3Params16ch must be 160 bytes");
// Standard layout makes offsetof well-defined; trivial copyability makes the
// byte-for-byte mirror of the WGSL struct meaningful.
static_assert(std::is_standard_layout<CnnV3Params16ch>::value &&
                  std::is_trivially_copyable<CnnV3Params16ch>::value,
              "CnnV3Params16ch must be standard-layout and trivially copyable");
static_assert(offsetof(CnnV3Params16ch, weight_offset) == 0,
              "CnnV3Params16ch::weight_offset must be at offset 0");
static_assert(offsetof(CnnV3Params16ch, _pad) == 4,
              "CnnV3Params16ch::_pad must start at offset 4");
static_assert(offsetof(CnnV3Params16ch, gamma) == 32,
              "CnnV3Params16ch::gamma must start at offset 32");
static_assert(offsetof(CnnV3Params16ch, beta) == 96,
              "CnnV3Params16ch::beta must start at offset 96");

// dec0: 4-channel FiLM
//
// WGSL layout (uniform address space):
//   offset  0: weight_offset (u32)
//   offset  4-15: implicit pad (the following vec3u is 16-byte aligned)
//   offset 16: _pad (vec3u, 12 bytes)
//   offset 28-31: implicit pad (the following vec4f is 16-byte aligned)
//   offset 32: gamma (vec4f)
//   offset 48: beta  (vec4f)
//   total: 64 bytes
//
// Every documented offset is pinned by a static_assert below.
struct CnnV3Params4ch {
  uint32_t weight_offset;  // offset 0
  uint32_t _pad[7];        // offsets 4-31
  float gamma[4];          // offset 32
  float beta[4];           // offset 48
};
static_assert(sizeof(CnnV3Params4ch) == 64, "CnnV3Params4ch must be 64 bytes");
// Standard layout makes offsetof well-defined; trivial copyability makes the
// byte-for-byte mirror of the WGSL struct meaningful.
static_assert(std::is_standard_layout<CnnV3Params4ch>::value &&
                  std::is_trivially_copyable<CnnV3Params4ch>::value,
              "CnnV3Params4ch must be standard-layout and trivially copyable");
static_assert(offsetof(CnnV3Params4ch, weight_offset) == 0,
              "CnnV3Params4ch::weight_offset must be at offset 0");
static_assert(offsetof(CnnV3Params4ch, _pad) == 4,
              "CnnV3Params4ch::_pad must start at offset 4");
static_assert(offsetof(CnnV3Params4ch, gamma) == 32,
              "CnnV3Params4ch::gamma must be at offset 32");
static_assert(offsetof(CnnV3Params4ch, beta) == 48,
              "CnnV3Params4ch::beta must be at offset 48");

// bottleneck: no FiLM — weight_offset + 3 pads (16 bytes total)
//
// NOTE(review): 16 bytes matches a WGSL Params made of four scalar u32
// fields. If the bottleneck shader instead declares `weight_offset: u32,
// _pad: vec3u` like the FiLM passes do, its size is 32 bytes (vec3u is
// 16-byte aligned) and this mirror would be too small — confirm against the
// bottleneck WGSL.
struct CnnV3ParamsBn {
  uint32_t weight_offset;
  uint32_t _pad[3];
};
static_assert(sizeof(CnnV3ParamsBn) == 16, "CnnV3ParamsBn must be 16 bytes");
// Standard layout makes offsetof well-defined; trivial copyability makes the
// byte-for-byte mirror of the WGSL struct meaningful.
static_assert(std::is_standard_layout<CnnV3ParamsBn>::value &&
                  std::is_trivially_copyable<CnnV3ParamsBn>::value,
              "CnnV3ParamsBn must be standard-layout and trivially copyable");
static_assert(offsetof(CnnV3ParamsBn, weight_offset) == 0,
              "CnnV3ParamsBn::weight_offset must be at offset 0");
static_assert(offsetof(CnnV3ParamsBn, _pad) == 4,
              "CnnV3ParamsBn::_pad must start at offset 4");

// ---------------------------------------------------------------------------
// FiLM conditioning inputs (CPU-side). Hand a filled instance to
// CNNv3Effect::set_film_params() once per frame, before render().
// ---------------------------------------------------------------------------
struct CNNv3FiLMParams {
  float beat_phase{0.0f};       // position inside the current beat, 0-1
  float beat_norm{0.0f};        // beat_time / 8.0 — normalized 8-beat cycle
  float audio_intensity{0.0f};  // peak audio level, 0-1
  float style_p0{0.0f};         // free user-defined style parameter
  float style_p1{0.0f};         // free user-defined style parameter
};

// FiLM MLP weights: Linear(5→16)→ReLU→Linear(16→72).
// Loaded from cnn_v3_film_mlp.bin (1320 f32 = 5280 bytes).
// Layout: l0_w(80) | l0_b(16) | l1_w(1152) | l1_b(72), all row-major f32.
//
// The 5 inputs presumably correspond to the five CNNv3FiLMParams fields. The
// 72 outputs equal the total FiLM gamma+beta width of the four conditioned
// passes, 2 * (8 + 16 + 8 + 4); the exact output ordering is defined by the
// loader/trainer — TODO confirm.
//
// The segment offsets of the .bin file are pinned by static_asserts below,
// so the struct can be filled from the file as one contiguous byte copy.
struct CNNv3FilmMlp {
  static constexpr int kIn = 5;       // conditioning inputs
  static constexpr int kHidden = 16;  // hidden units (after ReLU)
  static constexpr int kOut = 72;     // FiLM gamma/beta outputs

  float l0_w[kHidden * kIn];   // (16, 5) row-major
  float l0_b[kHidden];
  float l1_w[kOut * kHidden];  // (72, 16) row-major
  float l1_b[kOut];
};
static_assert(sizeof(CNNv3FilmMlp) == 1320 * 4, "CNNv3FilmMlp size mismatch");
static_assert(std::is_standard_layout<CNNv3FilmMlp>::value &&
                  std::is_trivially_copyable<CNNv3FilmMlp>::value,
              "CNNv3FilmMlp must be standard-layout and trivially copyable");
static_assert(offsetof(CNNv3FilmMlp, l0_b) == 80 * 4,
              "CNNv3FilmMlp::l0_b must follow 80 l0_w floats");
static_assert(offsetof(CNNv3FilmMlp, l1_w) == 96 * 4,
              "CNNv3FilmMlp::l1_w must follow l0_w + l0_b (96 floats)");
static_assert(offsetof(CNNv3FilmMlp, l1_b) == 1248 * 4,
              "CNNv3FilmMlp::l1_b must follow 1248 floats");

// CNN v3 U-Net + FiLM inference effect.
//
// Holds one compute pipeline, bind group and params uniform buffer per pass
// (enc0, enc1, bottleneck, dec1, dec0), plus one storage buffer of packed-f16
// conv weights shared by all passes. Intermediate activations live in
// registry nodes whose names are derived from outputs[0].
//
// Per-frame usage, per the method comments below:
//   upload_weights() / load_film_mlp()  — presumably once, when weights load
//   set_film_params()                   — each frame, before render()
//   render()
//
// NOTE(review): data members are initialized in declaration order by the
// out-of-view constructor; keep their order stable when editing.
class CNNv3Effect : public Effect {
 public:
  CNNv3Effect(const GpuContext& ctx, const std::vector<std::string>& inputs,
              const std::vector<std::string>& outputs, float start_time,
              float end_time);

  // Registers this effect's intermediate nodes with the registry.
  void declare_nodes(NodeRegistry& registry) override;

  // Records the five compute passes into `encoder`.
  void render(WGPUCommandEncoder encoder, const UniformsSequenceParams& params,
              NodeRegistry& nodes) override;

  // Update FiLM conditioning; call before render() each frame.
  void set_film_params(const CNNv3FiLMParams& fp);

  // Upload packed-f16 conv weights (kWeightsBufBytes bytes of u32 pairs).
  // NOTE(review): kWeightsBufBytes is not declared in this header; the
  // expected size is defined in the implementation — confirm there.
  void upload_weights(WGPUQueue queue, const void* data, uint32_t size_bytes);

  // Load FiLM MLP weights from cnn_v3_film_mlp.bin (1320 f32 = 5280 bytes).
  // Must be called before set_film_params() for learned conditioning.
  void load_film_mlp(const void* data, uint32_t size_bytes);

 private:
  // Intermediate node names (prefixed from output[0]).
  // 16-channel activations (enc1, bottleneck) span two rgba32uint textures,
  // hence the _lo/_hi pairs; 8-channel ones fit in one.
  std::string node_enc0_;
  std::string node_enc1_lo_;
  std::string node_enc1_hi_;
  std::string node_bn_lo_;
  std::string node_bn_hi_;
  std::string node_dec1_;

  // 5 compute pipelines
  ComputePipeline enc0_pipeline_;
  ComputePipeline enc1_pipeline_;
  ComputePipeline bn_pipeline_;
  ComputePipeline dec1_pipeline_;
  ComputePipeline dec0_pipeline_;

  // 5 bind groups (rebuilt each render since node views may change)
  BindGroup enc0_bg_;
  BindGroup enc1_bg_;
  BindGroup bn_bg_;
  BindGroup dec1_bg_;
  BindGroup dec0_bg_;

  // Params uniform buffers (one per pass). enc0 and dec1 share the 8-channel
  // layout; the bottleneck has no FiLM coefficients.
  UniformBuffer<CnnV3Params8ch>  enc0_params_buf_;
  UniformBuffer<CnnV3Params16ch> enc1_params_buf_;
  UniformBuffer<CnnV3ParamsBn>   bn_params_buf_;
  UniformBuffer<CnnV3Params8ch>  dec1_params_buf_;
  UniformBuffer<CnnV3Params4ch>  dec0_params_buf_;

  // Shared packed-f16 weights (storage buffer, read-only in all shaders)
  GpuBuffer weights_buf_;

  // Per-pass params shadow (updated by set_film_params, uploaded in render).
  // Value-initialized, so all fields (including padding words) start at zero.
  CnnV3Params8ch  enc0_params_{};
  CnnV3Params16ch enc1_params_{};
  CnnV3ParamsBn   bn_params_{};
  CnnV3Params8ch  dec1_params_{};
  CnnV3Params4ch  dec0_params_{};

  // Builds the five compute pipelines.
  void create_pipelines();
  // Rebuilds the five bind groups against the current node views.
  void update_bind_groups(NodeRegistry& nodes);

  // FiLM MLP weights, filled by load_film_mlp(); zero until then.
  CNNv3FilmMlp mlp_{};
  // Whether load_film_mlp() has supplied weights. Presumably set_film_params()
  // uses non-learned conditioning while this is false — confirm in the .cpp.
  bool mlp_loaded_ = false;
};