diff options
Diffstat (limited to 'training/export_cnn_v2_shader.py')
| -rwxr-xr-x | training/export_cnn_v2_shader.py | 214 |
1 files changed, 0 insertions, 214 deletions
diff --git a/training/export_cnn_v2_shader.py b/training/export_cnn_v2_shader.py deleted file mode 100755 index 1c74ad0..0000000 --- a/training/export_cnn_v2_shader.py +++ /dev/null @@ -1,214 +0,0 @@ -#!/usr/bin/env python3 -"""CNN v2 Shader Export Script - Uniform 12D→4D Architecture - -Converts PyTorch checkpoints to WGSL compute shaders with f16 weights. -Generates one shader per layer with embedded weight arrays. - -Note: Storage buffer approach (export_cnn_v2_weights.py) is preferred for size. - This script is for debugging/testing with per-layer shaders. -""" - -import argparse -import numpy as np -import torch -from pathlib import Path - - -def export_layer_shader(layer_idx, weights, kernel_size, output_dir, mip_level=0, is_output_layer=False): - """Generate WGSL compute shader for a single CNN layer. - - Args: - layer_idx: Layer index (0, 1, 2, ...) - weights: (4, 12, k, k) weight tensor (uniform 12D→4D) - kernel_size: Kernel size (3, 5, etc.) - output_dir: Output directory path - mip_level: Mip level used for p0-p3 (0=original, 1=half, etc.) - is_output_layer: True if this is the final RGBA output layer - """ - weights_flat = weights.flatten() - weights_f16 = weights_flat.astype(np.float16) - weights_f32 = weights_f16.astype(np.float32) # WGSL stores as f32 literals - - # Format weights as WGSL array - weights_str = ",\n ".join( - ", ".join(f"{w:.6f}" for w in weights_f32[i:i+8]) - for i in range(0, len(weights_f32), 8) - ) - - radius = kernel_size // 2 - if is_output_layer: - activation = "output[c] = clamp(sum, 0.0, 1.0); // Output layer" - elif layer_idx == 0: - activation = "output[c] = clamp(sum, 0.0, 1.0); // Layer 0: clamp [0,1]" - else: - activation = "output[c] = max(0.0, sum); // Middle layers: ReLU" - - shader_code = f"""// CNN v2 Layer {layer_idx} - Auto-generated (uniform 12D→4D) -// Kernel: {kernel_size}×{kernel_size}, In: 12D (4 prev + 8 static), Out: 4D -// Mip level: {mip_level} (p0-p3 features) - -const KERNEL_SIZE: u32 = {kernel_size}u; -const IN_CHANNELS: u32 = 12u; // 4 (input/prev) + 8 (static) -const OUT_CHANNELS: u32 = 4u; // Uniform output -const KERNEL_RADIUS: i32 = {radius}; - -// Weights quantized to float16 (stored as f32 in WGSL) -const weights: array<f32, {len(weights_f32)}> = array( - {weights_str} -); - -@group(0) @binding(0) var static_features: texture_2d<u32>; -@group(0) @binding(1) var layer_input: texture_2d<u32>; -@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>; - -fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {{ - let packed = textureLoad(static_features, coord, 0); - let v0 = unpack2x16float(packed.x); - let v1 = unpack2x16float(packed.y); - let v2 = unpack2x16float(packed.z); - let v3 = unpack2x16float(packed.w); - return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y); -}} - -fn unpack_layer_channels(coord: vec2<i32>) -> vec4<f32> {{ - let packed = textureLoad(layer_input, coord, 0); - let v0 = unpack2x16float(packed.x); - let v1 = unpack2x16float(packed.y); - return vec4<f32>(v0.x, v0.y, v1.x, v1.y); -}} - -fn pack_channels(values: vec4<f32>) -> vec4<u32> {{ - return vec4<u32>( - pack2x16float(vec2<f32>(values.x, values.y)), - pack2x16float(vec2<f32>(values.z, values.w)), - 0u, // Unused - 0u // Unused - ); -}} - -@compute @workgroup_size(8, 8) -fn main(@builtin(global_invocation_id) id: vec3<u32>) {{ - let coord = vec2<i32>(id.xy); - let dims = textureDimensions(static_features); - - if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {{ - return; - }} - - // Load static features (always available) - let static_feat = unpack_static_features(coord); - - // Convolution: 12D input (4 prev + 8 static) → 4D output - var output: vec4<f32> = vec4<f32>(0.0); - for (var c: u32 = 0u; c < 4u; c++) {{ - var sum: f32 = 0.0; - - for (var ky: i32 = -KERNEL_RADIUS; ky <= KERNEL_RADIUS; ky++) {{ - for (var kx: i32 = -KERNEL_RADIUS; kx <= KERNEL_RADIUS; kx++) {{ - let sample_coord = coord + vec2<i32>(kx, ky); - - // Border handling (clamp) - let clamped = vec2<i32>( - clamp(sample_coord.x, 0, i32(dims.x) - 1), - clamp(sample_coord.y, 0, i32(dims.y) - 1) - ); - - // Load features at this spatial location - let static_local = unpack_static_features(clamped); - let layer_local = unpack_layer_channels(clamped); // 4D - - // Weight index calculation - let ky_idx = u32(ky + KERNEL_RADIUS); - let kx_idx = u32(kx + KERNEL_RADIUS); - let spatial_idx = ky_idx * KERNEL_SIZE + kx_idx; - - // Accumulate: previous/input channels (4D) - for (var i: u32 = 0u; i < 4u; i++) {{ - let w_idx = c * 12u * KERNEL_SIZE * KERNEL_SIZE + - i * KERNEL_SIZE * KERNEL_SIZE + spatial_idx; - sum += weights[w_idx] * layer_local[i]; - }} - - // Accumulate: static features (8D) - for (var i: u32 = 0u; i < 8u; i++) {{ - let w_idx = c * 12u * KERNEL_SIZE * KERNEL_SIZE + - (4u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx; - sum += weights[w_idx] * static_local[i]; - }} - }} - }} - - {activation} - }} - - // Pack and store - textureStore(output_tex, coord, pack_channels(output)); -}} -""" - - output_path = Path(output_dir) / "cnn_v2" / f"cnn_v2_layer_{layer_idx}.wgsl" - output_path.write_text(shader_code) - print(f" → {output_path}") - - -def export_checkpoint(checkpoint_path, output_dir): - """Export PyTorch checkpoint to WGSL shaders. - - Args: - checkpoint_path: Path to .pth checkpoint - output_dir: Output directory for shaders - """ - print(f"Loading checkpoint: {checkpoint_path}") - checkpoint = torch.load(checkpoint_path, map_location='cpu') - - state_dict = checkpoint['model_state_dict'] - config = checkpoint['config'] - - kernel_size = config.get('kernel_size', 3) - num_layers = config.get('num_layers', 3) - mip_level = config.get('mip_level', 0) - - print(f"Configuration:") - print(f" Kernel size: {kernel_size}×{kernel_size}") - print(f" Layers: {num_layers}") - print(f" Mip level: {mip_level} (p0-p3 features)") - print(f" Architecture: uniform 12D→4D") - - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - print(f"\nExporting shaders to {output_dir}/") - - # All layers uniform: 12D→4D - for i in range(num_layers): - layer_key = f'layers.{i}.weight' - if layer_key not in state_dict: - raise ValueError(f"Missing weights for layer {i}: {layer_key}") - - layer_weights = state_dict[layer_key].detach().numpy() - is_output = (i == num_layers - 1) - - export_layer_shader( - layer_idx=i, - weights=layer_weights, - kernel_size=kernel_size, - output_dir=output_dir, - mip_level=mip_level, - is_output_layer=is_output - ) - - print(f"\nExport complete! Generated {num_layers} shader files.") - - -def main(): - parser = argparse.ArgumentParser(description='Export CNN v2 checkpoint to WGSL shaders') - parser.add_argument('checkpoint', type=str, help='Path to checkpoint .pth file') - parser.add_argument('--output-dir', type=str, default='workspaces/main/shaders', - help='Output directory for shaders') - - args = parser.parse_args() - export_checkpoint(args.checkpoint, args.output_dir) - - -if __name__ == '__main__': - main() |
