diff options
| author | skal <pascal.massimino@gmail.com> | 2026-03-26 07:03:01 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-03-26 07:03:01 +0100 |
| commit | 8f14bdd66cb002b2f89265b2a578ad93249089c9 (patch) | |
| tree | 2ccdb3939b673ebc3a5df429160631240239cee2 /cnn_v3/src/cnn_v3_effect.cc | |
| parent | 4ca498277b033ae10134045dae9c8c249a8d2b2b (diff) | |
feat(cnn_v3): upgrade architecture to enc_channels=[8,16]
Double encoder capacity: enc0 output 4→8ch, enc1 output 8→16ch (old→new).
Resulting conv shapes (in→out): bottleneck 16→16ch (was 8→8ch),
dec1 32→8ch (was 16→4ch), dec0 16→4ch (was 8→4ch).
Total weights 2476→7828 f16 (~15.3 KB).
FiLM MLP output 40→72 params (γ+β for enc0 8 + enc1 16 + dec1 8 + dec0 4 = 36 channels; bottleneck has no FiLM) (L1: 16×40→16×72).
16-ch textures split into _lo/_hi rgba32uint pairs (enc1, bottleneck).
enc0 and dec1 textures changed from rgba16float to rgba32uint (8ch).
GBUF_RGBA32UINT node gains CopySrc for parity test readback.
- WGSL shaders: all 5 passes rewritten for new channel counts
- C++ CNNv3Effect: new weight offsets/sizes, 8ch uniform structs
- Web tool (shaders.js + tester.js): matching texture formats and bindings
- Parity test: readback_rgba32uint_8ch helper, updated vector counts
- Training scripts: default enc_channels=[8,16], updated docstrings
- Docs + architecture PNG regenerated
handoff(Gemini): CNN v3 [8,16] upgrade complete. All code, tests, web
tool, training scripts, and docs updated. Next: run training pass.
Diffstat (limited to 'cnn_v3/src/cnn_v3_effect.cc')
| -rw-r--r-- | cnn_v3/src/cnn_v3_effect.cc | 247 |
1 files changed, 126 insertions, 121 deletions
diff --git a/cnn_v3/src/cnn_v3_effect.cc b/cnn_v3/src/cnn_v3_effect.cc index 1391eba..dc26751 100644 --- a/cnn_v3/src/cnn_v3_effect.cc +++ b/cnn_v3/src/cnn_v3_effect.cc @@ -1,5 +1,5 @@ // CNN v3 Effect — U-Net + FiLM inference (5 compute passes) -// See cnn_v3/docs/CNN_V3.md for architecture, HOWTO.md §7 for shader details. +// See cnn_v3/docs/CNN_V3.md for architecture, HOWTO.md for shader details. #include "cnn_v3_effect.h" @@ -17,17 +17,16 @@ #include <cstring> // --------------------------------------------------------------------------- -// Weight layout constants — explicit formulas matching WGSL shader comments -// --------------------------------------------------------------------------- +// Weight layout constants — enc_channels=[8,16] // // Format: Conv(IN→OUT, KxK) has OUT*IN*K*K weights + OUT biases -// Layout: OIHW order (out × in × kH × kW), biases appended after conv weights -// -static const uint32_t kEnc0Weights = 20 * 4 * 9 + 4; // Conv(20→4,3×3)+bias -static const uint32_t kEnc1Weights = 4 * 8 * 9 + 8; // Conv(4→8,3×3)+bias -static const uint32_t kBnWeights = 8 * 8 * 9 + 8; // Conv(8→8,3×3,dilation=2)+bias -static const uint32_t kDec1Weights = 16 * 4 * 9 + 4; // Conv(16→4,3×3)+bias -static const uint32_t kDec0Weights = 8 * 4 * 9 + 4; // Conv(8→4,3×3)+bias +// Layout: OIHW order (out × in × kH × kW), biases appended +// --------------------------------------------------------------------------- +static const uint32_t kEnc0Weights = 20 * 8 * 9 + 8; // Conv(20→8, 3×3)+bias = 1448 +static const uint32_t kEnc1Weights = 8 * 16 * 9 + 16; // Conv(8→16, 3×3)+bias = 1168 +static const uint32_t kBnWeights = 16 * 16 * 9 + 16; // Conv(16→16, 3×3,dil=2)+bias = 2320 +static const uint32_t kDec1Weights = 32 * 8 * 9 + 8; // Conv(32→8, 3×3)+bias = 2312 +static const uint32_t kDec0Weights = 16 * 4 * 9 + 4; // Conv(16→4, 3×3)+bias = 580 static const uint32_t kEnc0Offset = 0; static const uint32_t kEnc1Offset = kEnc0Offset + kEnc0Weights; @@ -35,13 +34,12 @@ static 
const uint32_t kBnOffset = kEnc1Offset + kEnc1Weights; static const uint32_t kDec1Offset = kBnOffset + kBnWeights; static const uint32_t kDec0Offset = kDec1Offset + kDec1Weights; static const uint32_t kTotalF16 = kDec0Offset + kDec0Weights; +// = 1448 + 1168 + 2320 + 2312 + 580 = 7828 f16 -// Weights buffer size in bytes: f16 values are packed two-per-u32. -// Round up to u32 boundary. static const uint32_t kWeightsBufBytes = ((kTotalF16 + 1) / 2) * 4; // --------------------------------------------------------------------------- -// Shader source externs (registered in shaders.cc via InitShaderComposer) +// Shader source externs // --------------------------------------------------------------------------- extern const char* cnn_v3_enc0_wgsl; extern const char* cnn_v3_enc1_wgsl; @@ -103,14 +101,6 @@ static WGPUBindGroupLayoutEntry bgl_uint_tex(uint32_t binding) { e.texture.viewDimension = WGPUTextureViewDimension_2D; return e; } -static WGPUBindGroupLayoutEntry bgl_float_tex(uint32_t binding) { - WGPUBindGroupLayoutEntry e = {}; - e.binding = binding; - e.visibility = WGPUShaderStage_Compute; - e.texture.sampleType = WGPUTextureSampleType_Float; - e.texture.viewDimension = WGPUTextureViewDimension_2D; - return e; -} static WGPUBindGroupLayoutEntry bgl_storage_buf(uint32_t binding) { WGPUBindGroupLayoutEntry e = {}; e.binding = binding; @@ -151,45 +141,46 @@ CNNv3Effect::CNNv3Effect(const GpuContext& ctx, const std::string& prefix = outputs.empty() ? std::string("cnn_v3") : outputs[0]; - node_enc0_ = prefix + "_enc0"; - node_enc1_ = prefix + "_enc1"; - node_bottleneck_ = prefix + "_bottleneck"; - node_dec1_ = prefix + "_dec1"; + node_enc0_ = prefix + "_enc0"; + node_enc1_lo_ = prefix + "_enc1_lo"; + node_enc1_hi_ = prefix + "_enc1_hi"; + node_bn_lo_ = prefix + "_bn_lo"; + node_bn_hi_ = prefix + "_bn_hi"; + node_dec1_ = prefix + "_dec1"; - // Allocate zeroed weights buffer (f16 pairs packed as u32). 
- // Weights are zero-initialized; load_weights() can fill from file later. weights_buf_ = gpu_create_buffer( ctx_.device, kWeightsBufBytes, WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst); - // Initialize uniform buffers. enc0_params_buf_.init(ctx_.device); enc1_params_buf_.init(ctx_.device); bn_params_buf_.init(ctx_.device); dec1_params_buf_.init(ctx_.device); dec0_params_buf_.init(ctx_.device); - // Set weight offsets (FiLM γ/β default to identity: γ=1, β=0). + // Set weight offsets; FiLM γ/β default to identity (γ=1, β=0). enc0_params_.weight_offset = kEnc0Offset; - for (int i = 0; i < 4; ++i) { enc0_params_.gamma[i] = 1.0f; } - - enc1_params_.weight_offset = kEnc1Offset; for (int i = 0; i < 4; ++i) { - enc1_params_.gamma_lo[i] = 1.0f; - enc1_params_.gamma_hi[i] = 1.0f; + enc0_params_.gamma_lo[i] = 1.0f; + enc0_params_.gamma_hi[i] = 1.0f; } + enc1_params_.weight_offset = kEnc1Offset; + for (int i = 0; i < 16; ++i) { enc1_params_.gamma[i] = 1.0f; } + bn_params_.weight_offset = kBnOffset; dec1_params_.weight_offset = kDec1Offset; - for (int i = 0; i < 4; ++i) { dec1_params_.gamma[i] = 1.0f; } + for (int i = 0; i < 4; ++i) { + dec1_params_.gamma_lo[i] = 1.0f; + dec1_params_.gamma_hi[i] = 1.0f; + } dec0_params_.weight_offset = kDec0Offset; for (int i = 0; i < 4; ++i) { dec0_params_.gamma[i] = 1.0f; } create_pipelines(); - // Load trained weights from asset system (zero-initialized if absent). 
size_t weights_size = 0; const void* weights_data = GetAsset(AssetId::ASSET_WEIGHTS_CNN_V3, &weights_size); @@ -206,20 +197,21 @@ void CNNv3Effect::declare_nodes(NodeRegistry& registry) { const int W = registry.default_width(); const int H = registry.default_height(); - // enc0_tex: rgba16float full-res - registry.declare_node(node_enc0_, NodeType::GBUF_ALBEDO, W, H); - // enc1_tex: rgba32uint half-res — shaders use textureDimensions() for bounds - registry.declare_node(node_enc1_, NodeType::GBUF_RGBA32UINT, W / 2, H / 2); - // bottleneck_tex: rgba32uint quarter-res - registry.declare_node(node_bottleneck_, NodeType::GBUF_RGBA32UINT, W / 4, H / 4); - // dec1_tex: rgba16float half-res - registry.declare_node(node_dec1_, NodeType::GBUF_ALBEDO, W / 2, H / 2); + // enc0: rgba32uint full-res (8ch packed f16) + registry.declare_node(node_enc0_, NodeType::GBUF_RGBA32UINT, W, H); + // enc1: two rgba32uint half-res (8ch each = 16ch total) + registry.declare_node(node_enc1_lo_, NodeType::GBUF_RGBA32UINT, W / 2, H / 2); + registry.declare_node(node_enc1_hi_, NodeType::GBUF_RGBA32UINT, W / 2, H / 2); + // bottleneck: two rgba32uint quarter-res (8ch each = 16ch total) + registry.declare_node(node_bn_lo_, NodeType::GBUF_RGBA32UINT, W / 4, H / 4); + registry.declare_node(node_bn_hi_, NodeType::GBUF_RGBA32UINT, W / 4, H / 4); + // dec1: rgba32uint half-res (8ch packed f16) + registry.declare_node(node_dec1_, NodeType::GBUF_RGBA32UINT, W / 2, H / 2); // output_nodes_[0]: rgba16float full-res — declared externally by caller } // --------------------------------------------------------------------------- -// set_film_params — simple linear mapping (placeholder, no MLP yet) -// TODO(phase-7): replace with CPU forward pass through cnn_v3_film_mlp.bin +// upload_weights / set_film_params // --------------------------------------------------------------------------- void CNNv3Effect::upload_weights(WGPUQueue queue, const void* data, @@ -228,26 +220,26 @@ void 
CNNv3Effect::upload_weights(WGPUQueue queue, const void* data, } void CNNv3Effect::set_film_params(const CNNv3FiLMParams& fp) { - // Identity + audio/beat modulation. - // Replace with FiLM MLP output once training is done. const float a = fp.audio_intensity; const float b = fp.beat_phase; for (int i = 0; i < 4; ++i) { - enc0_params_.gamma[i] = 1.0f + a * 0.5f; - enc0_params_.beta[i] = b * 0.1f; + enc0_params_.gamma_lo[i] = 1.0f + a * 0.5f; + enc0_params_.gamma_hi[i] = 1.0f + a * 0.5f; + enc0_params_.beta_lo[i] = b * 0.1f; + enc0_params_.beta_hi[i] = b * 0.1f; } - for (int i = 0; i < 4; ++i) { - enc1_params_.gamma_lo[i] = 1.0f + a * 0.3f; - enc1_params_.gamma_hi[i] = 1.0f + a * 0.3f; - enc1_params_.beta_lo[i] = fp.beat_norm * 0.1f; - enc1_params_.beta_hi[i] = fp.beat_norm * 0.1f; + for (int i = 0; i < 16; ++i) { + enc1_params_.gamma[i] = 1.0f + a * 0.3f; + enc1_params_.beta[i] = fp.beat_norm * 0.1f; } for (int i = 0; i < 4; ++i) { - dec1_params_.gamma[i] = 1.0f + fp.style_p0 * 0.5f; - dec1_params_.beta[i] = fp.style_p1 * 0.1f; - dec0_params_.gamma[i] = 1.0f + fp.style_p0 * 0.5f; - dec0_params_.beta[i] = fp.style_p1 * 0.1f; + dec1_params_.gamma_lo[i] = 1.0f + fp.style_p0 * 0.5f; + dec1_params_.gamma_hi[i] = 1.0f + fp.style_p0 * 0.5f; + dec1_params_.beta_lo[i] = fp.style_p1 * 0.1f; + dec1_params_.beta_hi[i] = fp.style_p1 * 0.1f; + dec0_params_.gamma[i] = 1.0f + fp.style_p0 * 0.5f; + dec0_params_.beta[i] = fp.style_p1 * 0.1f; } } @@ -258,7 +250,6 @@ void CNNv3Effect::set_film_params(const CNNv3FiLMParams& fp) { void CNNv3Effect::render(WGPUCommandEncoder encoder, const UniformsSequenceParams& params, NodeRegistry& nodes) { - // Upload params uniforms. 
enc0_params_buf_.update(ctx_.queue, enc0_params_); enc1_params_buf_.update(ctx_.queue, enc1_params_); bn_params_buf_.update(ctx_.queue, bn_params_); @@ -270,7 +261,6 @@ void CNNv3Effect::render(WGPUCommandEncoder encoder, const int W = (int)params.resolution.x; const int H = (int)params.resolution.y; - // Dispatch helper: ceil(dim / 8) workgroups auto dispatch = [&](WGPUComputePipeline pipe, WGPUBindGroup bg, int w, int h) { WGPUComputePassDescriptor pass_desc = {}; @@ -304,14 +294,14 @@ void CNNv3Effect::create_pipelines() { // --- enc0 --- // B0: feat_tex0 (u32), B1: feat_tex1 (u32), B2: weights (storage), - // B3: params (uniform), B4: enc0_out (storage_tex rgba16float write) + // B3: params (uniform, 96B), B4: enc0_out (rgba32uint write) { WGPUBindGroupLayoutEntry e[5] = { bgl_uint_tex(0), bgl_uint_tex(1), bgl_storage_buf(2), - bgl_uniform_buf(3, sizeof(CnnV3Params4ch)), // 64 bytes - bgl_storage_tex_write(4, WGPUTextureFormat_RGBA16Float), + bgl_uniform_buf(3, sizeof(CnnV3Params8ch)), + bgl_storage_tex_write(4, WGPUTextureFormat_RGBA32Uint), }; WGPUBindGroupLayout bgl = make_bgl(dev, e, 5); WGPUShaderModule sh = make_shader(dev, cnn_v3_enc0_wgsl); @@ -321,16 +311,18 @@ void CNNv3Effect::create_pipelines() { } // --- enc1 --- - // B0: enc0_tex (f32), B1: weights (storage), - // B2: params (uniform), B3: enc1_out (storage_tex rgba32uint write) + // B0: enc0_tex (u32), B1: weights (storage), + // B2: params (uniform, 160B), B3: enc1_out_lo (rgba32uint write), + // B4: enc1_out_hi (rgba32uint write) { - WGPUBindGroupLayoutEntry e[4] = { - bgl_float_tex(0), + WGPUBindGroupLayoutEntry e[5] = { + bgl_uint_tex(0), bgl_storage_buf(1), - bgl_uniform_buf(2, sizeof(CnnV3ParamsEnc1)), + bgl_uniform_buf(2, sizeof(CnnV3Params16ch)), bgl_storage_tex_write(3, WGPUTextureFormat_RGBA32Uint), + bgl_storage_tex_write(4, WGPUTextureFormat_RGBA32Uint), }; - WGPUBindGroupLayout bgl = make_bgl(dev, e, 4); + WGPUBindGroupLayout bgl = make_bgl(dev, e, 5); WGPUShaderModule sh = 
make_shader(dev, cnn_v3_enc1_wgsl); enc1_pipeline_.set(make_compute_pipeline(dev, sh, "enc1_main", bgl)); wgpuShaderModuleRelease(sh); @@ -338,16 +330,19 @@ void CNNv3Effect::create_pipelines() { } // --- bottleneck --- - // B0: enc1_tex (u32), B1: weights (storage), - // B2: params (uniform), B3: bottleneck_out (storage_tex rgba32uint write) + // B0: enc1_tex_lo (u32), B1: enc1_tex_hi (u32), B2: weights (storage), + // B3: params (uniform, 16B), B4: bn_out_lo (rgba32uint write), + // B5: bn_out_hi (rgba32uint write) { - WGPUBindGroupLayoutEntry e[4] = { + WGPUBindGroupLayoutEntry e[6] = { bgl_uint_tex(0), - bgl_storage_buf(1), - bgl_uniform_buf(2, sizeof(CnnV3ParamsBn)), - bgl_storage_tex_write(3, WGPUTextureFormat_RGBA32Uint), + bgl_uint_tex(1), + bgl_storage_buf(2), + bgl_uniform_buf(3, sizeof(CnnV3ParamsBn)), + bgl_storage_tex_write(4, WGPUTextureFormat_RGBA32Uint), + bgl_storage_tex_write(5, WGPUTextureFormat_RGBA32Uint), }; - WGPUBindGroupLayout bgl = make_bgl(dev, e, 4); + WGPUBindGroupLayout bgl = make_bgl(dev, e, 6); WGPUShaderModule sh = make_shader(dev, cnn_v3_bottleneck_wgsl); bn_pipeline_.set(make_compute_pipeline(dev, sh, "bottleneck_main", bgl)); wgpuShaderModuleRelease(sh); @@ -355,17 +350,21 @@ void CNNv3Effect::create_pipelines() { } // --- dec1 --- - // B0: bottleneck_tex (u32), B1: enc1_tex (u32), B2: weights (storage), - // B3: params (uniform), B4: dec1_out (storage_tex rgba16float write) + // B0: bn_tex_lo (u32), B1: bn_tex_hi (u32), + // B2: enc1_tex_lo (u32), B3: enc1_tex_hi (u32), + // B4: weights (storage), B5: params (uniform, 96B), + // B6: dec1_out (rgba32uint write) { - WGPUBindGroupLayoutEntry e[5] = { + WGPUBindGroupLayoutEntry e[7] = { bgl_uint_tex(0), bgl_uint_tex(1), - bgl_storage_buf(2), - bgl_uniform_buf(3, sizeof(CnnV3Params4ch)), // 64 bytes - bgl_storage_tex_write(4, WGPUTextureFormat_RGBA16Float), + bgl_uint_tex(2), + bgl_uint_tex(3), + bgl_storage_buf(4), + bgl_uniform_buf(5, sizeof(CnnV3Params8ch)), + 
bgl_storage_tex_write(6, WGPUTextureFormat_RGBA32Uint), }; - WGPUBindGroupLayout bgl = make_bgl(dev, e, 5); + WGPUBindGroupLayout bgl = make_bgl(dev, e, 7); WGPUShaderModule sh = make_shader(dev, cnn_v3_dec1_wgsl); dec1_pipeline_.set(make_compute_pipeline(dev, sh, "dec1_main", bgl)); wgpuShaderModuleRelease(sh); @@ -373,14 +372,14 @@ void CNNv3Effect::create_pipelines() { } // --- dec0 --- - // B0: dec1_tex (f32), B1: enc0_tex (f32), B2: weights (storage), - // B3: params (uniform), B4: output_tex (storage_tex rgba16float write) + // B0: dec1_tex (u32), B1: enc0_tex (u32), B2: weights (storage), + // B3: params (uniform, 64B), B4: output_tex (rgba16float write) { WGPUBindGroupLayoutEntry e[5] = { - bgl_float_tex(0), - bgl_float_tex(1), + bgl_uint_tex(0), + bgl_uint_tex(1), bgl_storage_buf(2), - bgl_uniform_buf(3, sizeof(CnnV3Params4ch)), // 64 bytes + bgl_uniform_buf(3, sizeof(CnnV3Params4ch)), bgl_storage_tex_write(4, WGPUTextureFormat_RGBA16Float), }; WGPUBindGroupLayout bgl = make_bgl(dev, e, 5); @@ -395,14 +394,12 @@ void CNNv3Effect::create_pipelines() { // update_bind_groups — rebuilt each frame (node views may be recreated) // --------------------------------------------------------------------------- -// Helper: set a texture view binding entry. static void bg_tex(WGPUBindGroupEntry& e, uint32_t binding, WGPUTextureView view) { e = {}; e.binding = binding; e.textureView = view; } -// Helper: set a buffer binding entry. 
static void bg_buf(WGPUBindGroupEntry& e, uint32_t binding, WGPUBuffer buf, uint64_t size) { e = {}; @@ -414,13 +411,15 @@ static void bg_buf(WGPUBindGroupEntry& e, uint32_t binding, WGPUBuffer buf, void CNNv3Effect::update_bind_groups(NodeRegistry& nodes) { WGPUDevice dev = ctx_.device; - WGPUTextureView feat0_view = nodes.get_view(input_nodes_[0]); - WGPUTextureView feat1_view = nodes.get_view(input_nodes_[1]); - WGPUTextureView enc0_view = nodes.get_view(node_enc0_); - WGPUTextureView enc1_view = nodes.get_view(node_enc1_); - WGPUTextureView bn_view = nodes.get_view(node_bottleneck_); - WGPUTextureView dec1_view = nodes.get_view(node_dec1_); - WGPUTextureView out_view = nodes.get_view(output_nodes_[0]); + WGPUTextureView feat0_view = nodes.get_view(input_nodes_[0]); + WGPUTextureView feat1_view = nodes.get_view(input_nodes_[1]); + WGPUTextureView enc0_view = nodes.get_view(node_enc0_); + WGPUTextureView enc1_lo_view = nodes.get_view(node_enc1_lo_); + WGPUTextureView enc1_hi_view = nodes.get_view(node_enc1_hi_); + WGPUTextureView bn_lo_view = nodes.get_view(node_bn_lo_); + WGPUTextureView bn_hi_view = nodes.get_view(node_bn_hi_); + WGPUTextureView dec1_view = nodes.get_view(node_dec1_); + WGPUTextureView out_view = nodes.get_view(output_nodes_[0]); WGPUBuffer wb = weights_buf_.buffer; @@ -437,49 +436,55 @@ void CNNv3Effect::update_bind_groups(NodeRegistry& nodes) { return bg; }; - // enc0: feat_tex0(B0), feat_tex1(B1), weights(B2), params(B3), enc0_out(B4) + // enc0: feat0(B0), feat1(B1), weights(B2), params(B3), enc0_out(B4) { WGPUBindGroupEntry e[5] = {}; bg_tex(e[0], 0, feat0_view); bg_tex(e[1], 1, feat1_view); bg_buf(e[2], 2, wb, kWeightsBufBytes); - bg_buf(e[3], 3, enc0_params_buf_.get().buffer, sizeof(CnnV3Params4ch)); + bg_buf(e[3], 3, enc0_params_buf_.get().buffer, sizeof(CnnV3Params8ch)); bg_tex(e[4], 4, enc0_view); enc0_bg_.replace(make_bg(enc0_pipeline_.get(), e, 5)); } - // enc1: enc0_tex(B0), weights(B1), params(B2), enc1_out(B3) + // enc1: enc0(B0), 
weights(B1), params(B2), enc1_lo(B3), enc1_hi(B4) { - WGPUBindGroupEntry e[4] = {}; + WGPUBindGroupEntry e[5] = {}; bg_tex(e[0], 0, enc0_view); bg_buf(e[1], 1, wb, kWeightsBufBytes); - bg_buf(e[2], 2, enc1_params_buf_.get().buffer, sizeof(CnnV3ParamsEnc1)); - bg_tex(e[3], 3, enc1_view); - enc1_bg_.replace(make_bg(enc1_pipeline_.get(), e, 4)); + bg_buf(e[2], 2, enc1_params_buf_.get().buffer, sizeof(CnnV3Params16ch)); + bg_tex(e[3], 3, enc1_lo_view); + bg_tex(e[4], 4, enc1_hi_view); + enc1_bg_.replace(make_bg(enc1_pipeline_.get(), e, 5)); } - // bottleneck: enc1_tex(B0), weights(B1), params(B2), bn_out(B3) + // bottleneck: enc1_lo(B0), enc1_hi(B1), weights(B2), params(B3), bn_lo(B4), bn_hi(B5) { - WGPUBindGroupEntry e[4] = {}; - bg_tex(e[0], 0, enc1_view); - bg_buf(e[1], 1, wb, kWeightsBufBytes); - bg_buf(e[2], 2, bn_params_buf_.get().buffer, sizeof(CnnV3ParamsBn)); - bg_tex(e[3], 3, bn_view); - bn_bg_.replace(make_bg(bn_pipeline_.get(), e, 4)); + WGPUBindGroupEntry e[6] = {}; + bg_tex(e[0], 0, enc1_lo_view); + bg_tex(e[1], 1, enc1_hi_view); + bg_buf(e[2], 2, wb, kWeightsBufBytes); + bg_buf(e[3], 3, bn_params_buf_.get().buffer, sizeof(CnnV3ParamsBn)); + bg_tex(e[4], 4, bn_lo_view); + bg_tex(e[5], 5, bn_hi_view); + bn_bg_.replace(make_bg(bn_pipeline_.get(), e, 6)); } - // dec1: bn_tex(B0), enc1_tex(B1), weights(B2), params(B3), dec1_out(B4) + // dec1: bn_lo(B0), bn_hi(B1), enc1_lo(B2), enc1_hi(B3), + // weights(B4), params(B5), dec1_out(B6) { - WGPUBindGroupEntry e[5] = {}; - bg_tex(e[0], 0, bn_view); - bg_tex(e[1], 1, enc1_view); - bg_buf(e[2], 2, wb, kWeightsBufBytes); - bg_buf(e[3], 3, dec1_params_buf_.get().buffer, sizeof(CnnV3Params4ch)); - bg_tex(e[4], 4, dec1_view); - dec1_bg_.replace(make_bg(dec1_pipeline_.get(), e, 5)); + WGPUBindGroupEntry e[7] = {}; + bg_tex(e[0], 0, bn_lo_view); + bg_tex(e[1], 1, bn_hi_view); + bg_tex(e[2], 2, enc1_lo_view); + bg_tex(e[3], 3, enc1_hi_view); + bg_buf(e[4], 4, wb, kWeightsBufBytes); + bg_buf(e[5], 5, 
dec1_params_buf_.get().buffer, sizeof(CnnV3Params8ch)); + bg_tex(e[6], 6, dec1_view); + dec1_bg_.replace(make_bg(dec1_pipeline_.get(), e, 7)); } - // dec0: dec1_tex(B0), enc0_tex(B1), weights(B2), params(B3), output(B4) + // dec0: dec1(B0), enc0(B1), weights(B2), params(B3), output(B4) { WGPUBindGroupEntry e[5] = {}; bg_tex(e[0], 0, dec1_view); |
