#!/usr/bin/env python3
# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "numpy",
#     "opencv-python",
#     "pillow",
#     "torch",
# ]
# ///
"""Export trained CNN v3 weights → binary files for C++ runtime.

Outputs
-------
output_dir/cnn_v3_weights.bin
    Conv+bias weights for all 5 passes, packed as f16-pairs-in-u32.
    Matches the format expected by CNNv3Effect::upload_weights().
    Layout: enc0 (1448) | enc1 (1168) | bottleneck (2320) | dec1 (2312) | dec0 (580)
    = 7828 f16 values = 3914 u32 = 15656 bytes.

output_dir/cnn_v3_film_mlp.bin
    FiLM MLP weights as raw f32: L0_W (5×16) L0_b (16) L1_W (16×72) L1_b (72).
    = 5*16 + 16 + 16*72 + 72 = 80 + 16 + 1152 + 72 = 1320 f32 = 5280 bytes.
    For future CPU-side MLP inference in CNNv3Effect::set_film_params().

Usage
-----
cd cnn_v3/training
python3 export_cnn_v3_weights.py checkpoints/checkpoint_epoch_200.pth
python3 export_cnn_v3_weights.py checkpoints/checkpoint_epoch_200.pth --output /tmp/out/
"""
import argparse
import base64
import struct
import sys
from pathlib import Path

import numpy as np
import torch

# ---------------------------------------------------------------------------
# Weight layout helpers — derived from enc_channels at runtime.
# Must stay in sync with cnn_v3/src/cnn_v3_effect.cc and gen_test_vectors.py.
# ---------------------------------------------------------------------------

N_IN = 20  # feature input channels (fixed)

# Default target for --html: the demo page's embedded-weights module.
# NOTE(review): reconstructed — the original definition was garbled in this
# copy of the file; docstrings reference cnn_v3/tools/weights.js. Verify.
_WEIGHTS_JS_DEFAULT = Path(__file__).resolve().parent.parent / 'tools' / 'weights.js'


def weight_counts(enc_channels):
    """Return per-layer f16 counts (weights + biases) for the 5 conv passes.

    All convolutions are 3×3 (hence the factor 9). Decoder layers take
    skip-concatenated inputs, hence the ``* 2`` on their input channels.

    Returns (enc0, enc1, bottleneck, dec1, dec0).
    """
    c0, c1 = enc_channels
    enc0 = N_IN * c0 * 9 + c0
    enc1 = c0 * c1 * 9 + c1
    bn = c1 * c1 * 9 + c1
    dec1 = (c1 * 2) * c0 * 9 + c0
    dec0 = (c0 * 2) * 4 * 9 + 4
    return enc0, enc1, bn, dec1, dec0


def pack_weights_u32(w_f16: np.ndarray) -> np.ndarray:
    """Pack flat f16 array as u32 pairs matching WGSL get_w() layout.

    WGSL get_w(buf, base, idx):
        pair = buf[(base+idx)/2]
        return f16 from low bits if even, high bits if odd.

    So w[0] in bits [15:0] of u32[0], w[1] in bits [31:16] of u32[0], etc.
    (Bit layout assumes a little-endian host, which matches WebGPU buffers.)
    """
    f16 = w_f16.astype(np.float16)
    if len(f16) % 2:
        # Odd count: pad with a zero so the array views cleanly as u32 pairs.
        f16 = np.append(f16, np.float16(0))
    u16 = f16.view(np.uint16)
    u32 = u16[0::2].astype(np.uint32) | (u16[1::2].astype(np.uint32) << 16)
    return u32


def extract_conv_layer(state: dict, name: str) -> np.ndarray:
    """Extract conv weight (OIHW, flattened) + bias as f16 numpy array."""
    w = state[f"{name}.weight"].cpu().numpy().astype(np.float16)  # OIHW
    b = state[f"{name}.bias"].cpu().numpy().astype(np.float16)
    return np.concatenate([w.flatten(), b.flatten()])


def export_weights(checkpoint_path: str, output_dir: str) -> None:
    """Load a checkpoint and write cnn_v3_weights.bin + cnn_v3_film_mlp.bin.

    Raises ValueError if any layer's parameter count disagrees with the
    layout derived from enc_channels.
    """
    # Deferred local import: keeps the packing helpers in this module
    # importable without train_cnn_v3.py on sys.path.
    sys.path.insert(0, str(Path(__file__).parent))
    from train_cnn_v3 import CNNv3

    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
    cfg = ckpt.get('config', {})
    enc_channels = cfg.get('enc_channels', [8, 16])
    film_cond_dim = cfg.get('film_cond_dim', 5)

    model = CNNv3(enc_channels=enc_channels, film_cond_dim=film_cond_dim)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()
    state = model.state_dict()

    epoch = ckpt.get('epoch', '?')
    loss = ckpt.get('loss', float('nan'))
    print(f"Checkpoint: epoch={epoch} loss={loss:.6f}")
    print(f"  enc_channels={enc_channels} film_cond_dim={film_cond_dim}")

    # -----------------------------------------------------------------------
    # 1. CNN conv weights → cnn_v3_weights.bin
    # -----------------------------------------------------------------------
    enc0_w, enc1_w, bn_w, dec1_w, dec0_w = weight_counts(enc_channels)
    total_f16 = enc0_w + enc1_w + bn_w + dec1_w + dec0_w
    layers = [
        ('enc0', enc0_w),
        ('enc1', enc1_w),
        ('bottleneck', bn_w),
        ('dec1', dec1_w),
        ('dec0', dec0_w),
    ]
    print(f"  Weight layout: enc0={enc0_w} enc1={enc1_w} bn={bn_w} "
          f"dec1={dec1_w} dec0={dec0_w} total={total_f16} f16 "
          f"({total_f16*2/1024:.1f} KB)")

    all_f16 = []
    for name, expected in layers:
        chunk = extract_conv_layer(state, name)
        if len(chunk) != expected:
            raise ValueError(
                f"{name}: expected {expected} f16 values, got {len(chunk)}")
        all_f16.append(chunk)

    flat_f16 = np.concatenate(all_f16)
    assert len(flat_f16) == total_f16, f"total mismatch: {len(flat_f16)} != {total_f16}"

    packed_u32 = pack_weights_u32(flat_f16)
    weights_path = out / 'cnn_v3_weights.bin'
    # Explicit little-endian u32 so the file matches the GPU-side layout
    # regardless of host byte order.
    packed_u32.astype('<u4').tofile(weights_path)
    print(f"  cnn_v3_weights.bin → {weights_path} "
          f"({weights_path.stat().st_size} bytes)")

    # -----------------------------------------------------------------------
    # 2. FiLM MLP weights → cnn_v3_film_mlp.bin (raw little-endian f32)
    # -----------------------------------------------------------------------
    # NOTE(review): this section was reconstructed from the module docstring
    # (layout: L0_W (cond×h) | L0_b (h) | L1_W (h×out) | L1_b (out)); the two
    # Linear layers are located by tensor shape instead of hard-coded state
    # dict key names — TODO confirm against train_cnn_v3.CNNv3.
    w0_key = next(k for k, v in state.items()
                  if k.endswith('.weight') and v.dim() == 2
                  and v.shape[1] == film_cond_dim)
    w0 = state[w0_key].cpu().numpy().astype(np.float32)  # torch layout: (h, cond)
    b0 = state[w0_key[:-len('.weight')] + '.bias'].cpu().numpy().astype(np.float32)
    w1_key = next(k for k, v in state.items()
                  if k.endswith('.weight') and v.dim() == 2
                  and k != w0_key and v.shape[1] == w0.shape[0])
    w1 = state[w1_key].cpu().numpy().astype(np.float32)  # torch layout: (out, h)
    b1 = state[w1_key[:-len('.weight')] + '.bias'].cpu().numpy().astype(np.float32)

    # Transpose to the documented (in×out) order before flattening.
    film = np.concatenate([w0.T.flatten(), b0, w1.T.flatten(), b1])
    film_path = out / 'cnn_v3_film_mlp.bin'
    film.astype('<f4').tofile(film_path)
    print(f"  cnn_v3_film_mlp.bin → {film_path} "
          f"({film_path.stat().st_size} bytes)")


def update_weights_js(weights_bin: Path, film_mlp_bin: Path, js_path: Path) -> None:
    """Encode both .bin files as base64 and write cnn_v3/tools/weights.js."""
    w_b64 = base64.b64encode(weights_bin.read_bytes()).decode('ascii')
    f_b64 = base64.b64encode(film_mlp_bin.read_bytes()).decode('ascii')
    js_path.write_text(
        "'use strict';\n"
        "// Auto-generated by export_cnn_v3_weights.py --html — do not edit by hand.\n"
        f"const CNN_V3_WEIGHTS_B64='{w_b64}';\n"
        f"const CNN_V3_FILM_MLP_B64='{f_b64}';\n"
    )
    print(f"\nweights.js → {js_path}")
    print(f"  CNN_V3_WEIGHTS_B64 {len(w_b64)} chars ({weights_bin.stat().st_size} bytes)")
    print(f"  CNN_V3_FILM_MLP_B64 {len(f_b64)} chars ({film_mlp_bin.stat().st_size} bytes)")


def main() -> None:
    """CLI entry point: export .bin files, optionally regenerate weights.js."""
    p = argparse.ArgumentParser(description='Export CNN v3 trained weights to .bin')
    p.add_argument('checkpoint', help='Path to .pth checkpoint file')
    p.add_argument('--output', default='export',
                   help='Output directory (default: export)')
    p.add_argument('--html', action='store_true',
                   help=f'Also update {_WEIGHTS_JS_DEFAULT} with base64-encoded weights')
    p.add_argument('--html-output', default=None, metavar='PATH',
                   help='Override default weights.js path (implies --html)')
    args = p.parse_args()

    export_weights(args.checkpoint, args.output)

    if args.html or args.html_output:
        out = Path(args.output)
        js_path = Path(args.html_output) if args.html_output else _WEIGHTS_JS_DEFAULT
        update_weights_js(out / 'cnn_v3_weights.bin',
                          out / 'cnn_v3_film_mlp.bin',
                          js_path)


if __name__ == '__main__':
    main()