diff options
| author | skal <pascal.massimino@gmail.com> | 2026-03-25 22:08:19 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-03-25 22:08:19 +0100 |
| commit | 4ca498277b033ae10134045dae9c8c249a8d2b2b (patch) | |
| tree | 0b834c3f5cacc6cbf7cdab37c4d41ac217fb0d9a /cnn_v3/training | |
| parent | 4ad0e121108261884cdf49374481e04095a6d9c7 (diff) | |
When resuming a checkpoint saved after the FiLM warmup phase, the optimizer
was created with frozen (fewer) param groups, causing a size mismatch when
loading the saved optimizer state. Fix: detect ckpt['epoch'] >= film_warmup,
unfreeze FiLM MLP, and rebuild the optimizer before loading its state dict.
handoff(Gemini): train_cnn_v3.py --resume now works past epoch 1500.
Diffstat (limited to 'cnn_v3/training')
| -rw-r--r-- | cnn_v3/training/train_cnn_v3.py | 10 |
1 file changed, 9 insertions(+), 1 deletion(-)
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py index c790495..c61c360 100644 --- a/cnn_v3/training/train_cnn_v3.py +++ b/cnn_v3/training/train_cnn_v3.py @@ -178,8 +178,16 @@ def train(args): print(f"Resuming from {ckpt_path}") ckpt = torch.load(ckpt_path, map_location=device) model.load_state_dict(ckpt['model_state_dict']) - optimizer.load_state_dict(ckpt['optimizer_state_dict']) start_epoch = ckpt['epoch'] + 1 + # If checkpoint was saved after FiLM warmup, unfreeze and rebuild optimizer + # to match the param groups that were active when the checkpoint was saved. + if not film_unfrozen and ckpt['epoch'] >= film_warmup: + for p in model.film_mlp.parameters(): + p.requires_grad = True + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + film_unfrozen = True + print(f" FiLM MLP unfrozen (checkpoint past warmup epoch {film_warmup})") + optimizer.load_state_dict(ckpt['optimizer_state_dict']) print(f" Resumed at epoch {start_epoch} (last loss {ckpt['loss']:.6f})") print(f"\nTraining epochs {start_epoch}–{args.epochs} batch={args.batch_size} lr={args.lr}") |
