summaryrefslogtreecommitdiff
path: root/cnn_v3/training/export_cnn_v3_weights.py
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v3/training/export_cnn_v3_weights.py')
-rw-r--r--cnn_v3/training/export_cnn_v3_weights.py51
1 files changed, 29 insertions, 22 deletions
diff --git a/cnn_v3/training/export_cnn_v3_weights.py b/cnn_v3/training/export_cnn_v3_weights.py
index 78f5f25..2fa83d1 100644
--- a/cnn_v3/training/export_cnn_v3_weights.py
+++ b/cnn_v3/training/export_cnn_v3_weights.py
@@ -15,12 +15,12 @@ Outputs
<output_dir>/cnn_v3_weights.bin
Conv+bias weights for all 5 passes, packed as f16-pairs-in-u32.
Matches the format expected by CNNv3Effect::upload_weights().
- Layout: enc0 (724) | enc1 (296) | bottleneck (584) | dec1 (580) | dec0 (292)
- = 2476 f16 values = 1238 u32 = 4952 bytes.
+ Layout: enc0 (1448) | enc1 (1168) | bottleneck (2320) | dec1 (2312) | dec0 (580)
+ = 7828 f16 values = 3914 u32 = 15656 bytes.
<output_dir>/cnn_v3_film_mlp.bin
- FiLM MLP weights as raw f32: L0_W (5×16) L0_b (16) L1_W (16×40) L1_b (40).
- = 5*16 + 16 + 16*40 + 40 = 80 + 16 + 640 + 40 = 776 f32 = 3104 bytes.
+ FiLM MLP weights as raw f32: L0_W (5×16) L0_b (16) L1_W (16×72) L1_b (72).
+ = 5*16 + 16 + 16*72 + 72 = 80 + 16 + 1152 + 72 = 1320 f32 = 5280 bytes.
For future CPU-side MLP inference in CNNv3Effect::set_film_params().
Usage
@@ -44,17 +44,19 @@ sys.path.insert(0, str(Path(__file__).parent))
from train_cnn_v3 import CNNv3
# ---------------------------------------------------------------------------
-# Weight layout constants — must stay in sync with:
-# cnn_v3/src/cnn_v3_effect.cc (kEnc0Weights, kEnc1Weights, …)
-# cnn_v3/training/gen_test_vectors.py (same constants)
+# Weight layout helpers — derived from enc_channels at runtime.
+# Must stay in sync with cnn_v3/src/cnn_v3_effect.cc and gen_test_vectors.py.
# ---------------------------------------------------------------------------
-ENC0_WEIGHTS = 20 * 4 * 9 + 4 # Conv(20→4,3×3)+bias = 724
-ENC1_WEIGHTS = 4 * 8 * 9 + 8 # Conv(4→8,3×3)+bias = 296
-BN_WEIGHTS = 8 * 8 * 9 + 8 # Conv(8→8,3×3,dil=2)+bias = 584
-DEC1_WEIGHTS = 16 * 4 * 9 + 4 # Conv(16→4,3×3)+bias = 580
-DEC0_WEIGHTS = 8 * 4 * 9 + 4 # Conv(8→4,3×3)+bias = 292
-TOTAL_F16 = ENC0_WEIGHTS + ENC1_WEIGHTS + BN_WEIGHTS + DEC1_WEIGHTS + DEC0_WEIGHTS
-# = 2476
+N_IN = 20 # feature input channels (fixed)
+
+def weight_counts(enc_channels):
+ c0, c1 = enc_channels
+ enc0 = N_IN * c0 * 9 + c0
+ enc1 = c0 * c1 * 9 + c1
+ bn = c1 * c1 * 9 + c1
+ dec1 = (c1 * 2) * c0 * 9 + c0
+ dec0 = (c0 * 2) * 4 * 9 + 4
+ return enc0, enc1, bn, dec1, dec0
def pack_weights_u32(w_f16: np.ndarray) -> np.ndarray:
@@ -86,7 +88,7 @@ def export_weights(checkpoint_path: str, output_dir: str) -> None:
ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
cfg = ckpt.get('config', {})
- enc_channels = cfg.get('enc_channels', [4, 8])
+ enc_channels = cfg.get('enc_channels', [8, 16])
film_cond_dim = cfg.get('film_cond_dim', 5)
model = CNNv3(enc_channels=enc_channels, film_cond_dim=film_cond_dim)
@@ -102,13 +104,18 @@ def export_weights(checkpoint_path: str, output_dir: str) -> None:
# -----------------------------------------------------------------------
# 1. CNN conv weights → cnn_v3_weights.bin
# -----------------------------------------------------------------------
+ enc0_w, enc1_w, bn_w, dec1_w, dec0_w = weight_counts(enc_channels)
+ total_f16 = enc0_w + enc1_w + bn_w + dec1_w + dec0_w
layers = [
- ('enc0', ENC0_WEIGHTS),
- ('enc1', ENC1_WEIGHTS),
- ('bottleneck', BN_WEIGHTS),
- ('dec1', DEC1_WEIGHTS),
- ('dec0', DEC0_WEIGHTS),
+ ('enc0', enc0_w),
+ ('enc1', enc1_w),
+ ('bottleneck', bn_w),
+ ('dec1', dec1_w),
+ ('dec0', dec0_w),
]
+ print(f" Weight layout: enc0={enc0_w} enc1={enc1_w} bn={bn_w} "
+ f"dec1={dec1_w} dec0={dec0_w} total={total_f16} f16 "
+ f"({total_f16*2/1024:.1f} KB)")
all_f16 = []
for name, expected in layers:
@@ -119,13 +126,13 @@ def export_weights(checkpoint_path: str, output_dir: str) -> None:
all_f16.append(chunk)
flat_f16 = np.concatenate(all_f16)
- assert len(flat_f16) == TOTAL_F16, f"total mismatch: {len(flat_f16)} != {TOTAL_F16}"
+ assert len(flat_f16) == total_f16, f"total mismatch: {len(flat_f16)} != {total_f16}"
packed_u32 = pack_weights_u32(flat_f16)
weights_path = out / 'cnn_v3_weights.bin'
packed_u32.astype('<u4').tofile(weights_path) # little-endian u32
print(f"\ncnn_v3_weights.bin")
- print(f" {TOTAL_F16} f16 values → {len(packed_u32)} u32 → {weights_path.stat().st_size} bytes")
+ print(f" {total_f16} f16 values → {len(packed_u32)} u32 → {weights_path.stat().st_size} bytes")
print(f" Upload via CNNv3Effect::upload_weights(queue, data, {len(packed_u32)*4})")
# -----------------------------------------------------------------------