diff options
Diffstat (limited to 'cnn_v3/src/cnn_v3_effect.h')
| -rw-r--r-- | cnn_v3/src/cnn_v3_effect.h | 101 |
1 files changed, 62 insertions, 39 deletions
diff --git a/cnn_v3/src/cnn_v3_effect.h b/cnn_v3/src/cnn_v3_effect.h index 36e2797..070f988 100644 --- a/cnn_v3/src/cnn_v3_effect.h +++ b/cnn_v3/src/cnn_v3_effect.h @@ -2,6 +2,13 @@ // Runs 5 compute passes (enc0→enc1→bottleneck→dec1→dec0) on G-buffer feature // textures produced by GBufferEffect. // +// Architecture: enc_channels=[8,16] +// enc0: Conv(20→8, 3×3) + FiLM8 + ReLU H×W rgba32uint +// enc1: Conv(8→16, 3×3) + FiLM16 + ReLU H/2×W/2 2× rgba32uint +// bottleneck: Conv(16→16, 3×3, dil=2) + ReLU H/4×W/4 2× rgba32uint +// dec1: Conv(32→8, 3×3) + FiLM8 + ReLU H/2×W/2 rgba32uint +// dec0: Conv(16→4, 3×3) + FiLM4 + ReLU + sig H×W rgba16float +// // Inputs: feat_tex0, feat_tex1 (rgba32uint, 20-channel G-buffer) // Output: output_tex (rgba16float, 4-channel RGBA) @@ -18,35 +25,19 @@ // Per-pass params uniform layouts (mirror WGSL Params structs exactly) // --------------------------------------------------------------------------- -// enc0, dec1, dec0: 4-channel FiLM +// enc0, dec1: 8-channel FiLM (lo/hi vec4 split) // -// WGSL layout (vec3u has align=16, so _pad sits at offset 16): -// offset 0: weight_offset (u32, 4 bytes) -// offset 4: (12 bytes implicit padding before vec3u) -// offset 16: _pad (vec3u, 12 bytes) -// offset 28: (4 bytes implicit padding before vec4f) -// offset 32: gamma (vec4f, 16 bytes) -// offset 48: beta (vec4f, 16 bytes) -// total: 64 bytes -struct CnnV3Params4ch { - uint32_t weight_offset; // offset 0 - uint32_t _pad[7]; // offsets 4-31 (mirrors implicit + vec3u + post-pad) - float gamma[4]; // offset 32 - float beta[4]; // offset 48 -}; -static_assert(sizeof(CnnV3Params4ch) == 64, "CnnV3Params4ch must be 64 bytes"); - -// enc1: 8-channel FiLM (split into lo/hi vec4 pairs) -// -// WGSL layout (same header padding as above): -// offset 0: weight_offset (u32, 4 bytes) -// offset 16: _pad (vec3u, 12 bytes) -// offset 32: gamma_lo (vec4f, 16 bytes) -// offset 48: gamma_hi (vec4f, 16 bytes) -// offset 64: beta_lo (vec4f, 16 bytes) -// offset 80: beta_hi (vec4f, 16 bytes) +// WGSL layout: +// offset 0: weight_offset (u32) +// offset 4-15: implicit pad, vec3u aligned to 16 +// offset 16: _pad (vec3u, 12 bytes) +// offset 28-31: implicit pad +// offset 32: gamma_lo (vec4f) +// offset 48: gamma_hi (vec4f) +// offset 64: beta_lo (vec4f) +// offset 80: beta_hi (vec4f) // total: 96 bytes -struct CnnV3ParamsEnc1 { +struct CnnV3Params8ch { uint32_t weight_offset; // offset 0 uint32_t _pad[7]; // offsets 4-31 float gamma_lo[4]; // offset 32 @@ -54,10 +45,41 @@ struct CnnV3ParamsEnc1 { float beta_lo[4]; // offset 64 float beta_hi[4]; // offset 80 }; -static_assert(sizeof(CnnV3ParamsEnc1) == 96, - "CnnV3ParamsEnc1 must be 96 bytes"); +static_assert(sizeof(CnnV3Params8ch) == 96, "CnnV3Params8ch must be 96 bytes"); + +// enc1: 16-channel FiLM (four vec4 groups for gamma + four for beta) +// +// WGSL layout: +// offset 0: weight_offset (u32) +// offset 16: _pad (vec3u) +// offset 32: gamma_0..3 (4x vec4f = 64 bytes) +// offset 96: beta_0..3 (4x vec4f = 64 bytes) +// total: 160 bytes +struct CnnV3Params16ch { + uint32_t weight_offset; // offset 0 + uint32_t _pad[7]; // offsets 4-31 + float gamma[16]; // offsets 32-95 + float beta[16]; // offsets 96-159 +}; +static_assert(sizeof(CnnV3Params16ch) == 160, "CnnV3Params16ch must be 160 bytes"); + +// dec0: 4-channel FiLM +// +// WGSL layout: +// offset 0: weight_offset (u32) +// offset 16: _pad (vec3u) +// offset 32: gamma (vec4f) +// offset 48: beta (vec4f) +// total: 64 bytes +struct CnnV3Params4ch { + uint32_t weight_offset; // offset 0 + uint32_t _pad[7]; // offsets 4-31 + float gamma[4]; // offset 32 + float beta[4]; // offset 48 +}; +static_assert(sizeof(CnnV3Params4ch) == 64, "CnnV3Params4ch must be 64 bytes"); -// bottleneck: no FiLM — 4 plain u32s, no alignment gap +// bottleneck: no FiLM — weight_offset + 3 pads struct CnnV3ParamsBn { uint32_t weight_offset; uint32_t _pad[3]; @@ -90,14 +112,15 @@ class CNNv3Effect : public Effect { void set_film_params(const CNNv3FiLMParams& fp); // Upload packed-f16 weights (kWeightsBufBytes bytes of u32 pairs). - // Used for testing and inference from trained .bin files. void upload_weights(WGPUQueue queue, const void* data, uint32_t size_bytes); private: // Intermediate node names (prefixed from output[0]) std::string node_enc0_; - std::string node_enc1_; - std::string node_bottleneck_; + std::string node_enc1_lo_; + std::string node_enc1_hi_; + std::string node_bn_lo_; + std::string node_bn_hi_; std::string node_dec1_; // 5 compute pipelines @@ -115,20 +138,20 @@ class CNNv3Effect : public Effect { BindGroup dec0_bg_; // Params uniform buffers (one per pass) - UniformBuffer<CnnV3Params4ch> enc0_params_buf_; - UniformBuffer<CnnV3ParamsEnc1> enc1_params_buf_; + UniformBuffer<CnnV3Params8ch> enc0_params_buf_; + UniformBuffer<CnnV3Params16ch> enc1_params_buf_; UniformBuffer<CnnV3ParamsBn> bn_params_buf_; - UniformBuffer<CnnV3Params4ch> dec1_params_buf_; + UniformBuffer<CnnV3Params8ch> dec1_params_buf_; UniformBuffer<CnnV3Params4ch> dec0_params_buf_; // Shared packed-f16 weights (storage buffer, read-only in all shaders) GpuBuffer weights_buf_; // Per-pass params shadow (updated by set_film_params, uploaded in render) - CnnV3Params4ch enc0_params_{}; - CnnV3ParamsEnc1 enc1_params_{}; + CnnV3Params8ch enc0_params_{}; + CnnV3Params16ch enc1_params_{}; CnnV3ParamsBn bn_params_{}; - CnnV3Params4ch dec1_params_{}; + CnnV3Params8ch dec1_params_{}; CnnV3Params4ch dec0_params_{}; void create_pipelines(); |
