summaryrefslogtreecommitdiff
path: root/cnn_v3/training/train_cnn_v3.py
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-03-26 07:10:19 +0100
committerskal <pascal.massimino@gmail.com>2026-03-26 07:10:19 +0100
commit130d70ea130242533a0fd3a7ffeabdb68598c88e (patch)
treec5f89a8cf6f9695b1799bb41c18a4bab7a2e9ad2 /cnn_v3/training/train_cnn_v3.py
parent8f14bdd66cb002b2f89265b2a578ad93249089c9 (diff)
fix(cnn_v3/training): fix defaults and help strings across py tools
- train_cnn_v3.py: enc-channels 4,8→8,16; checkpoint-every 50→100; add help strings for epochs/batch-size/lr/checkpoint-dir - gen_test_vectors.py: add help strings for --W/--H/--seed - export_cnn_v3_weights.py: fix --output help string (export/→export)
Diffstat (limited to 'cnn_v3/training/train_cnn_v3.py')
-rw-r--r--cnn_v3/training/train_cnn_v3.py21
1 files changed, 13 insertions, 8 deletions
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py
index 5b6a0be..e48f684 100644
--- a/cnn_v3/training/train_cnn_v3.py
+++ b/cnn_v3/training/train_cnn_v3.py
@@ -21,6 +21,7 @@ Weight budget: ~15.3 KB conv f16 (7828 f16); total with MLP ~17.9 KB
Training improvements:
--edge-loss-weight Sobel edge loss alongside MSE (default 0.1)
--film-warmup-epochs Train U-Net only for N epochs before unfreezing FiLM MLP (default 50)
+ --checkpoint-every Save checkpoint every N epochs (default 100)
"""
import argparse
@@ -309,18 +310,22 @@ def main():
help='Search ±N px in target to minimise grayscale MSE (default 0=disabled)')
# Model
- p.add_argument('--enc-channels', default='4,8',
- help='Encoder channels, comma-separated (default 4,8)')
+ p.add_argument('--enc-channels', default='8,16',
+ help='Encoder channels, comma-separated (default 8,16)')
p.add_argument('--film-cond-dim', type=int, default=5,
help='FiLM conditioning input dim (default 5)')
# Training
- p.add_argument('--epochs', type=int, default=200)
- p.add_argument('--batch-size', type=int, default=16)
- p.add_argument('--lr', type=float, default=1e-3)
- p.add_argument('--checkpoint-dir', default='checkpoints')
- p.add_argument('--checkpoint-every', type=int, default=50,
- help='Save checkpoint every N epochs (0=disable)')
+ p.add_argument('--epochs', type=int, default=200,
+ help='Total training epochs (default 200)')
+ p.add_argument('--batch-size', type=int, default=16,
+ help='Batch size (default 16)')
+ p.add_argument('--lr', type=float, default=1e-3,
+ help='Learning rate (default 1e-3)')
+ p.add_argument('--checkpoint-dir', default='checkpoints',
+ help='Directory to save checkpoints (default checkpoints)')
+ p.add_argument('--checkpoint-every', type=int, default=100,
+ help='Save checkpoint every N epochs (default 100; 0=disable)')
p.add_argument('--resume', default='', metavar='CKPT',
help='Resume from checkpoint path; if path missing, use latest in --checkpoint-dir')
p.add_argument('--edge-loss-weight', type=float, default=0.1,