From 4ca498277b033ae10134045dae9c8c249a8d2b2b Mon Sep 17 00:00:00 2001
From: skal
Date: Wed, 25 Mar 2026 22:08:19 +0100
Subject: fix(cnn_v3/training): rebuild optimizer before loading state on
 resume past FiLM warmup

When resuming from a checkpoint saved after the FiLM warmup phase, the
optimizer was created with the FiLM parameters still frozen, so its
param group was smaller than the one in the saved optimizer state,
causing a size mismatch on load.

Fix: detect ckpt['epoch'] >= film_warmup, unfreeze the FiLM MLP, and
rebuild the optimizer before loading its state dict.

handoff(Gemini): train_cnn_v3.py --resume now works past epoch 1500.
---
 cnn_v3/training/train_cnn_v3.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py
index c790495..c61c360 100644
--- a/cnn_v3/training/train_cnn_v3.py
+++ b/cnn_v3/training/train_cnn_v3.py
@@ -178,8 +178,16 @@ def train(args):
         print(f"Resuming from {ckpt_path}")
         ckpt = torch.load(ckpt_path, map_location=device)
         model.load_state_dict(ckpt['model_state_dict'])
-        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
         start_epoch = ckpt['epoch'] + 1
+        # If checkpoint was saved after FiLM warmup, unfreeze and rebuild optimizer
+        # to match the param groups that were active when the checkpoint was saved.
+        if not film_unfrozen and ckpt['epoch'] >= film_warmup:
+            for p in model.film_mlp.parameters():
+                p.requires_grad = True
+            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+            film_unfrozen = True
+            print(f" FiLM MLP unfrozen (checkpoint past warmup epoch {film_warmup})")
+        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
         print(f" Resumed at epoch {start_epoch} (last loss {ckpt['loss']:.6f})")
 
     print(f"\nTraining epochs {start_epoch}–{args.epochs} batch={args.batch_size} lr={args.lr}")
-- 
cgit v1.2.3
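
A minimal repro sketch of the failure mode (not part of the patch; assumes
PyTorch; the two Linear layers are stand-ins, only the film_mlp name matches
the real script):

    import torch
    import torch.nn as nn

    # Stand-in model: film_mlp is the submodule frozen during warmup.
    model = nn.ModuleDict({
        'backbone': nn.Linear(4, 4),
        'film_mlp': nn.Linear(4, 8),
    })

    # Post-warmup run: all params trainable; this state gets checkpointed.
    full_opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    saved_state = full_opt.state_dict()

    # Fresh resume: FiLM params start out frozen, so the rebuilt optimizer
    # holds 2 params in its group instead of 4.
    for p in model['film_mlp'].parameters():
        p.requires_grad = False
    frozen_opt = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad], lr=1e-3)

    try:
        frozen_opt.load_state_dict(saved_state)
    except ValueError as e:
        print(f"resume crashes: {e}")  # param group size mismatch

Unfreezing film_mlp and rebuilding the optimizer over model.parameters()
before load_state_dict, as the hunk above does, makes the group sizes
match again.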