summaryrefslogtreecommitdiff
path: root/cnn_v3/training/export_cnn_v3_weights.py
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v3/training/export_cnn_v3_weights.py')
-rw-r--r--cnn_v3/training/export_cnn_v3_weights.py44
1 files changed, 36 insertions, 8 deletions
diff --git a/cnn_v3/training/export_cnn_v3_weights.py b/cnn_v3/training/export_cnn_v3_weights.py
index 99f3a81..78f5f25 100644
--- a/cnn_v3/training/export_cnn_v3_weights.py
+++ b/cnn_v3/training/export_cnn_v3_weights.py
@@ -15,8 +15,8 @@ Outputs
<output_dir>/cnn_v3_weights.bin
Conv+bias weights for all 5 passes, packed as f16-pairs-in-u32.
Matches the format expected by CNNv3Effect::upload_weights().
- Layout: enc0 (724) | enc1 (296) | bottleneck (72) | dec1 (580) | dec0 (292)
- = 1964 f16 values = 982 u32 = 3928 bytes.
+ Layout: enc0 (724) | enc1 (296) | bottleneck (584) | dec1 (580) | dec0 (292)
+ = 2476 f16 values = 1238 u32 = 4952 bytes.
<output_dir>/cnn_v3_film_mlp.bin
FiLM MLP weights as raw f32: L0_W (5×16) L0_b (16) L1_W (16×40) L1_b (40).
@@ -31,6 +31,7 @@ Usage
"""
import argparse
+import base64
import struct
import sys
from pathlib import Path
@@ -47,13 +48,13 @@ from train_cnn_v3 import CNNv3
# cnn_v3/src/cnn_v3_effect.cc (kEnc0Weights, kEnc1Weights, …)
# cnn_v3/training/gen_test_vectors.py (same constants)
# ---------------------------------------------------------------------------
-ENC0_WEIGHTS = 20 * 4 * 9 + 4 # Conv(20→4,3×3)+bias = 724
-ENC1_WEIGHTS = 4 * 8 * 9 + 8 # Conv(4→8,3×3)+bias = 296
-BN_WEIGHTS = 8 * 8 * 1 + 8 # Conv(8→8,1×1)+bias = 72
-DEC1_WEIGHTS = 16 * 4 * 9 + 4 # Conv(16→4,3×3)+bias = 580
-DEC0_WEIGHTS = 8 * 4 * 9 + 4 # Conv(8→4,3×3)+bias = 292
+ENC0_WEIGHTS = 20 * 4 * 9 + 4 # Conv(20→4,3×3)+bias = 724
+ENC1_WEIGHTS = 4 * 8 * 9 + 8 # Conv(4→8,3×3)+bias = 296
+BN_WEIGHTS = 8 * 8 * 9 + 8 # Conv(8→8,3×3,dil=2)+bias = 584
+DEC1_WEIGHTS = 16 * 4 * 9 + 4 # Conv(16→4,3×3)+bias = 580
+DEC0_WEIGHTS = 8 * 4 * 9 + 4 # Conv(8→4,3×3)+bias = 292
TOTAL_F16 = ENC0_WEIGHTS + ENC1_WEIGHTS + BN_WEIGHTS + DEC1_WEIGHTS + DEC0_WEIGHTS
-# = 1964
+# = 2476
def pack_weights_u32(w_f16: np.ndarray) -> np.ndarray:
@@ -158,13 +159,40 @@ def export_weights(checkpoint_path: str, output_dir: str) -> None:
print(f"\nDone → {out}/")
+_WEIGHTS_JS_DEFAULT = Path(__file__).parent.parent / 'tools' / 'weights.js'
+
+
+def update_weights_js(weights_bin: Path, film_mlp_bin: Path,
+ js_path: Path = _WEIGHTS_JS_DEFAULT) -> None:
+ """Encode both .bin files as base64 and write cnn_v3/tools/weights.js."""
+ w_b64 = base64.b64encode(weights_bin.read_bytes()).decode('ascii')
+ f_b64 = base64.b64encode(film_mlp_bin.read_bytes()).decode('ascii')
+ js_path.write_text(
+ "'use strict';\n"
+ "// Auto-generated by export_cnn_v3_weights.py --html — do not edit by hand.\n"
+ f"const CNN_V3_WEIGHTS_B64='{w_b64}';\n"
+ f"const CNN_V3_FILM_MLP_B64='{f_b64}';\n"
+ )
+ print(f"\nweights.js → {js_path}")
+ print(f" CNN_V3_WEIGHTS_B64 {len(w_b64)} chars ({weights_bin.stat().st_size} bytes)")
+ print(f" CNN_V3_FILM_MLP_B64 {len(f_b64)} chars ({film_mlp_bin.stat().st_size} bytes)")
+
+
def main() -> None:
p = argparse.ArgumentParser(description='Export CNN v3 trained weights to .bin')
p.add_argument('checkpoint', help='Path to .pth checkpoint file')
p.add_argument('--output', default='export',
help='Output directory (default: export/)')
+ p.add_argument('--html', action='store_true',
+ help=f'Also update {_WEIGHTS_JS_DEFAULT} with base64-encoded weights')
+ p.add_argument('--html-output', default=None, metavar='PATH',
+ help='Override default weights.js path (implies --html)')
args = p.parse_args()
export_weights(args.checkpoint, args.output)
+ if args.html or args.html_output:
+ out = Path(args.output)
+ js_path = Path(args.html_output) if args.html_output else _WEIGHTS_JS_DEFAULT
+ update_weights_js(out / 'cnn_v3_weights.bin', out / 'cnn_v3_film_mlp.bin', js_path)
if __name__ == '__main__':