diff options
| author | skal <pascal.massimino@gmail.com> | 2026-03-21 10:27:50 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-03-21 10:27:50 +0100 |
| commit | e343021ac007549c76e58b27a361b11dd3f6a136 (patch) | |
| tree | a855b76dcc428752a09cbd192eabd16931baf804 /cnn_v3/training/export_cnn_v3_weights.py | |
| parent | 1e8ccfc67c264ce054c59257ee7c17ec4a584a9e (diff) | |
feat(cnn_v3): export script + HOW_TO_CNN.md playbook
- export_cnn_v3_weights.py: .pth → cnn_v3_weights.bin (f16 packed u32) + cnn_v3_film_mlp.bin (f32)
- HOW_TO_CNN.md: full pipeline playbook (data collection, training, export, C++ wiring, parity, HTML tool)
- TODO.md: mark export script done
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Diffstat (limited to 'cnn_v3/training/export_cnn_v3_weights.py')
| -rw-r--r-- | cnn_v3/training/export_cnn_v3_weights.py | 160 |
1 file changed, 160 insertions, 0 deletions
#!/usr/bin/env python3
"""Export trained CNN v3 weights → binary files for C++ runtime.

Outputs
-------
<output_dir>/cnn_v3_weights.bin
    Conv+bias weights for all 5 passes, packed as f16-pairs-in-u32.
    Matches the format expected by CNNv3Effect::upload_weights().
    Layout: enc0 (724) | enc1 (296) | bottleneck (72) | dec1 (580) | dec0 (292)
    = 1964 f16 values = 982 u32 = 3928 bytes.

<output_dir>/cnn_v3_film_mlp.bin
    FiLM MLP weights as raw f32: L0_W (5×16) L0_b (16) L1_W (16×40) L1_b (40).
    = 5*16 + 16 + 16*40 + 40 = 80 + 16 + 640 + 40 = 776 f32 = 3104 bytes.
    For future CPU-side MLP inference in CNNv3Effect::set_film_params().

Usage
-----
    cd cnn_v3/training
    python3 export_cnn_v3_weights.py checkpoints/checkpoint_epoch_200.pth
    python3 export_cnn_v3_weights.py checkpoints/checkpoint_epoch_200.pth --output /tmp/out/
"""

import argparse
import sys
from pathlib import Path

import numpy as np
import torch

# Local import (same dir): make the training module importable when the
# script is run from anywhere, not just cnn_v3/training/.
sys.path.insert(0, str(Path(__file__).parent))
from train_cnn_v3 import CNNv3

# ---------------------------------------------------------------------------
# Weight layout constants (must match cnn_v3_effect.cc and gen_test_vectors.py)
# Each entry is out_ch * in_ch * kernel_area + bias_count.
# ---------------------------------------------------------------------------
ENC0_WEIGHTS = 20 * 4 * 9 + 4   # 724  (3×3 conv, 20→4 channels)
ENC1_WEIGHTS = 4 * 8 * 9 + 8    # 296  (3×3 conv, 4→8 channels)
BN_WEIGHTS = 8 * 8 * 1 + 8      # 72   (1×1 bottleneck conv)
DEC1_WEIGHTS = 16 * 4 * 9 + 4   # 580  (3×3 conv, 16→4 channels)
DEC0_WEIGHTS = 8 * 4 * 9 + 4    # 292  (3×3 conv, 8→4 channels)
TOTAL_F16 = ENC0_WEIGHTS + ENC1_WEIGHTS + BN_WEIGHTS + DEC1_WEIGHTS + DEC0_WEIGHTS
# = 1964


def pack_weights_u32(w_f16: np.ndarray) -> np.ndarray:
    """Pack a flat f16 array as u32 pairs matching the WGSL get_w() layout.

    WGSL get_w(buf, base, idx):
        pair = buf[(base+idx)/2]
        return f16 from low bits if even, high bits if odd.
    So w[0] goes in bits [15:0] of u32[0], w[1] in bits [31:16] of u32[0], etc.

    Parameters
    ----------
    w_f16 : flat 1-D array of weights (cast to float16 here if needed).

    Returns
    -------
    np.ndarray of uint32, length ceil(len(w_f16) / 2); odd input lengths
    are zero-padded so the final pair is complete.
    """
    f16 = w_f16.astype(np.float16)
    if len(f16) % 2:
        # Pad odd-length input with one zero so pairwise packing lines up.
        f16 = np.append(f16, np.float16(0))
    u16 = f16.view(np.uint16)  # bit-level reinterpret, no value conversion
    u32 = u16[0::2].astype(np.uint32) | (u16[1::2].astype(np.uint32) << 16)
    return u32


def extract_conv_layer(state: dict, name: str) -> np.ndarray:
    """Extract conv weight (OIHW, flattened) + bias as one flat f16 array.

    `state` is a PyTorch state_dict; `name` is the layer prefix, e.g. 'enc0'
    (keys '{name}.weight' and '{name}.bias' must exist).
    """
    w = state[f"{name}.weight"].cpu().numpy().astype(np.float16)  # OIHW
    b = state[f"{name}.bias"].cpu().numpy().astype(np.float16)
    return np.concatenate([w.flatten(), b.flatten()])


def export_weights(checkpoint_path: str, output_dir: str) -> None:
    """Load a training checkpoint and write both runtime binary files.

    Writes <output_dir>/cnn_v3_weights.bin (f16 packed in u32) and
    <output_dir>/cnn_v3_film_mlp.bin (raw f32), creating the directory
    if needed. Raises ValueError if any layer's weight count does not
    match the hard-coded layout constants above.
    """
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    # weights_only=True: safe load path, no arbitrary pickled code execution.
    ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
    cfg = ckpt.get('config', {})
    enc_channels = cfg.get('enc_channels', [4, 8])
    film_cond_dim = cfg.get('film_cond_dim', 5)

    model = CNNv3(enc_channels=enc_channels, film_cond_dim=film_cond_dim)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()
    state = model.state_dict()

    epoch = ckpt.get('epoch', '?')
    loss = ckpt.get('loss', float('nan'))
    print(f"Checkpoint: epoch={epoch} loss={loss:.6f}")
    print(f" enc_channels={enc_channels} film_cond_dim={film_cond_dim}")

    # -----------------------------------------------------------------------
    # 1. CNN conv weights → cnn_v3_weights.bin
    # -----------------------------------------------------------------------
    layers = [
        ('enc0', ENC0_WEIGHTS),
        ('enc1', ENC1_WEIGHTS),
        ('bottleneck', BN_WEIGHTS),
        ('dec1', DEC1_WEIGHTS),
        ('dec0', DEC0_WEIGHTS),
    ]

    all_f16 = []
    for name, expected in layers:
        chunk = extract_conv_layer(state, name)
        if len(chunk) != expected:
            raise ValueError(
                f"{name}: expected {expected} f16 values, got {len(chunk)}")
        all_f16.append(chunk)

    flat_f16 = np.concatenate(all_f16)
    # Explicit raise (not assert) — must survive `python -O`.
    if len(flat_f16) != TOTAL_F16:
        raise ValueError(f"total mismatch: {len(flat_f16)} != {TOTAL_F16}")

    packed_u32 = pack_weights_u32(flat_f16)
    weights_path = out / 'cnn_v3_weights.bin'
    packed_u32.astype('<u4').tofile(weights_path)  # little-endian u32
    print("\ncnn_v3_weights.bin")
    print(f" {TOTAL_F16} f16 values → {len(packed_u32)} u32 → {weights_path.stat().st_size} bytes")
    print(f" Upload via CNNv3Effect::upload_weights(queue, data, {len(packed_u32)*4})")

    # -----------------------------------------------------------------------
    # 2. FiLM MLP weights → cnn_v3_film_mlp.bin (raw f32, row-major)
    # -----------------------------------------------------------------------
    # film_mlp: Linear(film_cond_dim→16) ReLU Linear(16→film_out)
    # State keys: film_mlp.0.weight (16, cond_dim), film_mlp.0.bias (16,)
    #             film_mlp.2.weight (film_out, 16), film_mlp.2.bias (film_out,)
    mlp_pieces = [
        state['film_mlp.0.weight'].cpu().numpy().astype(np.float32).flatten(),
        state['film_mlp.0.bias'].cpu().numpy().astype(np.float32).flatten(),
        state['film_mlp.2.weight'].cpu().numpy().astype(np.float32).flatten(),
        state['film_mlp.2.bias'].cpu().numpy().astype(np.float32).flatten(),
    ]
    mlp_f32 = np.concatenate(mlp_pieces)
    mlp_path = out / 'cnn_v3_film_mlp.bin'
    mlp_f32.astype('<f4').tofile(mlp_path)

    l0w = state['film_mlp.0.weight'].shape
    l1w = state['film_mlp.2.weight'].shape
    film_out = l1w[0]
    print("\ncnn_v3_film_mlp.bin")
    print(f" L0: weight {l0w} + bias ({l0w[0]},)")
    print(f" L1: weight {l1w} + bias ({film_out},)")
    print(f" {len(mlp_f32)} f32 values → {mlp_path.stat().st_size} bytes")
    print(" NOTE: future CPU MLP inference — feed [beat_phase, beat_norm,")
    print(f" audio_intensity, style_p0, style_p1] → {film_out} outputs")
    print(f" γ/β split: enc0({enc_channels[0]}×2) enc1({enc_channels[1]}×2)"
          f" dec1({enc_channels[0]}×2) dec0(4×2)")

    print(f"\nDone → {out}/")


def main() -> None:
    """CLI entry point: parse args and run the export."""
    p = argparse.ArgumentParser(description='Export CNN v3 trained weights to .bin')
    p.add_argument('checkpoint', help='Path to .pth checkpoint file')
    p.add_argument('--output', default='export',
                   help='Output directory (default: export/)')
    args = p.parse_args()
    export_weights(args.checkpoint, args.output)


if __name__ == '__main__':
    main()
