summaryrefslogtreecommitdiff
path: root/cnn_v3/training/export_cnn_v3_weights.py
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v3/training/export_cnn_v3_weights.py')
-rw-r--r--cnn_v3/training/export_cnn_v3_weights.py14
1 files changed, 8 insertions, 6 deletions
diff --git a/cnn_v3/training/export_cnn_v3_weights.py b/cnn_v3/training/export_cnn_v3_weights.py
index a1ad42d..6d99af9 100644
--- a/cnn_v3/training/export_cnn_v3_weights.py
+++ b/cnn_v3/training/export_cnn_v3_weights.py
@@ -34,13 +34,15 @@ sys.path.insert(0, str(Path(__file__).parent))
from train_cnn_v3 import CNNv3
# ---------------------------------------------------------------------------
-# Weight layout constants (must match cnn_v3_effect.cc and gen_test_vectors.py)
+# Weight layout constants — must stay in sync with:
+# cnn_v3/src/cnn_v3_effect.cc (kEnc0Weights, kEnc1Weights, …)
+# cnn_v3/training/gen_test_vectors.py (same constants)
# ---------------------------------------------------------------------------
-ENC0_WEIGHTS = 20 * 4 * 9 + 4 # 724
-ENC1_WEIGHTS = 4 * 8 * 9 + 8 # 296
-BN_WEIGHTS = 8 * 8 * 1 + 8 # 72
-DEC1_WEIGHTS = 16 * 4 * 9 + 4 # 580
-DEC0_WEIGHTS = 8 * 4 * 9 + 4 # 292
+ENC0_WEIGHTS = 20 * 4 * 9 + 4 # Conv(20→4,3×3)+bias = 724
+ENC1_WEIGHTS = 4 * 8 * 9 + 8 # Conv(4→8,3×3)+bias = 296
+BN_WEIGHTS = 8 * 8 * 1 + 8 # Conv(8→8,1×1)+bias = 72
+DEC1_WEIGHTS = 16 * 4 * 9 + 4 # Conv(16→4,3×3)+bias = 580
+DEC0_WEIGHTS = 8 * 4 * 9 + 4 # Conv(8→4,3×3)+bias = 292
TOTAL_F16 = ENC0_WEIGHTS + ENC1_WEIGHTS + BN_WEIGHTS + DEC1_WEIGHTS + DEC0_WEIGHTS
# = 1964