summary | refs | log | tree | commit | diff
path: root/cnn_v3/src/cnn_v3_effect.h
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v3/src/cnn_v3_effect.h')
-rw-r--r-- cnn_v3/src/cnn_v3_effect.h 101
1 files changed, 62 insertions, 39 deletions
diff --git a/cnn_v3/src/cnn_v3_effect.h b/cnn_v3/src/cnn_v3_effect.h
index 36e2797..070f988 100644
--- a/cnn_v3/src/cnn_v3_effect.h
+++ b/cnn_v3/src/cnn_v3_effect.h
@@ -2,6 +2,13 @@
// Runs 5 compute passes (enc0→enc1→bottleneck→dec1→dec0) on G-buffer feature
// textures produced by GBufferEffect.
//
+// Architecture: enc_channels=[8,16]
+// enc0: Conv(20→8, 3×3) + FiLM8 + ReLU H×W rgba32uint
+// enc1: Conv(8→16, 3×3) + FiLM16 + ReLU H/2×W/2 2× rgba32uint
+// bottleneck: Conv(16→16, 3×3, dil=2) + ReLU H/4×W/4 2× rgba32uint
+// dec1: Conv(32→8, 3×3) + FiLM8 + ReLU H/2×W/2 rgba32uint
+// dec0: Conv(16→4, 3×3) + FiLM4 + ReLU + sig H×W rgba16float
+//
// Inputs: feat_tex0, feat_tex1 (rgba32uint, 20-channel G-buffer)
// Output: output_tex (rgba16float, 4-channel RGBA)
@@ -18,35 +25,19 @@
// Per-pass params uniform layouts (mirror WGSL Params structs exactly)
// ---------------------------------------------------------------------------
-// enc0, dec1, dec0: 4-channel FiLM
+// enc0, dec1: 8-channel FiLM (lo/hi vec4 split)
//
-// WGSL layout (vec3u has align=16, so _pad sits at offset 16):
-// offset 0: weight_offset (u32, 4 bytes)
-// offset 4: (12 bytes implicit padding before vec3u)
-// offset 16: _pad (vec3u, 12 bytes)
-// offset 28: (4 bytes implicit padding before vec4f)
-// offset 32: gamma (vec4f, 16 bytes)
-// offset 48: beta (vec4f, 16 bytes)
-// total: 64 bytes
-struct CnnV3Params4ch {
- uint32_t weight_offset; // offset 0
- uint32_t _pad[7]; // offsets 4-31 (mirrors implicit + vec3u + post-pad)
- float gamma[4]; // offset 32
- float beta[4]; // offset 48
-};
-static_assert(sizeof(CnnV3Params4ch) == 64, "CnnV3Params4ch must be 64 bytes");
-
-// enc1: 8-channel FiLM (split into lo/hi vec4 pairs)
-//
-// WGSL layout (same header padding as above):
-// offset 0: weight_offset (u32, 4 bytes)
-// offset 16: _pad (vec3u, 12 bytes)
-// offset 32: gamma_lo (vec4f, 16 bytes)
-// offset 48: gamma_hi (vec4f, 16 bytes)
-// offset 64: beta_lo (vec4f, 16 bytes)
-// offset 80: beta_hi (vec4f, 16 bytes)
+// WGSL layout:
+// offset 0: weight_offset (u32)
+// offset 4-15: implicit pad, vec3u aligned to 16
+// offset 16: _pad (vec3u, 12 bytes)
+// offset 28-31: implicit pad
+// offset 32: gamma_lo (vec4f)
+// offset 48: gamma_hi (vec4f)
+// offset 64: beta_lo (vec4f)
+// offset 80: beta_hi (vec4f)
// total: 96 bytes
-struct CnnV3ParamsEnc1 {
+struct CnnV3Params8ch {
uint32_t weight_offset; // offset 0
uint32_t _pad[7]; // offsets 4-31
float gamma_lo[4]; // offset 32
@@ -54,10 +45,41 @@ struct CnnV3ParamsEnc1 {
float beta_lo[4]; // offset 64
float beta_hi[4]; // offset 80
};
-static_assert(sizeof(CnnV3ParamsEnc1) == 96,
- "CnnV3ParamsEnc1 must be 96 bytes");
+static_assert(sizeof(CnnV3Params8ch) == 96, "CnnV3Params8ch must be 96 bytes");
+
+// enc1: 16-channel FiLM (four vec4 groups for gamma + four for beta)
+//
+// WGSL layout:
+// offset 0: weight_offset (u32)
+// offset 16: _pad (vec3u)
+// offset 32: gamma_0..3 (4x vec4f = 64 bytes)
+// offset 96: beta_0..3 (4x vec4f = 64 bytes)
+// total: 160 bytes
+struct CnnV3Params16ch {
+ uint32_t weight_offset; // offset 0
+ uint32_t _pad[7]; // offsets 4-31
+ float gamma[16]; // offsets 32-95
+ float beta[16]; // offsets 96-159
+};
+static_assert(sizeof(CnnV3Params16ch) == 160, "CnnV3Params16ch must be 160 bytes");
+
+// dec0: 4-channel FiLM
+//
+// WGSL layout:
+// offset 0: weight_offset (u32)
+// offset 16: _pad (vec3u)
+// offset 32: gamma (vec4f)
+// offset 48: beta (vec4f)
+// total: 64 bytes
+struct CnnV3Params4ch {
+ uint32_t weight_offset; // offset 0
+ uint32_t _pad[7]; // offsets 4-31
+ float gamma[4]; // offset 32
+ float beta[4]; // offset 48
+};
+static_assert(sizeof(CnnV3Params4ch) == 64, "CnnV3Params4ch must be 64 bytes");
-// bottleneck: no FiLM — 4 plain u32s, no alignment gap
+// bottleneck: no FiLM — weight_offset + 3 pads
struct CnnV3ParamsBn {
uint32_t weight_offset;
uint32_t _pad[3];
@@ -90,14 +112,15 @@ class CNNv3Effect : public Effect {
void set_film_params(const CNNv3FiLMParams& fp);
// Upload packed-f16 weights (kWeightsBufBytes bytes of u32 pairs).
- // Used for testing and inference from trained .bin files.
void upload_weights(WGPUQueue queue, const void* data, uint32_t size_bytes);
private:
// Intermediate node names (prefixed from output[0])
std::string node_enc0_;
- std::string node_enc1_;
- std::string node_bottleneck_;
+ std::string node_enc1_lo_;
+ std::string node_enc1_hi_;
+ std::string node_bn_lo_;
+ std::string node_bn_hi_;
std::string node_dec1_;
// 5 compute pipelines
@@ -115,20 +138,20 @@ class CNNv3Effect : public Effect {
BindGroup dec0_bg_;
// Params uniform buffers (one per pass)
- UniformBuffer<CnnV3Params4ch> enc0_params_buf_;
- UniformBuffer<CnnV3ParamsEnc1> enc1_params_buf_;
+ UniformBuffer<CnnV3Params8ch> enc0_params_buf_;
+ UniformBuffer<CnnV3Params16ch> enc1_params_buf_;
UniformBuffer<CnnV3ParamsBn> bn_params_buf_;
- UniformBuffer<CnnV3Params4ch> dec1_params_buf_;
+ UniformBuffer<CnnV3Params8ch> dec1_params_buf_;
UniformBuffer<CnnV3Params4ch> dec0_params_buf_;
// Shared packed-f16 weights (storage buffer, read-only in all shaders)
GpuBuffer weights_buf_;
// Per-pass params shadow (updated by set_film_params, uploaded in render)
- CnnV3Params4ch enc0_params_{};
- CnnV3ParamsEnc1 enc1_params_{};
+ CnnV3Params8ch enc0_params_{};
+ CnnV3Params16ch enc1_params_{};
CnnV3ParamsBn bn_params_{};
- CnnV3Params4ch dec1_params_{};
+ CnnV3Params8ch dec1_params_{};
CnnV3Params4ch dec0_params_{};
void create_pipelines();