author     skal <pascal.massimino@gmail.com>  2026-03-25 22:08:19 +0100
committer  skal <pascal.massimino@gmail.com>  2026-03-25 22:08:19 +0100
commit     4ca498277b033ae10134045dae9c8c249a8d2b2b (patch)
tree       0b834c3f5cacc6cbf7cdab37c4d41ac217fb0d9a
parent     4ad0e121108261884cdf49374481e04095a6d9c7 (diff)
fix(cnn_v3/training): rebuild optimizer before loading state on resume past FiLM warmup (HEAD, main)
When resuming from a checkpoint saved after the FiLM warmup phase, the optimizer was still built with the frozen (fewer) param groups, causing a size mismatch when loading the saved optimizer state.

Fix: if ckpt['epoch'] >= film_warmup, unfreeze the FiLM MLP and rebuild the optimizer before loading its state dict.

handoff(Gemini): train_cnn_v3.py --resume now works past epoch 1500.
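For context, a minimal sketch of the failure mode on a toy model (the two-layer model and variable names below are illustrative, not from train_cnn_v3.py; only the PyTorch error behavior is being demonstrated):

    # Toy reproduction (illustrative names; not part of this repo).
    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))

    # Warmup phase: second block frozen, so the optimizer's single param
    # group holds only the first layer's 2 tensors.
    for p in model[1].parameters():
        p.requires_grad = False
    opt_frozen = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad], lr=1e-3)

    # Post-warmup phase: everything trainable; this optimizer (4 tensors
    # in its group) is what the checkpoint would have saved.
    for p in model[1].parameters():
        p.requires_grad = True
    opt_full = torch.optim.Adam(model.parameters(), lr=1e-3)
    saved = opt_full.state_dict()

    # Loading post-warmup state into the warmup-shaped optimizer fails:
    # ValueError: loaded state dict contains a parameter group that
    # doesn't match the size of optimizer's group
    try:
        opt_frozen.load_state_dict(saved)
    except ValueError as e:
        print(e)

    # The fix in this commit: rebuild the optimizer over all params
    # first, then load the state dict.
    opt_rebuilt = torch.optim.Adam(model.parameters(), lr=1e-3)
    opt_rebuilt.load_state_dict(saved)  # succeeds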
-rw-r--r--   cnn_v3/training/train_cnn_v3.py   10
1 file changed, 9 insertions(+), 1 deletion(-)
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py
index c790495..c61c360 100644
--- a/cnn_v3/training/train_cnn_v3.py
+++ b/cnn_v3/training/train_cnn_v3.py
@@ -178,8 +178,16 @@ def train(args):
         print(f"Resuming from {ckpt_path}")
         ckpt = torch.load(ckpt_path, map_location=device)
         model.load_state_dict(ckpt['model_state_dict'])
-        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
         start_epoch = ckpt['epoch'] + 1
+        # If the checkpoint was saved after FiLM warmup, unfreeze and rebuild the
+        # optimizer to match the param groups active when the checkpoint was saved.
+        if not film_unfrozen and ckpt['epoch'] >= film_warmup:
+            for p in model.film_mlp.parameters():
+                p.requires_grad = True
+            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+            film_unfrozen = True
+            print(f" FiLM MLP unfrozen (checkpoint past warmup epoch {film_warmup})")
+        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
         print(f" Resumed at epoch {start_epoch} (last loss {ckpt['loss']:.6f})")
     print(f"\nTraining epochs {start_epoch}–{args.epochs} batch={args.batch_size} lr={args.lr}")
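With this patch, resuming a run whose latest checkpoint lies past the warmup boundary should pick up cleanly, e.g. (only the --resume flag is confirmed by the commit message; whether it also takes a checkpoint path is an assumption):

    python cnn_v3/training/train_cnn_v3.py --resume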