#!/usr/bin/env python3
"""CNN v2 Shader Export Script - Uniform 12D→4D Architecture

Converts PyTorch checkpoints to WGSL compute shaders with f16 weights.
Generates one shader per layer with embedded weight arrays.

Note: Storage buffer approach (export_cnn_v2_weights.py) is preferred for size.
      This script is for debugging/testing with per-layer shaders.
"""

import argparse
from pathlib import Path
from typing import Union

import numpy as np
import torch

# Channel layout that the generated shader's index math hard-codes
# (`c * 12u * K * K + i * K * K + spatial_idx`, with `c < 4u`).
_IN_CHANNELS = 12   # 4 (input/prev) + 8 (static)
_OUT_CHANNELS = 4
# Number of weight literals emitted per line of the WGSL array.
_WEIGHTS_PER_ROW = 8


def _format_weights(values: np.ndarray) -> str:
    """Render a flat float array as WGSL literals, eight per line."""
    rows = (
        ", ".join(f"{w:.6f}" for w in values[start:start + _WEIGHTS_PER_ROW])
        for start in range(0, len(values), _WEIGHTS_PER_ROW)
    )
    return ",\n    ".join(rows)


def _activation_statement(layer_idx: int, is_output_layer: bool) -> str:
    """Return the WGSL statement that stores the activated sum for channel c."""
    if is_output_layer:
        return "output[c] = clamp(sum, 0.0, 1.0); // Output layer"
    if layer_idx == 0:
        return "output[c] = clamp(sum, 0.0, 1.0); // Layer 0: clamp [0,1]"
    return "output[c] = max(0.0, sum); // Middle layers: ReLU"


def export_layer_shader(layer_idx: int,
                        weights: Union[np.ndarray, torch.Tensor],
                        kernel_size: int,
                        output_dir,
                        mip_level: int = 0,
                        is_output_layer: bool = False) -> None:
    """Generate WGSL compute shader for a single CNN layer.

    The shader is written to
    ``<output_dir>/cnn_v2/cnn_v2_layer_<layer_idx>.wgsl``; the ``cnn_v2``
    directory is created if it does not exist.

    Args:
        layer_idx: Layer index (0, 1, 2, ...)
        weights: (4, 12, k, k) weight array (uniform 12D→4D). A NumPy array
            or a torch tensor; it is flattened in C order and rounded to
            float16 before being printed as f32 literals.
        kernel_size: Kernel size (3, 5, etc.). Must be a positive odd number.
        output_dir: Output directory path
        mip_level: Mip level used for p0-p3 (0=original, 1=half, etc.).
            Only recorded in the shader's header comment.
        is_output_layer: True if this is the final RGBA output layer

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer, if the
            number of weights is not ``4 * 12 * kernel_size**2``, or if a
            weight is not finite after float16 quantization.
    """
    if kernel_size < 1 or kernel_size % 2 == 0:
        # The shader loops over -radius..radius, i.e. 2*radius+1 taps, which
        # only equals kernel_size when kernel_size is odd.
        raise ValueError(
            f"kernel_size must be a positive odd integer, got {kernel_size}"
        )

    if isinstance(weights, torch.Tensor):
        weights = weights.detach().cpu().numpy()
    weights = np.asarray(weights)

    expected_count = _OUT_CHANNELS * _IN_CHANNELS * kernel_size * kernel_size
    if weights.size != expected_count:
        raise ValueError(
            f"Layer {layer_idx}: expected {expected_count} weights for shape "
            f"({_OUT_CHANNELS}, {_IN_CHANNELS}, {kernel_size}, {kernel_size}), "
            f"got {weights.size} (shape {weights.shape})"
        )

    weights_f16 = weights.flatten().astype(np.float16)
    weights_f32 = weights_f16.astype(np.float32)  # WGSL stores as f32 literals
    if not np.all(np.isfinite(weights_f32)):
        # inf/nan would be printed as "inf"/"nan", which are not WGSL literals.
        raise ValueError(
            f"Layer {layer_idx}: weights contain non-finite values after "
            f"float16 quantization"
        )

    # Format weights as WGSL array
    weights_str = _format_weights(weights_f32)
    num_weights = len(weights_f32)
    radius = kernel_size // 2
    activation = _activation_statement(layer_idx, is_output_layer)

    # NOTE(review): the textures are declared as u32 because the shader feeds
    # their texels to unpack2x16float / writes pack2x16float results. The
    # storage format `rgba32uint` is assumed from the vec4 of u32 being
    # stored — confirm against the pipeline that creates these textures.
    shader_code = f"""// CNN v2 Layer {layer_idx} - Auto-generated (uniform 12D→4D)
// Kernel: {kernel_size}×{kernel_size}, In: 12D (4 prev + 8 static), Out: 4D
// Mip level: {mip_level} (p0-p3 features)

const KERNEL_SIZE: u32 = {kernel_size}u;
const IN_CHANNELS: u32 = 12u;  // 4 (input/prev) + 8 (static)
const OUT_CHANNELS: u32 = 4u;  // Uniform output
const KERNEL_RADIUS: i32 = {radius};

// Weights quantized to float16 (stored as f32 in WGSL)
const weights: array<f32, {num_weights}> = array<f32, {num_weights}>(
    {weights_str}
);

@group(0) @binding(0) var static_features: texture_2d<u32>;
@group(0) @binding(1) var layer_input: texture_2d<u32>;
@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>;

fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {{
    let packed = textureLoad(static_features, coord, 0);
    let v0 = unpack2x16float(packed.x);
    let v1 = unpack2x16float(packed.y);
    let v2 = unpack2x16float(packed.z);
    let v3 = unpack2x16float(packed.w);
    return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
}}

fn unpack_layer_channels(coord: vec2<i32>) -> vec4<f32> {{
    let packed = textureLoad(layer_input, coord, 0);
    let v0 = unpack2x16float(packed.x);
    let v1 = unpack2x16float(packed.y);
    return vec4<f32>(v0.x, v0.y, v1.x, v1.y);
}}

fn pack_channels(values: vec4<f32>) -> vec4<u32> {{
    return vec4<u32>(
        pack2x16float(vec2<f32>(values.x, values.y)),
        pack2x16float(vec2<f32>(values.z, values.w)),
        0u,  // Unused
        0u   // Unused
    );
}}

@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {{
    let coord = vec2<i32>(id.xy);
    let dims = textureDimensions(static_features);

    if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {{
        return;
    }}

    // Load static features (always available)
    let static_feat = unpack_static_features(coord);

    // Convolution: 12D input (4 prev + 8 static) → 4D output
    var output: vec4<f32> = vec4<f32>(0.0);
    for (var c: u32 = 0u; c < 4u; c++) {{
        var sum: f32 = 0.0;

        for (var ky: i32 = -KERNEL_RADIUS; ky <= KERNEL_RADIUS; ky++) {{
            for (var kx: i32 = -KERNEL_RADIUS; kx <= KERNEL_RADIUS; kx++) {{
                let sample_coord = coord + vec2<i32>(kx, ky);

                // Border handling (clamp)
                let clamped = vec2<i32>(
                    clamp(sample_coord.x, 0, i32(dims.x) - 1),
                    clamp(sample_coord.y, 0, i32(dims.y) - 1)
                );

                // Load features at this spatial location
                let static_local = unpack_static_features(clamped);
                let layer_local = unpack_layer_channels(clamped);  // 4D

                // Weight index calculation
                let ky_idx = u32(ky + KERNEL_RADIUS);
                let kx_idx = u32(kx + KERNEL_RADIUS);
                let spatial_idx = ky_idx * KERNEL_SIZE + kx_idx;

                // Accumulate: previous/input channels (4D)
                for (var i: u32 = 0u; i < 4u; i++) {{
                    let w_idx = c * 12u * KERNEL_SIZE * KERNEL_SIZE +
                                i * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
                    sum += weights[w_idx] * layer_local[i];
                }}

                // Accumulate: static features (8D)
                for (var i: u32 = 0u; i < 8u; i++) {{
                    let w_idx = c * 12u * KERNEL_SIZE * KERNEL_SIZE +
                                (4u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
                    sum += weights[w_idx] * static_local[i];
                }}
            }}
        }}

        {activation}
    }}

    // Pack and store
    textureStore(output_tex, coord, pack_channels(output));
}}
"""

    output_path = Path(output_dir) / "cnn_v2" / f"cnn_v2_layer_{layer_idx}.wgsl"
    # The caller only guarantees output_dir itself; create the cnn_v2 subdir.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8: the shader comments contain non-ASCII characters.
    output_path.write_text(shader_code, encoding="utf-8")
    print(f"  → {output_path}")


def export_checkpoint(checkpoint_path, output_dir) -> None:
    """Export PyTorch checkpoint to WGSL shaders.

    Reads ``model_state_dict`` and ``config`` from the checkpoint and writes
    one shader per layer via :func:`export_layer_shader`. The last layer is
    exported as the output layer.

    Args:
        checkpoint_path: Path to .pth checkpoint
        output_dir: Output directory for shaders

    Raises:
        KeyError: If the checkpoint lacks ``model_state_dict`` or ``config``.
        ValueError: If ``layers.<i>.weight`` is missing for any layer, or if
            a layer's weights fail validation in ``export_layer_shader``.
    """
    print(f"Loading checkpoint: {checkpoint_path}")
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    state_dict = checkpoint['model_state_dict']
    config = checkpoint['config']

    kernel_size = config.get('kernel_size', 3)
    num_layers = config.get('num_layers', 3)
    mip_level = config.get('mip_level', 0)

    print("Configuration:")
    print(f"  Kernel size: {kernel_size}×{kernel_size}")
    print(f"  Layers: {num_layers}")
    print(f"  Mip level: {mip_level} (p0-p3 features)")
    print("  Architecture: uniform 12D→4D")

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\nExporting shaders to {output_dir}/")

    # All layers uniform: 12D→4D
    for i in range(num_layers):
        layer_key = f'layers.{i}.weight'
        if layer_key not in state_dict:
            raise ValueError(f"Missing weights for layer {i}: {layer_key}")

        layer_weights = state_dict[layer_key].detach().cpu().numpy()
        export_layer_shader(
            layer_idx=i,
            weights=layer_weights,
            kernel_size=kernel_size,
            output_dir=output_dir,
            mip_level=mip_level,
            is_output_layer=(i == num_layers - 1),
        )

    print(f"\nExport complete! Generated {num_layers} shader files.")


def main() -> None:
    """Parse command-line arguments and export the given checkpoint."""
    parser = argparse.ArgumentParser(
        description='Export CNN v2 checkpoint to WGSL shaders')
    parser.add_argument('checkpoint', type=str,
                        help='Path to checkpoint .pth file')
    parser.add_argument('--output-dir', type=str,
                        default='workspaces/main/shaders',
                        help='Output directory for shaders')

    args = parser.parse_args()
    export_checkpoint(args.checkpoint, args.output_dir)


if __name__ == '__main__':
    main()