summary refs log tree commit diff
path: root/cnn_v3/training/train_cnn_v3.py
diff options
context:
space:
mode:
author	skal <pascal.massimino@gmail.com>	2026-03-22 12:17:30 +0100
committer	skal <pascal.massimino@gmail.com>	2026-03-22 12:17:30 +0100
commit	fbc7cfdbcf4e33453b9ed4706f9d30190b1225f4 (patch)
tree	767b854da2d3171505db52211d3259afdea05573 /cnn_v3/training/train_cnn_v3.py
parent	24397204670dff183df2c4b56fa3fcdf87411f08 (diff)
feat(cnn_v3): patch alignment search, resume, Ctrl-C save
- --patch-search-window N: at dataset init, find per-patch (dx,dy) in [-N,N]² that minimises grayscale MSE between source albedo and target; result cached so __getitem__ pays only a list-lookup per sample.
- --resume [CKPT]: restore model + Adam state from a checkpoint; omit path to auto-select the latest in --checkpoint-dir.
- Ctrl-C (SIGINT) finishes the current batch, then saves a checkpoint before exiting; finally-block guarded so no spurious epoch-0 save.
- Review: remove unused sd variable, lift patch_idx out of duplicate computation, move _LUMA to Constants block, update module docstring.

handoff(Gemini): cnn_v3/training updated — no C++ or test changes.
Diffstat (limited to 'cnn_v3/training/train_cnn_v3.py')
-rw-r--r--	cnn_v3/training/train_cnn_v3.py	| 99
1 files changed, 70 insertions, 29 deletions
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py
index 083efb0..de10d6a 100644
--- a/cnn_v3/training/train_cnn_v3.py
+++ b/cnn_v3/training/train_cnn_v3.py
@@ -20,6 +20,7 @@ Weight budget: ~5.4 KB f16 (fits ≤6 KB target)
"""
import argparse
+import signal
import time
from pathlib import Path
@@ -113,6 +114,7 @@ def train(args):
channel_dropout_p=args.channel_dropout_p,
detector=args.detector,
augment=True,
+ patch_search_window=args.patch_search_window,
)
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
num_workers=0, drop_last=False)
@@ -122,44 +124,79 @@ def train(args):
print(f"Model: enc={enc_channels} film_cond_dim={args.film_cond_dim} "
f"params={nparams} (~{nparams*2/1024:.1f} KB f16)")
- optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
- criterion = nn.MSELoss()
- ckpt_dir = Path(args.checkpoint_dir)
+ optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+ criterion = nn.MSELoss()
+ ckpt_dir = Path(args.checkpoint_dir)
ckpt_dir.mkdir(parents=True, exist_ok=True)
+ start_epoch = 1
- print(f"\nTraining {args.epochs} epochs batch={args.batch_size} lr={args.lr}")
+ if args.resume:
+ ckpt_path = Path(args.resume)
+ if not ckpt_path.exists():
+ # Auto-find latest checkpoint in ckpt_dir
+ ckpts = sorted(ckpt_dir.glob('checkpoint_epoch_*.pth'),
+ key=lambda p: int(p.stem.split('_')[-1]))
+ if not ckpts:
+ raise FileNotFoundError(f"No checkpoints found in {ckpt_dir}")
+ ckpt_path = ckpts[-1]
+ print(f"Resuming from {ckpt_path}")
+ ckpt = torch.load(ckpt_path, map_location=device)
+ model.load_state_dict(ckpt['model_state_dict'])
+ optimizer.load_state_dict(ckpt['optimizer_state_dict'])
+ start_epoch = ckpt['epoch'] + 1
+ print(f" Resumed at epoch {start_epoch} (last loss {ckpt['loss']:.6f})")
+
+ print(f"\nTraining epochs {start_epoch}–{args.epochs} batch={args.batch_size} lr={args.lr}")
start = time.time()
avg_loss = float('nan')
+ epoch = start_epoch - 1
+
+ interrupted = False
+
+ def _on_sigint(sig, frame):
+ nonlocal interrupted
+ interrupted = True
- for epoch in range(1, args.epochs + 1):
- model.train()
- epoch_loss = 0.0
- n_batches = 0
+ signal.signal(signal.SIGINT, _on_sigint)
- for feat, cond, target in loader:
- feat, cond, target = feat.to(device), cond.to(device), target.to(device)
- optimizer.zero_grad()
- loss = criterion(model(feat, cond), target)
- loss.backward()
- optimizer.step()
- epoch_loss += loss.item()
- n_batches += 1
+ try:
+ for epoch in range(start_epoch, args.epochs + 1):
+ if interrupted:
+ break
+ model.train()
+ epoch_loss = 0.0
+ n_batches = 0
- avg_loss = epoch_loss / max(n_batches, 1)
- print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | "
- f"{time.time()-start:.0f}s", end='', flush=True)
+ for feat, cond, target in loader:
+ if interrupted:
+ break
+ feat, cond, target = feat.to(device), cond.to(device), target.to(device)
+ optimizer.zero_grad()
+ loss = criterion(model(feat, cond), target)
+ loss.backward()
+ optimizer.step()
+ epoch_loss += loss.item()
+ n_batches += 1
- if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0:
- print()
- ckpt = ckpt_dir / f"checkpoint_epoch_{epoch}.pth"
- torch.save(_checkpoint(model, optimizer, epoch, avg_loss, args), ckpt)
- print(f" → {ckpt}")
+ avg_loss = epoch_loss / max(n_batches, 1)
+ print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | "
+ f"{time.time()-start:.0f}s", end='', flush=True)
- print()
- final = ckpt_dir / f"checkpoint_epoch_{args.epochs}.pth"
- torch.save(_checkpoint(model, optimizer, args.epochs, avg_loss, args), final)
- print(f"Final checkpoint: {final}")
- print(f"Done. {time.time()-start:.1f}s")
+ if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0:
+ print()
+ ckpt = ckpt_dir / f"checkpoint_epoch_{epoch}.pth"
+ torch.save(_checkpoint(model, optimizer, epoch, avg_loss, args), ckpt)
+ print(f" → {ckpt}")
+ finally:
+ print()
+ if epoch >= start_epoch: # at least one epoch completed
+ final = ckpt_dir / f"checkpoint_epoch_{epoch}.pth"
+ torch.save(_checkpoint(model, optimizer, epoch, avg_loss, args), final)
+ if interrupted:
+ print(f"Interrupted. Checkpoint saved: {final}")
+ else:
+ print(f"Final checkpoint: {final}")
+ print(f"Done. {time.time()-start:.1f}s")
return model
@@ -204,6 +241,8 @@ def main():
p.add_argument('--detector', default='harris',
choices=['harris', 'shi-tomasi', 'fast', 'gradient', 'random'],
help='Salient point detector (default harris)')
+ p.add_argument('--patch-search-window', type=int, default=0,
+ help='Search ±N px in target to minimise grayscale MSE (default 0=disabled)')
# Model
p.add_argument('--enc-channels', default='4,8',
@@ -218,6 +257,8 @@ def main():
p.add_argument('--checkpoint-dir', default='checkpoints')
p.add_argument('--checkpoint-every', type=int, default=50,
help='Save checkpoint every N epochs (0=disable)')
+ p.add_argument('--resume', default='', metavar='CKPT',
+ help='Resume from checkpoint path; if path missing, use latest in --checkpoint-dir')
train(p.parse_args())