#!/usr/bin/env python3
"""CNN v2 Shader Export Script

Converts PyTorch checkpoints to WGSL compute shaders with f16 weights.
Generates one shader per layer with embedded weight arrays.
"""

import argparse
import numpy as np
import torch
from pathlib import Path

# The generated shaders move features around as 8 f16 values packed into one
# rgba32uint texel (4 x u32, two halves each via pack2x16float). Every layer
# reads the 8 static feature channels plus at most 8 channels from the
# previous layer, and writes at most 8 channels. These limits are hard-coded
# in the WGSL template below.
_STATIC_CHANNELS = 8
_PACKED_CHANNELS = 8

# Number of weight literals emitted per line of the WGSL weight array.
_WEIGHTS_PER_LINE = 8


def _validate_layer(weights, kernel_size, in_channels, out_channels):
    """Raise ValueError unless the layer fits the generated shader's layout."""
    if kernel_size < 1 or kernel_size % 2 == 0:
        # The shader iterates -radius..radius (2 * (k // 2) + 1 taps), which
        # only matches the weight layout when k is odd.
        raise ValueError(
            f"kernel_size must be a positive odd integer, got {kernel_size}"
        )
    max_in = _STATIC_CHANNELS + _PACKED_CHANNELS
    if not _STATIC_CHANNELS <= in_channels <= max_in:
        raise ValueError(
            f"in_channels must be in [{_STATIC_CHANNELS}, {max_in}] "
            f"({_STATIC_CHANNELS} static + up to {_PACKED_CHANNELS} packed "
            f"layer channels), got {in_channels}"
        )
    if not 1 <= out_channels <= _PACKED_CHANNELS:
        raise ValueError(
            f"out_channels must be in [1, {_PACKED_CHANNELS}], got {out_channels}"
        )
    expected_shape = (out_channels, in_channels, kernel_size, kernel_size)
    expected_size = out_channels * in_channels * kernel_size * kernel_size
    if (weights.ndim == 4 and weights.shape != expected_shape) or (
        weights.size != expected_size
    ):
        raise ValueError(
            f"weights shape {weights.shape} does not match expected "
            f"{expected_shape}"
        )


def _render_layer_shader(layer_idx, weight_literals, kernel_size, in_channels,
                         out_channels, is_output_layer):
    """Return the WGSL source for one layer (see ``export_layer_shader``)."""
    num_weights = len(weight_literals)
    weights_str = ",\n    ".join(
        ", ".join(weight_literals[i:i + _WEIGHTS_PER_LINE])
        for i in range(0, num_weights, _WEIGHTS_PER_LINE)
    )
    radius = kernel_size // 2
    activation = (
        "output[c] = clamp(sum, 0.0, 1.0); // Sigmoid approximation"
        if is_output_layer
        else "output[c] = max(0.0, sum); // ReLU"
    )

    # Weight index in the shader is c*IN*K*K + i*K*K + ky*K + kx, i.e. the
    # C-order flattening of a (out, in, k, k) array.
    return f"""// CNN v2 Layer {layer_idx} - Auto-generated
// Kernel: {kernel_size}×{kernel_size}, In: {in_channels}, Out: {out_channels}

const KERNEL_SIZE: u32 = {kernel_size}u;
const IN_CHANNELS: u32 = {in_channels}u;
const OUT_CHANNELS: u32 = {out_channels}u;
const KERNEL_RADIUS: i32 = {radius};

// Weights quantized to float16 (stored as f32 in WGSL)
const weights: array<f32, {num_weights}> = array<f32, {num_weights}>(
    {weights_str}
);

@group(0) @binding(0) var static_features: texture_2d<u32>;
@group(0) @binding(1) var layer_input: texture_2d<u32>;
@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>;

fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {{
    let packed = textureLoad(static_features, coord, 0);
    let v0 = unpack2x16float(packed.x);
    let v1 = unpack2x16float(packed.y);
    let v2 = unpack2x16float(packed.z);
    let v3 = unpack2x16float(packed.w);
    return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
}}

fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> {{
    let packed = textureLoad(layer_input, coord, 0);
    let v0 = unpack2x16float(packed.x);
    let v1 = unpack2x16float(packed.y);
    let v2 = unpack2x16float(packed.z);
    let v3 = unpack2x16float(packed.w);
    return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
}}

fn pack_channels(values: array<f32, 8>) -> vec4<u32> {{
    return vec4<u32>(
        pack2x16float(vec2<f32>(values[0], values[1])),
        pack2x16float(vec2<f32>(values[2], values[3])),
        pack2x16float(vec2<f32>(values[4], values[5])),
        pack2x16float(vec2<f32>(values[6], values[7]))
    );
}}

@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {{
    let coord = vec2<i32>(id.xy);
    let dims = textureDimensions(static_features);

    if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {{
        return;
    }}

    // Load static features (always available)
    let static_feat = unpack_static_features(coord);

    // Convolution
    var output: array<f32, 8>;
    for (var c: u32 = 0u; c < OUT_CHANNELS; c++) {{
        var sum: f32 = 0.0;

        for (var ky: i32 = -KERNEL_RADIUS; ky <= KERNEL_RADIUS; ky++) {{
            for (var kx: i32 = -KERNEL_RADIUS; kx <= KERNEL_RADIUS; kx++) {{
                let sample_coord = coord + vec2<i32>(kx, ky);

                // Border handling (clamp)
                let clamped = vec2<i32>(
                    clamp(sample_coord.x, 0, i32(dims.x) - 1),
                    clamp(sample_coord.y, 0, i32(dims.y) - 1)
                );

                // Load input features
                let static_local = unpack_static_features(clamped);
                let layer_local = unpack_layer_channels(clamped);

                // Weight index calculation
                let ky_idx = u32(ky + KERNEL_RADIUS);
                let kx_idx = u32(kx + KERNEL_RADIUS);
                let spatial_idx = ky_idx * KERNEL_SIZE + kx_idx;

                // Accumulate: static features (8D)
                for (var i: u32 = 0u; i < 8u; i++) {{
                    let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE +
                                i * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
                    sum += weights[w_idx] * static_local[i];
                }}

                // Accumulate: layer input channels (if layer_idx > 0)
                let prev_channels = IN_CHANNELS - 8u;
                for (var i: u32 = 0u; i < prev_channels; i++) {{
                    let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE +
                                (8u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
                    sum += weights[w_idx] * layer_local[i];
                }}
            }}
        }}

        {activation}
    }}

    // Pack and store
    textureStore(output_tex, coord, pack_channels(output));
}}
"""


def export_layer_shader(layer_idx, weights, kernel_size, in_channels,
                        out_channels, output_dir, is_output_layer=False):
    """Generate WGSL compute shader for a single CNN layer.

    Writes ``cnn_v2_layer_{layer_idx}.wgsl`` into *output_dir* (which must
    already exist) and prints its path.

    Args:
        layer_idx: Layer index (0, 1, 2); used in the file name and header.
        weights: (out_ch, in_ch, k, k) weight array (anything ``np.asarray``
            accepts). A pre-flattened array of the same size is also accepted
            and must already be in C order of that shape.
        kernel_size: Odd kernel size (1, 3, 5, etc.).
        in_channels: Input channels: the 8 static features followed by up to
            8 channels from the previous layer.
        out_channels: Output channels (1 to 8).
        output_dir: Output directory path.
        is_output_layer: True if this is the final RGBA output layer (clamp to
            [0, 1] instead of ReLU).

    Raises:
        ValueError: if the layer does not fit the shader's fixed 8-channel
            packing, *weights* has the wrong shape, or a weight is non-finite
            after float16 quantization.
    """
    weights = np.asarray(weights)
    _validate_layer(weights, kernel_size, in_channels, out_channels)

    # Quantize to f16. Values beyond the f16 range become inf and are rejected
    # just below, so NumPy's overflow warning is redundant here.
    with np.errstate(over="ignore"):
        weights_f16 = weights.reshape(-1).astype(np.float16)
    if not np.all(np.isfinite(weights_f16)):
        raise ValueError(
            f"layer {layer_idx}: weights are non-finite after float16 "
            "quantization (NaN, or |w| > 65504)"
        )

    # Every f16 value is exactly representable as a double, and repr() is the
    # shortest decimal that round-trips it, so WGSL sees the exact quantized
    # weight. A fixed '.6f' format would truncate small weights
    # (e.g. f16(1e-4) -> "0.000100").
    weight_literals = [repr(float(w)) for w in weights_f16]

    shader_code = _render_layer_shader(
        layer_idx, weight_literals, kernel_size, in_channels, out_channels,
        is_output_layer,
    )

    output_path = Path(output_dir) / f"cnn_v2_layer_{layer_idx}.wgsl"
    # WGSL source is UTF-8 and the header contains '×'; don't depend on the
    # platform's locale encoding.
    output_path.write_text(shader_code, encoding="utf-8")
    print(f"  → {output_path}")


def export_checkpoint(checkpoint_path, output_dir):
    """Export PyTorch checkpoint to WGSL shaders.

    Reads ``model_state_dict`` (``layer{0,1,2}.weight``) and ``config``
    (``kernels``, ``channels``, ``features``) from the checkpoint and writes
    one shader per layer.

    Args:
        checkpoint_path: Path to .pth checkpoint
        output_dir: Output directory for shaders (created if missing)
    """
    print(f"Loading checkpoint: {checkpoint_path}")
    # NOTE(review): torch.load unpickles the file (without restrictions on
    # torch < 2.6, where weights_only defaults to False). Only export trusted
    # checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    state_dict = checkpoint['model_state_dict']
    config = checkpoint['config']

    print("Configuration:")
    print(f"  Kernels: {config['kernels']}")
    print(f"  Channels: {config['channels']}")
    print(f"  Features: {config['features']}")

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\nExporting shaders to {output_dir}/")

    channels = config['channels']
    # (in_channels, out_channels, is_output_layer) per layer. Every layer
    # also reads the static features in addition to the previous layer.
    layers = [
        (_STATIC_CHANNELS, channels[0], False),                # 8 → channels[0]
        (_STATIC_CHANNELS + channels[0], channels[1], False),  # (8 + channels[0]) → channels[1]
        (_STATIC_CHANNELS + channels[1], 4, True),             # (8 + channels[1]) → 4 (RGBA)
    ]

    for layer_idx, (in_channels, out_channels, is_output) in enumerate(layers):
        # .cpu().float() lets non-NumPy dtypes (e.g. bfloat16) convert; the
        # values are quantized to f16 during export anyway.
        layer_weights = (
            state_dict[f'layer{layer_idx}.weight'].detach().cpu().float().numpy()
        )
        export_layer_shader(
            layer_idx=layer_idx,
            weights=layer_weights,
            kernel_size=config['kernels'][layer_idx],
            in_channels=in_channels,
            out_channels=out_channels,
            output_dir=output_dir,
            is_output_layer=is_output,
        )

    print(f"\nExport complete! Generated {len(layers)} shader files.")


def main():
    parser = argparse.ArgumentParser(description='Export CNN v2 checkpoint to WGSL shaders')
    parser.add_argument('checkpoint', type=str, help='Path to checkpoint .pth file')
    parser.add_argument('--output-dir', type=str, default='workspaces/main/shaders',
                        help='Output directory for shaders')

    args = parser.parse_args()
    export_checkpoint(args.checkpoint, args.output_dir)


if __name__ == '__main__':
    main()