summaryrefslogtreecommitdiff
path: root/training/export_cnn_v2_shader.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/export_cnn_v2_shader.py')
-rwxr-xr-xtraining/export_cnn_v2_shader.py225
1 files changed, 225 insertions, 0 deletions
diff --git a/training/export_cnn_v2_shader.py b/training/export_cnn_v2_shader.py
new file mode 100755
index 0000000..3c53ce2
--- /dev/null
+++ b/training/export_cnn_v2_shader.py
@@ -0,0 +1,225 @@
+#!/usr/bin/env python3
+"""CNN v2 Shader Export Script
+
+Converts PyTorch checkpoints to WGSL compute shaders with f16 weights.
+Generates one shader per layer with embedded weight arrays.
+"""
+
+import argparse
+import numpy as np
+import torch
+from pathlib import Path
+
+
+def export_layer_shader(layer_idx, weights, kernel_size, in_channels, out_channels,
+ output_dir, is_output_layer=False):
+ """Generate WGSL compute shader for a single CNN layer.
+
+ Args:
+ layer_idx: Layer index (0, 1, 2)
+ weights: (out_ch, in_ch, k, k) weight tensor
+ kernel_size: Kernel size (1, 3, 5, etc.)
+ in_channels: Input channels (includes 8D static features)
+ out_channels: Output channels
+ output_dir: Output directory path
+ is_output_layer: True if this is the final RGBA output layer
+ """
+ weights_flat = weights.flatten()
+ weights_f16 = weights_flat.astype(np.float16)
+ weights_f32 = weights_f16.astype(np.float32) # WGSL stores as f32 literals
+
+ # Format weights as WGSL array
+ weights_str = ",\n ".join(
+ ", ".join(f"{w:.6f}" for w in weights_f32[i:i+8])
+ for i in range(0, len(weights_f32), 8)
+ )
+
+ radius = kernel_size // 2
+ activation = "" if is_output_layer else "output[c] = max(0.0, sum); // ReLU"
+ if is_output_layer:
+ activation = "output[c] = clamp(sum, 0.0, 1.0); // Sigmoid approximation"
+
+ shader_code = f"""// CNN v2 Layer {layer_idx} - Auto-generated
+// Kernel: {kernel_size}×{kernel_size}, In: {in_channels}, Out: {out_channels}
+
+const KERNEL_SIZE: u32 = {kernel_size}u;
+const IN_CHANNELS: u32 = {in_channels}u;
+const OUT_CHANNELS: u32 = {out_channels}u;
+const KERNEL_RADIUS: i32 = {radius};
+
+// Weights quantized to float16 (stored as f32 in WGSL)
+const weights: array<f32, {len(weights_f32)}> = array(
+ {weights_str}
+);
+
+@group(0) @binding(0) var static_features: texture_2d<u32>;
+@group(0) @binding(1) var layer_input: texture_2d<u32>;
+@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>;
+
+fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {{
+ let packed = textureLoad(static_features, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let v2 = unpack2x16float(packed.z);
+ let v3 = unpack2x16float(packed.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}}
+
+fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> {{
+ let packed = textureLoad(layer_input, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let v2 = unpack2x16float(packed.z);
+ let v3 = unpack2x16float(packed.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}}
+
+fn pack_channels(values: array<f32, 8>) -> vec4<u32> {{
+ return vec4<u32>(
+ pack2x16float(vec2<f32>(values[0], values[1])),
+ pack2x16float(vec2<f32>(values[2], values[3])),
+ pack2x16float(vec2<f32>(values[4], values[5])),
+ pack2x16float(vec2<f32>(values[6], values[7]))
+ );
+}}
+
+@compute @workgroup_size(8, 8)
+fn main(@builtin(global_invocation_id) id: vec3<u32>) {{
+ let coord = vec2<i32>(id.xy);
+ let dims = textureDimensions(static_features);
+
+ if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {{
+ return;
+ }}
+
+ // Load static features (always available)
+ let static_feat = unpack_static_features(coord);
+
+ // Convolution
+ var output: array<f32, OUT_CHANNELS>;
+ for (var c: u32 = 0u; c < OUT_CHANNELS; c++) {{
+ var sum: f32 = 0.0;
+
+ for (var ky: i32 = -KERNEL_RADIUS; ky <= KERNEL_RADIUS; ky++) {{
+ for (var kx: i32 = -KERNEL_RADIUS; kx <= KERNEL_RADIUS; kx++) {{
+ let sample_coord = coord + vec2<i32>(kx, ky);
+
+ // Border handling (clamp)
+ let clamped = vec2<i32>(
+ clamp(sample_coord.x, 0, i32(dims.x) - 1),
+ clamp(sample_coord.y, 0, i32(dims.y) - 1)
+ );
+
+ // Load input features
+ let static_local = unpack_static_features(clamped);
+ let layer_local = unpack_layer_channels(clamped);
+
+ // Weight index calculation
+ let ky_idx = u32(ky + KERNEL_RADIUS);
+ let kx_idx = u32(kx + KERNEL_RADIUS);
+ let spatial_idx = ky_idx * KERNEL_SIZE + kx_idx;
+
+ // Accumulate: static features (8D)
+ for (var i: u32 = 0u; i < 8u; i++) {{
+ let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE +
+ i * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
+ sum += weights[w_idx] * static_local[i];
+ }}
+
+ // Accumulate: layer input channels (if layer_idx > 0)
+ let prev_channels = IN_CHANNELS - 8u;
+ for (var i: u32 = 0u; i < prev_channels; i++) {{
+ let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE +
+ (8u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
+ sum += weights[w_idx] * layer_local[i];
+ }}
+ }}
+ }}
+
+ {activation}
+ }}
+
+ // Pack and store
+ textureStore(output_tex, coord, pack_channels(output));
+}}
+"""
+
+ output_path = Path(output_dir) / f"cnn_v2_layer_{layer_idx}.wgsl"
+ output_path.write_text(shader_code)
+ print(f" → {output_path}")
+
+
+def export_checkpoint(checkpoint_path, output_dir):
+ """Export PyTorch checkpoint to WGSL shaders.
+
+ Args:
+ checkpoint_path: Path to .pth checkpoint
+ output_dir: Output directory for shaders
+ """
+ print(f"Loading checkpoint: {checkpoint_path}")
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
+
+ state_dict = checkpoint['model_state_dict']
+ config = checkpoint['config']
+
+ print(f"Configuration:")
+ print(f" Kernels: {config['kernels']}")
+ print(f" Channels: {config['channels']}")
+ print(f" Features: {config['features']}")
+
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ print(f"\nExporting shaders to {output_dir}/")
+
+ # Layer 0: 8 → channels[0]
+ layer0_weights = state_dict['layer0.weight'].detach().numpy()
+ export_layer_shader(
+ layer_idx=0,
+ weights=layer0_weights,
+ kernel_size=config['kernels'][0],
+ in_channels=8,
+ out_channels=config['channels'][0],
+ output_dir=output_dir,
+ is_output_layer=False
+ )
+
+ # Layer 1: (8 + channels[0]) → channels[1]
+ layer1_weights = state_dict['layer1.weight'].detach().numpy()
+ export_layer_shader(
+ layer_idx=1,
+ weights=layer1_weights,
+ kernel_size=config['kernels'][1],
+ in_channels=8 + config['channels'][0],
+ out_channels=config['channels'][1],
+ output_dir=output_dir,
+ is_output_layer=False
+ )
+
+ # Layer 2: (8 + channels[1]) → 4 (RGBA)
+ layer2_weights = state_dict['layer2.weight'].detach().numpy()
+ export_layer_shader(
+ layer_idx=2,
+ weights=layer2_weights,
+ kernel_size=config['kernels'][2],
+ in_channels=8 + config['channels'][1],
+ out_channels=4,
+ output_dir=output_dir,
+ is_output_layer=True
+ )
+
+ print(f"\nExport complete! Generated 3 shader files.")
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Export CNN v2 checkpoint to WGSL shaders')
+ parser.add_argument('checkpoint', type=str, help='Path to checkpoint .pth file')
+ parser.add_argument('--output-dir', type=str, default='workspaces/main/shaders',
+ help='Output directory for shaders')
+
+ args = parser.parse_args()
+ export_checkpoint(args.checkpoint, args.output_dir)
+
+
+if __name__ == '__main__':
+ main()