#!/usr/bin/env python3
"""Export trained CNN v3 weights → binary files for C++ runtime.

Outputs
-------
<output>/cnn_v3_weights.bin
    Conv+bias weights for all 5 passes, packed as f16-pairs-in-u32.
    Matches the format expected by CNNv3Effect::upload_weights().
    Layout: enc0 (724) | enc1 (296) | bottleneck (72) | dec1 (580) | dec0 (292)
    = 1964 f16 values = 982 u32 = 3928 bytes.

<output>/cnn_v3_film_mlp.bin
    FiLM MLP weights as raw f32: L0_W (5×16) L0_b (16) L1_W (16×40) L1_b (40).
    = 5*16 + 16 + 16*40 + 40 = 80 + 16 + 640 + 40 = 776 f32 = 3104 bytes.
    For future CPU-side MLP inference in CNNv3Effect::set_film_params().

Usage
-----
    cd cnn_v3/training
    python3 export_cnn_v3_weights.py checkpoints/checkpoint_epoch_200.pth
    python3 export_cnn_v3_weights.py checkpoints/checkpoint_epoch_200.pth --output /tmp/out/
"""

import argparse
import struct
import sys
from pathlib import Path

import numpy as np
import torch

# Make the sibling training module (train_cnn_v3.py) importable regardless of
# the current working directory.  The actual `CNNv3` import is deferred into
# export_weights() so that the pure packing helpers below can be imported
# without the training code being present.
sys.path.insert(0, str(Path(__file__).parent))

# ---------------------------------------------------------------------------
# Weight layout constants (must match cnn_v3_effect.cc and gen_test_vectors.py)
# ---------------------------------------------------------------------------
ENC0_WEIGHTS = 20 * 4 * 9 + 4   # 724
ENC1_WEIGHTS = 4 * 8 * 9 + 8    # 296
BN_WEIGHTS = 8 * 8 * 1 + 8      # 72
DEC1_WEIGHTS = 16 * 4 * 9 + 4   # 580
DEC0_WEIGHTS = 8 * 4 * 9 + 4    # 292
TOTAL_F16 = ENC0_WEIGHTS + ENC1_WEIGHTS + BN_WEIGHTS + DEC1_WEIGHTS + DEC0_WEIGHTS
# = 1964

# FiLM MLP: L0 (5 → 16) + L1 (16 → 40), weights and biases, raw f32.
FILM_TOTAL_F32 = 5 * 16 + 16 + 16 * 40 + 40  # 776


def pack_weights_u32(w_f16: np.ndarray) -> np.ndarray:
    """Pack flat f16 array as u32 pairs matching WGSL get_w() layout.

    WGSL get_w(buf, base, idx):
        pair = buf[(base+idx)/2]
        return f16 from low bits if even, high bits if odd.
    So w[0] in bits [15:0] of u32[0], w[1] in bits [31:16] of u32[0], etc.
    An odd-length input is zero-padded to an even number of halves.
    """
    f16 = w_f16.astype(np.float16)
    if len(f16) % 2:
        f16 = np.append(f16, np.float16(0))
    u16 = f16.view(np.uint16)
    u32 = u16[0::2].astype(np.uint32) | (u16[1::2].astype(np.uint32) << 16)
    return u32


def extract_conv_layer(state: dict, name: str) -> np.ndarray:
    """Extract conv weight (OIHW, flattened) + bias as f16 numpy array."""
    w = state[f"{name}.weight"].cpu().numpy().astype(np.float16)  # OIHW
    b = state[f"{name}.bias"].cpu().numpy().astype(np.float16)
    return np.concatenate([w.flatten(), b.flatten()])


def export_weights(checkpoint_path: str, output_dir: str) -> None:
    """Load *checkpoint_path*, rebuild the model, and write both .bin files.

    Raises ValueError if any layer's flattened size disagrees with the
    layout constants above (i.e. the checkpoint does not match the C++
    runtime's expectations).
    """
    # Deferred import: only the full export needs the training module.
    from train_cnn_v3 import CNNv3

    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    # weights_only=True: safe tensor-only deserialization of the checkpoint.
    ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
    cfg = ckpt.get('config', {})
    enc_channels = cfg.get('enc_channels', [4, 8])
    film_cond_dim = cfg.get('film_cond_dim', 5)

    model = CNNv3(enc_channels=enc_channels, film_cond_dim=film_cond_dim)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()
    state = model.state_dict()

    epoch = ckpt.get('epoch', '?')
    loss = ckpt.get('loss', float('nan'))
    print(f"Checkpoint: epoch={epoch} loss={loss:.6f}")
    print(f"  enc_channels={enc_channels} film_cond_dim={film_cond_dim}")

    # -----------------------------------------------------------------------
    # 1. CNN conv weights → cnn_v3_weights.bin
    # -----------------------------------------------------------------------
    layers = [
        ('enc0', ENC0_WEIGHTS),
        ('enc1', ENC1_WEIGHTS),
        ('bottleneck', BN_WEIGHTS),
        ('dec1', DEC1_WEIGHTS),
        ('dec0', DEC0_WEIGHTS),
    ]

    all_f16 = []
    for name, expected in layers:
        chunk = extract_conv_layer(state, name)
        if len(chunk) != expected:
            raise ValueError(
                f"{name}: expected {expected} f16 values, got {len(chunk)}")
        all_f16.append(chunk)

    flat_f16 = np.concatenate(all_f16)
    # Runtime check (not `assert`: must survive `python -O`).
    if len(flat_f16) != TOTAL_F16:
        raise ValueError(f"total mismatch: {len(flat_f16)} != {TOTAL_F16}")

    packed_u32 = pack_weights_u32(flat_f16)
    weights_path = out / 'cnn_v3_weights.bin'
    # Explicit little-endian u32 so the file is byte-identical on any host.
    packed_u32.astype('<u4').tofile(weights_path)
    print(f"Wrote {weights_path}: {len(flat_f16)} f16 → {len(packed_u32)} u32 "
          f"({weights_path.stat().st_size} bytes)")

    # -----------------------------------------------------------------------
    # 2. FiLM MLP weights → cnn_v3_film_mlp.bin
    #
    # NOTE(review): the original source for this section was corrupted in
    # transit; it is reconstructed from the module docstring's layout spec
    # (L0_W, L0_b, L1_W, L1_b as 776 raw f32).  Confirm the parameter names
    # and whether L0_W must be transposed (PyTorch Linear stores (out, in),
    # the docstring says 5×16) against train_cnn_v3.CNNv3 and
    # CNNv3Effect::set_film_params().
    # -----------------------------------------------------------------------
    film_keys = [k for k in state if 'film' in k]  # registration order: W, b, W, b
    film_flat = np.concatenate(
        [state[k].cpu().numpy().astype(np.float32).flatten() for k in film_keys])
    if len(film_flat) != FILM_TOTAL_F32:
        raise ValueError(
            f"FiLM MLP: expected {FILM_TOTAL_F32} f32 values, got {len(film_flat)} "
            f"(keys: {film_keys})")
    film_path = out / 'cnn_v3_film_mlp.bin'
    film_flat.astype('<f4').tofile(film_path)
    print(f"Wrote {film_path}: {len(film_flat)} f32 "
          f"({film_path.stat().st_size} bytes)")


def main() -> None:
    p = argparse.ArgumentParser(description='Export CNN v3 trained weights to .bin')
    p.add_argument('checkpoint', help='Path to .pth checkpoint file')
    p.add_argument('--output', default='export',
                   help='Output directory (default: export/)')
    args = p.parse_args()
    export_weights(args.checkpoint, args.output)


if __name__ == '__main__':
    main()