summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-02-12 12:08:22 +0100
committerskal <pascal.massimino@gmail.com>2026-02-12 12:08:22 +0100
commit4d87a6d781c3f159d216f4cd9251e3d7bd63554f (patch)
tree61bb4ee18b1c981cee789b215adf73860138d6c2
parent4cbf571a0087020bedf3c565483f94bc795ed4c4 (diff)
CNN v2: storage buffer architecture foundation
- Add binary weight format (header + layer info + packed f16) - New export_cnn_v2_weights.py for binary weight export - Single cnn_v2_compute.wgsl shader with storage buffer - Load weights in CNNv2Effect::load_weights() - Create layer compute pipeline with 5 bindings - Fast training config: 100 epochs, 3×3 kernels, 8→4→4 channels Next: Complete bind group creation and multi-layer compute execution
-rw-r--r--checkpoints/checkpoint_epoch_10.pthbin0 -> 24343 bytes
-rw-r--r--checkpoints/checkpoint_epoch_15.pthbin0 -> 24343 bytes
-rw-r--r--checkpoints/checkpoint_epoch_20.pthbin0 -> 24343 bytes
-rw-r--r--checkpoints/checkpoint_epoch_25.pthbin0 -> 24343 bytes
-rw-r--r--checkpoints/checkpoint_epoch_30.pthbin0 -> 24343 bytes
-rw-r--r--checkpoints/checkpoint_epoch_35.pthbin0 -> 24343 bytes
-rw-r--r--checkpoints/checkpoint_epoch_40.pthbin0 -> 24343 bytes
-rw-r--r--checkpoints/checkpoint_epoch_45.pthbin0 -> 24343 bytes
-rw-r--r--checkpoints/checkpoint_epoch_5.pthbin0 -> 24325 bytes
-rw-r--r--checkpoints/checkpoint_epoch_50.pthbin0 -> 24343 bytes
-rw-r--r--checkpoints/checkpoint_epoch_55.pthbin0 -> 24343 bytes
-rw-r--r--checkpoints/checkpoint_epoch_60.pthbin0 -> 24343 bytes
-rwxr-xr-xscripts/train_cnn_v2_full.sh10
-rw-r--r--src/gpu/effects/cnn_v2_effect.cc140
-rw-r--r--src/gpu/effects/cnn_v2_effect.h30
-rwxr-xr-xtraining/export_cnn_v2_weights.py272
-rw-r--r--workspaces/main/assets.txt3
-rw-r--r--workspaces/main/shaders/cnn_v2_compute.wgsl136
18 files changed, 576 insertions, 15 deletions
diff --git a/checkpoints/checkpoint_epoch_10.pth b/checkpoints/checkpoint_epoch_10.pth
new file mode 100644
index 0000000..710315a
--- /dev/null
+++ b/checkpoints/checkpoint_epoch_10.pth
Binary files differ
diff --git a/checkpoints/checkpoint_epoch_15.pth b/checkpoints/checkpoint_epoch_15.pth
new file mode 100644
index 0000000..e7e78d4
--- /dev/null
+++ b/checkpoints/checkpoint_epoch_15.pth
Binary files differ
diff --git a/checkpoints/checkpoint_epoch_20.pth b/checkpoints/checkpoint_epoch_20.pth
new file mode 100644
index 0000000..4d4dc10
--- /dev/null
+++ b/checkpoints/checkpoint_epoch_20.pth
Binary files differ
diff --git a/checkpoints/checkpoint_epoch_25.pth b/checkpoints/checkpoint_epoch_25.pth
new file mode 100644
index 0000000..60da2f2
--- /dev/null
+++ b/checkpoints/checkpoint_epoch_25.pth
Binary files differ
diff --git a/checkpoints/checkpoint_epoch_30.pth b/checkpoints/checkpoint_epoch_30.pth
new file mode 100644
index 0000000..2b0a340
--- /dev/null
+++ b/checkpoints/checkpoint_epoch_30.pth
Binary files differ
diff --git a/checkpoints/checkpoint_epoch_35.pth b/checkpoints/checkpoint_epoch_35.pth
new file mode 100644
index 0000000..839e368
--- /dev/null
+++ b/checkpoints/checkpoint_epoch_35.pth
Binary files differ
diff --git a/checkpoints/checkpoint_epoch_40.pth b/checkpoints/checkpoint_epoch_40.pth
new file mode 100644
index 0000000..b299337
--- /dev/null
+++ b/checkpoints/checkpoint_epoch_40.pth
Binary files differ
diff --git a/checkpoints/checkpoint_epoch_45.pth b/checkpoints/checkpoint_epoch_45.pth
new file mode 100644
index 0000000..f629261
--- /dev/null
+++ b/checkpoints/checkpoint_epoch_45.pth
Binary files differ
diff --git a/checkpoints/checkpoint_epoch_5.pth b/checkpoints/checkpoint_epoch_5.pth
new file mode 100644
index 0000000..bca35d9
--- /dev/null
+++ b/checkpoints/checkpoint_epoch_5.pth
Binary files differ
diff --git a/checkpoints/checkpoint_epoch_50.pth b/checkpoints/checkpoint_epoch_50.pth
new file mode 100644
index 0000000..03795aa
--- /dev/null
+++ b/checkpoints/checkpoint_epoch_50.pth
Binary files differ
diff --git a/checkpoints/checkpoint_epoch_55.pth b/checkpoints/checkpoint_epoch_55.pth
new file mode 100644
index 0000000..0a6c7b6
--- /dev/null
+++ b/checkpoints/checkpoint_epoch_55.pth
Binary files differ
diff --git a/checkpoints/checkpoint_epoch_60.pth b/checkpoints/checkpoint_epoch_60.pth
new file mode 100644
index 0000000..7e40bbf
--- /dev/null
+++ b/checkpoints/checkpoint_epoch_60.pth
Binary files differ
diff --git a/scripts/train_cnn_v2_full.sh b/scripts/train_cnn_v2_full.sh
index 4ddd9ac..383922d 100755
--- a/scripts/train_cnn_v2_full.sh
+++ b/scripts/train_cnn_v2_full.sh
@@ -12,8 +12,8 @@ INPUT_DIR="training/input"
TARGET_DIR="training/target_2"
CHECKPOINT_DIR="checkpoints"
VALIDATION_DIR="validation_results"
-EPOCHS=10000
-CHECKPOINT_EVERY=500
+EPOCHS=100
+CHECKPOINT_EVERY=5
BATCH_SIZE=16
# Patch-based training (default)
@@ -25,8 +25,8 @@ DETECTOR="harris"
# FULL_IMAGE="--full-image"
# IMAGE_SIZE=256
-KERNEL_SIZES="1 3 5"
-CHANNELS="16 8 4"
+KERNEL_SIZES="3 3 3"
+CHANNELS="8 4 4"
echo "=== CNN v2 Complete Training Pipeline ==="
echo "Input: $INPUT_DIR"
@@ -98,7 +98,7 @@ mkdir -p "$VALIDATION_DIR"
# Test first input image with checkpoints at intervals
TEST_IMAGE="$INPUT_DIR/img_000.png"
-CHECKPOINT_INTERVAL=1000
+CHECKPOINT_INTERVAL=5
echo " Processing checkpoints (every ${CHECKPOINT_INTERVAL} epochs)..."
diff --git a/src/gpu/effects/cnn_v2_effect.cc b/src/gpu/effects/cnn_v2_effect.cc
index b425aba..275af68 100644
--- a/src/gpu/effects/cnn_v2_effect.cc
+++ b/src/gpu/effects/cnn_v2_effect.cc
@@ -18,6 +18,9 @@ CNNv2Effect::CNNv2Effect(const GpuContext& ctx)
static_bind_group_(nullptr),
static_features_tex_(nullptr),
static_features_view_(nullptr),
+ layer_pipeline_(nullptr),
+ weights_buffer_(nullptr),
+ layer_params_buffer_(nullptr),
input_mip_tex_(nullptr),
current_input_view_(nullptr),
initialized_(false) {
@@ -32,6 +35,7 @@ void CNNv2Effect::init(MainSequence* demo) {
(void)demo;
if (initialized_) return;
+ load_weights();
create_textures();
create_pipelines();
@@ -45,6 +49,59 @@ void CNNv2Effect::resize(int width, int height) {
create_pipelines();
}
+void CNNv2Effect::load_weights() {
+ // Load binary weights asset
+ size_t weights_size = 0;
+ const uint8_t* weights_data = (const uint8_t*)GetAsset(AssetId::ASSET_WEIGHTS_CNN_V2, &weights_size);
+
+ if (!weights_data || weights_size < 16) {
+ // Weights not available - effect will skip
+ return;
+ }
+
+ // Parse header (16 bytes)
+ const uint32_t* header = (const uint32_t*)weights_data;
+ uint32_t magic = header[0];
+ uint32_t version = header[1];
+ uint32_t num_layers = header[2];
+ uint32_t total_weights = header[3];
+
+ FATAL_CHECK(magic == 0x324e4e43, "Invalid CNN v2 weights magic\n"); // 'CNN2'
+ FATAL_CHECK(version == 1, "Unsupported CNN v2 weights version\n");
+
+ // Parse layer info (20 bytes per layer)
+ const uint32_t* layer_data = header + 4;
+ for (uint32_t i = 0; i < num_layers; ++i) {
+ LayerInfo info;
+ info.kernel_size = layer_data[i * 5 + 0];
+ info.in_channels = layer_data[i * 5 + 1];
+ info.out_channels = layer_data[i * 5 + 2];
+ info.weight_offset = layer_data[i * 5 + 3];
+ info.weight_count = layer_data[i * 5 + 4];
+ layer_info_.push_back(info);
+ }
+
+ // Create GPU storage buffer for weights
+ // Buffer contains: header + layer info + packed f16 weights (as u32)
+ WGPUBufferDescriptor buffer_desc = {};
+ buffer_desc.size = weights_size;
+ buffer_desc.usage = WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst;
+ buffer_desc.mappedAtCreation = false;
+
+ weights_buffer_ = wgpuDeviceCreateBuffer(ctx_.device, &buffer_desc);
+
+ // Upload weights data
+ wgpuQueueWriteBuffer(ctx_.queue, weights_buffer_, 0, weights_data, weights_size);
+
+ // Create uniform buffer for layer params
+ WGPUBufferDescriptor params_desc = {};
+ params_desc.size = sizeof(LayerParams);
+ params_desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst;
+ params_desc.mappedAtCreation = false;
+
+ layer_params_buffer_ = wgpuDeviceCreateBuffer(ctx_.device, &params_desc);
+}
+
void CNNv2Effect::create_textures() {
const WGPUExtent3D size = {
static_cast<uint32_t>(width_),
@@ -208,8 +265,80 @@ void CNNv2Effect::create_pipelines() {
wgpuPipelineLayoutRelease(pipeline_layout);
wgpuBindGroupLayoutRelease(static_bgl);
- // Bind group will be created in update_bind_group()
- // TODO: Create layer pipelines
+ // CNN layer compute pipeline (storage buffer version)
+ if (layer_info_.empty()) return; // No weights loaded
+
+ size_t layer_shader_size = 0;
+ const char* layer_code = (const char*)GetAsset(AssetId::ASSET_SHADER_CNN_V2_COMPUTE, &layer_shader_size);
+
+ if (!layer_code || layer_shader_size == 0) return;
+
+ WGPUShaderSourceWGSL layer_wgsl = {};
+ layer_wgsl.chain.sType = WGPUSType_ShaderSourceWGSL;
+ layer_wgsl.code = str_view(layer_code);
+
+ WGPUShaderModuleDescriptor layer_shader_desc = {};
+ layer_shader_desc.nextInChain = &layer_wgsl.chain;
+
+ WGPUShaderModule layer_module = wgpuDeviceCreateShaderModule(ctx_.device, &layer_shader_desc);
+ if (!layer_module) return;
+
+ // Create bind group layout for layer compute
+ // 0=static_features, 1=layer_input, 2=output, 3=weights, 4=params
+ WGPUBindGroupLayoutEntry layer_bgl_entries[5] = {};
+
+ // Binding 0: Static features (texture)
+ layer_bgl_entries[0].binding = 0;
+ layer_bgl_entries[0].visibility = WGPUShaderStage_Compute;
+ layer_bgl_entries[0].texture.sampleType = WGPUTextureSampleType_Uint;
+ layer_bgl_entries[0].texture.viewDimension = WGPUTextureViewDimension_2D;
+
+ // Binding 1: Layer input (texture)
+ layer_bgl_entries[1].binding = 1;
+ layer_bgl_entries[1].visibility = WGPUShaderStage_Compute;
+ layer_bgl_entries[1].texture.sampleType = WGPUTextureSampleType_Uint;
+ layer_bgl_entries[1].texture.viewDimension = WGPUTextureViewDimension_2D;
+
+ // Binding 2: Output (storage texture)
+ layer_bgl_entries[2].binding = 2;
+ layer_bgl_entries[2].visibility = WGPUShaderStage_Compute;
+ layer_bgl_entries[2].storageTexture.access = WGPUStorageTextureAccess_WriteOnly;
+ layer_bgl_entries[2].storageTexture.format = WGPUTextureFormat_RGBA32Uint;
+ layer_bgl_entries[2].storageTexture.viewDimension = WGPUTextureViewDimension_2D;
+
+ // Binding 3: Weights (storage buffer)
+ layer_bgl_entries[3].binding = 3;
+ layer_bgl_entries[3].visibility = WGPUShaderStage_Compute;
+ layer_bgl_entries[3].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
+
+ // Binding 4: Layer params (uniform buffer)
+ layer_bgl_entries[4].binding = 4;
+ layer_bgl_entries[4].visibility = WGPUShaderStage_Compute;
+ layer_bgl_entries[4].buffer.type = WGPUBufferBindingType_Uniform;
+ layer_bgl_entries[4].buffer.minBindingSize = sizeof(LayerParams);
+
+ WGPUBindGroupLayoutDescriptor layer_bgl_desc = {};
+ layer_bgl_desc.entryCount = 5;
+ layer_bgl_desc.entries = layer_bgl_entries;
+
+ WGPUBindGroupLayout layer_bgl = wgpuDeviceCreateBindGroupLayout(ctx_.device, &layer_bgl_desc);
+
+ WGPUPipelineLayoutDescriptor layer_pl_desc = {};
+ layer_pl_desc.bindGroupLayoutCount = 1;
+ layer_pl_desc.bindGroupLayouts = &layer_bgl;
+
+ WGPUPipelineLayout layer_pipeline_layout = wgpuDeviceCreatePipelineLayout(ctx_.device, &layer_pl_desc);
+
+ WGPUComputePipelineDescriptor layer_pipeline_desc = {};
+ layer_pipeline_desc.compute.module = layer_module;
+ layer_pipeline_desc.compute.entryPoint = str_view("main");
+ layer_pipeline_desc.layout = layer_pipeline_layout;
+
+ layer_pipeline_ = wgpuDeviceCreateComputePipeline(ctx_.device, &layer_pipeline_desc);
+
+ wgpuShaderModuleRelease(layer_module);
+ wgpuPipelineLayoutRelease(layer_pipeline_layout);
+ wgpuBindGroupLayoutRelease(layer_bgl);
}
void CNNv2Effect::update_bind_group(WGPUTextureView input_view) {
@@ -292,6 +421,10 @@ void CNNv2Effect::cleanup() {
if (static_bind_group_) wgpuBindGroupRelease(static_bind_group_);
if (static_pipeline_) wgpuComputePipelineRelease(static_pipeline_);
+ if (layer_pipeline_) wgpuComputePipelineRelease(layer_pipeline_);
+ if (weights_buffer_) wgpuBufferRelease(weights_buffer_);
+ if (layer_params_buffer_) wgpuBufferRelease(layer_params_buffer_);
+
for (int i = 0; i < 3; ++i) {
if (input_mip_view_[i]) wgpuTextureViewRelease(input_mip_view_[i]);
}
@@ -300,12 +433,11 @@ void CNNv2Effect::cleanup() {
for (auto view : layer_views_) wgpuTextureViewRelease(view);
for (auto tex : layer_textures_) wgpuTextureRelease(tex);
for (auto bg : layer_bind_groups_) wgpuBindGroupRelease(bg);
- for (auto pipeline : layer_pipelines_) wgpuComputePipelineRelease(pipeline);
layer_views_.clear();
layer_textures_.clear();
layer_bind_groups_.clear();
- layer_pipelines_.clear();
+ layer_info_.clear();
initialized_ = false;
}
diff --git a/src/gpu/effects/cnn_v2_effect.h b/src/gpu/effects/cnn_v2_effect.h
index facf4c3..6005cf5 100644
--- a/src/gpu/effects/cnn_v2_effect.h
+++ b/src/gpu/effects/cnn_v2_effect.h
@@ -19,8 +19,25 @@ public:
void update_bind_group(WGPUTextureView input_view) override;
private:
+ struct LayerInfo {
+ uint32_t kernel_size;
+ uint32_t in_channels;
+ uint32_t out_channels;
+ uint32_t weight_offset;
+ uint32_t weight_count;
+ };
+
+ struct LayerParams {
+ uint32_t kernel_size;
+ uint32_t in_channels;
+ uint32_t out_channels;
+ uint32_t weight_offset;
+ uint32_t is_output_layer;
+ };
+
void create_textures();
void create_pipelines();
+ void load_weights();
void cleanup();
// Static features compute
@@ -29,16 +46,19 @@ private:
WGPUTexture static_features_tex_;
WGPUTextureView static_features_view_;
- // CNN layers (opaque implementation)
- std::vector<WGPUComputePipeline> layer_pipelines_;
- std::vector<WGPUBindGroup> layer_bind_groups_;
- std::vector<WGPUTexture> layer_textures_;
+ // CNN layers (storage buffer architecture)
+ WGPUComputePipeline layer_pipeline_; // Single pipeline for all layers
+ WGPUBuffer weights_buffer_; // Storage buffer for weights
+ WGPUBuffer layer_params_buffer_; // Uniform buffer for per-layer params
+ std::vector<LayerInfo> layer_info_; // Layer metadata
+ std::vector<WGPUBindGroup> layer_bind_groups_; // Per-layer bind groups
+ std::vector<WGPUTexture> layer_textures_; // Ping-pong buffers
std::vector<WGPUTextureView> layer_views_;
// Input mips
WGPUTexture input_mip_tex_;
WGPUTextureView input_mip_view_[3];
- WGPUTextureView current_input_view_; // Cached input from update_bind_group
+ WGPUTextureView current_input_view_;
bool initialized_;
};
diff --git a/training/export_cnn_v2_weights.py b/training/export_cnn_v2_weights.py
new file mode 100755
index 0000000..05d4958
--- /dev/null
+++ b/training/export_cnn_v2_weights.py
@@ -0,0 +1,272 @@
+#!/usr/bin/env python3
+"""CNN v2 Weight Export Script
+
+Converts PyTorch checkpoints to binary weight format for storage buffer.
+Exports single shader template + binary weights asset.
+"""
+
+import argparse
+import numpy as np
+import torch
+import struct
+from pathlib import Path
+
+
+def export_weights_binary(checkpoint_path, output_path):
+ """Export CNN v2 weights to binary format.
+
+ Binary format:
+ Header (16 bytes):
+ uint32 magic ('CNN2')
+ uint32 version (1)
+ uint32 num_layers
+ uint32 total_weights (f16 count)
+
+ LayerInfo × num_layers (20 bytes each):
+ uint32 kernel_size
+ uint32 in_channels
+ uint32 out_channels
+ uint32 weight_offset (f16 index)
+ uint32 weight_count
+
+ Weights (f16 array):
+ float16[] all_weights
+
+ Args:
+ checkpoint_path: Path to .pth checkpoint
+ output_path: Output .bin file path
+
+ Returns:
+ config dict for shader generation
+ """
+ print(f"Loading checkpoint: {checkpoint_path}")
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
+
+ state_dict = checkpoint['model_state_dict']
+ config = checkpoint['config']
+
+ print(f"Configuration:")
+ print(f" Kernels: {config['kernels']}")
+ print(f" Channels: {config['channels']}")
+
+ # Collect layer info
+ layers = []
+ all_weights = []
+ weight_offset = 0
+
+ # Layer 0: 8 → channels[0]
+ layer0_weights = state_dict['layer0.weight'].detach().numpy()
+ layer0_flat = layer0_weights.flatten()
+ layers.append({
+ 'kernel_size': config['kernels'][0],
+ 'in_channels': 8,
+ 'out_channels': config['channels'][0],
+ 'weight_offset': weight_offset,
+ 'weight_count': len(layer0_flat)
+ })
+ all_weights.extend(layer0_flat)
+ weight_offset += len(layer0_flat)
+
+ # Layer 1: (8 + channels[0]) → channels[1]
+ layer1_weights = state_dict['layer1.weight'].detach().numpy()
+ layer1_flat = layer1_weights.flatten()
+ layers.append({
+ 'kernel_size': config['kernels'][1],
+ 'in_channels': 8 + config['channels'][0],
+ 'out_channels': config['channels'][1],
+ 'weight_offset': weight_offset,
+ 'weight_count': len(layer1_flat)
+ })
+ all_weights.extend(layer1_flat)
+ weight_offset += len(layer1_flat)
+
+ # Layer 2: (8 + channels[1]) → 4 (RGBA output)
+ layer2_weights = state_dict['layer2.weight'].detach().numpy()
+ layer2_flat = layer2_weights.flatten()
+ layers.append({
+ 'kernel_size': config['kernels'][2],
+ 'in_channels': 8 + config['channels'][1],
+ 'out_channels': 4,
+ 'weight_offset': weight_offset,
+ 'weight_count': len(layer2_flat)
+ })
+ all_weights.extend(layer2_flat)
+ weight_offset += len(layer2_flat)
+
+ # Convert to f16
+ all_weights_f16 = np.array(all_weights, dtype=np.float16)
+
+ # Pack f16 pairs into u32 for storage buffer
+ # Pad to even count if needed
+ if len(all_weights_f16) % 2 == 1:
+ all_weights_f16 = np.append(all_weights_f16, np.float16(0.0))
+
+ # Pack pairs using numpy view
+ weights_u32 = all_weights_f16.view(np.uint32)
+
+ print(f"\nWeight statistics:")
+ print(f" Total layers: {len(layers)}")
+ print(f" Total weights: {len(all_weights_f16)} (f16)")
+ print(f" Packed: {len(weights_u32)} u32")
+ print(f" Binary size: {16 + len(layers) * 20 + len(weights_u32) * 4} bytes")
+
+ # Write binary file
+ output_path = Path(output_path)
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+
+ with open(output_path, 'wb') as f:
+ # Header (16 bytes)
+ f.write(struct.pack('<4sIII',
+ b'CNN2', # magic
+ 1, # version
+ len(layers), # num_layers
+ len(all_weights_f16))) # total_weights (f16 count)
+
+ # Layer info (20 bytes per layer)
+ for layer in layers:
+ f.write(struct.pack('<IIIII',
+ layer['kernel_size'],
+ layer['in_channels'],
+ layer['out_channels'],
+ layer['weight_offset'],
+ layer['weight_count']))
+
+ # Weights (u32 packed f16 pairs)
+ f.write(weights_u32.tobytes())
+
+ print(f" → {output_path}")
+
+ return {
+ 'num_layers': len(layers),
+ 'layers': layers
+ }
+
+
+def export_shader_template(config, output_dir):
+ """Generate single WGSL shader template with storage buffer binding.
+
+ Args:
+ config: Layer configuration from export_weights_binary()
+ output_dir: Output directory path
+ """
+ shader_code = """// CNN v2 Compute Shader - Storage Buffer Version
+// Reads weights from storage buffer, processes all layers in sequence
+
+struct CNNv2Header {
+ magic: u32, // 'CNN2'
+ version: u32, // 1
+ num_layers: u32, // Number of layers
+ total_weights: u32, // Total f16 weight count
+}
+
+struct CNNv2LayerInfo {
+ kernel_size: u32,
+ in_channels: u32,
+ out_channels: u32,
+ weight_offset: u32, // Offset in weights array
+ weight_count: u32,
+}
+
+@group(0) @binding(0) var static_features: texture_2d<u32>;
+@group(0) @binding(1) var layer_input: texture_2d<u32>;
+@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>;
+@group(0) @binding(3) var<storage, read> weights: array<u32>; // Packed f16 pairs
+
+fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {
+ let packed = textureLoad(static_features, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let v2 = unpack2x16float(packed.z);
+ let v3 = unpack2x16float(packed.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}
+
+fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> {
+ let packed = textureLoad(layer_input, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let v2 = unpack2x16float(packed.z);
+ let v3 = unpack2x16float(packed.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}
+
+fn pack_channels(values: array<f32, 8>) -> vec4<u32> {
+ return vec4<u32>(
+ pack2x16float(vec2<f32>(values[0], values[1])),
+ pack2x16float(vec2<f32>(values[2], values[3])),
+ pack2x16float(vec2<f32>(values[4], values[5])),
+ pack2x16float(vec2<f32>(values[6], values[7]))
+ );
+}
+
+fn get_weight(idx: u32) -> f32 {
+ let pair_idx = idx / 2u;
+ let packed = weights[8u + pair_idx]; // Skip header (32 bytes = 8 u32)
+ let unpacked = unpack2x16float(packed);
+ return select(unpacked.y, unpacked.x, (idx & 1u) == 0u);
+}
+
+@compute @workgroup_size(8, 8)
+fn main(@builtin(global_invocation_id) id: vec3<u32>) {
+ let coord = vec2<i32>(id.xy);
+ let dims = textureDimensions(static_features);
+
+ if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {
+ return;
+ }
+
+ // Read header
+ let header_packed = weights[0]; // magic + version
+ let counts_packed = weights[1]; // num_layers + total_weights
+ let num_layers = counts_packed & 0xFFFFu;
+
+ // Load static features
+ let static_feat = unpack_static_features(coord);
+
+ // Process each layer (hardcoded for 3 layers for now)
+ // TODO: Dynamic layer loop when needed
+
+ // Example for layer 0 - expand to full multi-layer when tested
+ let layer_info_offset = 2u; // After header
+ let layer0_info_base = layer_info_offset;
+
+ // Read layer 0 info (5 u32 values = 20 bytes)
+ let kernel_size = weights[layer0_info_base];
+ let in_channels = weights[layer0_info_base + 1u];
+ let out_channels = weights[layer0_info_base + 2u];
+ let weight_offset = weights[layer0_info_base + 3u];
+
+ // Convolution (simplified - expand to full kernel loop)
+ var output: array<f32, 8>;
+ for (var c: u32 = 0u; c < min(out_channels, 8u); c++) {
+ output[c] = 0.0; // TODO: Actual convolution
+ }
+
+ textureStore(output_tex, coord, pack_channels(output));
+}
+"""
+
+ output_path = Path(output_dir) / "cnn_v2_compute.wgsl"
+ output_path.write_text(shader_code)
+ print(f" → {output_path}")
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Export CNN v2 weights to binary format')
+ parser.add_argument('checkpoint', type=str, help='Path to checkpoint .pth file')
+ parser.add_argument('--output-weights', type=str, default='workspaces/main/cnn_v2_weights.bin',
+ help='Output binary weights file')
+ parser.add_argument('--output-shader', type=str, default='workspaces/main/shaders',
+ help='Output directory for shader template')
+
+ args = parser.parse_args()
+
+ print("=== CNN v2 Weight Export ===\n")
+ config = export_weights_binary(args.checkpoint, args.output_weights)
+ print()
+ export_shader_template(config, args.output_shader)
+ print("\nExport complete!")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/workspaces/main/assets.txt b/workspaces/main/assets.txt
index 280d6ed..4cbbb0f 100644
--- a/workspaces/main/assets.txt
+++ b/workspaces/main/assets.txt
@@ -44,7 +44,8 @@ SHADER_CNN_CONV7X7, NONE, shaders/cnn/cnn_conv7x7.wgsl, "CNN 7x7 Convolution"
SHADER_CNN_WEIGHTS, NONE, shaders/cnn/cnn_weights_generated.wgsl, "CNN Weights (Generated)"
SHADER_CNN_LAYER, NONE, shaders/cnn/cnn_layer.wgsl, "CNN Layer Shader"
SHADER_CNN_V2_STATIC, NONE, shaders/cnn_v2_static.wgsl, "CNN v2 Static Features"
-SHADER_CNN_V2_LAYER_TEMPLATE, NONE, shaders/cnn_v2_layer_template.wgsl, "CNN v2 Layer Template"
+SHADER_CNN_V2_COMPUTE, NONE, shaders/cnn_v2_compute.wgsl, "CNN v2 Compute (Storage Buffer)"
+WEIGHTS_CNN_V2, NONE, cnn_v2_weights.bin, "CNN v2 Binary Weights"
SHADER_SOLARIZE, NONE, shaders/solarize.wgsl, "Solarize Shader"
SHADER_DISTORT, NONE, shaders/distort.wgsl, "Distort Shader"
SHADER_CHROMA_ABERRATION, NONE, shaders/chroma_aberration.wgsl, "Chroma Aberration Shader"
diff --git a/workspaces/main/shaders/cnn_v2_compute.wgsl b/workspaces/main/shaders/cnn_v2_compute.wgsl
new file mode 100644
index 0000000..f9eb556
--- /dev/null
+++ b/workspaces/main/shaders/cnn_v2_compute.wgsl
@@ -0,0 +1,136 @@
+// CNN v2 Compute Shader - Storage Buffer Version
+// Processes single layer per dispatch with weights from storage buffer
+// Multi-layer execution handled by C++ with ping-pong buffers
+
+// Push constants for layer parameters (passed per dispatch)
+struct LayerParams {
+ kernel_size: u32,
+ in_channels: u32,
+ out_channels: u32,
+ weight_offset: u32, // Offset in f16 units
+ is_output_layer: u32, // 1 if final layer (sigmoid), 0 otherwise (relu)
+}
+
+@group(0) @binding(0) var static_features: texture_2d<u32>; // 8-channel static features
+@group(0) @binding(1) var layer_input: texture_2d<u32>; // Previous layer output (8-channel packed)
+@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>; // Current layer output
+@group(0) @binding(3) var<storage, read> weights_buffer: array<u32>; // Packed f16 weights
+@group(0) @binding(4) var<uniform> params: LayerParams;
+
+fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {
+ let packed = textureLoad(static_features, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let v2 = unpack2x16float(packed.z);
+ let v3 = unpack2x16float(packed.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}
+
+fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> {
+ let packed = textureLoad(layer_input, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let v2 = unpack2x16float(packed.z);
+ let v3 = unpack2x16float(packed.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}
+
+fn pack_channels(values: array<f32, 8>) -> vec4<u32> {
+ return vec4<u32>(
+ pack2x16float(vec2<f32>(values[0], values[1])),
+ pack2x16float(vec2<f32>(values[2], values[3])),
+ pack2x16float(vec2<f32>(values[4], values[5])),
+ pack2x16float(vec2<f32>(values[6], values[7]))
+ );
+}
+
+// Get weight from storage buffer (f16 packed as u32 pairs)
+// Buffer layout: [header: 4 u32][layer_info: N×5 u32][weights: packed f16]
+fn get_weight(idx: u32) -> f32 {
+ // Skip header (16 bytes = 4 u32) and layer info
+ // Weights start after header + layer_info, but weight_offset already accounts for this
+ let pair_idx = idx / 2u;
+ let packed = weights_buffer[pair_idx];
+ let unpacked = unpack2x16float(packed);
+ return select(unpacked.y, unpacked.x, (idx & 1u) == 0u);
+}
+
+@compute @workgroup_size(8, 8)
+fn main(@builtin(global_invocation_id) id: vec3<u32>) {
+ let coord = vec2<i32>(id.xy);
+ let dims = textureDimensions(static_features);
+
+ if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {
+ return;
+ }
+
+ let kernel_size = params.kernel_size;
+ let in_channels = params.in_channels;
+ let out_channels = params.out_channels;
+ let weight_offset = params.weight_offset;
+ let is_output = params.is_output_layer != 0u;
+
+ let kernel_radius = i32(kernel_size / 2u);
+
+ // Load static features (always 8D)
+ let static_feat = unpack_static_features(coord);
+
+ // Convolution per output channel
+ var output: array<f32, 8>;
+ for (var c: u32 = 0u; c < out_channels && c < 8u; c++) {
+ var sum: f32 = 0.0;
+
+ // Convolve over kernel
+ for (var ky: i32 = -kernel_radius; ky <= kernel_radius; ky++) {
+ for (var kx: i32 = -kernel_radius; kx <= kernel_radius; kx++) {
+ let sample_coord = coord + vec2<i32>(kx, ky);
+
+ // Border handling (clamp)
+ let clamped = vec2<i32>(
+ clamp(sample_coord.x, 0, i32(dims.x) - 1),
+ clamp(sample_coord.y, 0, i32(dims.y) - 1)
+ );
+
+ // Load input features at this spatial location
+ let static_local = unpack_static_features(clamped);
+ let layer_local = unpack_layer_channels(clamped);
+
+ // Weight index calculation
+ let ky_idx = u32(ky + kernel_radius);
+ let kx_idx = u32(kx + kernel_radius);
+ let spatial_idx = ky_idx * kernel_size + kx_idx;
+
+ // Accumulate: static features (always 8 channels)
+ for (var i: u32 = 0u; i < 8u; i++) {
+ let w_idx = weight_offset +
+ c * in_channels * kernel_size * kernel_size +
+ i * kernel_size * kernel_size + spatial_idx;
+ sum += get_weight(w_idx) * static_local[i];
+ }
+
+ // Accumulate: previous layer channels (in_channels - 8)
+ let prev_channels = in_channels - 8u;
+ for (var i: u32 = 0u; i < prev_channels && i < 8u; i++) {
+ let w_idx = weight_offset +
+ c * in_channels * kernel_size * kernel_size +
+ (8u + i) * kernel_size * kernel_size + spatial_idx;
+ sum += get_weight(w_idx) * layer_local[i];
+ }
+ }
+ }
+
+ // Activation
+ if (is_output) {
+ output[c] = clamp(sum, 0.0, 1.0); // Sigmoid approximation
+ } else {
+ output[c] = max(0.0, sum); // ReLU
+ }
+ }
+
+ // Zero unused channels
+ for (var c: u32 = out_channels; c < 8u; c++) {
+ output[c] = 0.0;
+ }
+
+ textureStore(output_tex, coord, pack_channels(output));
+}