#!/usr/bin/env python3
"""CNN v2 Weight Export Script

Converts PyTorch checkpoints to binary weight format for storage buffer.
Exports single shader template + binary weights asset.
"""
import argparse
import struct
from pathlib import Path

import numpy as np
import torch


def export_weights_binary(checkpoint_path, output_path):
    """Export CNN v2 weights to binary format.

    Binary format:
        Header (16 bytes):
            uint32 magic ('CNN2')
            uint32 version (1)
            uint32 num_layers
            uint32 total_weights (f16 count)
        LayerInfo x num_layers (20 bytes each):
            uint32 kernel_size
            uint32 in_channels
            uint32 out_channels
            uint32 weight_offset (f16 index)
            uint32 weight_count
        Weights (f16 array):
            float16[] all_weights

    Args:
        checkpoint_path: Path to .pth checkpoint
        output_path: Output .bin file path

    Returns:
        config dict for shader generation
    """
    print(f"Loading checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    state_dict = checkpoint['model_state_dict']
    config = checkpoint['config']

    print("Configuration:")
    print(f"  Kernels: {config['kernels']}")
    print(f"  Channels: {config['channels']}")

    # Fixed 3-layer topology. Each layer re-reads the 8 static feature
    # channels in addition to the previous layer's output; the final layer
    # emits 4 channels (RGBA).
    layer_shapes = [
        (8, config['channels'][0]),                          # layer 0
        (8 + config['channels'][0], config['channels'][1]),  # layer 1
        (8 + config['channels'][1], 4),                      # layer 2 (RGBA output)
    ]

    # Collect per-layer metadata and a flat list of all weight values.
    layers = []
    all_weights = []
    for i, (in_ch, out_ch) in enumerate(layer_shapes):
        flat = state_dict[f'layer{i}.weight'].detach().numpy().flatten()
        layers.append({
            'kernel_size': config['kernels'][i],
            'in_channels': in_ch,
            'out_channels': out_ch,
            'weight_offset': len(all_weights),  # f16 index into the weight array
            'weight_count': len(flat),
        })
        all_weights.extend(flat)

    # Convert to f16
    # TODO: Use 8-bit quantization for 2x size reduction
    # Requires quantization-aware training (QAT) to maintain accuracy
    all_weights_f16 = np.array(all_weights, dtype=np.float16)

    # Pack f16 pairs into u32 for storage buffer; pad to an even count so the
    # reinterpret below is exact.
    if len(all_weights_f16) % 2 == 1:
        all_weights_f16 = np.append(all_weights_f16, np.float16(0.0))

    # Zero-copy reinterpret: two adjacent f16 -> one u32.
    # NOTE(review): pair order matches WGSL unpack2x16float on little-endian
    # hosts only — confirm before running this exporter on a big-endian machine.
    weights_u32 = all_weights_f16.view(np.uint32)

    print("\nWeight statistics:")
    print(f"  Total layers: {len(layers)}")
    print(f"  Total weights: {len(all_weights_f16)} (f16)")
    print(f"  Packed: {len(weights_u32)} u32")
    print(f"  Binary size: {16 + len(layers) * 20 + len(weights_u32) * 4} bytes")

    # Write binary file
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, 'wb') as f:
        # Header (16 bytes)
        f.write(struct.pack('<4sIII',
                            b'CNN2',                # magic
                            1,                      # version
                            len(layers),            # num_layers
                            len(all_weights_f16)))  # total_weights (f16 count, incl. pad)

        # Layer info (20 bytes per layer)
        for layer in layers:
            f.write(struct.pack('<IIIII',
                                layer['kernel_size'],
                                layer['in_channels'],
                                layer['out_channels'],
                                layer['weight_offset'],
                                layer['weight_count']))

        # Weights: raw little-endian f16 pairs packed as u32
        f.write(weights_u32.tobytes())

    print(f"\nExported: {output_path}")
    return config


def export_shader_template(config, output_dir):
    """Write the static WGSL compute-shader template for CNN v2.

    NOTE(review): currently unused — main() keeps the shader under manual
    maintenance. Kept for reference / future regeneration. The buffer offsets
    below match the binary layout written by export_weights_binary
    (4-u32 header + 5 u32 per LayerInfo), not a packed 16-bit header.

    Args:
        config: config dict returned by export_weights_binary (unused for now)
        output_dir: directory under which cnn_v2/cnn_v2_compute.wgsl is written
    """
    shader_code = """// CNN v2 inference shader (template — regenerate via export script).
//
// Weight buffer layout (matches export_weights_binary):
//   u32[0..3]   header: magic 'CNN2', version, num_layers, total f16 count
//   u32[4..18]  3 x LayerInfo (5 u32 each)
//   u32[19..]   weights, two f16 packed per u32

@group(0) @binding(0) var static_features: texture_2d<u32>;
@group(0) @binding(1) var layer_input: texture_2d<u32>;
@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>;
@group(0) @binding(3) var<storage, read> weights: array<u32>; // Packed f16 pairs

fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {
    let packed = textureLoad(static_features, coord, 0);
    let v0 = unpack2x16float(packed.x);
    let v1 = unpack2x16float(packed.y);
    let v2 = unpack2x16float(packed.z);
    let v3 = unpack2x16float(packed.w);
    return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
}

fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> {
    let packed = textureLoad(layer_input, coord, 0);
    let v0 = unpack2x16float(packed.x);
    let v1 = unpack2x16float(packed.y);
    let v2 = unpack2x16float(packed.z);
    let v3 = unpack2x16float(packed.w);
    return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
}

fn pack_channels(values: array<f32, 8>) -> vec4<u32> {
    return vec4<u32>(
        pack2x16float(vec2<f32>(values[0], values[1])),
        pack2x16float(vec2<f32>(values[2], values[3])),
        pack2x16float(vec2<f32>(values[4], values[5])),
        pack2x16float(vec2<f32>(values[6], values[7]))
    );
}

fn get_weight(idx: u32) -> f32 {
    let pair_idx = idx / 2u;
    // Skip 4-u32 header + 3 layers x 5-u32 LayerInfo = 19 u32 (76 bytes).
    let packed = weights[19u + pair_idx];
    let unpacked = unpack2x16float(packed);
    return select(unpacked.y, unpacked.x, (idx & 1u) == 0u);
}

@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let coord = vec2<i32>(id.xy);
    let dims = textureDimensions(static_features);
    if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {
        return;
    }

    // Header fields are full u32 values (exporter uses struct.pack '<4sIII').
    let num_layers = weights[2];

    // Load static features
    let static_feat = unpack_static_features(coord);

    // Process each layer (hardcoded for 3 layers for now)
    // TODO: Dynamic layer loop when needed
    // Example for layer 0 - expand to full multi-layer when tested
    let layer_info_offset = 4u; // After 4-u32 header
    let layer0_info_base = layer_info_offset;

    // Read layer 0 info (5 u32 values = 20 bytes)
    let kernel_size = weights[layer0_info_base];
    let in_channels = weights[layer0_info_base + 1u];
    let out_channels = weights[layer0_info_base + 2u];
    let weight_offset = weights[layer0_info_base + 3u];

    // Convolution (simplified - expand to full kernel loop)
    var output: array<f32, 8>;
    for (var c: u32 = 0u; c < min(out_channels, 8u); c++) {
        output[c] = 0.0; // TODO: Actual convolution
    }

    textureStore(output_tex, coord, pack_channels(output));
}
"""
    output_path = Path(output_dir) / "cnn_v2" / "cnn_v2_compute.wgsl"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(shader_code)
    print(f" → {output_path}")


def main():
    """CLI entry point: export checkpoint weights to a binary asset."""
    parser = argparse.ArgumentParser(description='Export CNN v2 weights to binary format')
    parser.add_argument('checkpoint', type=str,
                        help='Path to checkpoint .pth file')
    parser.add_argument('--output-weights', type=str,
                        default='workspaces/main/cnn_v2_weights.bin',
                        help='Output binary weights file')
    parser.add_argument('--output-shader', type=str,
                        default='workspaces/main/shaders',
                        help='Output directory for shader template')
    args = parser.parse_args()

    print("=== CNN v2 Weight Export ===\n")

    config = export_weights_binary(args.checkpoint, args.output_weights)
    print()

    # Shader is manually maintained in cnn_v2_compute.wgsl
    # export_shader_template(config, args.output_shader)

    print("\nExport complete!")


if __name__ == '__main__':
    main()