diff options
| author | skal <pascal.massimino@gmail.com> | 2026-05-21 08:10:47 +0200 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-05-21 08:10:47 +0200 |
| commit | d806027dcaeadcdd8d2febd88bc46b2fd2c465de (patch) | |
| tree | 30bc1ef9f40ccab7c00e31ee20e62bb86755fa26 /cnn_v3/src/cnn_v3_effect.cc | |
| parent | 680042a18c11ad5e58757e45b260745c2f52417f (diff) | |
Diffstat (limited to 'cnn_v3/src/cnn_v3_effect.cc')
| -rw-r--r-- | cnn_v3/src/cnn_v3_effect.cc | 185 |
1 files changed, 100 insertions, 85 deletions
diff --git a/cnn_v3/src/cnn_v3_effect.cc b/cnn_v3/src/cnn_v3_effect.cc index e576ceb..fa1716f 100644 --- a/cnn_v3/src/cnn_v3_effect.cc +++ b/cnn_v3/src/cnn_v3_effect.cc @@ -22,18 +22,23 @@ // Format: Conv(IN→OUT, KxK) has OUT*IN*K*K weights + OUT biases // Layout: OIHW order (out × in × kH × kW), biases appended // --------------------------------------------------------------------------- -static const uint32_t kEnc0Weights = 20 * 8 * 9 + 8; // Conv(20→8, 3×3)+bias = 1448 -static const uint32_t kEnc1Weights = 8 * 16 * 9 + 16; // Conv(8→16, 3×3)+bias = 1168 -static const uint32_t kBnWeights = 16 * 16 * 9 + 16; // Conv(16→16, 3×3,dil=2)+bias = 2320 -static const uint32_t kDec1Weights = 32 * 8 * 9 + 8; // Conv(32→8, 3×3)+bias = 2312 -static const uint32_t kDec0Weights = 16 * 4 * 9 + 4; // Conv(16→4, 3×3)+bias = 580 +static const uint32_t kEnc0Weights = + 20 * 8 * 9 + 8; // Conv(20→8, 3×3)+bias = 1448 +static const uint32_t kEnc1Weights = + 8 * 16 * 9 + 16; // Conv(8→16, 3×3)+bias = 1168 +static const uint32_t kBnWeights = + 16 * 16 * 9 + 16; // Conv(16→16, 3×3,dil=2)+bias = 2320 +static const uint32_t kDec1Weights = + 32 * 8 * 9 + 8; // Conv(32→8, 3×3)+bias = 2312 +static const uint32_t kDec0Weights = + 16 * 4 * 9 + 4; // Conv(16→4, 3×3)+bias = 580 -static const uint32_t kEnc0Offset = 0; -static const uint32_t kEnc1Offset = kEnc0Offset + kEnc0Weights; -static const uint32_t kBnOffset = kEnc1Offset + kEnc1Weights; -static const uint32_t kDec1Offset = kBnOffset + kBnWeights; -static const uint32_t kDec0Offset = kDec1Offset + kDec1Weights; -static const uint32_t kTotalF16 = kDec0Offset + kDec0Weights; +static const uint32_t kEnc0Offset = 0; +static const uint32_t kEnc1Offset = kEnc0Offset + kEnc0Weights; +static const uint32_t kBnOffset = kEnc1Offset + kEnc1Weights; +static const uint32_t kDec1Offset = kBnOffset + kBnWeights; +static const uint32_t kDec0Offset = kDec1Offset + kDec1Weights; +static const uint32_t kTotalF16 = kDec0Offset + kDec0Weights; // = 1448 + 1168 + 2320 + 2312 + 580 = 7828 f16 static const uint32_t kWeightsBufBytes = ((kTotalF16 + 1) / 2) * 4; @@ -57,7 +62,7 @@ static WGPUShaderModule make_shader(WGPUDevice device, const char* wgsl) { WGPUShaderSourceWGSL src = {}; src.chain.sType = WGPUSType_ShaderSourceWGSL; - src.code = str_view(composed.c_str()); + src.code = str_view(composed.c_str()); WGPUShaderModuleDescriptor desc = {}; desc.nextInChain = &src.chain; @@ -69,7 +74,7 @@ static WGPUBindGroupLayout make_bgl(WGPUDevice device, uint32_t count) { WGPUBindGroupLayoutDescriptor desc = {}; desc.entryCount = count; - desc.entries = entries; + desc.entries = entries; return wgpuDeviceCreateBindGroupLayout(device, &desc); } @@ -79,14 +84,15 @@ static WGPUComputePipeline make_compute_pipeline(WGPUDevice device, WGPUBindGroupLayout bgl) { WGPUPipelineLayoutDescriptor pl_desc = {}; pl_desc.bindGroupLayoutCount = 1; - pl_desc.bindGroupLayouts = &bgl; + pl_desc.bindGroupLayouts = &bgl; WGPUPipelineLayout pl = wgpuDeviceCreatePipelineLayout(device, &pl_desc); WGPUComputePipelineDescriptor pipe_desc = {}; - pipe_desc.layout = pl; - pipe_desc.compute.module = shader; - pipe_desc.compute.entryPoint = str_view(entry); - WGPUComputePipeline pipe = wgpuDeviceCreateComputePipeline(device, &pipe_desc); + pipe_desc.layout = pl; + pipe_desc.compute.module = shader; + pipe_desc.compute.entryPoint = str_view(entry); + WGPUComputePipeline pipe = + wgpuDeviceCreateComputePipeline(device, &pipe_desc); wgpuPipelineLayoutRelease(pl); return pipe; @@ -95,36 +101,36 @@ static WGPUComputePipeline make_compute_pipeline(WGPUDevice device, // BGL entry helpers static WGPUBindGroupLayoutEntry bgl_uint_tex(uint32_t binding) { WGPUBindGroupLayoutEntry e = {}; - e.binding = binding; - e.visibility = WGPUShaderStage_Compute; - e.texture.sampleType = WGPUTextureSampleType_Uint; - e.texture.viewDimension = WGPUTextureViewDimension_2D; + e.binding = binding; + e.visibility = WGPUShaderStage_Compute; + e.texture.sampleType = WGPUTextureSampleType_Uint; + e.texture.viewDimension = WGPUTextureViewDimension_2D; return e; } static WGPUBindGroupLayoutEntry bgl_storage_buf(uint32_t binding) { WGPUBindGroupLayoutEntry e = {}; - e.binding = binding; - e.visibility = WGPUShaderStage_Compute; - e.buffer.type = WGPUBufferBindingType_ReadOnlyStorage; + e.binding = binding; + e.visibility = WGPUShaderStage_Compute; + e.buffer.type = WGPUBufferBindingType_ReadOnlyStorage; return e; } static WGPUBindGroupLayoutEntry bgl_uniform_buf(uint32_t binding, uint64_t min_size) { WGPUBindGroupLayoutEntry e = {}; - e.binding = binding; - e.visibility = WGPUShaderStage_Compute; - e.buffer.type = WGPUBufferBindingType_Uniform; - e.buffer.minBindingSize = min_size; + e.binding = binding; + e.visibility = WGPUShaderStage_Compute; + e.buffer.type = WGPUBufferBindingType_Uniform; + e.buffer.minBindingSize = min_size; return e; } -static WGPUBindGroupLayoutEntry bgl_storage_tex_write( - uint32_t binding, WGPUTextureFormat fmt) { +static WGPUBindGroupLayoutEntry bgl_storage_tex_write(uint32_t binding, + WGPUTextureFormat fmt) { WGPUBindGroupLayoutEntry e = {}; - e.binding = binding; - e.visibility = WGPUShaderStage_Compute; - e.storageTexture.access = WGPUStorageTextureAccess_WriteOnly; - e.storageTexture.format = fmt; - e.storageTexture.viewDimension = WGPUTextureViewDimension_2D; + e.binding = binding; + e.visibility = WGPUShaderStage_Compute; + e.storageTexture.access = WGPUStorageTextureAccess_WriteOnly; + e.storageTexture.format = fmt; + e.storageTexture.viewDimension = WGPUTextureViewDimension_2D; return e; } @@ -141,16 +147,16 @@ CNNv3Effect::CNNv3Effect(const GpuContext& ctx, const std::string& prefix = outputs.empty() ? std::string("cnn_v3") : outputs[0]; - node_enc0_ = prefix + "_enc0"; + node_enc0_ = prefix + "_enc0"; node_enc1_lo_ = prefix + "_enc1_lo"; node_enc1_hi_ = prefix + "_enc1_hi"; - node_bn_lo_ = prefix + "_bn_lo"; - node_bn_hi_ = prefix + "_bn_hi"; - node_dec1_ = prefix + "_dec1"; + node_bn_lo_ = prefix + "_bn_lo"; + node_bn_hi_ = prefix + "_bn_hi"; + node_dec1_ = prefix + "_dec1"; - weights_buf_ = gpu_create_buffer( - ctx_.device, kWeightsBufBytes, - WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst); + weights_buf_ = + gpu_create_buffer(ctx_.device, kWeightsBufBytes, + WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst); enc0_params_buf_.init(ctx_.device); enc1_params_buf_.init(ctx_.device); @@ -166,7 +172,9 @@ CNNv3Effect::CNNv3Effect(const GpuContext& ctx, } enc1_params_.weight_offset = kEnc1Offset; - for (int i = 0; i < 16; ++i) { enc1_params_.gamma[i] = 1.0f; } + for (int i = 0; i < 16; ++i) { + enc1_params_.gamma[i] = 1.0f; + } bn_params_.weight_offset = kBnOffset; @@ -177,7 +185,9 @@ CNNv3Effect::CNNv3Effect(const GpuContext& ctx, } dec0_params_.weight_offset = kDec0Offset; - for (int i = 0; i < 4; ++i) { dec0_params_.gamma[i] = 1.0f; } + for (int i = 0; i < 4; ++i) { + dec0_params_.gamma[i] = 1.0f; + } create_pipelines(); @@ -205,15 +215,15 @@ void CNNv3Effect::declare_nodes(NodeRegistry& registry) { const int H = registry.default_height(); // enc0: rgba32uint full-res (8ch packed f16) - registry.declare_node(node_enc0_, NodeType::GBUF_RGBA32UINT, W, H); + registry.declare_node(node_enc0_, NodeType::GBUF_RGBA32UINT, W, H); // enc1: two rgba32uint half-res (8ch each = 16ch total) registry.declare_node(node_enc1_lo_, NodeType::GBUF_RGBA32UINT, W / 2, H / 2); registry.declare_node(node_enc1_hi_, NodeType::GBUF_RGBA32UINT, W / 2, H / 2); // bottleneck: two rgba32uint quarter-res (8ch each = 16ch total) - registry.declare_node(node_bn_lo_, NodeType::GBUF_RGBA32UINT, W / 4, H / 4); - registry.declare_node(node_bn_hi_, NodeType::GBUF_RGBA32UINT, W / 4, H / 4); + registry.declare_node(node_bn_lo_, NodeType::GBUF_RGBA32UINT, W / 4, H / 4); + registry.declare_node(node_bn_hi_, NodeType::GBUF_RGBA32UINT, W / 4, H / 4); // dec1: rgba32uint half-res (8ch packed f16) - registry.declare_node(node_dec1_, NodeType::GBUF_RGBA32UINT, W / 2, H / 2); + registry.declare_node(node_dec1_, NodeType::GBUF_RGBA32UINT, W / 2, H / 2); // output_nodes_[0]: rgba16float full-res — declared externally by caller } @@ -222,12 +232,13 @@ void CNNv3Effect::declare_nodes(NodeRegistry& registry) { // --------------------------------------------------------------------------- void CNNv3Effect::upload_weights(WGPUQueue queue, const void* data, - uint32_t size_bytes) { + uint32_t size_bytes) { wgpuQueueWriteBuffer(queue, weights_buf_.buffer, 0, data, size_bytes); } void CNNv3Effect::load_film_mlp(const void* data, uint32_t size_bytes) { - if (size_bytes != sizeof(CNNv3FilmMlp)) return; + if (size_bytes != sizeof(CNNv3FilmMlp)) + return; memcpy(&mlp_, data, sizeof(CNNv3FilmMlp)); mlp_loaded_ = true; } @@ -246,16 +257,19 @@ void CNNv3Effect::set_film_params(const CNNv3FiLMParams& fp) { float h[16]; for (int j = 0; j < 16; ++j) { float s = mlp_.l0_b[j]; - for (int i = 0; i < 5; ++i) s += mlp_.l0_w[j * 5 + i] * cond[i]; + for (int i = 0; i < 5; ++i) + s += mlp_.l0_w[j * 5 + i] * cond[i]; h[j] = s > 0.f ? s : 0.f; } // Layer 1: Linear(16→72) - // Output split: g_enc0(8)|b_enc0(8)|g_enc1(16)|b_enc1(16)|g_dec1(8)|b_dec1(8)|g_dec0(4)|b_dec0(4) + // Output split: + // g_enc0(8)|b_enc0(8)|g_enc1(16)|b_enc1(16)|g_dec1(8)|b_dec1(8)|g_dec0(4)|b_dec0(4) float film[72]; for (int j = 0; j < 72; ++j) { float s = mlp_.l1_b[j]; - for (int i = 0; i < 16; ++i) s += mlp_.l1_w[j * 16 + i] * h[i]; + for (int i = 0; i < 16; ++i) + s += mlp_.l1_w[j * 16 + i] * h[i]; film[j] = s; } @@ -270,9 +284,11 @@ void CNNv3Effect::set_film_params(const CNNv3FiLMParams& fp) { enc0_params_.beta_hi[i] = p[i + 4]; } p += 8; - for (int i = 0; i < 16; ++i) enc1_params_.gamma[i] = p[i]; + for (int i = 0; i < 16; ++i) + enc1_params_.gamma[i] = p[i]; p += 16; - for (int i = 0; i < 16; ++i) enc1_params_.beta[i] = p[i]; + for (int i = 0; i < 16; ++i) + enc1_params_.beta[i] = p[i]; p += 16; for (int i = 0; i < 4; ++i) { dec1_params_.gamma_lo[i] = p[i]; @@ -284,9 +300,11 @@ void CNNv3Effect::set_film_params(const CNNv3FiLMParams& fp) { dec1_params_.beta_hi[i] = p[i + 4]; } p += 8; - for (int i = 0; i < 4; ++i) dec0_params_.gamma[i] = p[i]; + for (int i = 0; i < 4; ++i) + dec0_params_.gamma[i] = p[i]; p += 4; - for (int i = 0; i < 4; ++i) dec0_params_.beta[i] = p[i]; + for (int i = 0; i < 4; ++i) + dec0_params_.beta[i] = p[i]; } // --------------------------------------------------------------------------- @@ -307,27 +325,24 @@ void CNNv3Effect::render(WGPUCommandEncoder encoder, const int W = (int)params.resolution.x; const int H = (int)params.resolution.y; - auto dispatch = [&](WGPUComputePipeline pipe, WGPUBindGroup bg, - int w, int h) { + auto dispatch = [&](WGPUComputePipeline pipe, WGPUBindGroup bg, int w, + int h) { WGPUComputePassDescriptor pass_desc = {}; WGPUComputePassEncoder pass = wgpuCommandEncoderBeginComputePass(encoder, &pass_desc); wgpuComputePassEncoderSetPipeline(pass, pipe); wgpuComputePassEncoderSetBindGroup(pass, 0, bg, 0, nullptr); - wgpuComputePassEncoderDispatchWorkgroups( - pass, - (uint32_t)((w + 7) / 8), - (uint32_t)((h + 7) / 8), - 1); + wgpuComputePassEncoderDispatchWorkgroups(pass, (uint32_t)((w + 7) / 8), + (uint32_t)((h + 7) / 8), 1); wgpuComputePassEncoderEnd(pass); wgpuComputePassEncoderRelease(pass); }; - dispatch(enc0_pipeline_.get(), enc0_bg_.get(), W, H); - dispatch(enc1_pipeline_.get(), enc1_bg_.get(), W / 2, H / 2); - dispatch(bn_pipeline_.get(), bn_bg_.get(), W / 4, H / 4); - dispatch(dec1_pipeline_.get(), dec1_bg_.get(), W / 2, H / 2); - dispatch(dec0_pipeline_.get(), dec0_bg_.get(), W, H); + dispatch(enc0_pipeline_.get(), enc0_bg_.get(), W, H); + dispatch(enc1_pipeline_.get(), enc1_bg_.get(), W / 2, H / 2); + dispatch(bn_pipeline_.get(), bn_bg_.get(), W / 4, H / 4); + dispatch(dec1_pipeline_.get(), dec1_bg_.get(), W / 2, H / 2); + dispatch(dec0_pipeline_.get(), dec0_bg_.get(), W, H); } // --------------------------------------------------------------------------- @@ -443,40 +458,39 @@ void CNNv3Effect::create_pipelines() { static void bg_tex(WGPUBindGroupEntry& e, uint32_t binding, WGPUTextureView view) { e = {}; - e.binding = binding; + e.binding = binding; e.textureView = view; } static void bg_buf(WGPUBindGroupEntry& e, uint32_t binding, WGPUBuffer buf, uint64_t size) { e = {}; e.binding = binding; - e.buffer = buf; - e.size = size; + e.buffer = buf; + e.size = size; } void CNNv3Effect::update_bind_groups(NodeRegistry& nodes) { WGPUDevice dev = ctx_.device; - WGPUTextureView feat0_view = nodes.get_view(input_nodes_[0]); - WGPUTextureView feat1_view = nodes.get_view(input_nodes_[1]); - WGPUTextureView enc0_view = nodes.get_view(node_enc0_); + WGPUTextureView feat0_view = nodes.get_view(input_nodes_[0]); + WGPUTextureView feat1_view = nodes.get_view(input_nodes_[1]); + WGPUTextureView enc0_view = nodes.get_view(node_enc0_); WGPUTextureView enc1_lo_view = nodes.get_view(node_enc1_lo_); WGPUTextureView enc1_hi_view = nodes.get_view(node_enc1_hi_); - WGPUTextureView bn_lo_view = nodes.get_view(node_bn_lo_); - WGPUTextureView bn_hi_view = nodes.get_view(node_bn_hi_); - WGPUTextureView dec1_view = nodes.get_view(node_dec1_); - WGPUTextureView out_view = nodes.get_view(output_nodes_[0]); + WGPUTextureView bn_lo_view = nodes.get_view(node_bn_lo_); + WGPUTextureView bn_hi_view = nodes.get_view(node_bn_hi_); + WGPUTextureView dec1_view = nodes.get_view(node_dec1_); + WGPUTextureView out_view = nodes.get_view(output_nodes_[0]); WGPUBuffer wb = weights_buf_.buffer; auto make_bg = [&](WGPUComputePipeline pipe, WGPUBindGroupEntry* e, uint32_t count) -> WGPUBindGroup { - WGPUBindGroupLayout bgl = - wgpuComputePipelineGetBindGroupLayout(pipe, 0); + WGPUBindGroupLayout bgl = wgpuComputePipelineGetBindGroupLayout(pipe, 0); WGPUBindGroupDescriptor desc = {}; - desc.layout = bgl; + desc.layout = bgl; desc.entryCount = count; - desc.entries = e; + desc.entries = e; WGPUBindGroup bg = wgpuDeviceCreateBindGroup(dev, &desc); wgpuBindGroupLayoutRelease(bgl); return bg; @@ -504,7 +518,8 @@ void CNNv3Effect::update_bind_groups(NodeRegistry& nodes) { enc1_bg_.replace(make_bg(enc1_pipeline_.get(), e, 5)); } - // bottleneck: enc1_lo(B0), enc1_hi(B1), weights(B2), params(B3), bn_lo(B4), bn_hi(B5) + // bottleneck: enc1_lo(B0), enc1_hi(B1), weights(B2), params(B3), bn_lo(B4), + // bn_hi(B5) { WGPUBindGroupEntry e[6] = {}; bg_tex(e[0], 0, enc1_lo_view); |
