1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
|
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "numpy",
# "opencv-python",
# "pillow",
# "torch",
# ]
# ///
"""Export trained CNN v3 weights → binary files for C++ runtime.
Outputs
-------
<output_dir>/cnn_v3_weights.bin
Conv+bias weights for all 5 passes, packed as f16-pairs-in-u32.
Matches the format expected by CNNv3Effect::upload_weights().
Layout: enc0 (724) | enc1 (296) | bottleneck (584) | dec1 (580) | dec0 (292)
= 2476 f16 values = 1238 u32 = 4952 bytes.
<output_dir>/cnn_v3_film_mlp.bin
FiLM MLP weights as raw little-endian f32, each tensor flattened row-major in PyTorch (out×in) order: L0_W (16×5) L0_b (16) L1_W (40×16) L1_b (40).
= 5*16 + 16 + 16*40 + 40 = 80 + 16 + 640 + 40 = 776 f32 = 3104 bytes.
For future CPU-side MLP inference in CNNv3Effect::set_film_params().
Usage
-----
cd cnn_v3/training
python3 export_cnn_v3_weights.py checkpoints/checkpoint_epoch_200.pth
python3 export_cnn_v3_weights.py checkpoints/checkpoint_epoch_200.pth --output /tmp/out/
"""
import argparse
import base64
import struct
import sys
from pathlib import Path
import numpy as np
import torch
# Local import (same dir)
sys.path.insert(0, str(Path(__file__).parent))
from train_cnn_v3 import CNNv3
# ---------------------------------------------------------------------------
# Weight layout constants — must stay in sync with:
#   cnn_v3/src/cnn_v3_effect.cc          (kEnc0Weights, kEnc1Weights, …)
#   cnn_v3/training/gen_test_vectors.py  (same constants)
#
# Every pass is a single 3×3 convolution, stored as
#   in_ch * out_ch * (3 * 3) kernel values followed by out_ch bias values.
# ---------------------------------------------------------------------------
ENC0_WEIGHTS = (20 * 4) * (3 * 3) + 4   # enc0:       20 → 4           = 724
ENC1_WEIGHTS = (4 * 8) * (3 * 3) + 8    # enc1:        4 → 8           = 296
BN_WEIGHTS = (8 * 8) * (3 * 3) + 8      # bottleneck:  8 → 8, dil=2    = 584
DEC1_WEIGHTS = (16 * 4) * (3 * 3) + 4   # dec1:       16 → 4           = 580
DEC0_WEIGHTS = (8 * 4) * (3 * 3) + 4    # dec0:        8 → 4           = 292

# Total f16 values across all five passes (= 2476).
TOTAL_F16 = sum((ENC0_WEIGHTS, ENC1_WEIGHTS, BN_WEIGHTS, DEC1_WEIGHTS, DEC0_WEIGHTS))
def pack_weights_u32(w_f16: np.ndarray) -> np.ndarray:
    """Pack f16 weights two-per-u32, matching the WGSL ``get_w()`` layout.

    WGSL ``get_w(buf, base, idx)`` reads ``pair = buf[(base + idx) / 2]`` and
    returns the f16 from the low 16 bits when the index is even, and from the
    high 16 bits when it is odd. So ``w[0]`` lands in bits [15:0] of
    ``u32[0]``, ``w[1]`` in bits [31:16] of ``u32[0]``, and so on.

    Args:
        w_f16: Weights as any array-like of any shape. They are converted to
            float16 and flattened in C order. An odd count is padded with a
            single 0.0.

    Returns:
        1-D ``np.uint32`` array of length ``ceil(n / 2)``, in native byte order.
    """
    # Flatten first. With a 2-D input, len() and [0::2] would act on axis 0
    # and pair whole rows instead of neighbouring weights.
    f16 = np.asarray(w_f16, dtype=np.float16).ravel()
    if f16.size % 2:
        f16 = np.append(f16, np.float16(0))
    # Reinterpret the f16 bit patterns as raw 16-bit integers (no conversion).
    u16 = f16.view(np.uint16)
    lo = u16[0::2].astype(np.uint32)  # even indices -> bits [15:0]
    hi = u16[1::2].astype(np.uint32)  # odd indices  -> bits [31:16]
    return lo | (hi << 16)
def extract_conv_layer(state: dict, name: str) -> np.ndarray:
    """Flatten one conv layer's parameters into a single f16 vector.

    The result is ``state[f"{name}.weight"]`` flattened in its stored order
    (OIHW for a PyTorch Conv2d), immediately followed by
    ``state[f"{name}.bias"]``.

    Args:
        state: Mapping of parameter name to tensor (e.g. a ``state_dict``).
        name: Layer prefix, such as ``'enc0'``.

    Returns:
        1-D ``np.float16`` array holding the weight values, then the bias values.
    """
    kernel = state[f"{name}.weight"].cpu().numpy()
    bias = state[f"{name}.bias"].cpu().numpy()
    return np.concatenate((
        kernel.astype(np.float16).ravel(),
        bias.astype(np.float16).ravel(),
    ))
def _export_conv_weights(state: dict, out: Path) -> Path:
    """Write the packed conv weights of all five passes to ``out/cnn_v3_weights.bin``.

    Layers are concatenated in the order enc0, enc1, bottleneck, dec1, dec0.
    Each contributes its flattened weight followed by its bias (as returned by
    ``extract_conv_layer``). The f16 stream is packed two values per u32 and
    written as little-endian u32.

    Args:
        state: Model ``state_dict`` (parameter name -> tensor).
        out: Output directory; must already exist.

    Returns:
        Path of the file written.

    Raises:
        ValueError: If a layer's value count, or the total, differs from the
            layout constants shared with the C++ runtime.
    """
    layers = [
        ('enc0', ENC0_WEIGHTS),
        ('enc1', ENC1_WEIGHTS),
        ('bottleneck', BN_WEIGHTS),
        ('dec1', DEC1_WEIGHTS),
        ('dec0', DEC0_WEIGHTS),
    ]
    all_f16 = []
    for name, expected in layers:
        chunk = extract_conv_layer(state, name)
        if len(chunk) != expected:
            raise ValueError(
                f"{name}: expected {expected} f16 values, got {len(chunk)}")
        all_f16.append(chunk)
    flat_f16 = np.concatenate(all_f16)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(flat_f16) != TOTAL_F16:
        raise ValueError(f"total mismatch: {len(flat_f16)} != {TOTAL_F16}")

    packed_u32 = pack_weights_u32(flat_f16)
    weights_path = out / 'cnn_v3_weights.bin'
    packed_u32.astype('<u4').tofile(weights_path)  # little-endian u32
    print("\ncnn_v3_weights.bin")
    print(f" {TOTAL_F16} f16 values → {len(packed_u32)} u32 → {weights_path.stat().st_size} bytes")
    print(f" Upload via CNNv3Effect::upload_weights(queue, data, {len(packed_u32)*4})")
    return weights_path


def _export_film_mlp(state: dict, out: Path, enc_channels) -> Path:
    """Write the FiLM MLP parameters as raw little-endian f32 to ``out/cnn_v3_film_mlp.bin``.

    The tensors are written in this order: film_mlp.0.weight, film_mlp.0.bias,
    film_mlp.2.weight, film_mlp.2.bias. Each is flattened in its stored
    (row-major) order.

    Args:
        state: Model ``state_dict`` (parameter name -> tensor).
        out: Output directory; must already exist.
        enc_channels: Encoder channel counts, used only in the printed γ/β
            summary.

    Returns:
        Path of the file written.
    """
    # film_mlp: Linear(film_cond_dim→16) → ReLU → Linear(16→film_out)
    # State keys: film_mlp.0.weight (16, cond_dim), film_mlp.0.bias (16,)
    #             film_mlp.2.weight (film_out, 16), film_mlp.2.bias (film_out,)
    mlp_keys = ('film_mlp.0.weight', 'film_mlp.0.bias',
                'film_mlp.2.weight', 'film_mlp.2.bias')
    mlp_f32 = np.concatenate([
        state[key].cpu().numpy().astype(np.float32).flatten() for key in mlp_keys
    ])
    mlp_path = out / 'cnn_v3_film_mlp.bin'
    mlp_f32.astype('<f4').tofile(mlp_path)

    l0w = state['film_mlp.0.weight'].shape
    l1w = state['film_mlp.2.weight'].shape
    film_out = l1w[0]
    print("\ncnn_v3_film_mlp.bin")
    print(f" L0: weight {l0w} + bias ({l0w[0]},)")
    print(f" L1: weight {l1w} + bias ({film_out},)")
    print(f" {len(mlp_f32)} f32 values → {mlp_path.stat().st_size} bytes")
    print(" NOTE: future CPU MLP inference — feed [beat_phase, beat_norm,")
    print(f" audio_intensity, style_p0, style_p1] → {film_out} outputs")
    print(f" γ/β split: enc0({enc_channels[0]}×2) enc1({enc_channels[1]}×2)"
          f" dec1({enc_channels[0]}×2) dec0(4×2)")
    return mlp_path


def export_weights(checkpoint_path: str, output_dir: str) -> None:
    """Load a CNN v3 checkpoint and write the runtime weight files.

    Creates *output_dir* if needed. Writes ``cnn_v3_weights.bin`` (packed f16
    conv weights) and ``cnn_v3_film_mlp.bin`` (raw f32 FiLM MLP) into it, and
    prints a summary of each.

    Args:
        checkpoint_path: ``.pth`` file containing ``model_state_dict``. The
            optional ``config`` (``enc_channels``, ``film_cond_dim``),
            ``epoch`` and ``loss`` entries are used when present.
        output_dir: Destination directory.

    Raises:
        ValueError: If the conv layer sizes do not match the layout constants.
    """
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    # weights_only=True restricts unpickling to tensors and plain containers.
    ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
    cfg = ckpt.get('config', {})
    enc_channels = cfg.get('enc_channels', [4, 8])
    film_cond_dim = cfg.get('film_cond_dim', 5)

    # Round-trip through the model so mismatched keys fail (load_state_dict is
    # strict by default) before any weight file is written.
    model = CNNv3(enc_channels=enc_channels, film_cond_dim=film_cond_dim)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()
    state = model.state_dict()

    epoch = ckpt.get('epoch', '?')
    loss = ckpt.get('loss', float('nan'))
    print(f"Checkpoint: epoch={epoch} loss={loss:.6f}")
    print(f" enc_channels={enc_channels} film_cond_dim={film_cond_dim}")

    # 1. CNN conv weights → cnn_v3_weights.bin
    _export_conv_weights(state, out)
    # 2. FiLM MLP weights → cnn_v3_film_mlp.bin (raw f32, row-major)
    _export_film_mlp(state, out, enc_channels)

    print(f"\nDone → {out}/")
# Default destination: ../tools/weights.js relative to this script's directory.
_WEIGHTS_JS_DEFAULT = Path(__file__).parent.parent / 'tools' / 'weights.js'


def update_weights_js(weights_bin: Path, film_mlp_bin: Path,
                      js_path: Path = _WEIGHTS_JS_DEFAULT) -> None:
    """Encode both .bin files as base64 and write them into a JS file.

    The generated file defines two constants, ``CNN_V3_WEIGHTS_B64`` and
    ``CNN_V3_FILM_MLP_B64``, and overwrites any existing file at *js_path*.

    Args:
        weights_bin: Path to ``cnn_v3_weights.bin``.
        film_mlp_bin: Path to ``cnn_v3_film_mlp.bin``.
        js_path: Output file (default: cnn_v3/tools/weights.js). Missing
            parent directories are created.
    """
    weights_bin = Path(weights_bin)
    film_mlp_bin = Path(film_mlp_bin)
    js_path = Path(js_path)

    w_b64 = base64.b64encode(weights_bin.read_bytes()).decode('ascii')
    f_b64 = base64.b64encode(film_mlp_bin.read_bytes()).decode('ascii')

    # A custom --html-output may point into a directory that doesn't exist yet.
    js_path.parent.mkdir(parents=True, exist_ok=True)
    # The header contains a non-ASCII em dash. Pin the encoding so the output
    # doesn't depend on the platform locale.
    js_path.write_text(
        "'use strict';\n"
        "// Auto-generated by export_cnn_v3_weights.py --html — do not edit by hand.\n"
        f"const CNN_V3_WEIGHTS_B64='{w_b64}';\n"
        f"const CNN_V3_FILM_MLP_B64='{f_b64}';\n",
        encoding='utf-8',
    )
    print(f"\nweights.js → {js_path}")
    print(f" CNN_V3_WEIGHTS_B64 {len(w_b64)} chars ({weights_bin.stat().st_size} bytes)")
    print(f" CNN_V3_FILM_MLP_B64 {len(f_b64)} chars ({film_mlp_bin.stat().st_size} bytes)")
def main() -> None:
    """Command-line entry point: export a checkpoint, then optionally refresh weights.js."""
    parser = argparse.ArgumentParser(description='Export CNN v3 trained weights to .bin')
    parser.add_argument('checkpoint', help='Path to .pth checkpoint file')
    parser.add_argument('--output', default='export',
                        help='Output directory (default: export/)')
    parser.add_argument('--html', action='store_true',
                        help=f'Also update {_WEIGHTS_JS_DEFAULT} with base64-encoded weights')
    parser.add_argument('--html-output', default=None, metavar='PATH',
                        help='Override default weights.js path (implies --html)')
    opts = parser.parse_args()

    export_weights(opts.checkpoint, opts.output)

    # A non-empty --html-output on its own is enough to request the JS bundle.
    if not (opts.html or opts.html_output):
        return

    export_dir = Path(opts.output)
    if opts.html_output:
        target = Path(opts.html_output)
    else:
        target = _WEIGHTS_JS_DEFAULT
    update_weights_js(export_dir / 'cnn_v3_weights.bin',
                      export_dir / 'cnn_v3_film_mlp.bin',
                      target)


if __name__ == '__main__':
    main()
|