summaryrefslogtreecommitdiff
path: root/training/export_cnn_v2_weights.py
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-02-15 18:44:17 +0100
committerskal <pascal.massimino@gmail.com>2026-02-15 18:44:17 +0100
commit161a59fa50bb92e3664c389fa03b95aefe349b3f (patch)
tree71548f64b2bdea958388f9063b74137659d70306 /training/export_cnn_v2_weights.py
parent9c3b72c710bf1ffa7e18f7c7390a425d57487eba (diff)
refactor(cnn): isolate CNN v2 to cnn_v2/ subdirectory
Move all CNN v2 files to dedicated cnn_v2/ directory to prepare for CNN v3 development. Zero functional changes. Structure: - cnn_v2/src/ - C++ effect implementation - cnn_v2/shaders/ - WGSL shaders (6 files) - cnn_v2/weights/ - Binary weights (3 files) - cnn_v2/training/ - Python training scripts (4 files) - cnn_v2/scripts/ - Shell scripts (train_cnn_v2_full.sh) - cnn_v2/tools/ - Validation tools (HTML) - cnn_v2/docs/ - Documentation (4 markdown files) Changes: - Update CMake source list to cnn_v2/src/cnn_v2_effect.cc - Update assets.txt with relative paths to cnn_v2/ - Update includes to ../../cnn_v2/src/cnn_v2_effect.h - Add PROJECT_ROOT resolution to Python/shell scripts - Update doc references in HOWTO.md, TODO.md - Add cnn_v2/README.md Verification: 34/34 tests passing, demo runs correctly. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'training/export_cnn_v2_weights.py')
-rwxr-xr-xtraining/export_cnn_v2_weights.py284
1 files changed, 0 insertions, 284 deletions
diff --git a/training/export_cnn_v2_weights.py b/training/export_cnn_v2_weights.py
deleted file mode 100755
index f64bd8d..0000000
--- a/training/export_cnn_v2_weights.py
+++ /dev/null
@@ -1,284 +0,0 @@
-#!/usr/bin/env python3
-"""CNN v2 Weight Export Script
-
-Converts PyTorch checkpoints to binary weight format for storage buffer.
-Exports single shader template + binary weights asset.
-"""
-
-import argparse
-import numpy as np
-import torch
-import struct
-from pathlib import Path
-
-
-def export_weights_binary(checkpoint_path, output_path, quiet=False):
- """Export CNN v2 weights to binary format.
-
- Binary format:
- Header (20 bytes):
- uint32 magic ('CNN2')
- uint32 version (2)
- uint32 num_layers
- uint32 total_weights (f16 count)
- uint32 mip_level (0-3)
-
- LayerInfo × num_layers (20 bytes each):
- uint32 kernel_size
- uint32 in_channels
- uint32 out_channels
- uint32 weight_offset (f16 index)
- uint32 weight_count
-
- Weights (f16 array):
- float16[] all_weights
-
- Args:
- checkpoint_path: Path to .pth checkpoint
- output_path: Output .bin file path
-
- Returns:
- config dict for shader generation
- """
- if not quiet:
- print(f"Loading checkpoint: {checkpoint_path}")
- checkpoint = torch.load(checkpoint_path, map_location='cpu')
-
- state_dict = checkpoint['model_state_dict']
- config = checkpoint['config']
-
- # Support both old (kernel_size) and new (kernel_sizes) format
- if 'kernel_sizes' in config:
- kernel_sizes = config['kernel_sizes']
- elif 'kernel_size' in config:
- kernel_size = config['kernel_size']
- num_layers = config.get('num_layers', 3)
- kernel_sizes = [kernel_size] * num_layers
- else:
- kernel_sizes = [3, 3, 3] # fallback
-
- num_layers = config.get('num_layers', len(kernel_sizes))
- mip_level = config.get('mip_level', 0)
-
- if not quiet:
- print(f"Configuration:")
- print(f" Kernel sizes: {kernel_sizes}")
- print(f" Layers: {num_layers}")
- print(f" Mip level: {mip_level} (p0-p3 features)")
- print(f" Architecture: uniform 12D→4D (bias=False)")
-
- # Collect layer info - all layers uniform 12D→4D
- layers = []
- all_weights = []
- weight_offset = 0
-
- for i in range(num_layers):
- layer_key = f'layers.{i}.weight'
- if layer_key not in state_dict:
- raise ValueError(f"Missing weights for layer {i}: {layer_key}")
-
- layer_weights = state_dict[layer_key].detach().numpy()
- layer_flat = layer_weights.flatten()
- kernel_size = kernel_sizes[i]
-
- layers.append({
- 'kernel_size': kernel_size,
- 'in_channels': 12, # 4 (input/prev) + 8 (static)
- 'out_channels': 4, # Uniform output
- 'weight_offset': weight_offset,
- 'weight_count': len(layer_flat)
- })
- all_weights.extend(layer_flat)
- weight_offset += len(layer_flat)
-
- if not quiet:
- print(f" Layer {i}: 12D→4D, {kernel_size}×{kernel_size}, {len(layer_flat)} weights")
-
- # Convert to f16
- # TODO: Use 8-bit quantization for 2× size reduction
- # Requires quantization-aware training (QAT) to maintain accuracy
- all_weights_f16 = np.array(all_weights, dtype=np.float16)
-
- # Pack f16 pairs into u32 for storage buffer
- # Pad to even count if needed
- if len(all_weights_f16) % 2 == 1:
- all_weights_f16 = np.append(all_weights_f16, np.float16(0.0))
-
- # Pack pairs using numpy view
- weights_u32 = all_weights_f16.view(np.uint32)
-
- binary_size = 20 + len(layers) * 20 + len(weights_u32) * 4
- if not quiet:
- print(f"\nWeight statistics:")
- print(f" Total layers: {len(layers)}")
- print(f" Total weights: {len(all_weights_f16)} (f16)")
- print(f" Packed: {len(weights_u32)} u32")
- print(f" Binary size: {binary_size} bytes")
-
- # Write binary file
- output_path = Path(output_path)
- output_path.parent.mkdir(parents=True, exist_ok=True)
-
- with open(output_path, 'wb') as f:
- # Header (20 bytes) - version 2 with mip_level
- f.write(struct.pack('<4sIIII',
- b'CNN2', # magic
- 2, # version (bumped to 2)
- len(layers), # num_layers
- len(all_weights_f16), # total_weights (f16 count)
- mip_level)) # mip_level
-
- # Layer info (20 bytes per layer)
- for layer in layers:
- f.write(struct.pack('<IIIII',
- layer['kernel_size'],
- layer['in_channels'],
- layer['out_channels'],
- layer['weight_offset'],
- layer['weight_count']))
-
- # Weights (u32 packed f16 pairs)
- f.write(weights_u32.tobytes())
-
- if quiet:
- print(f" Exported {num_layers} layers, {len(all_weights_f16)} weights, {binary_size} bytes → {output_path}")
- else:
- print(f" → {output_path}")
-
- return {
- 'num_layers': len(layers),
- 'layers': layers
- }
-
-
-def export_shader_template(config, output_dir):
- """Generate single WGSL shader template with storage buffer binding.
-
- Args:
- config: Layer configuration from export_weights_binary()
- output_dir: Output directory path
- """
- shader_code = """// CNN v2 Compute Shader - Storage Buffer Version
-// Reads weights from storage buffer, processes all layers in sequence
-
-struct CNNv2Header {
- magic: u32, // 'CNN2'
- version: u32, // 1
- num_layers: u32, // Number of layers
- total_weights: u32, // Total f16 weight count
-}
-
-struct CNNv2LayerInfo {
- kernel_size: u32,
- in_channels: u32,
- out_channels: u32,
- weight_offset: u32, // Offset in weights array
- weight_count: u32,
-}
-
-@group(0) @binding(0) var static_features: texture_2d<u32>;
-@group(0) @binding(1) var layer_input: texture_2d<u32>;
-@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>;
-@group(0) @binding(3) var<storage, read> weights: array<u32>; // Packed f16 pairs
-
-fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {
- let packed = textureLoad(static_features, coord, 0);
- let v0 = unpack2x16float(packed.x);
- let v1 = unpack2x16float(packed.y);
- let v2 = unpack2x16float(packed.z);
- let v3 = unpack2x16float(packed.w);
- return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
-}
-
-fn unpack_layer_channels(coord: vec2<i32>) -> vec4<f32> {
- let packed = textureLoad(layer_input, coord, 0);
- let v0 = unpack2x16float(packed.x);
- let v1 = unpack2x16float(packed.y);
- return vec4<f32>(v0.x, v0.y, v1.x, v1.y);
-}
-
-fn pack_channels(values: vec4<f32>) -> vec4<u32> {
- return vec4<u32>(
- pack2x16float(vec2<f32>(values.x, values.y)),
- pack2x16float(vec2<f32>(values.z, values.w)),
- 0u, // Unused
- 0u // Unused
- );
-}
-
-fn get_weight(idx: u32) -> f32 {
- let pair_idx = idx / 2u;
- let packed = weights[8u + pair_idx]; // Skip header (32 bytes = 8 u32)
- let unpacked = unpack2x16float(packed);
- return select(unpacked.y, unpacked.x, (idx & 1u) == 0u);
-}
-
-@compute @workgroup_size(8, 8)
-fn main(@builtin(global_invocation_id) id: vec3<u32>) {
- let coord = vec2<i32>(id.xy);
- let dims = textureDimensions(static_features);
-
- if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {
- return;
- }
-
- // Read header
- let header_packed = weights[0]; // magic + version
- let counts_packed = weights[1]; // num_layers + total_weights
- let num_layers = counts_packed & 0xFFFFu;
-
- // Load static features
- let static_feat = unpack_static_features(coord);
-
- // Process each layer (hardcoded for 3 layers for now)
- // TODO: Dynamic layer loop when needed
-
- // Example for layer 0 - expand to full multi-layer when tested
- let layer_info_offset = 2u; // After header
- let layer0_info_base = layer_info_offset;
-
- // Read layer 0 info (5 u32 values = 20 bytes)
- let kernel_size = weights[layer0_info_base];
- let in_channels = weights[layer0_info_base + 1u];
- let out_channels = weights[layer0_info_base + 2u];
- let weight_offset = weights[layer0_info_base + 3u];
-
- // Convolution: 12D input (4 prev + 8 static) → 4D output
- var output: vec4<f32> = vec4<f32>(0.0);
- for (var c: u32 = 0u; c < 4u; c++) {
- output[c] = 0.0; // TODO: Actual convolution
- }
-
- textureStore(output_tex, coord, pack_channels(output));
-}
-"""
-
- output_path = Path(output_dir) / "cnn_v2" / "cnn_v2_compute.wgsl"
- output_path.write_text(shader_code)
- print(f" → {output_path}")
-
-
-def main():
- parser = argparse.ArgumentParser(description='Export CNN v2 weights to binary format')
- parser.add_argument('checkpoint', type=str, help='Path to checkpoint .pth file')
- parser.add_argument('--output-weights', type=str, default='workspaces/main/weights/cnn_v2_weights.bin',
- help='Output binary weights file')
- parser.add_argument('--output-shader', type=str, default='workspaces/main/shaders',
- help='Output directory for shader template')
- parser.add_argument('--quiet', action='store_true',
- help='Suppress detailed output')
-
- args = parser.parse_args()
-
- if not args.quiet:
- print("=== CNN v2 Weight Export ===\n")
- config = export_weights_binary(args.checkpoint, args.output_weights, quiet=args.quiet)
- if not args.quiet:
- print()
- # Shader is manually maintained in cnn_v2_compute.wgsl
- # export_shader_template(config, args.output_shader)
- print("\nExport complete!")
-
-
-if __name__ == '__main__':
- main()