summaryrefslogtreecommitdiff
path: root/cnn_v3/src/cnn_v3_effect.cc
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-05-21 08:10:47 +0200
committerskal <pascal.massimino@gmail.com>2026-05-21 08:10:47 +0200
commitd806027dcaeadcdd8d2febd88bc46b2fd2c465de (patch)
tree30bc1ef9f40ccab7c00e31ee20e62bb86755fa26 /cnn_v3/src/cnn_v3_effect.cc
parent680042a18c11ad5e58757e45b260745c2f52417f (diff)
style: apply clang-formatHEADmain
Diffstat (limited to 'cnn_v3/src/cnn_v3_effect.cc')
-rw-r--r--cnn_v3/src/cnn_v3_effect.cc185
1 files changed, 100 insertions, 85 deletions
diff --git a/cnn_v3/src/cnn_v3_effect.cc b/cnn_v3/src/cnn_v3_effect.cc
index e576ceb..fa1716f 100644
--- a/cnn_v3/src/cnn_v3_effect.cc
+++ b/cnn_v3/src/cnn_v3_effect.cc
@@ -22,18 +22,23 @@
// Format: Conv(IN→OUT, KxK) has OUT*IN*K*K weights + OUT biases
// Layout: OIHW order (out × in × kH × kW), biases appended
// ---------------------------------------------------------------------------
-static const uint32_t kEnc0Weights = 20 * 8 * 9 + 8; // Conv(20→8, 3×3)+bias = 1448
-static const uint32_t kEnc1Weights = 8 * 16 * 9 + 16; // Conv(8→16, 3×3)+bias = 1168
-static const uint32_t kBnWeights = 16 * 16 * 9 + 16; // Conv(16→16, 3×3,dil=2)+bias = 2320
-static const uint32_t kDec1Weights = 32 * 8 * 9 + 8; // Conv(32→8, 3×3)+bias = 2312
-static const uint32_t kDec0Weights = 16 * 4 * 9 + 4; // Conv(16→4, 3×3)+bias = 580
+static const uint32_t kEnc0Weights =
+ 20 * 8 * 9 + 8; // Conv(20→8, 3×3)+bias = 1448
+static const uint32_t kEnc1Weights =
+ 8 * 16 * 9 + 16; // Conv(8→16, 3×3)+bias = 1168
+static const uint32_t kBnWeights =
+ 16 * 16 * 9 + 16; // Conv(16→16, 3×3,dil=2)+bias = 2320
+static const uint32_t kDec1Weights =
+ 32 * 8 * 9 + 8; // Conv(32→8, 3×3)+bias = 2312
+static const uint32_t kDec0Weights =
+ 16 * 4 * 9 + 4; // Conv(16→4, 3×3)+bias = 580
-static const uint32_t kEnc0Offset = 0;
-static const uint32_t kEnc1Offset = kEnc0Offset + kEnc0Weights;
-static const uint32_t kBnOffset = kEnc1Offset + kEnc1Weights;
-static const uint32_t kDec1Offset = kBnOffset + kBnWeights;
-static const uint32_t kDec0Offset = kDec1Offset + kDec1Weights;
-static const uint32_t kTotalF16 = kDec0Offset + kDec0Weights;
+static const uint32_t kEnc0Offset = 0;
+static const uint32_t kEnc1Offset = kEnc0Offset + kEnc0Weights;
+static const uint32_t kBnOffset = kEnc1Offset + kEnc1Weights;
+static const uint32_t kDec1Offset = kBnOffset + kBnWeights;
+static const uint32_t kDec0Offset = kDec1Offset + kDec1Weights;
+static const uint32_t kTotalF16 = kDec0Offset + kDec0Weights;
// = 1448 + 1168 + 2320 + 2312 + 580 = 7828 f16
static const uint32_t kWeightsBufBytes = ((kTotalF16 + 1) / 2) * 4;
@@ -57,7 +62,7 @@ static WGPUShaderModule make_shader(WGPUDevice device, const char* wgsl) {
WGPUShaderSourceWGSL src = {};
src.chain.sType = WGPUSType_ShaderSourceWGSL;
- src.code = str_view(composed.c_str());
+ src.code = str_view(composed.c_str());
WGPUShaderModuleDescriptor desc = {};
desc.nextInChain = &src.chain;
@@ -69,7 +74,7 @@ static WGPUBindGroupLayout make_bgl(WGPUDevice device,
uint32_t count) {
WGPUBindGroupLayoutDescriptor desc = {};
desc.entryCount = count;
- desc.entries = entries;
+ desc.entries = entries;
return wgpuDeviceCreateBindGroupLayout(device, &desc);
}
@@ -79,14 +84,15 @@ static WGPUComputePipeline make_compute_pipeline(WGPUDevice device,
WGPUBindGroupLayout bgl) {
WGPUPipelineLayoutDescriptor pl_desc = {};
pl_desc.bindGroupLayoutCount = 1;
- pl_desc.bindGroupLayouts = &bgl;
+ pl_desc.bindGroupLayouts = &bgl;
WGPUPipelineLayout pl = wgpuDeviceCreatePipelineLayout(device, &pl_desc);
WGPUComputePipelineDescriptor pipe_desc = {};
- pipe_desc.layout = pl;
- pipe_desc.compute.module = shader;
- pipe_desc.compute.entryPoint = str_view(entry);
- WGPUComputePipeline pipe = wgpuDeviceCreateComputePipeline(device, &pipe_desc);
+ pipe_desc.layout = pl;
+ pipe_desc.compute.module = shader;
+ pipe_desc.compute.entryPoint = str_view(entry);
+ WGPUComputePipeline pipe =
+ wgpuDeviceCreateComputePipeline(device, &pipe_desc);
wgpuPipelineLayoutRelease(pl);
return pipe;
@@ -95,36 +101,36 @@ static WGPUComputePipeline make_compute_pipeline(WGPUDevice device,
// BGL entry helpers
static WGPUBindGroupLayoutEntry bgl_uint_tex(uint32_t binding) {
WGPUBindGroupLayoutEntry e = {};
- e.binding = binding;
- e.visibility = WGPUShaderStage_Compute;
- e.texture.sampleType = WGPUTextureSampleType_Uint;
- e.texture.viewDimension = WGPUTextureViewDimension_2D;
+ e.binding = binding;
+ e.visibility = WGPUShaderStage_Compute;
+ e.texture.sampleType = WGPUTextureSampleType_Uint;
+ e.texture.viewDimension = WGPUTextureViewDimension_2D;
return e;
}
static WGPUBindGroupLayoutEntry bgl_storage_buf(uint32_t binding) {
WGPUBindGroupLayoutEntry e = {};
- e.binding = binding;
- e.visibility = WGPUShaderStage_Compute;
- e.buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
+ e.binding = binding;
+ e.visibility = WGPUShaderStage_Compute;
+ e.buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
return e;
}
static WGPUBindGroupLayoutEntry bgl_uniform_buf(uint32_t binding,
uint64_t min_size) {
WGPUBindGroupLayoutEntry e = {};
- e.binding = binding;
- e.visibility = WGPUShaderStage_Compute;
- e.buffer.type = WGPUBufferBindingType_Uniform;
- e.buffer.minBindingSize = min_size;
+ e.binding = binding;
+ e.visibility = WGPUShaderStage_Compute;
+ e.buffer.type = WGPUBufferBindingType_Uniform;
+ e.buffer.minBindingSize = min_size;
return e;
}
-static WGPUBindGroupLayoutEntry bgl_storage_tex_write(
- uint32_t binding, WGPUTextureFormat fmt) {
+static WGPUBindGroupLayoutEntry bgl_storage_tex_write(uint32_t binding,
+ WGPUTextureFormat fmt) {
WGPUBindGroupLayoutEntry e = {};
- e.binding = binding;
- e.visibility = WGPUShaderStage_Compute;
- e.storageTexture.access = WGPUStorageTextureAccess_WriteOnly;
- e.storageTexture.format = fmt;
- e.storageTexture.viewDimension = WGPUTextureViewDimension_2D;
+ e.binding = binding;
+ e.visibility = WGPUShaderStage_Compute;
+ e.storageTexture.access = WGPUStorageTextureAccess_WriteOnly;
+ e.storageTexture.format = fmt;
+ e.storageTexture.viewDimension = WGPUTextureViewDimension_2D;
return e;
}
@@ -141,16 +147,16 @@ CNNv3Effect::CNNv3Effect(const GpuContext& ctx,
const std::string& prefix =
outputs.empty() ? std::string("cnn_v3") : outputs[0];
- node_enc0_ = prefix + "_enc0";
+ node_enc0_ = prefix + "_enc0";
node_enc1_lo_ = prefix + "_enc1_lo";
node_enc1_hi_ = prefix + "_enc1_hi";
- node_bn_lo_ = prefix + "_bn_lo";
- node_bn_hi_ = prefix + "_bn_hi";
- node_dec1_ = prefix + "_dec1";
+ node_bn_lo_ = prefix + "_bn_lo";
+ node_bn_hi_ = prefix + "_bn_hi";
+ node_dec1_ = prefix + "_dec1";
- weights_buf_ = gpu_create_buffer(
- ctx_.device, kWeightsBufBytes,
- WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst);
+ weights_buf_ =
+ gpu_create_buffer(ctx_.device, kWeightsBufBytes,
+ WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst);
enc0_params_buf_.init(ctx_.device);
enc1_params_buf_.init(ctx_.device);
@@ -166,7 +172,9 @@ CNNv3Effect::CNNv3Effect(const GpuContext& ctx,
}
enc1_params_.weight_offset = kEnc1Offset;
- for (int i = 0; i < 16; ++i) { enc1_params_.gamma[i] = 1.0f; }
+ for (int i = 0; i < 16; ++i) {
+ enc1_params_.gamma[i] = 1.0f;
+ }
bn_params_.weight_offset = kBnOffset;
@@ -177,7 +185,9 @@ CNNv3Effect::CNNv3Effect(const GpuContext& ctx,
}
dec0_params_.weight_offset = kDec0Offset;
- for (int i = 0; i < 4; ++i) { dec0_params_.gamma[i] = 1.0f; }
+ for (int i = 0; i < 4; ++i) {
+ dec0_params_.gamma[i] = 1.0f;
+ }
create_pipelines();
@@ -205,15 +215,15 @@ void CNNv3Effect::declare_nodes(NodeRegistry& registry) {
const int H = registry.default_height();
// enc0: rgba32uint full-res (8ch packed f16)
- registry.declare_node(node_enc0_, NodeType::GBUF_RGBA32UINT, W, H);
+ registry.declare_node(node_enc0_, NodeType::GBUF_RGBA32UINT, W, H);
// enc1: two rgba32uint half-res (8ch each = 16ch total)
registry.declare_node(node_enc1_lo_, NodeType::GBUF_RGBA32UINT, W / 2, H / 2);
registry.declare_node(node_enc1_hi_, NodeType::GBUF_RGBA32UINT, W / 2, H / 2);
// bottleneck: two rgba32uint quarter-res (8ch each = 16ch total)
- registry.declare_node(node_bn_lo_, NodeType::GBUF_RGBA32UINT, W / 4, H / 4);
- registry.declare_node(node_bn_hi_, NodeType::GBUF_RGBA32UINT, W / 4, H / 4);
+ registry.declare_node(node_bn_lo_, NodeType::GBUF_RGBA32UINT, W / 4, H / 4);
+ registry.declare_node(node_bn_hi_, NodeType::GBUF_RGBA32UINT, W / 4, H / 4);
// dec1: rgba32uint half-res (8ch packed f16)
- registry.declare_node(node_dec1_, NodeType::GBUF_RGBA32UINT, W / 2, H / 2);
+ registry.declare_node(node_dec1_, NodeType::GBUF_RGBA32UINT, W / 2, H / 2);
// output_nodes_[0]: rgba16float full-res — declared externally by caller
}
@@ -222,12 +232,13 @@ void CNNv3Effect::declare_nodes(NodeRegistry& registry) {
// ---------------------------------------------------------------------------
void CNNv3Effect::upload_weights(WGPUQueue queue, const void* data,
- uint32_t size_bytes) {
+ uint32_t size_bytes) {
wgpuQueueWriteBuffer(queue, weights_buf_.buffer, 0, data, size_bytes);
}
void CNNv3Effect::load_film_mlp(const void* data, uint32_t size_bytes) {
- if (size_bytes != sizeof(CNNv3FilmMlp)) return;
+ if (size_bytes != sizeof(CNNv3FilmMlp))
+ return;
memcpy(&mlp_, data, sizeof(CNNv3FilmMlp));
mlp_loaded_ = true;
}
@@ -246,16 +257,19 @@ void CNNv3Effect::set_film_params(const CNNv3FiLMParams& fp) {
float h[16];
for (int j = 0; j < 16; ++j) {
float s = mlp_.l0_b[j];
- for (int i = 0; i < 5; ++i) s += mlp_.l0_w[j * 5 + i] * cond[i];
+ for (int i = 0; i < 5; ++i)
+ s += mlp_.l0_w[j * 5 + i] * cond[i];
h[j] = s > 0.f ? s : 0.f;
}
// Layer 1: Linear(16→72)
- // Output split: g_enc0(8)|b_enc0(8)|g_enc1(16)|b_enc1(16)|g_dec1(8)|b_dec1(8)|g_dec0(4)|b_dec0(4)
+ // Output split:
+ // g_enc0(8)|b_enc0(8)|g_enc1(16)|b_enc1(16)|g_dec1(8)|b_dec1(8)|g_dec0(4)|b_dec0(4)
float film[72];
for (int j = 0; j < 72; ++j) {
float s = mlp_.l1_b[j];
- for (int i = 0; i < 16; ++i) s += mlp_.l1_w[j * 16 + i] * h[i];
+ for (int i = 0; i < 16; ++i)
+ s += mlp_.l1_w[j * 16 + i] * h[i];
film[j] = s;
}
@@ -270,9 +284,11 @@ void CNNv3Effect::set_film_params(const CNNv3FiLMParams& fp) {
enc0_params_.beta_hi[i] = p[i + 4];
}
p += 8;
- for (int i = 0; i < 16; ++i) enc1_params_.gamma[i] = p[i];
+ for (int i = 0; i < 16; ++i)
+ enc1_params_.gamma[i] = p[i];
p += 16;
- for (int i = 0; i < 16; ++i) enc1_params_.beta[i] = p[i];
+ for (int i = 0; i < 16; ++i)
+ enc1_params_.beta[i] = p[i];
p += 16;
for (int i = 0; i < 4; ++i) {
dec1_params_.gamma_lo[i] = p[i];
@@ -284,9 +300,11 @@ void CNNv3Effect::set_film_params(const CNNv3FiLMParams& fp) {
dec1_params_.beta_hi[i] = p[i + 4];
}
p += 8;
- for (int i = 0; i < 4; ++i) dec0_params_.gamma[i] = p[i];
+ for (int i = 0; i < 4; ++i)
+ dec0_params_.gamma[i] = p[i];
p += 4;
- for (int i = 0; i < 4; ++i) dec0_params_.beta[i] = p[i];
+ for (int i = 0; i < 4; ++i)
+ dec0_params_.beta[i] = p[i];
}
// ---------------------------------------------------------------------------
@@ -307,27 +325,24 @@ void CNNv3Effect::render(WGPUCommandEncoder encoder,
const int W = (int)params.resolution.x;
const int H = (int)params.resolution.y;
- auto dispatch = [&](WGPUComputePipeline pipe, WGPUBindGroup bg,
- int w, int h) {
+ auto dispatch = [&](WGPUComputePipeline pipe, WGPUBindGroup bg, int w,
+ int h) {
WGPUComputePassDescriptor pass_desc = {};
WGPUComputePassEncoder pass =
wgpuCommandEncoderBeginComputePass(encoder, &pass_desc);
wgpuComputePassEncoderSetPipeline(pass, pipe);
wgpuComputePassEncoderSetBindGroup(pass, 0, bg, 0, nullptr);
- wgpuComputePassEncoderDispatchWorkgroups(
- pass,
- (uint32_t)((w + 7) / 8),
- (uint32_t)((h + 7) / 8),
- 1);
+ wgpuComputePassEncoderDispatchWorkgroups(pass, (uint32_t)((w + 7) / 8),
+ (uint32_t)((h + 7) / 8), 1);
wgpuComputePassEncoderEnd(pass);
wgpuComputePassEncoderRelease(pass);
};
- dispatch(enc0_pipeline_.get(), enc0_bg_.get(), W, H);
- dispatch(enc1_pipeline_.get(), enc1_bg_.get(), W / 2, H / 2);
- dispatch(bn_pipeline_.get(), bn_bg_.get(), W / 4, H / 4);
- dispatch(dec1_pipeline_.get(), dec1_bg_.get(), W / 2, H / 2);
- dispatch(dec0_pipeline_.get(), dec0_bg_.get(), W, H);
+ dispatch(enc0_pipeline_.get(), enc0_bg_.get(), W, H);
+ dispatch(enc1_pipeline_.get(), enc1_bg_.get(), W / 2, H / 2);
+ dispatch(bn_pipeline_.get(), bn_bg_.get(), W / 4, H / 4);
+ dispatch(dec1_pipeline_.get(), dec1_bg_.get(), W / 2, H / 2);
+ dispatch(dec0_pipeline_.get(), dec0_bg_.get(), W, H);
}
// ---------------------------------------------------------------------------
@@ -443,40 +458,39 @@ void CNNv3Effect::create_pipelines() {
static void bg_tex(WGPUBindGroupEntry& e, uint32_t binding,
WGPUTextureView view) {
e = {};
- e.binding = binding;
+ e.binding = binding;
e.textureView = view;
}
static void bg_buf(WGPUBindGroupEntry& e, uint32_t binding, WGPUBuffer buf,
uint64_t size) {
e = {};
e.binding = binding;
- e.buffer = buf;
- e.size = size;
+ e.buffer = buf;
+ e.size = size;
}
void CNNv3Effect::update_bind_groups(NodeRegistry& nodes) {
WGPUDevice dev = ctx_.device;
- WGPUTextureView feat0_view = nodes.get_view(input_nodes_[0]);
- WGPUTextureView feat1_view = nodes.get_view(input_nodes_[1]);
- WGPUTextureView enc0_view = nodes.get_view(node_enc0_);
+ WGPUTextureView feat0_view = nodes.get_view(input_nodes_[0]);
+ WGPUTextureView feat1_view = nodes.get_view(input_nodes_[1]);
+ WGPUTextureView enc0_view = nodes.get_view(node_enc0_);
WGPUTextureView enc1_lo_view = nodes.get_view(node_enc1_lo_);
WGPUTextureView enc1_hi_view = nodes.get_view(node_enc1_hi_);
- WGPUTextureView bn_lo_view = nodes.get_view(node_bn_lo_);
- WGPUTextureView bn_hi_view = nodes.get_view(node_bn_hi_);
- WGPUTextureView dec1_view = nodes.get_view(node_dec1_);
- WGPUTextureView out_view = nodes.get_view(output_nodes_[0]);
+ WGPUTextureView bn_lo_view = nodes.get_view(node_bn_lo_);
+ WGPUTextureView bn_hi_view = nodes.get_view(node_bn_hi_);
+ WGPUTextureView dec1_view = nodes.get_view(node_dec1_);
+ WGPUTextureView out_view = nodes.get_view(output_nodes_[0]);
WGPUBuffer wb = weights_buf_.buffer;
auto make_bg = [&](WGPUComputePipeline pipe, WGPUBindGroupEntry* e,
uint32_t count) -> WGPUBindGroup {
- WGPUBindGroupLayout bgl =
- wgpuComputePipelineGetBindGroupLayout(pipe, 0);
+ WGPUBindGroupLayout bgl = wgpuComputePipelineGetBindGroupLayout(pipe, 0);
WGPUBindGroupDescriptor desc = {};
- desc.layout = bgl;
+ desc.layout = bgl;
desc.entryCount = count;
- desc.entries = e;
+ desc.entries = e;
WGPUBindGroup bg = wgpuDeviceCreateBindGroup(dev, &desc);
wgpuBindGroupLayoutRelease(bgl);
return bg;
@@ -504,7 +518,8 @@ void CNNv3Effect::update_bind_groups(NodeRegistry& nodes) {
enc1_bg_.replace(make_bg(enc1_pipeline_.get(), e, 5));
}
- // bottleneck: enc1_lo(B0), enc1_hi(B1), weights(B2), params(B3), bn_lo(B4), bn_hi(B5)
+ // bottleneck: enc1_lo(B0), enc1_hi(B1), weights(B2), params(B3), bn_lo(B4),
+ // bn_hi(B5)
{
WGPUBindGroupEntry e[6] = {};
bg_tex(e[0], 0, enc1_lo_view);