From 130d70ea130242533a0fd3a7ffeabdb68598c88e Mon Sep 17 00:00:00 2001 From: skal Date: Thu, 26 Mar 2026 07:10:19 +0100 Subject: fix(cnn_v3/training): fix defaults and help strings across py tools MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - train_cnn_v3.py: enc-channels 4,8→8,16; checkpoint-every 50→100; add help strings for epochs/batch-size/lr/checkpoint-dir - gen_test_vectors.py: add help strings for --W/--H/--seed - export_cnn_v3_weights.py: fix --output help string (export/→export) --- cnn_v3/training/export_cnn_v3_weights.py | 2 +- cnn_v3/training/gen_test_vectors.py | 9 ++++++--- cnn_v3/training/train_cnn_v3.py | 21 +++++++++++++-------- 3 files changed, 20 insertions(+), 12 deletions(-) (limited to 'cnn_v3') diff --git a/cnn_v3/training/export_cnn_v3_weights.py b/cnn_v3/training/export_cnn_v3_weights.py index 2fa83d1..67330a6 100644 --- a/cnn_v3/training/export_cnn_v3_weights.py +++ b/cnn_v3/training/export_cnn_v3_weights.py @@ -189,7 +189,7 @@ def main() -> None: p = argparse.ArgumentParser(description='Export CNN v3 trained weights to .bin') p.add_argument('checkpoint', help='Path to .pth checkpoint file') p.add_argument('--output', default='export', - help='Output directory (default: export/)') + help='Output directory (default: export)') p.add_argument('--html', action='store_true', help=f'Also update {_WEIGHTS_JS_DEFAULT} with base64-encoded weights') p.add_argument('--html-output', default=None, metavar='PATH', diff --git a/cnn_v3/training/gen_test_vectors.py b/cnn_v3/training/gen_test_vectors.py index cdda5a5..3f81247 100644 --- a/cnn_v3/training/gen_test_vectors.py +++ b/cnn_v3/training/gen_test_vectors.py @@ -405,9 +405,12 @@ def main(): parser = argparse.ArgumentParser(description="CNN v3 parity test vector generator") parser.add_argument('--header', action='store_true', help='Emit C header to stdout') - parser.add_argument('--W', type=int, default=8) - parser.add_argument('--H', type=int, default=8) - parser.add_argument('--seed', type=int, default=42) + parser.add_argument('--W', type=int, default=8, + help='Test image width (default 8)') + parser.add_argument('--H', type=int, default=8, + help='Test image height (default 8)') + parser.add_argument('--seed', type=int, default=42, + help='Random seed (default 42)') args = parser.parse_args() # Send self-test output to stderr so --header stdout stays clean diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py index 5b6a0be..e48f684 100644 --- a/cnn_v3/training/train_cnn_v3.py +++ b/cnn_v3/training/train_cnn_v3.py @@ -21,6 +21,7 @@ Weight budget: ~15.3 KB conv f16 (7828 f16); total with MLP ~17.9 KB Training improvements: --edge-loss-weight Sobel edge loss alongside MSE (default 0.1) --film-warmup-epochs Train U-Net only for N epochs before unfreezing FiLM MLP (default 50) + --checkpoint-every Save checkpoint every N epochs (default 100) """ import argparse @@ -309,18 +310,22 @@ def main(): help='Search ±N px in target to minimise grayscale MSE (default 0=disabled)') # Model - p.add_argument('--enc-channels', default='4,8', - help='Encoder channels, comma-separated (default 4,8)') + p.add_argument('--enc-channels', default='8,16', + help='Encoder channels, comma-separated (default 8,16)') p.add_argument('--film-cond-dim', type=int, default=5, help='FiLM conditioning input dim (default 5)') # Training - p.add_argument('--epochs', type=int, default=200) - p.add_argument('--batch-size', type=int, default=16) - p.add_argument('--lr', type=float, default=1e-3) - p.add_argument('--checkpoint-dir', default='checkpoints') - p.add_argument('--checkpoint-every', type=int, default=50, - help='Save checkpoint every N epochs (0=disable)') + p.add_argument('--epochs', type=int, default=200, + help='Total training epochs (default 200)') + p.add_argument('--batch-size', type=int, default=16, + help='Batch size (default 16)') + p.add_argument('--lr', type=float, default=1e-3, + help='Learning rate (default 1e-3)') + p.add_argument('--checkpoint-dir', default='checkpoints', + help='Directory to save checkpoints (default checkpoints)') + p.add_argument('--checkpoint-every', type=int, default=100, + help='Save checkpoint every N epochs (default 100; 0=disable)') p.add_argument('--resume', default='', metavar='CKPT', help='Resume from checkpoint path; if path missing, use latest in --checkpoint-dir') p.add_argument('--edge-loss-weight', type=float, default=0.1, -- cgit v1.2.3