summaryrefslogtreecommitdiff
path: root/cnn_v3/training
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v3/training')
-rw-r--r--cnn_v3/training/export_cnn_v3_weights.py160
1 files changed, 160 insertions, 0 deletions
diff --git a/cnn_v3/training/export_cnn_v3_weights.py b/cnn_v3/training/export_cnn_v3_weights.py
new file mode 100644
index 0000000..a1ad42d
--- /dev/null
+++ b/cnn_v3/training/export_cnn_v3_weights.py
@@ -0,0 +1,160 @@
+#!/usr/bin/env python3
+"""Export trained CNN v3 weights → binary files for C++ runtime.
+
+Outputs
+-------
+<output_dir>/cnn_v3_weights.bin
+ Conv+bias weights for all 5 passes, packed as f16-pairs-in-u32.
+ Matches the format expected by CNNv3Effect::upload_weights().
+ Layout: enc0 (724) | enc1 (296) | bottleneck (72) | dec1 (580) | dec0 (292)
+ = 1964 f16 values = 982 u32 = 3928 bytes.
+
+<output_dir>/cnn_v3_film_mlp.bin
+ FiLM MLP weights as raw f32: L0_W (5×16) L0_b (16) L1_W (16×40) L1_b (40).
+ = 5*16 + 16 + 16*40 + 40 = 80 + 16 + 640 + 40 = 776 f32 = 3104 bytes.
+ For future CPU-side MLP inference in CNNv3Effect::set_film_params().
+
+Usage
+-----
+ cd cnn_v3/training
+ python3 export_cnn_v3_weights.py checkpoints/checkpoint_epoch_200.pth
+ python3 export_cnn_v3_weights.py checkpoints/checkpoint_epoch_200.pth --output /tmp/out/
+"""
+
+import argparse
+import struct
+import sys
+from pathlib import Path
+
+import numpy as np
+import torch
+
+# Local import (same dir)
+sys.path.insert(0, str(Path(__file__).parent))
+from train_cnn_v3 import CNNv3
+
+# ---------------------------------------------------------------------------
+# Weight layout constants (must match cnn_v3_effect.cc and gen_test_vectors.py)
+# ---------------------------------------------------------------------------
# Per-layer f16 counts: (flattened conv weight elements) + (bias elements).
# The `* 9` factor corresponds to a 3x3 kernel and the `* 1` on the
# bottleneck to a 1x1 kernel; the trailing addend equals the layer's bias
# (output-channel) count.  NOTE(review): which of the remaining two factors
# is in- vs out-channels is not visible here — confirm against
# train_cnn_v3.CNNv3 before relying on it.
ENC0_WEIGHTS = 20 * 4 * 9 + 4  # 724
ENC1_WEIGHTS = 4 * 8 * 9 + 8  # 296
BN_WEIGHTS = 8 * 8 * 1 + 8  # 72
DEC1_WEIGHTS = 16 * 4 * 9 + 4  # 580
DEC0_WEIGHTS = 8 * 4 * 9 + 4  # 292
TOTAL_F16 = ENC0_WEIGHTS + ENC1_WEIGHTS + BN_WEIGHTS + DEC1_WEIGHTS + DEC0_WEIGHTS
# = 1964
+
+
def pack_weights_u32(w_f16: np.ndarray) -> np.ndarray:
    """Pack a flat f16 array into u32 pairs matching the WGSL get_w() layout.

    The shader's get_w(buf, base, idx) reads u32 word (base+idx)/2 and takes
    the low 16 bits for even idx, the high 16 bits for odd idx.  So each
    output word carries w[2k] in bits [15:0] and w[2k+1] in bits [31:16].
    A zero f16 is appended first when the input length is odd.
    """
    half = w_f16.astype(np.float16)
    if half.size % 2 != 0:
        half = np.append(half, np.float16(0))
    # Reinterpret the f16 bit patterns as u16, then fuse adjacent pairs.
    bits = half.view(np.uint16).astype(np.uint32)
    low, high = bits[0::2], bits[1::2]
    return low | (high << 16)
+
+
def extract_conv_layer(state: dict, name: str) -> np.ndarray:
    """Return one conv layer's weight (OIHW order, flattened) then bias, as f16.

    Looks up ``<name>.weight`` and ``<name>.bias`` in *state* (a PyTorch
    state dict) and concatenates them into a single flat float16 array.
    """
    weight = state[f"{name}.weight"].cpu().numpy().astype(np.float16)
    bias = state[f"{name}.bias"].cpu().numpy().astype(np.float16)
    return np.concatenate([weight.ravel(), bias.ravel()])
+
+
def export_weights(checkpoint_path: str, output_dir: str) -> None:
    """Export a trained CNN v3 checkpoint to the two runtime binaries.

    Writes ``<output_dir>/cnn_v3_weights.bin`` (conv+bias weights for all
    5 passes, f16 packed as u32 pairs) and ``<output_dir>/cnn_v3_film_mlp.bin``
    (FiLM MLP weights as raw little-endian f32).

    Parameters
    ----------
    checkpoint_path:
        Path to a .pth checkpoint containing 'model_state_dict' plus
        optional 'config', 'epoch' and 'loss' metadata.
    output_dir:
        Destination directory; created (with parents) if missing.

    Raises
    ------
    ValueError
        If a conv layer's flattened weight+bias count does not match the
        layout constants shared with the C++/WGSL runtime.
    """
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    # weights_only=True keeps torch.load from unpickling arbitrary objects.
    ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
    cfg = ckpt.get('config', {})
    enc_channels = cfg.get('enc_channels', [4, 8])
    film_cond_dim = cfg.get('film_cond_dim', 5)

    model = CNNv3(enc_channels=enc_channels, film_cond_dim=film_cond_dim)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()
    state = model.state_dict()

    epoch = ckpt.get('epoch', '?')
    loss = ckpt.get('loss', float('nan'))
    print(f"Checkpoint: epoch={epoch} loss={loss:.6f}")
    print(f"  enc_channels={enc_channels} film_cond_dim={film_cond_dim}")

    _export_conv_weights(state, out)
    _export_film_mlp(state, out, enc_channels)

    print(f"\nDone → {out}/")


def _export_conv_weights(state: dict, out: Path) -> None:
    """Pack conv+bias weights of all 5 passes and write cnn_v3_weights.bin."""
    # Order and per-layer sizes must match cnn_v3_effect.cc / gen_test_vectors.py.
    layers = [
        ('enc0', ENC0_WEIGHTS),
        ('enc1', ENC1_WEIGHTS),
        ('bottleneck', BN_WEIGHTS),
        ('dec1', DEC1_WEIGHTS),
        ('dec0', DEC0_WEIGHTS),
    ]

    all_f16 = []
    for name, expected in layers:
        chunk = extract_conv_layer(state, name)
        if len(chunk) != expected:
            raise ValueError(
                f"{name}: expected {expected} f16 values, got {len(chunk)}")
        all_f16.append(chunk)

    flat_f16 = np.concatenate(all_f16)
    # Raise (not assert) so the size check survives `python -O`.
    if len(flat_f16) != TOTAL_F16:
        raise ValueError(f"total mismatch: {len(flat_f16)} != {TOTAL_F16}")

    packed_u32 = pack_weights_u32(flat_f16)
    weights_path = out / 'cnn_v3_weights.bin'
    packed_u32.astype('<u4').tofile(weights_path)  # little-endian u32
    print("\ncnn_v3_weights.bin")
    print(f"  {TOTAL_F16} f16 values → {len(packed_u32)} u32 → {weights_path.stat().st_size} bytes")
    print(f"  Upload via CNNv3Effect::upload_weights(queue, data, {len(packed_u32)*4})")


def _export_film_mlp(state: dict, out: Path, enc_channels: list) -> None:
    """Dump FiLM MLP weights as raw f32 (row-major) → cnn_v3_film_mlp.bin."""
    # film_mlp: Linear(film_cond_dim→16) ReLU Linear(16→film_out)
    # State keys: film_mlp.0.weight (16, cond_dim), film_mlp.0.bias (16,)
    #             film_mlp.2.weight (film_out, 16), film_mlp.2.bias (film_out,)
    keys = ('film_mlp.0.weight', 'film_mlp.0.bias',
            'film_mlp.2.weight', 'film_mlp.2.bias')
    mlp_f32 = np.concatenate(
        [state[k].cpu().numpy().astype(np.float32).flatten() for k in keys])
    mlp_path = out / 'cnn_v3_film_mlp.bin'
    mlp_f32.astype('<f4').tofile(mlp_path)

    l0w = state['film_mlp.0.weight'].shape
    l1w = state['film_mlp.2.weight'].shape
    film_out = l1w[0]
    print("\ncnn_v3_film_mlp.bin")
    print(f"  L0: weight {l0w} + bias ({l0w[0]},)")
    print(f"  L1: weight {l1w} + bias ({film_out},)")
    print(f"  {len(mlp_f32)} f32 values → {mlp_path.stat().st_size} bytes")
    print("  NOTE: future CPU MLP inference — feed [beat_phase, beat_norm,")
    print(f"        audio_intensity, style_p0, style_p1] → {film_out} outputs")
    print(f"  γ/β split: enc0({enc_channels[0]}×2) enc1({enc_channels[1]}×2)"
          f" dec1({enc_channels[0]}×2) dec0(4×2)")
+
+
def main() -> None:
    """CLI entry point: parse arguments and run the export."""
    parser = argparse.ArgumentParser(
        description='Export CNN v3 trained weights to .bin')
    parser.add_argument('checkpoint', help='Path to .pth checkpoint file')
    parser.add_argument('--output', default='export',
                        help='Output directory (default: export/)')
    opts = parser.parse_args()
    export_weights(opts.checkpoint, opts.output)


if __name__ == '__main__':
    main()