summaryrefslogtreecommitdiff
path: root/training/export_cnn_v2_shader.py
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-02-15 18:44:17 +0100
committerskal <pascal.massimino@gmail.com>2026-02-15 18:44:17 +0100
commit161a59fa50bb92e3664c389fa03b95aefe349b3f (patch)
tree71548f64b2bdea958388f9063b74137659d70306 /training/export_cnn_v2_shader.py
parent9c3b72c710bf1ffa7e18f7c7390a425d57487eba (diff)
refactor(cnn): isolate CNN v2 to cnn_v2/ subdirectory
Move all CNN v2 files to dedicated cnn_v2/ directory to prepare for CNN v3 development. Zero functional changes. Structure: - cnn_v2/src/ - C++ effect implementation - cnn_v2/shaders/ - WGSL shaders (6 files) - cnn_v2/weights/ - Binary weights (3 files) - cnn_v2/training/ - Python training scripts (4 files) - cnn_v2/scripts/ - Shell scripts (train_cnn_v2_full.sh) - cnn_v2/tools/ - Validation tools (HTML) - cnn_v2/docs/ - Documentation (4 markdown files) Changes: - Update CMake source list to cnn_v2/src/cnn_v2_effect.cc - Update assets.txt with relative paths to cnn_v2/ - Update includes to ../../cnn_v2/src/cnn_v2_effect.h - Add PROJECT_ROOT resolution to Python/shell scripts - Update doc references in HOWTO.md, TODO.md - Add cnn_v2/README.md Verification: 34/34 tests passing, demo runs correctly. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'training/export_cnn_v2_shader.py')
-rwxr-xr-xtraining/export_cnn_v2_shader.py214
1 files changed, 0 insertions, 214 deletions
diff --git a/training/export_cnn_v2_shader.py b/training/export_cnn_v2_shader.py
deleted file mode 100755
index 1c74ad0..0000000
--- a/training/export_cnn_v2_shader.py
+++ /dev/null
@@ -1,214 +0,0 @@
-#!/usr/bin/env python3
-"""CNN v2 Shader Export Script - Uniform 12D→4D Architecture
-
-Converts PyTorch checkpoints to WGSL compute shaders with f16 weights.
-Generates one shader per layer with embedded weight arrays.
-
-Note: Storage buffer approach (export_cnn_v2_weights.py) is preferred for size.
- This script is for debugging/testing with per-layer shaders.
-"""
-
-import argparse
-import numpy as np
-import torch
-from pathlib import Path
-
-
-def export_layer_shader(layer_idx, weights, kernel_size, output_dir, mip_level=0, is_output_layer=False):
- """Generate WGSL compute shader for a single CNN layer.
-
- Args:
- layer_idx: Layer index (0, 1, 2, ...)
- weights: (4, 12, k, k) weight tensor (uniform 12D→4D)
- kernel_size: Kernel size (3, 5, etc.)
- output_dir: Output directory path
- mip_level: Mip level used for p0-p3 (0=original, 1=half, etc.)
- is_output_layer: True if this is the final RGBA output layer
- """
- weights_flat = weights.flatten()
- weights_f16 = weights_flat.astype(np.float16)
- weights_f32 = weights_f16.astype(np.float32) # WGSL stores as f32 literals
-
- # Format weights as WGSL array
- weights_str = ",\n ".join(
- ", ".join(f"{w:.6f}" for w in weights_f32[i:i+8])
- for i in range(0, len(weights_f32), 8)
- )
-
- radius = kernel_size // 2
- if is_output_layer:
- activation = "output[c] = clamp(sum, 0.0, 1.0); // Output layer"
- elif layer_idx == 0:
- activation = "output[c] = clamp(sum, 0.0, 1.0); // Layer 0: clamp [0,1]"
- else:
- activation = "output[c] = max(0.0, sum); // Middle layers: ReLU"
-
- shader_code = f"""// CNN v2 Layer {layer_idx} - Auto-generated (uniform 12D→4D)
-// Kernel: {kernel_size}×{kernel_size}, In: 12D (4 prev + 8 static), Out: 4D
-// Mip level: {mip_level} (p0-p3 features)
-
-const KERNEL_SIZE: u32 = {kernel_size}u;
-const IN_CHANNELS: u32 = 12u; // 4 (input/prev) + 8 (static)
-const OUT_CHANNELS: u32 = 4u; // Uniform output
-const KERNEL_RADIUS: i32 = {radius};
-
-// Weights quantized to float16 (stored as f32 in WGSL)
-const weights: array<f32, {len(weights_f32)}> = array(
- {weights_str}
-);
-
-@group(0) @binding(0) var static_features: texture_2d<u32>;
-@group(0) @binding(1) var layer_input: texture_2d<u32>;
-@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>;
-
-fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {{
- let packed = textureLoad(static_features, coord, 0);
- let v0 = unpack2x16float(packed.x);
- let v1 = unpack2x16float(packed.y);
- let v2 = unpack2x16float(packed.z);
- let v3 = unpack2x16float(packed.w);
- return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
-}}
-
-fn unpack_layer_channels(coord: vec2<i32>) -> vec4<f32> {{
- let packed = textureLoad(layer_input, coord, 0);
- let v0 = unpack2x16float(packed.x);
- let v1 = unpack2x16float(packed.y);
- return vec4<f32>(v0.x, v0.y, v1.x, v1.y);
-}}
-
-fn pack_channels(values: vec4<f32>) -> vec4<u32> {{
- return vec4<u32>(
- pack2x16float(vec2<f32>(values.x, values.y)),
- pack2x16float(vec2<f32>(values.z, values.w)),
- 0u, // Unused
- 0u // Unused
- );
-}}
-
-@compute @workgroup_size(8, 8)
-fn main(@builtin(global_invocation_id) id: vec3<u32>) {{
- let coord = vec2<i32>(id.xy);
- let dims = textureDimensions(static_features);
-
- if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {{
- return;
- }}
-
- // Load static features (always available)
- let static_feat = unpack_static_features(coord);
-
- // Convolution: 12D input (4 prev + 8 static) → 4D output
- var output: vec4<f32> = vec4<f32>(0.0);
- for (var c: u32 = 0u; c < 4u; c++) {{
- var sum: f32 = 0.0;
-
- for (var ky: i32 = -KERNEL_RADIUS; ky <= KERNEL_RADIUS; ky++) {{
- for (var kx: i32 = -KERNEL_RADIUS; kx <= KERNEL_RADIUS; kx++) {{
- let sample_coord = coord + vec2<i32>(kx, ky);
-
- // Border handling (clamp)
- let clamped = vec2<i32>(
- clamp(sample_coord.x, 0, i32(dims.x) - 1),
- clamp(sample_coord.y, 0, i32(dims.y) - 1)
- );
-
- // Load features at this spatial location
- let static_local = unpack_static_features(clamped);
- let layer_local = unpack_layer_channels(clamped); // 4D
-
- // Weight index calculation
- let ky_idx = u32(ky + KERNEL_RADIUS);
- let kx_idx = u32(kx + KERNEL_RADIUS);
- let spatial_idx = ky_idx * KERNEL_SIZE + kx_idx;
-
- // Accumulate: previous/input channels (4D)
- for (var i: u32 = 0u; i < 4u; i++) {{
- let w_idx = c * 12u * KERNEL_SIZE * KERNEL_SIZE +
- i * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
- sum += weights[w_idx] * layer_local[i];
- }}
-
- // Accumulate: static features (8D)
- for (var i: u32 = 0u; i < 8u; i++) {{
- let w_idx = c * 12u * KERNEL_SIZE * KERNEL_SIZE +
- (4u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
- sum += weights[w_idx] * static_local[i];
- }}
- }}
- }}
-
- {activation}
- }}
-
- // Pack and store
- textureStore(output_tex, coord, pack_channels(output));
-}}
-"""
-
- output_path = Path(output_dir) / "cnn_v2" / f"cnn_v2_layer_{layer_idx}.wgsl"
- output_path.write_text(shader_code)
- print(f" → {output_path}")
-
-
-def export_checkpoint(checkpoint_path, output_dir):
- """Export PyTorch checkpoint to WGSL shaders.
-
- Args:
- checkpoint_path: Path to .pth checkpoint
- output_dir: Output directory for shaders
- """
- print(f"Loading checkpoint: {checkpoint_path}")
- checkpoint = torch.load(checkpoint_path, map_location='cpu')
-
- state_dict = checkpoint['model_state_dict']
- config = checkpoint['config']
-
- kernel_size = config.get('kernel_size', 3)
- num_layers = config.get('num_layers', 3)
- mip_level = config.get('mip_level', 0)
-
- print(f"Configuration:")
- print(f" Kernel size: {kernel_size}×{kernel_size}")
- print(f" Layers: {num_layers}")
- print(f" Mip level: {mip_level} (p0-p3 features)")
- print(f" Architecture: uniform 12D→4D")
-
- output_dir = Path(output_dir)
- output_dir.mkdir(parents=True, exist_ok=True)
-
- print(f"\nExporting shaders to {output_dir}/")
-
- # All layers uniform: 12D→4D
- for i in range(num_layers):
- layer_key = f'layers.{i}.weight'
- if layer_key not in state_dict:
- raise ValueError(f"Missing weights for layer {i}: {layer_key}")
-
- layer_weights = state_dict[layer_key].detach().numpy()
- is_output = (i == num_layers - 1)
-
- export_layer_shader(
- layer_idx=i,
- weights=layer_weights,
- kernel_size=kernel_size,
- output_dir=output_dir,
- mip_level=mip_level,
- is_output_layer=is_output
- )
-
- print(f"\nExport complete! Generated {num_layers} shader files.")
-
-
-def main():
- parser = argparse.ArgumentParser(description='Export CNN v2 checkpoint to WGSL shaders')
- parser.add_argument('checkpoint', type=str, help='Path to checkpoint .pth file')
- parser.add_argument('--output-dir', type=str, default='workspaces/main/shaders',
- help='Output directory for shaders')
-
- args = parser.parse_args()
- export_checkpoint(args.checkpoint, args.output_dir)
-
-
-if __name__ == '__main__':
- main()