Diffstat (limited to 'cnn_v3/training/train_cnn_v3.py')
-rw-r--r--  cnn_v3/training/train_cnn_v3.py  99
1 file changed, 70 insertions, 29 deletions
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py
index 083efb0..de10d6a 100644
--- a/cnn_v3/training/train_cnn_v3.py
+++ b/cnn_v3/training/train_cnn_v3.py
@@ -20,6 +20,7 @@ Weight budget: ~5.4 KB f16 (fits ≤6 KB target)
 """
 
 import argparse
+import signal
 import time
 from pathlib import Path
 
@@ -113,6 +114,7 @@ def train(args):
         channel_dropout_p=args.channel_dropout_p,
         detector=args.detector,
         augment=True,
+        patch_search_window=args.patch_search_window,
     )
     loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                         num_workers=0, drop_last=False)
@@ -122,44 +124,79 @@ def train(args):
     print(f"Model: enc={enc_channels} film_cond_dim={args.film_cond_dim} "
           f"params={nparams} (~{nparams*2/1024:.1f} KB f16)")
 
-    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
-    criterion = nn.MSELoss()
-    ckpt_dir = Path(args.checkpoint_dir)
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+    criterion = nn.MSELoss()
+    ckpt_dir = Path(args.checkpoint_dir)
     ckpt_dir.mkdir(parents=True, exist_ok=True)
+    start_epoch = 1
 
-    print(f"\nTraining {args.epochs} epochs batch={args.batch_size} lr={args.lr}")
+    if args.resume:
+        ckpt_path = Path(args.resume)
+        if not ckpt_path.exists():
+            # Auto-find latest checkpoint in ckpt_dir
+            ckpts = sorted(ckpt_dir.glob('checkpoint_epoch_*.pth'),
+                           key=lambda p: int(p.stem.split('_')[-1]))
+            if not ckpts:
+                raise FileNotFoundError(f"No checkpoints found in {ckpt_dir}")
+            ckpt_path = ckpts[-1]
+        print(f"Resuming from {ckpt_path}")
+        ckpt = torch.load(ckpt_path, map_location=device)
+        model.load_state_dict(ckpt['model_state_dict'])
+        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
+        start_epoch = ckpt['epoch'] + 1
+        print(f"  Resumed at epoch {start_epoch} (last loss {ckpt['loss']:.6f})")
+
+    print(f"\nTraining epochs {start_epoch}–{args.epochs} batch={args.batch_size} lr={args.lr}")
     start = time.time()
     avg_loss = float('nan')
+    epoch = start_epoch - 1
+
+    interrupted = False
+
+    def _on_sigint(sig, frame):
+        nonlocal interrupted
+        interrupted = True
 
-    for epoch in range(1, args.epochs + 1):
-        model.train()
-        epoch_loss = 0.0
-        n_batches = 0
+    signal.signal(signal.SIGINT, _on_sigint)
 
-        for feat, cond, target in loader:
-            feat, cond, target = feat.to(device), cond.to(device), target.to(device)
-            optimizer.zero_grad()
-            loss = criterion(model(feat, cond), target)
-            loss.backward()
-            optimizer.step()
-            epoch_loss += loss.item()
-            n_batches += 1
+    try:
+        for epoch in range(start_epoch, args.epochs + 1):
+            if interrupted:
+                break
+            model.train()
+            epoch_loss = 0.0
+            n_batches = 0
 
-        avg_loss = epoch_loss / max(n_batches, 1)
-        print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | "
-              f"{time.time()-start:.0f}s", end='', flush=True)
+            for feat, cond, target in loader:
+                if interrupted:
+                    break
+                feat, cond, target = feat.to(device), cond.to(device), target.to(device)
+                optimizer.zero_grad()
+                loss = criterion(model(feat, cond), target)
+                loss.backward()
+                optimizer.step()
+                epoch_loss += loss.item()
+                n_batches += 1
 
-        if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0:
-            print()
-            ckpt = ckpt_dir / f"checkpoint_epoch_{epoch}.pth"
-            torch.save(_checkpoint(model, optimizer, epoch, avg_loss, args), ckpt)
-            print(f"  → {ckpt}")
+            avg_loss = epoch_loss / max(n_batches, 1)
+            print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | "
+                  f"{time.time()-start:.0f}s", end='', flush=True)
 
-    print()
-    final = ckpt_dir / f"checkpoint_epoch_{args.epochs}.pth"
-    torch.save(_checkpoint(model, optimizer, args.epochs, avg_loss, args), final)
-    print(f"Final checkpoint: {final}")
-    print(f"Done. {time.time()-start:.1f}s")
+            if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0:
+                print()
+                ckpt = ckpt_dir / f"checkpoint_epoch_{epoch}.pth"
+                torch.save(_checkpoint(model, optimizer, epoch, avg_loss, args), ckpt)
+                print(f"  → {ckpt}")
+    finally:
+        print()
+        if epoch >= start_epoch:  # at least one epoch was started
+            final = ckpt_dir / f"checkpoint_epoch_{epoch}.pth"
+            torch.save(_checkpoint(model, optimizer, epoch, avg_loss, args), final)
+            if interrupted:
+                print(f"Interrupted. Checkpoint saved: {final}")
+            else:
+                print(f"Final checkpoint: {final}")
        print(f"Done. {time.time()-start:.1f}s")
 
     return model
 
@@ -204,6 +241,8 @@ def main():
     p.add_argument('--detector', default='harris',
                    choices=['harris', 'shi-tomasi', 'fast', 'gradient', 'random'],
                    help='Salient point detector (default harris)')
+    p.add_argument('--patch-search-window', type=int, default=0,
+                   help='Search ±N px in target to minimise grayscale MSE (default 0=disabled)')
 
     # Model
     p.add_argument('--enc-channels', default='4,8',
@@ -218,6 +257,8 @@
     p.add_argument('--checkpoint-dir', default='checkpoints')
     p.add_argument('--checkpoint-every', type=int, default=50,
                    help='Save checkpoint every N epochs (0=disable)')
+    p.add_argument('--resume', default='', metavar='CKPT',
+                   help='Resume from checkpoint path; if path missing, use latest in --checkpoint-dir')
 
     train(p.parse_args())
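Note: the _checkpoint helper is called throughout this commit but is not itself part of the diff. Below is a minimal sketch of what it plausibly returns: the first four keys are inferred from the resume path above, which reads them back via torch.load; the 'args' entry is an assumption, not confirmed by this diff.

    def _checkpoint(model, optimizer, epoch, loss, args):
        # Keys 'model_state_dict', 'optimizer_state_dict', 'epoch' and 'loss'
        # are the ones the --resume code path above reads back.
        return {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch,
            'loss': loss,
            # Assumption: hyperparameters stored for reproducibility; not read on resume.
            'args': vars(args),
        }

Since --resume falls back to the newest checkpoint_epoch_*.pth when the given path does not exist, any placeholder argument (e.g. --resume latest) resumes from the most recent checkpoint in --checkpoint-dir.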
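The new --patch-search-window flag is only threaded through to the dataset here; the dataset change itself is outside this diff. A rough sketch, under the assumption that the option slides the target patch within ±N pixels and keeps the offset with the lowest grayscale MSE (the function name and NumPy usage are illustrative, not the dataset's actual code):

    import numpy as np

    def best_patch_offset(src_gray, tgt_gray, y, x, size, window):
        """Return the (dy, dx) within ±window whose target patch best
        matches the source patch at (y, x) under mean-squared error."""
        ref = src_gray[y:y + size, x:x + size].astype(np.float32)
        best, best_err = (0, 0), float('inf')
        for dy in range(-window, window + 1):
            for dx in range(-window, window + 1):
                ty, tx = y + dy, x + dx
                # Skip candidate patches that would fall outside the target image.
                if ty < 0 or tx < 0 or ty + size > tgt_gray.shape[0] \
                        or tx + size > tgt_gray.shape[1]:
                    continue
                cand = tgt_gray[ty:ty + size, tx:tx + size].astype(np.float32)
                err = float(np.mean((cand - ref) ** 2))
                if err < best_err:
                    best, best_err = (dy, dx), err
        return best

With the default window=0 the only candidate offset is (0, 0), which matches the 'disabled' behaviour described in the help string.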
