summaryrefslogtreecommitdiff
path: root/cnn_v3/training/train_cnn_v3.py
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v3/training/train_cnn_v3.py')
-rw-r--r--cnn_v3/training/train_cnn_v3.py21
1 files changed, 13 insertions, 8 deletions
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py
index 5b6a0be..e48f684 100644
--- a/cnn_v3/training/train_cnn_v3.py
+++ b/cnn_v3/training/train_cnn_v3.py
@@ -21,6 +21,7 @@ Weight budget: ~15.3 KB conv f16 (7828 f16); total with MLP ~17.9 KB
Training improvements:
--edge-loss-weight Sobel edge loss alongside MSE (default 0.1)
--film-warmup-epochs Train U-Net only for N epochs before unfreezing FiLM MLP (default 50)
+ --checkpoint-every Save checkpoint every N epochs (default 100)
"""
import argparse
@@ -309,18 +310,22 @@ def main():
help='Search ±N px in target to minimise grayscale MSE (default 0=disabled)')
# Model
- p.add_argument('--enc-channels', default='4,8',
- help='Encoder channels, comma-separated (default 4,8)')
+ p.add_argument('--enc-channels', default='8,16',
+ help='Encoder channels, comma-separated (default 8,16)')
p.add_argument('--film-cond-dim', type=int, default=5,
help='FiLM conditioning input dim (default 5)')
# Training
- p.add_argument('--epochs', type=int, default=200)
- p.add_argument('--batch-size', type=int, default=16)
- p.add_argument('--lr', type=float, default=1e-3)
- p.add_argument('--checkpoint-dir', default='checkpoints')
- p.add_argument('--checkpoint-every', type=int, default=50,
- help='Save checkpoint every N epochs (0=disable)')
+ p.add_argument('--epochs', type=int, default=200,
+ help='Total training epochs (default 200)')
+ p.add_argument('--batch-size', type=int, default=16,
+ help='Batch size (default 16)')
+ p.add_argument('--lr', type=float, default=1e-3,
+ help='Learning rate (default 1e-3)')
+ p.add_argument('--checkpoint-dir', default='checkpoints',
+ help='Directory to save checkpoints (default checkpoints)')
+ p.add_argument('--checkpoint-every', type=int, default=100,
+ help='Save checkpoint every N epochs (default 100; 0=disable)')
p.add_argument('--resume', default='', metavar='CKPT',
help='Resume from checkpoint path; if path missing, use latest in --checkpoint-dir')
p.add_argument('--edge-loss-weight', type=float, default=0.1,