1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
|
#!/usr/bin/env python3
"""Export trained CNN v3 weights → binary files for C++ runtime.
Outputs
-------
<output_dir>/cnn_v3_weights.bin
Conv weights (OIHW, flattened) + biases for all 5 passes as f16, packed two per
little-endian u32 (even index in the low 16 bits, odd index in the high 16 bits).
Matches the format expected by CNNv3Effect::upload_weights().
Layout: enc0 (724) | enc1 (296) | bottleneck (72) | dec1 (580) | dec0 (292)
= 1964 f16 values = 982 u32 = 3928 bytes.
<output_dir>/cnn_v3_film_mlp.bin
FiLM MLP weights as raw little-endian f32, each tensor flattened row-major in
PyTorch (out, in) order: L0_W (16×5) L0_b (16) L1_W (40×16) L1_b (40).
= 16*5 + 16 + 40*16 + 40 = 80 + 16 + 640 + 40 = 776 f32 = 3104 bytes.
For future CPU-side MLP inference in CNNv3Effect::set_film_params().
Usage
-----
cd cnn_v3/training
python3 export_cnn_v3_weights.py checkpoints/checkpoint_epoch_200.pth
python3 export_cnn_v3_weights.py checkpoints/checkpoint_epoch_200.pth --output /tmp/out/
"""
import argparse
import struct
import sys
from pathlib import Path
import numpy as np
import torch
# Local import (same dir)
sys.path.insert(0, str(Path(__file__).parent))
from train_cnn_v3 import CNNv3
# ---------------------------------------------------------------------------
# Weight layout constants (must match cnn_v3_effect.cc and gen_test_vectors.py)
# ---------------------------------------------------------------------------
# Per-layer f16 counts = out_ch * in_ch * k_h * k_w conv weights followed by
# out_ch biases, i.e. the length extract_conv_layer() returns for that layer.
# (The "9"s in the original formulation are read here as square 3x3 kernels.)
ENC0_WEIGHTS = 4 * 20 * 3 * 3 + 4  # 20 -> 4 channels, 3x3: 724
ENC1_WEIGHTS = 8 * 4 * 3 * 3 + 8   # 4 -> 8 channels, 3x3: 296
BN_WEIGHTS = 8 * 8 * 1 * 1 + 8     # 8 -> 8 channels, 1x1: 72
DEC1_WEIGHTS = 4 * 16 * 3 * 3 + 4  # 16 -> 4 channels, 3x3: 580
DEC0_WEIGHTS = 4 * 8 * 3 * 3 + 4   # 8 -> 4 channels, 3x3: 292
TOTAL_F16 = sum(
    (ENC0_WEIGHTS, ENC1_WEIGHTS, BN_WEIGHTS, DEC1_WEIGHTS, DEC0_WEIGHTS)
)  # = 1964
def pack_weights_u32(w_f16: np.ndarray) -> np.ndarray:
    """Pack f16 weights two-per-u32, matching the WGSL get_w() layout.

    WGSL get_w(buf, base, idx):
        pair = buf[(base+idx)/2]
        return f16 from low bits if even, high bits if odd.
    So w[0] occupies bits [15:0] of u32[0], w[1] bits [31:16] of u32[0], etc.

    Args:
        w_f16: Weights of any float dtype and any shape (array or sequence).
            They are flattened in C (row-major) order and converted to
            float16 before packing.

    Returns:
        1-D uint32 array of length ceil(n / 2). An odd-length input is padded
        with one f16 zero, which lands in the high half of the last word.
    """
    # ravel() first: len() and [0::2] only look at the first axis, so a
    # multi-dimensional input would otherwise pair whole rows, not scalars.
    halves = np.ravel(w_f16).astype(np.float16)
    if halves.size % 2:
        halves = np.append(halves, np.float16(0))
    # Reinterpret the f16 bit patterns (no numeric conversion) as u16.
    bits = halves.view(np.uint16)
    return bits[0::2].astype(np.uint32) | (bits[1::2].astype(np.uint32) << 16)
def extract_conv_layer(state: dict, name: str) -> np.ndarray:
    """Return one conv layer's parameters as a single flat float16 vector.

    The weight tensor ``<name>.weight`` is converted to float16 and flattened
    in C order (OIHW for a torch Conv2d weight), then the bias tensor
    ``<name>.bias`` is appended after it.
    """
    def as_half(key: str) -> np.ndarray:
        # Move to host, convert to numpy, down-cast to f16, flatten.
        return state[key].cpu().numpy().astype(np.float16).ravel()

    weight = as_half(f"{name}.weight")
    bias = as_half(f"{name}.bias")
    return np.concatenate((weight, bias))
def _format_loss(loss) -> str:
    """Format a checkpoint's 'loss' entry for logging without crashing.

    The entry's type is not guaranteed (float, 0-d tensor, None, ...), and
    ``f"{None:.6f}"`` raises TypeError *after* the checkpoint loaded fine.
    Values that cannot be converted to float are shown verbatim instead.
    """
    try:
        return f"{float(loss):.6f}"
    except (TypeError, ValueError, RuntimeError):
        return str(loss)


def _export_conv_weights(state: dict, out: Path) -> None:
    """Write the five conv layers to <out>/cnn_v3_weights.bin as packed f16."""
    layers = (
        ('enc0', ENC0_WEIGHTS),
        ('enc1', ENC1_WEIGHTS),
        ('bottleneck', BN_WEIGHTS),
        ('dec1', DEC1_WEIGHTS),
        ('dec0', DEC0_WEIGHTS),
    )
    chunks = []
    for name, expected in layers:
        chunk = extract_conv_layer(state, name)
        if len(chunk) != expected:
            raise ValueError(
                f"{name}: expected {expected} f16 values, got {len(chunk)}")
        # astype(float16) silently maps |x| > 65504 to inf and keeps NaN;
        # either would corrupt the runtime weights, so refuse to export.
        if not np.isfinite(chunk).all():
            raise ValueError(
                f"{name}: non-finite value after f16 conversion "
                f"(NaN in checkpoint or magnitude > 65504)")
        chunks.append(chunk)

    flat_f16 = np.concatenate(chunks)
    # Explicit check rather than `assert`, so it still runs under `python -O`.
    if len(flat_f16) != TOTAL_F16:
        raise ValueError(f"total mismatch: {len(flat_f16)} != {TOTAL_F16}")

    packed_u32 = pack_weights_u32(flat_f16)
    weights_path = out / 'cnn_v3_weights.bin'
    packed_u32.astype('<u4').tofile(weights_path)  # little-endian u32
    print("\ncnn_v3_weights.bin")
    print(f" {TOTAL_F16} f16 values → {len(packed_u32)} u32 → "
          f"{weights_path.stat().st_size} bytes")
    print(f" Upload via CNNv3Effect::upload_weights(queue, data, {len(packed_u32) * 4})")


def _export_film_mlp(state: dict, out: Path, enc_channels: list) -> None:
    """Write the FiLM MLP to <out>/cnn_v3_film_mlp.bin as raw little-endian f32.

    Order: film_mlp.0.weight, film_mlp.0.bias, film_mlp.2.weight,
    film_mlp.2.bias, each flattened row-major in PyTorch (out, in) layout.
    """
    # film_mlp: Linear(film_cond_dim→16) ReLU Linear(16→film_out)
    # State keys: film_mlp.0.weight (16, cond_dim), film_mlp.0.bias (16,)
    #             film_mlp.2.weight (film_out, 16), film_mlp.2.bias (film_out,)
    keys = ('film_mlp.0.weight', 'film_mlp.0.bias',
            'film_mlp.2.weight', 'film_mlp.2.bias')
    mlp_f32 = np.concatenate(
        [state[k].cpu().numpy().astype(np.float32).ravel() for k in keys])
    mlp_path = out / 'cnn_v3_film_mlp.bin'
    mlp_f32.astype('<f4').tofile(mlp_path)

    l0w = tuple(state['film_mlp.0.weight'].shape)
    l1w = tuple(state['film_mlp.2.weight'].shape)
    film_out = l1w[0]
    print("\ncnn_v3_film_mlp.bin")
    print(f" L0: weight {l0w} + bias ({l0w[0]},)")
    print(f" L1: weight {l1w} + bias ({film_out},)")
    print(f" {len(mlp_f32)} f32 values → {mlp_path.stat().st_size} bytes")
    print(" NOTE: future CPU MLP inference — feed [beat_phase, beat_norm,")
    print(f" audio_intensity, style_p0, style_p1] → {film_out} outputs")
    print(f" γ/β split: enc0({enc_channels[0]}×2) enc1({enc_channels[1]}×2)"
          f" dec1({enc_channels[0]}×2) dec0(4×2)")


def export_weights(checkpoint_path: str, output_dir: str) -> None:
    """Export a CNN v3 training checkpoint to the C++ runtime's binary files.

    Loads the checkpoint on CPU, rebuilds CNNv3 from its stored 'config'
    (falling back to enc_channels=[4, 8], film_cond_dim=5), and writes
    ``cnn_v3_weights.bin`` and ``cnn_v3_film_mlp.bin`` into *output_dir*,
    creating the directory if needed.

    Args:
        checkpoint_path: Path to a ``.pth`` file containing at least
            'model_state_dict' (optionally 'config', 'epoch', 'loss').
        output_dir: Destination directory.

    Raises:
        KeyError: If the checkpoint or model state lacks a required key.
        RuntimeError: From ``load_state_dict`` if the checkpoint does not
            match the rebuilt model.
        ValueError: If a conv layer's size does not match the fixed runtime
            layout, or a weight is non-finite after f16 conversion.
    """
    ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
    # `or {}` also covers a checkpoint that stored config=None.
    cfg = ckpt.get('config') or {}
    enc_channels = cfg.get('enc_channels', [4, 8])
    film_cond_dim = cfg.get('film_cond_dim', 5)
    model = CNNv3(enc_channels=enc_channels, film_cond_dim=film_cond_dim)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()
    state = model.state_dict()

    epoch = ckpt.get('epoch', '?')
    loss = _format_loss(ckpt.get('loss', float('nan')))
    print(f"Checkpoint: epoch={epoch} loss={loss}")
    print(f" enc_channels={enc_channels} film_cond_dim={film_cond_dim}")

    # Create the output directory only after the checkpoint loaded, so a bad
    # path or corrupt file does not leave an empty directory behind.
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    # 1. CNN conv weights → cnn_v3_weights.bin
    _export_conv_weights(state, out)
    # 2. FiLM MLP weights → cnn_v3_film_mlp.bin (raw f32, row-major)
    _export_film_mlp(state, out, enc_channels)
    print(f"\nDone → {out}/")
def main() -> None:
    """Command-line entry point: parse arguments and call export_weights()."""
    parser = argparse.ArgumentParser(
        description='Export CNN v3 trained weights to .bin')
    parser.add_argument('checkpoint', help='Path to .pth checkpoint file')
    parser.add_argument(
        '--output',
        default='export',
        help='Output directory (default: export/)',
    )
    cli = parser.parse_args()
    export_weights(checkpoint_path=cli.checkpoint, output_dir=cli.output)


if __name__ == '__main__':
    main()
|