summaryrefslogtreecommitdiff
path: root/cnn_v3/src
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-03-26 07:03:01 +0100
committerskal <pascal.massimino@gmail.com>2026-03-26 07:03:01 +0100
commit8f14bdd66cb002b2f89265b2a578ad93249089c9 (patch)
tree2ccdb3939b673ebc3a5df429160631240239cee2 /cnn_v3/src
parent4ca498277b033ae10134045dae9c8c249a8d2b2b (diff)
feat(cnn_v3): upgrade architecture to enc_channels=[8,16]
Double encoder capacity: enc0 4→8ch, enc1 8→16ch, bottleneck 16→16ch, dec1 32→8ch, dec0 16→4ch. Total weights 2476→7828 f16 (~15.3 KB). FiLM MLP output 40→72 params (L1: 16×40→16×72). 16-ch textures split into _lo/_hi rgba32uint pairs (enc1, bottleneck). enc0 and dec1 textures changed from rgba16float to rgba32uint (8ch). GBUF_RGBA32UINT node gains CopySrc for parity test readback. - WGSL shaders: all 5 passes rewritten for new channel counts - C++ CNNv3Effect: new weight offsets/sizes, 8ch uniform structs - Web tool (shaders.js + tester.js): matching texture formats and bindings - Parity test: readback_rgba32uint_8ch helper, updated vector counts - Training scripts: default enc_channels=[8,16], updated docstrings - Docs + architecture PNG regenerated handoff(Gemini): CNN v3 [8,16] upgrade complete. All code, tests, web tool, training scripts, and docs updated. Next: run training pass.
Diffstat (limited to 'cnn_v3/src')
-rw-r--r--cnn_v3/src/cnn_v3_effect.cc247
-rw-r--r--cnn_v3/src/cnn_v3_effect.h101
2 files changed, 188 insertions, 160 deletions
diff --git a/cnn_v3/src/cnn_v3_effect.cc b/cnn_v3/src/cnn_v3_effect.cc
index 1391eba..dc26751 100644
--- a/cnn_v3/src/cnn_v3_effect.cc
+++ b/cnn_v3/src/cnn_v3_effect.cc
@@ -1,5 +1,5 @@
// CNN v3 Effect — U-Net + FiLM inference (5 compute passes)
-// See cnn_v3/docs/CNN_V3.md for architecture, HOWTO.md §7 for shader details.
+// See cnn_v3/docs/CNN_V3.md for architecture, HOWTO.md for shader details.
#include "cnn_v3_effect.h"
@@ -17,17 +17,16 @@
#include <cstring>
// ---------------------------------------------------------------------------
-// Weight layout constants — explicit formulas matching WGSL shader comments
-// ---------------------------------------------------------------------------
+// Weight layout constants — enc_channels=[8,16]
//
// Format: Conv(IN→OUT, KxK) has OUT*IN*K*K weights + OUT biases
-// Layout: OIHW order (out × in × kH × kW), biases appended after conv weights
-//
-static const uint32_t kEnc0Weights = 20 * 4 * 9 + 4; // Conv(20→4,3×3)+bias
-static const uint32_t kEnc1Weights = 4 * 8 * 9 + 8; // Conv(4→8,3×3)+bias
-static const uint32_t kBnWeights = 8 * 8 * 9 + 8; // Conv(8→8,3×3,dilation=2)+bias
-static const uint32_t kDec1Weights = 16 * 4 * 9 + 4; // Conv(16→4,3×3)+bias
-static const uint32_t kDec0Weights = 8 * 4 * 9 + 4; // Conv(8→4,3×3)+bias
+// Layout: OIHW order (out × in × kH × kW), biases appended
+// ---------------------------------------------------------------------------
+static const uint32_t kEnc0Weights = 20 * 8 * 9 + 8; // Conv(20→8, 3×3)+bias = 1448
+static const uint32_t kEnc1Weights = 8 * 16 * 9 + 16; // Conv(8→16, 3×3)+bias = 1168
+static const uint32_t kBnWeights = 16 * 16 * 9 + 16; // Conv(16→16, 3×3,dil=2)+bias = 2320
+static const uint32_t kDec1Weights = 32 * 8 * 9 + 8; // Conv(32→8, 3×3)+bias = 2312
+static const uint32_t kDec0Weights = 16 * 4 * 9 + 4; // Conv(16→4, 3×3)+bias = 580
static const uint32_t kEnc0Offset = 0;
static const uint32_t kEnc1Offset = kEnc0Offset + kEnc0Weights;
@@ -35,13 +34,12 @@ static const uint32_t kBnOffset = kEnc1Offset + kEnc1Weights;
static const uint32_t kDec1Offset = kBnOffset + kBnWeights;
static const uint32_t kDec0Offset = kDec1Offset + kDec1Weights;
static const uint32_t kTotalF16 = kDec0Offset + kDec0Weights;
+// = 1448 + 1168 + 2320 + 2312 + 580 = 7828 f16
-// Weights buffer size in bytes: f16 values are packed two-per-u32.
-// Round up to u32 boundary.
static const uint32_t kWeightsBufBytes = ((kTotalF16 + 1) / 2) * 4;
// ---------------------------------------------------------------------------
-// Shader source externs (registered in shaders.cc via InitShaderComposer)
+// Shader source externs
// ---------------------------------------------------------------------------
extern const char* cnn_v3_enc0_wgsl;
extern const char* cnn_v3_enc1_wgsl;
@@ -103,14 +101,6 @@ static WGPUBindGroupLayoutEntry bgl_uint_tex(uint32_t binding) {
e.texture.viewDimension = WGPUTextureViewDimension_2D;
return e;
}
-static WGPUBindGroupLayoutEntry bgl_float_tex(uint32_t binding) {
- WGPUBindGroupLayoutEntry e = {};
- e.binding = binding;
- e.visibility = WGPUShaderStage_Compute;
- e.texture.sampleType = WGPUTextureSampleType_Float;
- e.texture.viewDimension = WGPUTextureViewDimension_2D;
- return e;
-}
static WGPUBindGroupLayoutEntry bgl_storage_buf(uint32_t binding) {
WGPUBindGroupLayoutEntry e = {};
e.binding = binding;
@@ -151,45 +141,46 @@ CNNv3Effect::CNNv3Effect(const GpuContext& ctx,
const std::string& prefix =
outputs.empty() ? std::string("cnn_v3") : outputs[0];
- node_enc0_ = prefix + "_enc0";
- node_enc1_ = prefix + "_enc1";
- node_bottleneck_ = prefix + "_bottleneck";
- node_dec1_ = prefix + "_dec1";
+ node_enc0_ = prefix + "_enc0";
+ node_enc1_lo_ = prefix + "_enc1_lo";
+ node_enc1_hi_ = prefix + "_enc1_hi";
+ node_bn_lo_ = prefix + "_bn_lo";
+ node_bn_hi_ = prefix + "_bn_hi";
+ node_dec1_ = prefix + "_dec1";
- // Allocate zeroed weights buffer (f16 pairs packed as u32).
- // Weights are zero-initialized; load_weights() can fill from file later.
weights_buf_ = gpu_create_buffer(
ctx_.device, kWeightsBufBytes,
WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst);
- // Initialize uniform buffers.
enc0_params_buf_.init(ctx_.device);
enc1_params_buf_.init(ctx_.device);
bn_params_buf_.init(ctx_.device);
dec1_params_buf_.init(ctx_.device);
dec0_params_buf_.init(ctx_.device);
- // Set weight offsets (FiLM γ/β default to identity: γ=1, β=0).
+ // Set weight offsets; FiLM γ/β default to identity (γ=1, β=0).
enc0_params_.weight_offset = kEnc0Offset;
- for (int i = 0; i < 4; ++i) { enc0_params_.gamma[i] = 1.0f; }
-
- enc1_params_.weight_offset = kEnc1Offset;
for (int i = 0; i < 4; ++i) {
- enc1_params_.gamma_lo[i] = 1.0f;
- enc1_params_.gamma_hi[i] = 1.0f;
+ enc0_params_.gamma_lo[i] = 1.0f;
+ enc0_params_.gamma_hi[i] = 1.0f;
}
+ enc1_params_.weight_offset = kEnc1Offset;
+ for (int i = 0; i < 16; ++i) { enc1_params_.gamma[i] = 1.0f; }
+
bn_params_.weight_offset = kBnOffset;
dec1_params_.weight_offset = kDec1Offset;
- for (int i = 0; i < 4; ++i) { dec1_params_.gamma[i] = 1.0f; }
+ for (int i = 0; i < 4; ++i) {
+ dec1_params_.gamma_lo[i] = 1.0f;
+ dec1_params_.gamma_hi[i] = 1.0f;
+ }
dec0_params_.weight_offset = kDec0Offset;
for (int i = 0; i < 4; ++i) { dec0_params_.gamma[i] = 1.0f; }
create_pipelines();
- // Load trained weights from asset system (zero-initialized if absent).
size_t weights_size = 0;
const void* weights_data =
GetAsset(AssetId::ASSET_WEIGHTS_CNN_V3, &weights_size);
@@ -206,20 +197,21 @@ void CNNv3Effect::declare_nodes(NodeRegistry& registry) {
const int W = registry.default_width();
const int H = registry.default_height();
- // enc0_tex: rgba16float full-res
- registry.declare_node(node_enc0_, NodeType::GBUF_ALBEDO, W, H);
- // enc1_tex: rgba32uint half-res — shaders use textureDimensions() for bounds
- registry.declare_node(node_enc1_, NodeType::GBUF_RGBA32UINT, W / 2, H / 2);
- // bottleneck_tex: rgba32uint quarter-res
- registry.declare_node(node_bottleneck_, NodeType::GBUF_RGBA32UINT, W / 4, H / 4);
- // dec1_tex: rgba16float half-res
- registry.declare_node(node_dec1_, NodeType::GBUF_ALBEDO, W / 2, H / 2);
+ // enc0: rgba32uint full-res (8ch packed f16)
+ registry.declare_node(node_enc0_, NodeType::GBUF_RGBA32UINT, W, H);
+ // enc1: two rgba32uint half-res (8ch each = 16ch total)
+ registry.declare_node(node_enc1_lo_, NodeType::GBUF_RGBA32UINT, W / 2, H / 2);
+ registry.declare_node(node_enc1_hi_, NodeType::GBUF_RGBA32UINT, W / 2, H / 2);
+ // bottleneck: two rgba32uint quarter-res (8ch each = 16ch total)
+ registry.declare_node(node_bn_lo_, NodeType::GBUF_RGBA32UINT, W / 4, H / 4);
+ registry.declare_node(node_bn_hi_, NodeType::GBUF_RGBA32UINT, W / 4, H / 4);
+ // dec1: rgba32uint half-res (8ch packed f16)
+ registry.declare_node(node_dec1_, NodeType::GBUF_RGBA32UINT, W / 2, H / 2);
// output_nodes_[0]: rgba16float full-res — declared externally by caller
}
// ---------------------------------------------------------------------------
-// set_film_params — simple linear mapping (placeholder, no MLP yet)
-// TODO(phase-7): replace with CPU forward pass through cnn_v3_film_mlp.bin
+// upload_weights / set_film_params
// ---------------------------------------------------------------------------
void CNNv3Effect::upload_weights(WGPUQueue queue, const void* data,
@@ -228,26 +220,26 @@ void CNNv3Effect::upload_weights(WGPUQueue queue, const void* data,
}
void CNNv3Effect::set_film_params(const CNNv3FiLMParams& fp) {
- // Identity + audio/beat modulation.
- // Replace with FiLM MLP output once training is done.
const float a = fp.audio_intensity;
const float b = fp.beat_phase;
for (int i = 0; i < 4; ++i) {
- enc0_params_.gamma[i] = 1.0f + a * 0.5f;
- enc0_params_.beta[i] = b * 0.1f;
+ enc0_params_.gamma_lo[i] = 1.0f + a * 0.5f;
+ enc0_params_.gamma_hi[i] = 1.0f + a * 0.5f;
+ enc0_params_.beta_lo[i] = b * 0.1f;
+ enc0_params_.beta_hi[i] = b * 0.1f;
}
- for (int i = 0; i < 4; ++i) {
- enc1_params_.gamma_lo[i] = 1.0f + a * 0.3f;
- enc1_params_.gamma_hi[i] = 1.0f + a * 0.3f;
- enc1_params_.beta_lo[i] = fp.beat_norm * 0.1f;
- enc1_params_.beta_hi[i] = fp.beat_norm * 0.1f;
+ for (int i = 0; i < 16; ++i) {
+ enc1_params_.gamma[i] = 1.0f + a * 0.3f;
+ enc1_params_.beta[i] = fp.beat_norm * 0.1f;
}
for (int i = 0; i < 4; ++i) {
- dec1_params_.gamma[i] = 1.0f + fp.style_p0 * 0.5f;
- dec1_params_.beta[i] = fp.style_p1 * 0.1f;
- dec0_params_.gamma[i] = 1.0f + fp.style_p0 * 0.5f;
- dec0_params_.beta[i] = fp.style_p1 * 0.1f;
+ dec1_params_.gamma_lo[i] = 1.0f + fp.style_p0 * 0.5f;
+ dec1_params_.gamma_hi[i] = 1.0f + fp.style_p0 * 0.5f;
+ dec1_params_.beta_lo[i] = fp.style_p1 * 0.1f;
+ dec1_params_.beta_hi[i] = fp.style_p1 * 0.1f;
+ dec0_params_.gamma[i] = 1.0f + fp.style_p0 * 0.5f;
+ dec0_params_.beta[i] = fp.style_p1 * 0.1f;
}
}
@@ -258,7 +250,6 @@ void CNNv3Effect::set_film_params(const CNNv3FiLMParams& fp) {
void CNNv3Effect::render(WGPUCommandEncoder encoder,
const UniformsSequenceParams& params,
NodeRegistry& nodes) {
- // Upload params uniforms.
enc0_params_buf_.update(ctx_.queue, enc0_params_);
enc1_params_buf_.update(ctx_.queue, enc1_params_);
bn_params_buf_.update(ctx_.queue, bn_params_);
@@ -270,7 +261,6 @@ void CNNv3Effect::render(WGPUCommandEncoder encoder,
const int W = (int)params.resolution.x;
const int H = (int)params.resolution.y;
- // Dispatch helper: ceil(dim / 8) workgroups
auto dispatch = [&](WGPUComputePipeline pipe, WGPUBindGroup bg,
int w, int h) {
WGPUComputePassDescriptor pass_desc = {};
@@ -304,14 +294,14 @@ void CNNv3Effect::create_pipelines() {
// --- enc0 ---
// B0: feat_tex0 (u32), B1: feat_tex1 (u32), B2: weights (storage),
- // B3: params (uniform), B4: enc0_out (storage_tex rgba16float write)
+ // B3: params (uniform, 96B), B4: enc0_out (rgba32uint write)
{
WGPUBindGroupLayoutEntry e[5] = {
bgl_uint_tex(0),
bgl_uint_tex(1),
bgl_storage_buf(2),
- bgl_uniform_buf(3, sizeof(CnnV3Params4ch)), // 64 bytes
- bgl_storage_tex_write(4, WGPUTextureFormat_RGBA16Float),
+ bgl_uniform_buf(3, sizeof(CnnV3Params8ch)),
+ bgl_storage_tex_write(4, WGPUTextureFormat_RGBA32Uint),
};
WGPUBindGroupLayout bgl = make_bgl(dev, e, 5);
WGPUShaderModule sh = make_shader(dev, cnn_v3_enc0_wgsl);
@@ -321,16 +311,18 @@ void CNNv3Effect::create_pipelines() {
}
// --- enc1 ---
- // B0: enc0_tex (f32), B1: weights (storage),
- // B2: params (uniform), B3: enc1_out (storage_tex rgba32uint write)
+ // B0: enc0_tex (u32), B1: weights (storage),
+ // B2: params (uniform, 160B), B3: enc1_out_lo (rgba32uint write),
+ // B4: enc1_out_hi (rgba32uint write)
{
- WGPUBindGroupLayoutEntry e[4] = {
- bgl_float_tex(0),
+ WGPUBindGroupLayoutEntry e[5] = {
+ bgl_uint_tex(0),
bgl_storage_buf(1),
- bgl_uniform_buf(2, sizeof(CnnV3ParamsEnc1)),
+ bgl_uniform_buf(2, sizeof(CnnV3Params16ch)),
bgl_storage_tex_write(3, WGPUTextureFormat_RGBA32Uint),
+ bgl_storage_tex_write(4, WGPUTextureFormat_RGBA32Uint),
};
- WGPUBindGroupLayout bgl = make_bgl(dev, e, 4);
+ WGPUBindGroupLayout bgl = make_bgl(dev, e, 5);
WGPUShaderModule sh = make_shader(dev, cnn_v3_enc1_wgsl);
enc1_pipeline_.set(make_compute_pipeline(dev, sh, "enc1_main", bgl));
wgpuShaderModuleRelease(sh);
@@ -338,16 +330,19 @@ void CNNv3Effect::create_pipelines() {
}
// --- bottleneck ---
- // B0: enc1_tex (u32), B1: weights (storage),
- // B2: params (uniform), B3: bottleneck_out (storage_tex rgba32uint write)
+ // B0: enc1_tex_lo (u32), B1: enc1_tex_hi (u32), B2: weights (storage),
+ // B3: params (uniform, 16B), B4: bn_out_lo (rgba32uint write),
+ // B5: bn_out_hi (rgba32uint write)
{
- WGPUBindGroupLayoutEntry e[4] = {
+ WGPUBindGroupLayoutEntry e[6] = {
bgl_uint_tex(0),
- bgl_storage_buf(1),
- bgl_uniform_buf(2, sizeof(CnnV3ParamsBn)),
- bgl_storage_tex_write(3, WGPUTextureFormat_RGBA32Uint),
+ bgl_uint_tex(1),
+ bgl_storage_buf(2),
+ bgl_uniform_buf(3, sizeof(CnnV3ParamsBn)),
+ bgl_storage_tex_write(4, WGPUTextureFormat_RGBA32Uint),
+ bgl_storage_tex_write(5, WGPUTextureFormat_RGBA32Uint),
};
- WGPUBindGroupLayout bgl = make_bgl(dev, e, 4);
+ WGPUBindGroupLayout bgl = make_bgl(dev, e, 6);
WGPUShaderModule sh = make_shader(dev, cnn_v3_bottleneck_wgsl);
bn_pipeline_.set(make_compute_pipeline(dev, sh, "bottleneck_main", bgl));
wgpuShaderModuleRelease(sh);
@@ -355,17 +350,21 @@ void CNNv3Effect::create_pipelines() {
}
// --- dec1 ---
- // B0: bottleneck_tex (u32), B1: enc1_tex (u32), B2: weights (storage),
- // B3: params (uniform), B4: dec1_out (storage_tex rgba16float write)
+ // B0: bn_tex_lo (u32), B1: bn_tex_hi (u32),
+ // B2: enc1_tex_lo (u32), B3: enc1_tex_hi (u32),
+ // B4: weights (storage), B5: params (uniform, 96B),
+ // B6: dec1_out (rgba32uint write)
{
- WGPUBindGroupLayoutEntry e[5] = {
+ WGPUBindGroupLayoutEntry e[7] = {
bgl_uint_tex(0),
bgl_uint_tex(1),
- bgl_storage_buf(2),
- bgl_uniform_buf(3, sizeof(CnnV3Params4ch)), // 64 bytes
- bgl_storage_tex_write(4, WGPUTextureFormat_RGBA16Float),
+ bgl_uint_tex(2),
+ bgl_uint_tex(3),
+ bgl_storage_buf(4),
+ bgl_uniform_buf(5, sizeof(CnnV3Params8ch)),
+ bgl_storage_tex_write(6, WGPUTextureFormat_RGBA32Uint),
};
- WGPUBindGroupLayout bgl = make_bgl(dev, e, 5);
+ WGPUBindGroupLayout bgl = make_bgl(dev, e, 7);
WGPUShaderModule sh = make_shader(dev, cnn_v3_dec1_wgsl);
dec1_pipeline_.set(make_compute_pipeline(dev, sh, "dec1_main", bgl));
wgpuShaderModuleRelease(sh);
@@ -373,14 +372,14 @@ void CNNv3Effect::create_pipelines() {
}
// --- dec0 ---
- // B0: dec1_tex (f32), B1: enc0_tex (f32), B2: weights (storage),
- // B3: params (uniform), B4: output_tex (storage_tex rgba16float write)
+ // B0: dec1_tex (u32), B1: enc0_tex (u32), B2: weights (storage),
+ // B3: params (uniform, 64B), B4: output_tex (rgba16float write)
{
WGPUBindGroupLayoutEntry e[5] = {
- bgl_float_tex(0),
- bgl_float_tex(1),
+ bgl_uint_tex(0),
+ bgl_uint_tex(1),
bgl_storage_buf(2),
- bgl_uniform_buf(3, sizeof(CnnV3Params4ch)), // 64 bytes
+ bgl_uniform_buf(3, sizeof(CnnV3Params4ch)),
bgl_storage_tex_write(4, WGPUTextureFormat_RGBA16Float),
};
WGPUBindGroupLayout bgl = make_bgl(dev, e, 5);
@@ -395,14 +394,12 @@ void CNNv3Effect::create_pipelines() {
// update_bind_groups — rebuilt each frame (node views may be recreated)
// ---------------------------------------------------------------------------
-// Helper: set a texture view binding entry.
static void bg_tex(WGPUBindGroupEntry& e, uint32_t binding,
WGPUTextureView view) {
e = {};
e.binding = binding;
e.textureView = view;
}
-// Helper: set a buffer binding entry.
static void bg_buf(WGPUBindGroupEntry& e, uint32_t binding, WGPUBuffer buf,
uint64_t size) {
e = {};
@@ -414,13 +411,15 @@ static void bg_buf(WGPUBindGroupEntry& e, uint32_t binding, WGPUBuffer buf,
void CNNv3Effect::update_bind_groups(NodeRegistry& nodes) {
WGPUDevice dev = ctx_.device;
- WGPUTextureView feat0_view = nodes.get_view(input_nodes_[0]);
- WGPUTextureView feat1_view = nodes.get_view(input_nodes_[1]);
- WGPUTextureView enc0_view = nodes.get_view(node_enc0_);
- WGPUTextureView enc1_view = nodes.get_view(node_enc1_);
- WGPUTextureView bn_view = nodes.get_view(node_bottleneck_);
- WGPUTextureView dec1_view = nodes.get_view(node_dec1_);
- WGPUTextureView out_view = nodes.get_view(output_nodes_[0]);
+ WGPUTextureView feat0_view = nodes.get_view(input_nodes_[0]);
+ WGPUTextureView feat1_view = nodes.get_view(input_nodes_[1]);
+ WGPUTextureView enc0_view = nodes.get_view(node_enc0_);
+ WGPUTextureView enc1_lo_view = nodes.get_view(node_enc1_lo_);
+ WGPUTextureView enc1_hi_view = nodes.get_view(node_enc1_hi_);
+ WGPUTextureView bn_lo_view = nodes.get_view(node_bn_lo_);
+ WGPUTextureView bn_hi_view = nodes.get_view(node_bn_hi_);
+ WGPUTextureView dec1_view = nodes.get_view(node_dec1_);
+ WGPUTextureView out_view = nodes.get_view(output_nodes_[0]);
WGPUBuffer wb = weights_buf_.buffer;
@@ -437,49 +436,55 @@ void CNNv3Effect::update_bind_groups(NodeRegistry& nodes) {
return bg;
};
- // enc0: feat_tex0(B0), feat_tex1(B1), weights(B2), params(B3), enc0_out(B4)
+ // enc0: feat0(B0), feat1(B1), weights(B2), params(B3), enc0_out(B4)
{
WGPUBindGroupEntry e[5] = {};
bg_tex(e[0], 0, feat0_view);
bg_tex(e[1], 1, feat1_view);
bg_buf(e[2], 2, wb, kWeightsBufBytes);
- bg_buf(e[3], 3, enc0_params_buf_.get().buffer, sizeof(CnnV3Params4ch));
+ bg_buf(e[3], 3, enc0_params_buf_.get().buffer, sizeof(CnnV3Params8ch));
bg_tex(e[4], 4, enc0_view);
enc0_bg_.replace(make_bg(enc0_pipeline_.get(), e, 5));
}
- // enc1: enc0_tex(B0), weights(B1), params(B2), enc1_out(B3)
+ // enc1: enc0(B0), weights(B1), params(B2), enc1_lo(B3), enc1_hi(B4)
{
- WGPUBindGroupEntry e[4] = {};
+ WGPUBindGroupEntry e[5] = {};
bg_tex(e[0], 0, enc0_view);
bg_buf(e[1], 1, wb, kWeightsBufBytes);
- bg_buf(e[2], 2, enc1_params_buf_.get().buffer, sizeof(CnnV3ParamsEnc1));
- bg_tex(e[3], 3, enc1_view);
- enc1_bg_.replace(make_bg(enc1_pipeline_.get(), e, 4));
+ bg_buf(e[2], 2, enc1_params_buf_.get().buffer, sizeof(CnnV3Params16ch));
+ bg_tex(e[3], 3, enc1_lo_view);
+ bg_tex(e[4], 4, enc1_hi_view);
+ enc1_bg_.replace(make_bg(enc1_pipeline_.get(), e, 5));
}
- // bottleneck: enc1_tex(B0), weights(B1), params(B2), bn_out(B3)
+ // bottleneck: enc1_lo(B0), enc1_hi(B1), weights(B2), params(B3), bn_lo(B4), bn_hi(B5)
{
- WGPUBindGroupEntry e[4] = {};
- bg_tex(e[0], 0, enc1_view);
- bg_buf(e[1], 1, wb, kWeightsBufBytes);
- bg_buf(e[2], 2, bn_params_buf_.get().buffer, sizeof(CnnV3ParamsBn));
- bg_tex(e[3], 3, bn_view);
- bn_bg_.replace(make_bg(bn_pipeline_.get(), e, 4));
+ WGPUBindGroupEntry e[6] = {};
+ bg_tex(e[0], 0, enc1_lo_view);
+ bg_tex(e[1], 1, enc1_hi_view);
+ bg_buf(e[2], 2, wb, kWeightsBufBytes);
+ bg_buf(e[3], 3, bn_params_buf_.get().buffer, sizeof(CnnV3ParamsBn));
+ bg_tex(e[4], 4, bn_lo_view);
+ bg_tex(e[5], 5, bn_hi_view);
+ bn_bg_.replace(make_bg(bn_pipeline_.get(), e, 6));
}
- // dec1: bn_tex(B0), enc1_tex(B1), weights(B2), params(B3), dec1_out(B4)
+ // dec1: bn_lo(B0), bn_hi(B1), enc1_lo(B2), enc1_hi(B3),
+ // weights(B4), params(B5), dec1_out(B6)
{
- WGPUBindGroupEntry e[5] = {};
- bg_tex(e[0], 0, bn_view);
- bg_tex(e[1], 1, enc1_view);
- bg_buf(e[2], 2, wb, kWeightsBufBytes);
- bg_buf(e[3], 3, dec1_params_buf_.get().buffer, sizeof(CnnV3Params4ch));
- bg_tex(e[4], 4, dec1_view);
- dec1_bg_.replace(make_bg(dec1_pipeline_.get(), e, 5));
+ WGPUBindGroupEntry e[7] = {};
+ bg_tex(e[0], 0, bn_lo_view);
+ bg_tex(e[1], 1, bn_hi_view);
+ bg_tex(e[2], 2, enc1_lo_view);
+ bg_tex(e[3], 3, enc1_hi_view);
+ bg_buf(e[4], 4, wb, kWeightsBufBytes);
+ bg_buf(e[5], 5, dec1_params_buf_.get().buffer, sizeof(CnnV3Params8ch));
+ bg_tex(e[6], 6, dec1_view);
+ dec1_bg_.replace(make_bg(dec1_pipeline_.get(), e, 7));
}
- // dec0: dec1_tex(B0), enc0_tex(B1), weights(B2), params(B3), output(B4)
+ // dec0: dec1(B0), enc0(B1), weights(B2), params(B3), output(B4)
{
WGPUBindGroupEntry e[5] = {};
bg_tex(e[0], 0, dec1_view);
diff --git a/cnn_v3/src/cnn_v3_effect.h b/cnn_v3/src/cnn_v3_effect.h
index 36e2797..070f988 100644
--- a/cnn_v3/src/cnn_v3_effect.h
+++ b/cnn_v3/src/cnn_v3_effect.h
@@ -2,6 +2,13 @@
// Runs 5 compute passes (enc0→enc1→bottleneck→dec1→dec0) on G-buffer feature
// textures produced by GBufferEffect.
//
+// Architecture: enc_channels=[8,16]
+// enc0: Conv(20→8, 3×3) + FiLM8 + ReLU H×W rgba32uint
+// enc1: Conv(8→16, 3×3) + FiLM16 + ReLU H/2×W/2 2× rgba32uint
+// bottleneck: Conv(16→16, 3×3, dil=2) + ReLU H/4×W/4 2× rgba32uint
+// dec1: Conv(32→8, 3×3) + FiLM8 + ReLU H/2×W/2 rgba32uint
+// dec0: Conv(16→4, 3×3) + FiLM4 + ReLU + sig H×W rgba16float
+//
// Inputs: feat_tex0, feat_tex1 (rgba32uint, 20-channel G-buffer)
// Output: output_tex (rgba16float, 4-channel RGBA)
@@ -18,35 +25,19 @@
// Per-pass params uniform layouts (mirror WGSL Params structs exactly)
// ---------------------------------------------------------------------------
-// enc0, dec1, dec0: 4-channel FiLM
+// enc0, dec1: 8-channel FiLM (lo/hi vec4 split)
//
-// WGSL layout (vec3u has align=16, so _pad sits at offset 16):
-// offset 0: weight_offset (u32, 4 bytes)
-// offset 4: (12 bytes implicit padding before vec3u)
-// offset 16: _pad (vec3u, 12 bytes)
-// offset 28: (4 bytes implicit padding before vec4f)
-// offset 32: gamma (vec4f, 16 bytes)
-// offset 48: beta (vec4f, 16 bytes)
-// total: 64 bytes
-struct CnnV3Params4ch {
- uint32_t weight_offset; // offset 0
- uint32_t _pad[7]; // offsets 4-31 (mirrors implicit + vec3u + post-pad)
- float gamma[4]; // offset 32
- float beta[4]; // offset 48
-};
-static_assert(sizeof(CnnV3Params4ch) == 64, "CnnV3Params4ch must be 64 bytes");
-
-// enc1: 8-channel FiLM (split into lo/hi vec4 pairs)
-//
-// WGSL layout (same header padding as above):
-// offset 0: weight_offset (u32, 4 bytes)
-// offset 16: _pad (vec3u, 12 bytes)
-// offset 32: gamma_lo (vec4f, 16 bytes)
-// offset 48: gamma_hi (vec4f, 16 bytes)
-// offset 64: beta_lo (vec4f, 16 bytes)
-// offset 80: beta_hi (vec4f, 16 bytes)
+// WGSL layout:
+// offset 0: weight_offset (u32)
+// offset 4-15: implicit pad, vec3u aligned to 16
+// offset 16: _pad (vec3u, 12 bytes)
+// offset 28-31: implicit pad
+// offset 32: gamma_lo (vec4f)
+// offset 48: gamma_hi (vec4f)
+// offset 64: beta_lo (vec4f)
+// offset 80: beta_hi (vec4f)
// total: 96 bytes
-struct CnnV3ParamsEnc1 {
+struct CnnV3Params8ch {
uint32_t weight_offset; // offset 0
uint32_t _pad[7]; // offsets 4-31
float gamma_lo[4]; // offset 32
@@ -54,10 +45,41 @@ struct CnnV3ParamsEnc1 {
float beta_lo[4]; // offset 64
float beta_hi[4]; // offset 80
};
-static_assert(sizeof(CnnV3ParamsEnc1) == 96,
- "CnnV3ParamsEnc1 must be 96 bytes");
+static_assert(sizeof(CnnV3Params8ch) == 96, "CnnV3Params8ch must be 96 bytes");
+
+// enc1: 16-channel FiLM (four vec4 groups for gamma + four for beta)
+//
+// WGSL layout:
+// offset 0: weight_offset (u32)
+// offset 16: _pad (vec3u)
+// offset 32: gamma_0..3 (4x vec4f = 64 bytes)
+// offset 96: beta_0..3 (4x vec4f = 64 bytes)
+// total: 160 bytes
+struct CnnV3Params16ch {
+ uint32_t weight_offset; // offset 0
+ uint32_t _pad[7]; // offsets 4-31
+ float gamma[16]; // offsets 32-95
+ float beta[16]; // offsets 96-159
+};
+static_assert(sizeof(CnnV3Params16ch) == 160, "CnnV3Params16ch must be 160 bytes");
+
+// dec0: 4-channel FiLM
+//
+// WGSL layout:
+// offset 0: weight_offset (u32)
+// offset 16: _pad (vec3u)
+// offset 32: gamma (vec4f)
+// offset 48: beta (vec4f)
+// total: 64 bytes
+struct CnnV3Params4ch {
+ uint32_t weight_offset; // offset 0
+ uint32_t _pad[7]; // offsets 4-31
+ float gamma[4]; // offset 32
+ float beta[4]; // offset 48
+};
+static_assert(sizeof(CnnV3Params4ch) == 64, "CnnV3Params4ch must be 64 bytes");
-// bottleneck: no FiLM — 4 plain u32s, no alignment gap
+// bottleneck: no FiLM — weight_offset + 3 pads
struct CnnV3ParamsBn {
uint32_t weight_offset;
uint32_t _pad[3];
@@ -90,14 +112,15 @@ class CNNv3Effect : public Effect {
void set_film_params(const CNNv3FiLMParams& fp);
// Upload packed-f16 weights (kWeightsBufBytes bytes of u32 pairs).
- // Used for testing and inference from trained .bin files.
void upload_weights(WGPUQueue queue, const void* data, uint32_t size_bytes);
private:
// Intermediate node names (prefixed from output[0])
std::string node_enc0_;
- std::string node_enc1_;
- std::string node_bottleneck_;
+ std::string node_enc1_lo_;
+ std::string node_enc1_hi_;
+ std::string node_bn_lo_;
+ std::string node_bn_hi_;
std::string node_dec1_;
// 5 compute pipelines
@@ -115,20 +138,20 @@ class CNNv3Effect : public Effect {
BindGroup dec0_bg_;
// Params uniform buffers (one per pass)
- UniformBuffer<CnnV3Params4ch> enc0_params_buf_;
- UniformBuffer<CnnV3ParamsEnc1> enc1_params_buf_;
+ UniformBuffer<CnnV3Params8ch> enc0_params_buf_;
+ UniformBuffer<CnnV3Params16ch> enc1_params_buf_;
UniformBuffer<CnnV3ParamsBn> bn_params_buf_;
- UniformBuffer<CnnV3Params4ch> dec1_params_buf_;
+ UniformBuffer<CnnV3Params8ch> dec1_params_buf_;
UniformBuffer<CnnV3Params4ch> dec0_params_buf_;
// Shared packed-f16 weights (storage buffer, read-only in all shaders)
GpuBuffer weights_buf_;
// Per-pass params shadow (updated by set_film_params, uploaded in render)
- CnnV3Params4ch enc0_params_{};
- CnnV3ParamsEnc1 enc1_params_{};
+ CnnV3Params8ch enc0_params_{};
+ CnnV3Params16ch enc1_params_{};
CnnV3ParamsBn bn_params_{};
- CnnV3Params4ch dec1_params_{};
+ CnnV3Params8ch dec1_params_{};
CnnV3Params4ch dec0_params_{};
void create_pipelines();