#!/usr/bin/env python3
# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "numpy",
#     "opencv-python",
#     "pillow",
#     "torch",
# ]
# ///
"""Export trained CNN v3 weights → binary files for C++ runtime.

Outputs
-------
<output_dir>/cnn_v3_weights.bin
    Conv+bias weights for all 5 passes, packed as f16-pairs-in-u32.
    Matches the format expected by CNNv3Effect::upload_weights().
    Layout: enc0 (724) | enc1 (296) | bottleneck (72) | dec1 (580) | dec0 (292)
    = 1964 f16 values = 982 u32 = 3928 bytes.

<output_dir>/cnn_v3_film_mlp.bin
    FiLM MLP weights as raw f32: L0_W (5×16) L0_b (16) L1_W (16×40) L1_b (40).
    = 5*16 + 16 + 16*40 + 40 = 80 + 16 + 640 + 40 = 776 f32 = 3104 bytes.
    For future CPU-side MLP inference in CNNv3Effect::set_film_params().

Usage
-----
    cd cnn_v3/training
    python3 export_cnn_v3_weights.py checkpoints/checkpoint_epoch_200.pth
    python3 export_cnn_v3_weights.py checkpoints/checkpoint_epoch_200.pth --output /tmp/out/
"""

import argparse
import struct
import sys
from pathlib import Path

import numpy as np
import torch

# Local import (same dir)
sys.path.insert(0, str(Path(__file__).parent))
from train_cnn_v3 import CNNv3

# ---------------------------------------------------------------------------
# Weight layout constants — must stay in sync with:
#   cnn_v3/src/cnn_v3_effect.cc          (kEnc0Weights, kEnc1Weights, …)
#   cnn_v3/training/gen_test_vectors.py  (same constants)
# ---------------------------------------------------------------------------
ENC0_WEIGHTS = 20 * 4 * 9 + 4  # Conv(20→4,3×3)+bias = 724
ENC1_WEIGHTS = 4 * 8 * 9 + 8   # Conv(4→8,3×3)+bias  = 296
BN_WEIGHTS = 8 * 8 * 1 + 8     # Conv(8→8,1×1)+bias  = 72
DEC1_WEIGHTS = 16 * 4 * 9 + 4  # Conv(16→4,3×3)+bias = 580
DEC0_WEIGHTS = 8 * 4 * 9 + 4   # Conv(8→4,3×3)+bias  = 292
TOTAL_F16 = ENC0_WEIGHTS + ENC1_WEIGHTS + BN_WEIGHTS + DEC1_WEIGHTS + DEC0_WEIGHTS
# = 1964


def pack_weights_u32(w_f16: np.ndarray) -> np.ndarray:
    """Pack flat f16 array as u32 pairs matching WGSL get_w() layout.

    WGSL get_w(buf, base, idx):
        pair = buf[(base+idx)/2]
        return f16 from low bits if even, high bits if odd.

    So w[0] in bits [15:0] of u32[0], w[1] in bits [31:16] of u32[0], etc.

    Args:
        w_f16: Weight values; converted to float16 and flattened first.

    Returns:
        1-D uint32 array of ceil(n / 2) words. An odd-length input is padded
        with a trailing zero so that the last word is complete.
    """
    # ravel() guarantees a contiguous 1-D buffer, which view() requires.
    f16 = np.asarray(w_f16).astype(np.float16).ravel()
    if f16.size % 2:
        f16 = np.append(f16, np.float16(0))
    u16 = f16.view(np.uint16)
    low = u16[0::2].astype(np.uint32)    # even indices → bits [15:0]
    high = u16[1::2].astype(np.uint32)   # odd indices  → bits [31:16]
    return low | (high << 16)


def extract_conv_layer(state: dict, name: str) -> np.ndarray:
    """Extract conv weight (OIHW, flattened) + bias as f16 numpy array.

    Args:
        state: Model state dict holding "<name>.weight" and "<name>.bias".
        name: Layer prefix inside the state dict (e.g. "enc0").

    Returns:
        1-D float16 array: flattened weights followed by the bias values.

    Raises:
        KeyError: if either tensor is missing from ``state``.
    """
    w = state[f"{name}.weight"].cpu().numpy().astype(np.float16)  # OIHW
    b = state[f"{name}.bias"].cpu().numpy().astype(np.float16)
    return np.concatenate([w.flatten(), b.flatten()])


def _collect_film_mlp_f32(state: dict) -> np.ndarray:
    """Flatten every linear layer in ``state`` as weight-then-bias f32.

    NOTE(review): the original implementation of this step was lost when the
    file was mangled. Linear layers are identified here by having a 2-D
    ".weight" tensor (the conv layers are 4-D) and are emitted in state-dict
    order. This assumes the FiLM MLP is the only source of 2-D weights in the
    model, and that the runtime wants each matrix in PyTorch's native
    (out_features, in_features) row-major order — confirm both against
    train_cnn_v3.py and CNNv3Effect::set_film_params().
    """
    parts = []
    for key, tensor in state.items():
        if not key.endswith('.weight') or tensor.ndim != 2:
            continue
        parts.append(tensor.cpu().numpy().astype(np.float32).ravel())
        bias_key = key.removesuffix('weight') + 'bias'
        if bias_key in state:
            parts.append(state[bias_key].cpu().numpy().astype(np.float32).ravel())
    if not parts:
        raise ValueError("no linear (2-D weight) layers found for the FiLM MLP")
    return np.concatenate(parts)


def export_weights(checkpoint_path: str, output_dir: str) -> None:
    """Load a checkpoint and write the two runtime weight files.

    Writes ``cnn_v3_weights.bin`` (f16 pairs packed into little-endian u32)
    and ``cnn_v3_film_mlp.bin`` (raw little-endian f32) into ``output_dir``,
    creating the directory if needed.

    Args:
        checkpoint_path: Path to the .pth checkpoint. It must contain
            'model_state_dict'; 'config', 'epoch' and 'loss' are optional.
        output_dir: Destination directory for the .bin files.

    Raises:
        ValueError: if a layer, or the concatenated total, does not have the
            element count fixed by the layout constants.
    """
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
    cfg = ckpt.get('config', {})
    enc_channels = cfg.get('enc_channels', [4, 8])
    film_cond_dim = cfg.get('film_cond_dim', 5)

    # Round-trip through the model so that a checkpoint which does not match
    # the architecture fails in load_state_dict() rather than exporting junk.
    model = CNNv3(enc_channels=enc_channels, film_cond_dim=film_cond_dim)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()
    state = model.state_dict()

    epoch = ckpt.get('epoch', '?')
    loss = ckpt.get('loss', float('nan'))
    print(f"Checkpoint: epoch={epoch} loss={loss:.6f}")
    print(f" enc_channels={enc_channels} film_cond_dim={film_cond_dim}")

    # -----------------------------------------------------------------------
    # 1. CNN conv weights → cnn_v3_weights.bin
    # -----------------------------------------------------------------------
    layers = [
        ('enc0', ENC0_WEIGHTS),
        ('enc1', ENC1_WEIGHTS),
        ('bottleneck', BN_WEIGHTS),
        ('dec1', DEC1_WEIGHTS),
        ('dec0', DEC0_WEIGHTS),
    ]
    all_f16 = []
    for name, expected in layers:
        chunk = extract_conv_layer(state, name)
        if len(chunk) != expected:
            raise ValueError(
                f"{name}: expected {expected} f16 values, got {len(chunk)}")
        all_f16.append(chunk)

    flat_f16 = np.concatenate(all_f16)
    # Raised explicitly (not asserted) so the check survives `python -O`.
    if len(flat_f16) != TOTAL_F16:
        raise ValueError(f"total mismatch: {len(flat_f16)} != {TOTAL_F16}")

    packed_u32 = pack_weights_u32(flat_f16)
    weights_path = out / 'cnn_v3_weights.bin'
    # NOTE(review): everything from this write down to main()'s header was
    # lost in the mangled source and has been reconstructed from the module
    # docstring; the progress messages below are new wording.
    packed_u32.astype('<u4').tofile(weights_path)
    print(f"Wrote {weights_path}: {len(flat_f16)} f16 values = "
          f"{len(packed_u32)} u32 = {packed_u32.nbytes} bytes")

    # -----------------------------------------------------------------------
    # 2. FiLM MLP weights → cnn_v3_film_mlp.bin
    # -----------------------------------------------------------------------
    film_f32 = _collect_film_mlp_f32(state)
    film_path = out / 'cnn_v3_film_mlp.bin'
    film_f32.astype('<f4').tofile(film_path)
    print(f"Wrote {film_path}: {len(film_f32)} f32 = {film_f32.nbytes} bytes")


def main() -> None:
    """Parse the command line and run the export."""
    p = argparse.ArgumentParser(description='Export CNN v3 trained weights to .bin')
    p.add_argument('checkpoint', help='Path to .pth checkpoint file')
    p.add_argument('--output', default='export',
                   help='Output directory (default: export/)')
    args = p.parse_args()
    export_weights(args.checkpoint, args.output)


if __name__ == '__main__':
    main()