summary | refs | log | tree | commit | diff
path: root/cnn_v3/training
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v3/training')
-rw-r--r-- cnn_v3/training/train_cnn_v3.py 10
1 file changed, 9 insertions, 1 deletion
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py
index c790495..c61c360 100644
--- a/cnn_v3/training/train_cnn_v3.py
+++ b/cnn_v3/training/train_cnn_v3.py
@@ -178,8 +178,16 @@ def train(args):
print(f"Resuming from {ckpt_path}")
ckpt = torch.load(ckpt_path, map_location=device)
model.load_state_dict(ckpt['model_state_dict'])
- optimizer.load_state_dict(ckpt['optimizer_state_dict'])
start_epoch = ckpt['epoch'] + 1
+ # If checkpoint was saved after FiLM warmup, unfreeze and rebuild optimizer
+ # to match the param groups that were active when the checkpoint was saved.
+ if not film_unfrozen and ckpt['epoch'] >= film_warmup:
+ for p in model.film_mlp.parameters():
+ p.requires_grad = True
+ optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+ film_unfrozen = True
+ print(f" FiLM MLP unfrozen (checkpoint past warmup epoch {film_warmup})")
+ optimizer.load_state_dict(ckpt['optimizer_state_dict'])
print(f" Resumed at epoch {start_epoch} (last loss {ckpt['loss']:.6f})")
print(f"\nTraining epochs {start_epoch}–{args.epochs} batch={args.batch_size} lr={args.lr}")