diff options
Diffstat (limited to 'training/export_cnn_v2_shader.py')
| -rwxr-xr-x | training/export_cnn_v2_shader.py | 225 |
1 files changed, 225 insertions, 0 deletions
diff --git a/training/export_cnn_v2_shader.py b/training/export_cnn_v2_shader.py new file mode 100755 index 0000000..3c53ce2 --- /dev/null +++ b/training/export_cnn_v2_shader.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python3 +"""CNN v2 Shader Export Script + +Converts PyTorch checkpoints to WGSL compute shaders with f16 weights. +Generates one shader per layer with embedded weight arrays. +""" + +import argparse +import numpy as np +import torch +from pathlib import Path + + +def export_layer_shader(layer_idx, weights, kernel_size, in_channels, out_channels, + output_dir, is_output_layer=False): + """Generate WGSL compute shader for a single CNN layer. + + Args: + layer_idx: Layer index (0, 1, 2) + weights: (out_ch, in_ch, k, k) weight tensor + kernel_size: Kernel size (1, 3, 5, etc.) + in_channels: Input channels (includes 8D static features) + out_channels: Output channels + output_dir: Output directory path + is_output_layer: True if this is the final RGBA output layer + """ + weights_flat = weights.flatten() + weights_f16 = weights_flat.astype(np.float16) + weights_f32 = weights_f16.astype(np.float32) # WGSL stores as f32 literals + + # Format weights as WGSL array + weights_str = ",\n ".join( + ", ".join(f"{w:.6f}" for w in weights_f32[i:i+8]) + for i in range(0, len(weights_f32), 8) + ) + + radius = kernel_size // 2 + activation = "" if is_output_layer else "output[c] = max(0.0, sum); // ReLU" + if is_output_layer: + activation = "output[c] = clamp(sum, 0.0, 1.0); // Sigmoid approximation" + + shader_code = f"""// CNN v2 Layer {layer_idx} - Auto-generated +// Kernel: {kernel_size}×{kernel_size}, In: {in_channels}, Out: {out_channels} + +const KERNEL_SIZE: u32 = {kernel_size}u; +const IN_CHANNELS: u32 = {in_channels}u; +const OUT_CHANNELS: u32 = {out_channels}u; +const KERNEL_RADIUS: i32 = {radius}; + +// Weights quantized to float16 (stored as f32 in WGSL) +const weights: array<f32, {len(weights_f32)}> = array( + {weights_str} +); + +@group(0) @binding(0) var static_features: texture_2d<u32>; +@group(0) @binding(1) var layer_input: texture_2d<u32>; +@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>; + +fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {{ + let packed = textureLoad(static_features, coord, 0); + let v0 = unpack2x16float(packed.x); + let v1 = unpack2x16float(packed.y); + let v2 = unpack2x16float(packed.z); + let v3 = unpack2x16float(packed.w); + return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y); +}} + +fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> {{ + let packed = textureLoad(layer_input, coord, 0); + let v0 = unpack2x16float(packed.x); + let v1 = unpack2x16float(packed.y); + let v2 = unpack2x16float(packed.z); + let v3 = unpack2x16float(packed.w); + return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y); +}} + +fn pack_channels(values: array<f32, 8>) -> vec4<u32> {{ + return vec4<u32>( + pack2x16float(vec2<f32>(values[0], values[1])), + pack2x16float(vec2<f32>(values[2], values[3])), + pack2x16float(vec2<f32>(values[4], values[5])), + pack2x16float(vec2<f32>(values[6], values[7])) + ); +}} + +@compute @workgroup_size(8, 8) +fn main(@builtin(global_invocation_id) id: vec3<u32>) {{ + let coord = vec2<i32>(id.xy); + let dims = textureDimensions(static_features); + + if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {{ + return; + }} + + // Load static features (always available) + let static_feat = unpack_static_features(coord); + + // Convolution + var output: array<f32, OUT_CHANNELS>; + for (var c: u32 = 0u; c < OUT_CHANNELS; c++) {{ + var sum: f32 = 0.0; + + for (var ky: i32 = -KERNEL_RADIUS; ky <= KERNEL_RADIUS; ky++) {{ + for (var kx: i32 = -KERNEL_RADIUS; kx <= KERNEL_RADIUS; kx++) {{ + let sample_coord = coord + vec2<i32>(kx, ky); + + // Border handling (clamp) + let clamped = vec2<i32>( + clamp(sample_coord.x, 0, i32(dims.x) - 1), + clamp(sample_coord.y, 0, i32(dims.y) - 1) + ); + + // Load input features + let static_local = unpack_static_features(clamped); + let layer_local = unpack_layer_channels(clamped); + + // Weight index calculation + let ky_idx = u32(ky + KERNEL_RADIUS); + let kx_idx = u32(kx + KERNEL_RADIUS); + let spatial_idx = ky_idx * KERNEL_SIZE + kx_idx; + + // Accumulate: static features (8D) + for (var i: u32 = 0u; i < 8u; i++) {{ + let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE + + i * KERNEL_SIZE * KERNEL_SIZE + spatial_idx; + sum += weights[w_idx] * static_local[i]; + }} + + // Accumulate: layer input channels (if layer_idx > 0) + let prev_channels = IN_CHANNELS - 8u; + for (var i: u32 = 0u; i < prev_channels; i++) {{ + let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE + + (8u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx; + sum += weights[w_idx] * layer_local[i]; + }} + }} + }} + + {activation} + }} + + // Pack and store + textureStore(output_tex, coord, pack_channels(output)); +}} +""" + + output_path = Path(output_dir) / f"cnn_v2_layer_{layer_idx}.wgsl" + output_path.write_text(shader_code) + print(f" → {output_path}") + + +def export_checkpoint(checkpoint_path, output_dir): + """Export PyTorch checkpoint to WGSL shaders. + + Args: + checkpoint_path: Path to .pth checkpoint + output_dir: Output directory for shaders + """ + print(f"Loading checkpoint: {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location='cpu') + + state_dict = checkpoint['model_state_dict'] + config = checkpoint['config'] + + print(f"Configuration:") + print(f" Kernels: {config['kernels']}") + print(f" Channels: {config['channels']}") + print(f" Features: {config['features']}") + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + print(f"\nExporting shaders to {output_dir}/") + + # Layer 0: 8 → channels[0] + layer0_weights = state_dict['layer0.weight'].detach().numpy() + export_layer_shader( + layer_idx=0, + weights=layer0_weights, + kernel_size=config['kernels'][0], + in_channels=8, + out_channels=config['channels'][0], + output_dir=output_dir, + is_output_layer=False + ) + + # Layer 1: (8 + channels[0]) → channels[1] + layer1_weights = state_dict['layer1.weight'].detach().numpy() + export_layer_shader( + layer_idx=1, + weights=layer1_weights, + kernel_size=config['kernels'][1], + in_channels=8 + config['channels'][0], + out_channels=config['channels'][1], + output_dir=output_dir, + is_output_layer=False + ) + + # Layer 2: (8 + channels[1]) → 4 (RGBA) + layer2_weights = state_dict['layer2.weight'].detach().numpy() + export_layer_shader( + layer_idx=2, + weights=layer2_weights, + kernel_size=config['kernels'][2], + in_channels=8 + config['channels'][1], + out_channels=4, + output_dir=output_dir, + is_output_layer=True + ) + + print(f"\nExport complete! Generated 3 shader files.") + + +def main(): + parser = argparse.ArgumentParser(description='Export CNN v2 checkpoint to WGSL shaders') + parser.add_argument('checkpoint', type=str, help='Path to checkpoint .pth file') + parser.add_argument('--output-dir', type=str, default='workspaces/main/shaders', + help='Output directory for shaders') + + args = parser.parse_args() + export_checkpoint(args.checkpoint, args.output_dir) + + +if __name__ == '__main__': + main() |
