| -rw-r--r-- | cnn_v3/training/train_cnn_v3.py | 10 |
1 file changed, 9 insertions(+), 1 deletion(-)
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py
index c790495..c61c360 100644
--- a/cnn_v3/training/train_cnn_v3.py
+++ b/cnn_v3/training/train_cnn_v3.py
@@ -178,8 +178,16 @@ def train(args):
         print(f"Resuming from {ckpt_path}")
         ckpt = torch.load(ckpt_path, map_location=device)
         model.load_state_dict(ckpt['model_state_dict'])
-        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
         start_epoch = ckpt['epoch'] + 1
+        # If checkpoint was saved after FiLM warmup, unfreeze and rebuild optimizer
+        # to match the param groups that were active when the checkpoint was saved.
+        if not film_unfrozen and ckpt['epoch'] >= film_warmup:
+            for p in model.film_mlp.parameters():
+                p.requires_grad = True
+            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+            film_unfrozen = True
+            print(f" FiLM MLP unfrozen (checkpoint past warmup epoch {film_warmup})")
+        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
         print(f" Resumed at epoch {start_epoch} (last loss {ckpt['loss']:.6f})")
 
     print(f"\nTraining epochs {start_epoch}–{args.epochs} batch={args.batch_size} lr={args.lr}")
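The ordering this patch enforces matters: torch.optim.Optimizer.load_state_dict expects the live optimizer's parameter groups to match the ones that produced the checkpoint, and raises a ValueError about mismatched parameter groups otherwise. The sketch below shows the resume path in isolation, under the assumptions visible in the diff: a model.film_mlp submodule frozen for the first film_warmup epochs, and checkpoints holding 'epoch', 'model_state_dict', 'optimizer_state_dict', and 'loss'. The build_optimizer helper and resume function are hypothetical illustrations, not code from the patched script.

    import torch

    def build_optimizer(model, lr):
        # Hypothetical helper: optimize only trainable parameters. During
        # warmup this excludes the frozen FiLM MLP, which is exactly why the
        # optimizer must be rebuilt before loading a post-warmup checkpoint.
        params = [p for p in model.parameters() if p.requires_grad]
        return torch.optim.Adam(params, lr=lr)

    def resume(model, ckpt_path, lr, film_warmup, device):
        # Assumes the caller froze model.film_mlp at startup, mirroring the
        # warmup state of a fresh training run.
        ckpt = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(ckpt['model_state_dict'])

        # Decide the optimizer's shape *before* loading its state: a
        # checkpoint from past the warmup boundary carries optimizer state
        # for the FiLM MLP parameters too, so unfreeze and rebuild first.
        film_unfrozen = ckpt['epoch'] >= film_warmup
        if film_unfrozen:
            for p in model.film_mlp.parameters():
                p.requires_grad = True
        optimizer = build_optimizer(model, lr)

        # Safe now: the param groups match the ones that were active when
        # the checkpoint was saved.
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        return optimizer, ckpt['epoch'] + 1, film_unfrozen

Note that the sketch builds the warmup optimizer over trainable parameters only; if the script instead constructed it over model.parameters() unconditionally, the group sizes would already match and only the unfreeze step would be needed. The diff does not show the original construction, so this remains an assumption.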
