summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: skal <pascal.massimino@gmail.com> 2026-03-26 07:10:19 +0100
committer: skal <pascal.massimino@gmail.com> 2026-03-26 07:10:19 +0100
commit: 130d70ea130242533a0fd3a7ffeabdb68598c88e (patch)
tree: c5f89a8cf6f9695b1799bb41c18a4bab7a2e9ad2
parent: 8f14bdd66cb002b2f89265b2a578ad93249089c9 (diff)
fix(cnn_v3/training): fix defaults and help strings across py tools
- train_cnn_v3.py: enc-channels 4,8→8,16; checkpoint-every 50→100; add help strings for epochs/batch-size/lr/checkpoint-dir
- gen_test_vectors.py: add help strings for --W/--H/--seed
- export_cnn_v3_weights.py: fix --output help string (export/→export)
-rw-r--r-- cnn_v3/training/export_cnn_v3_weights.py 2
-rw-r--r-- cnn_v3/training/gen_test_vectors.py 9
-rw-r--r-- cnn_v3/training/train_cnn_v3.py 21
3 files changed, 20 insertions, 12 deletions
diff --git a/cnn_v3/training/export_cnn_v3_weights.py b/cnn_v3/training/export_cnn_v3_weights.py
index 2fa83d1..67330a6 100644
--- a/cnn_v3/training/export_cnn_v3_weights.py
+++ b/cnn_v3/training/export_cnn_v3_weights.py
@@ -189,7 +189,7 @@ def main() -> None:
p = argparse.ArgumentParser(description='Export CNN v3 trained weights to .bin')
p.add_argument('checkpoint', help='Path to .pth checkpoint file')
p.add_argument('--output', default='export',
- help='Output directory (default: export/)')
+ help='Output directory (default: export)')
p.add_argument('--html', action='store_true',
help=f'Also update {_WEIGHTS_JS_DEFAULT} with base64-encoded weights')
p.add_argument('--html-output', default=None, metavar='PATH',
diff --git a/cnn_v3/training/gen_test_vectors.py b/cnn_v3/training/gen_test_vectors.py
index cdda5a5..3f81247 100644
--- a/cnn_v3/training/gen_test_vectors.py
+++ b/cnn_v3/training/gen_test_vectors.py
@@ -405,9 +405,12 @@ def main():
parser = argparse.ArgumentParser(description="CNN v3 parity test vector generator")
parser.add_argument('--header', action='store_true',
help='Emit C header to stdout')
- parser.add_argument('--W', type=int, default=8)
- parser.add_argument('--H', type=int, default=8)
- parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--W', type=int, default=8,
+ help='Test image width (default 8)')
+ parser.add_argument('--H', type=int, default=8,
+ help='Test image height (default 8)')
+ parser.add_argument('--seed', type=int, default=42,
+ help='Random seed (default 42)')
args = parser.parse_args()
# Send self-test output to stderr so --header stdout stays clean
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py
index 5b6a0be..e48f684 100644
--- a/cnn_v3/training/train_cnn_v3.py
+++ b/cnn_v3/training/train_cnn_v3.py
@@ -21,6 +21,7 @@ Weight budget: ~15.3 KB conv f16 (7828 f16); total with MLP ~17.9 KB
Training improvements:
--edge-loss-weight Sobel edge loss alongside MSE (default 0.1)
--film-warmup-epochs Train U-Net only for N epochs before unfreezing FiLM MLP (default 50)
+ --checkpoint-every Save checkpoint every N epochs (default 100)
"""
import argparse
@@ -309,18 +310,22 @@ def main():
help='Search ±N px in target to minimise grayscale MSE (default 0=disabled)')
# Model
- p.add_argument('--enc-channels', default='4,8',
- help='Encoder channels, comma-separated (default 4,8)')
+ p.add_argument('--enc-channels', default='8,16',
+ help='Encoder channels, comma-separated (default 8,16)')
p.add_argument('--film-cond-dim', type=int, default=5,
help='FiLM conditioning input dim (default 5)')
# Training
- p.add_argument('--epochs', type=int, default=200)
- p.add_argument('--batch-size', type=int, default=16)
- p.add_argument('--lr', type=float, default=1e-3)
- p.add_argument('--checkpoint-dir', default='checkpoints')
- p.add_argument('--checkpoint-every', type=int, default=50,
- help='Save checkpoint every N epochs (0=disable)')
+ p.add_argument('--epochs', type=int, default=200,
+ help='Total training epochs (default 200)')
+ p.add_argument('--batch-size', type=int, default=16,
+ help='Batch size (default 16)')
+ p.add_argument('--lr', type=float, default=1e-3,
+ help='Learning rate (default 1e-3)')
+ p.add_argument('--checkpoint-dir', default='checkpoints',
+ help='Directory to save checkpoints (default checkpoints)')
+ p.add_argument('--checkpoint-every', type=int, default=100,
+ help='Save checkpoint every N epochs (default 100; 0=disable)')
p.add_argument('--resume', default='', metavar='CKPT',
help='Resume from checkpoint path; if path missing, use latest in --checkpoint-dir')
p.add_argument('--edge-loss-weight', type=float, default=0.1,