diff options
| author | skal <pascal.massimino@gmail.com> | 2026-02-12 12:08:22 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-02-12 12:08:22 +0100 |
| commit | 4d87a6d781c3f159d216f4cd9251e3d7bd63554f (patch) | |
| tree | 61bb4ee18b1c981cee789b215adf73860138d6c2 /training | |
| parent | 4cbf571a0087020bedf3c565483f94bc795ed4c4 (diff) | |
CNN v2: storage buffer architecture foundation
- Add binary weight format (header + layer info + packed f16)
- New export_cnn_v2_weights.py for binary weight export
- Single cnn_v2_compute.wgsl shader with storage buffer
- Load weights in CNNv2Effect::load_weights()
- Create layer compute pipeline with 5 bindings
- Fast training config: 100 epochs, 3×3 kernels, 8→4→4 channels
Next: Complete bind group creation and multi-layer compute execution
Diffstat (limited to 'training')
| -rwxr-xr-x | training/export_cnn_v2_weights.py | 272 |
1 files changed, 272 insertions, 0 deletions
diff --git a/training/export_cnn_v2_weights.py b/training/export_cnn_v2_weights.py new file mode 100755 index 0000000..05d4958 --- /dev/null +++ b/training/export_cnn_v2_weights.py @@ -0,0 +1,272 @@ +#!/usr/bin/env python3 +"""CNN v2 Weight Export Script + +Converts PyTorch checkpoints to binary weight format for storage buffer. +Exports single shader template + binary weights asset. +""" + +import argparse +import numpy as np +import torch +import struct +from pathlib import Path + + +def export_weights_binary(checkpoint_path, output_path): + """Export CNN v2 weights to binary format. + + Binary format: + Header (16 bytes): + uint32 magic ('CNN2') + uint32 version (1) + uint32 num_layers + uint32 total_weights (f16 count) + + LayerInfo × num_layers (20 bytes each): + uint32 kernel_size + uint32 in_channels + uint32 out_channels + uint32 weight_offset (f16 index) + uint32 weight_count + + Weights (f16 array): + float16[] all_weights + + Args: + checkpoint_path: Path to .pth checkpoint + output_path: Output .bin file path + + Returns: + config dict for shader generation + """ + print(f"Loading checkpoint: {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location='cpu') + + state_dict = checkpoint['model_state_dict'] + config = checkpoint['config'] + + print(f"Configuration:") + print(f" Kernels: {config['kernels']}") + print(f" Channels: {config['channels']}") + + # Collect layer info + layers = [] + all_weights = [] + weight_offset = 0 + + # Layer 0: 8 → channels[0] + layer0_weights = state_dict['layer0.weight'].detach().numpy() + layer0_flat = layer0_weights.flatten() + layers.append({ + 'kernel_size': config['kernels'][0], + 'in_channels': 8, + 'out_channels': config['channels'][0], + 'weight_offset': weight_offset, + 'weight_count': len(layer0_flat) + }) + all_weights.extend(layer0_flat) + weight_offset += len(layer0_flat) + + # Layer 1: (8 + channels[0]) → channels[1] + layer1_weights = state_dict['layer1.weight'].detach().numpy() + layer1_flat = layer1_weights.flatten() + layers.append({ + 'kernel_size': config['kernels'][1], + 'in_channels': 8 + config['channels'][0], + 'out_channels': config['channels'][1], + 'weight_offset': weight_offset, + 'weight_count': len(layer1_flat) + }) + all_weights.extend(layer1_flat) + weight_offset += len(layer1_flat) + + # Layer 2: (8 + channels[1]) → 4 (RGBA output) + layer2_weights = state_dict['layer2.weight'].detach().numpy() + layer2_flat = layer2_weights.flatten() + layers.append({ + 'kernel_size': config['kernels'][2], + 'in_channels': 8 + config['channels'][1], + 'out_channels': 4, + 'weight_offset': weight_offset, + 'weight_count': len(layer2_flat) + }) + all_weights.extend(layer2_flat) + weight_offset += len(layer2_flat) + + # Convert to f16 + all_weights_f16 = np.array(all_weights, dtype=np.float16) + + # Pack f16 pairs into u32 for storage buffer + # Pad to even count if needed + if len(all_weights_f16) % 2 == 1: + all_weights_f16 = np.append(all_weights_f16, np.float16(0.0)) + + # Pack pairs using numpy view + weights_u32 = all_weights_f16.view(np.uint32) + + print(f"\nWeight statistics:") + print(f" Total layers: {len(layers)}") + print(f" Total weights: {len(all_weights_f16)} (f16)") + print(f" Packed: {len(weights_u32)} u32") + print(f" Binary size: {16 + len(layers) * 20 + len(weights_u32) * 4} bytes") + + # Write binary file + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, 'wb') as f: + # Header (16 bytes) + f.write(struct.pack('<4sIII', + b'CNN2', # magic + 1, # version + len(layers), # num_layers + len(all_weights_f16))) # total_weights (f16 count) + + # Layer info (20 bytes per layer) + for layer in layers: + f.write(struct.pack('<IIIII', + layer['kernel_size'], + layer['in_channels'], + layer['out_channels'], + layer['weight_offset'], + layer['weight_count'])) + + # Weights (u32 packed f16 pairs) + f.write(weights_u32.tobytes()) + + print(f" → {output_path}") + + return { + 'num_layers': len(layers), + 'layers': layers + } + + +def export_shader_template(config, output_dir): + """Generate single WGSL shader template with storage buffer binding. + + Args: + config: Layer configuration from export_weights_binary() + output_dir: Output directory path + """ + shader_code = """// CNN v2 Compute Shader - Storage Buffer Version +// Reads weights from storage buffer, processes all layers in sequence + +struct CNNv2Header { + magic: u32, // 'CNN2' + version: u32, // 1 + num_layers: u32, // Number of layers + total_weights: u32, // Total f16 weight count +} + +struct CNNv2LayerInfo { + kernel_size: u32, + in_channels: u32, + out_channels: u32, + weight_offset: u32, // Offset in weights array + weight_count: u32, +} + +@group(0) @binding(0) var static_features: texture_2d<u32>; +@group(0) @binding(1) var layer_input: texture_2d<u32>; +@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>; +@group(0) @binding(3) var<storage, read> weights: array<u32>; // Packed f16 pairs + +fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> { + let packed = textureLoad(static_features, coord, 0); + let v0 = unpack2x16float(packed.x); + let v1 = unpack2x16float(packed.y); + let v2 = unpack2x16float(packed.z); + let v3 = unpack2x16float(packed.w); + return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y); +} + +fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> { + let packed = textureLoad(layer_input, coord, 0); + let v0 = unpack2x16float(packed.x); + let v1 = unpack2x16float(packed.y); + let v2 = unpack2x16float(packed.z); + let v3 = unpack2x16float(packed.w); + return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y); +} + +fn pack_channels(values: array<f32, 8>) -> vec4<u32> { + return vec4<u32>( + pack2x16float(vec2<f32>(values[0], values[1])), + pack2x16float(vec2<f32>(values[2], values[3])), + pack2x16float(vec2<f32>(values[4], values[5])), + pack2x16float(vec2<f32>(values[6], values[7])) + ); +} + +fn get_weight(idx: u32) -> f32 { + let pair_idx = idx / 2u; + let packed = weights[8u + pair_idx]; // Skip header (32 bytes = 8 u32) + let unpacked = unpack2x16float(packed); + return select(unpacked.y, unpacked.x, (idx & 1u) == 0u); +} + +@compute @workgroup_size(8, 8) +fn main(@builtin(global_invocation_id) id: vec3<u32>) { + let coord = vec2<i32>(id.xy); + let dims = textureDimensions(static_features); + + if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) { + return; + } + + // Read header + let header_packed = weights[0]; // magic + version + let counts_packed = weights[1]; // num_layers + total_weights + let num_layers = counts_packed & 0xFFFFu; + + // Load static features + let static_feat = unpack_static_features(coord); + + // Process each layer (hardcoded for 3 layers for now) + // TODO: Dynamic layer loop when needed + + // Example for layer 0 - expand to full multi-layer when tested + let layer_info_offset = 2u; // After header + let layer0_info_base = layer_info_offset; + + // Read layer 0 info (5 u32 values = 20 bytes) + let kernel_size = weights[layer0_info_base]; + let in_channels = weights[layer0_info_base + 1u]; + let out_channels = weights[layer0_info_base + 2u]; + let weight_offset = weights[layer0_info_base + 3u]; + + // Convolution (simplified - expand to full kernel loop) + var output: array<f32, 8>; + for (var c: u32 = 0u; c < min(out_channels, 8u); c++) { + output[c] = 0.0; // TODO: Actual convolution + } + + textureStore(output_tex, coord, pack_channels(output)); +} +""" + + output_path = Path(output_dir) / "cnn_v2_compute.wgsl" + output_path.write_text(shader_code) + print(f" → {output_path}") + + +def main(): + parser = argparse.ArgumentParser(description='Export CNN v2 weights to binary format') + parser.add_argument('checkpoint', type=str, help='Path to checkpoint .pth file') + parser.add_argument('--output-weights', type=str, default='workspaces/main/cnn_v2_weights.bin', + help='Output binary weights file') + parser.add_argument('--output-shader', type=str, default='workspaces/main/shaders', + help='Output directory for shader template') + + args = parser.parse_args() + + print("=== CNN v2 Weight Export ===\n") + config = export_weights_binary(args.checkpoint, args.output_weights) + print() + export_shader_template(config, args.output_shader) + print("\nExport complete!") + + +if __name__ == '__main__': + main() |
