diff options
| author | skal <pascal.massimino@gmail.com> | 2026-03-22 12:17:30 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-03-22 12:17:30 +0100 |
| commit | fbc7cfdbcf4e33453b9ed4706f9d30190b1225f4 (patch) | |
| tree | 767b854da2d3171505db52211d3259afdea05573 | |
| parent | 24397204670dff183df2c4b56fa3fcdf87411f08 (diff) | |
feat(cnn_v3): patch alignment search, resume, Ctrl-C save
- --patch-search-window N: at dataset init, find per-patch (dx,dy) in
[-N,N]² that minimises grayscale MSE between source albedo and target;
result cached so __getitem__ pays only a list-lookup per patch.
- --resume CKPT: restore model + Adam state from a checkpoint; if the
  given path does not exist, auto-select the latest in --checkpoint-dir.
- Ctrl-C (SIGINT) finishes the current batch, then saves a checkpoint
before exiting; finally-block guarded so no spurious epoch-0 save.
- Review: remove unused sd variable, lift patch_idx out of duplicate
computation, move _LUMA to Constants block, update module docstring.
handoff(Gemini): cnn_v3/training updated — no C++ or test changes.
| -rw-r--r-- | cnn_v3/training/cnn_v3_utils.py | 88 | ||||
| -rw-r--r-- | cnn_v3/training/train_cnn_v3.py | 99 |
2 files changed, 142 insertions, 45 deletions
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py index b32e548..5a3d56c 100644 --- a/cnn_v3/training/cnn_v3_utils.py +++ b/cnn_v3/training/cnn_v3_utils.py @@ -22,6 +22,13 @@ Sample directory layout (per sample_xxx/): shadow.png R uint8 [0=dark, 255=lit] transp.png R uint8 [0=opaque, 255=clear] target.png RGBA uint8 + +Patch alignment (patch_search_window > 0): + Source (albedo) and target images may not be perfectly co-registered. + When patch_search_window=N, each target patch centre is shifted by the + (dx, dy) in [-N, N]² that minimises grayscale MSE against the source + albedo patch. The search runs once at dataset init and results are + cached, so __getitem__ pays only a list-lookup per sample. """ import random @@ -44,6 +51,8 @@ GEOMETRIC_CHANNELS = [3, 4, 5, 6, 7] # normal.xy, depth, depth_grad.xy CONTEXT_CHANNELS = [8, 18, 19] # mat_id, shadow, transp TEMPORAL_CHANNELS = [9, 10, 11] # prev.rgb +_LUMA = np.array([0.2126, 0.7152, 0.0722], dtype=np.float32) # BT.709 + # --------------------------------------------------------------------------- # Image I/O # --------------------------------------------------------------------------- @@ -203,6 +212,34 @@ def detect_salient_points(albedo: np.ndarray, n: int, detector: str, # Dataset # --------------------------------------------------------------------------- +def _find_target_offsets(albedo: np.ndarray, target: np.ndarray, + centers: List[Tuple[int, int]], + patch_size: int, window: int) -> List[Tuple[int, int]]: + """For each source centre, find the (dx, dy) offset in target that minimises + grayscale MSE between the source albedo patch and the target patch.""" + h, w = albedo.shape[:2] + half = patch_size // 2 + offsets = [] + for cx, cy in centers: + cx = max(half, min(cx, w - half)) + cy = max(half, min(cy, h - half)) + src_gray = (albedo[cy - half:cy - half + patch_size, + cx - half:cx - half + patch_size, :3] @ _LUMA) + best_dx, best_dy, best_mse = 0, 0, float('inf') + for dy 
in range(-window, window + 1): + for dx in range(-window, window + 1): + tcx = max(half, min(cx + dx, w - half)) + tcy = max(half, min(cy + dy, h - half)) + tgt_gray = (target[tcy - half:tcy - half + patch_size, + tcx - half:tcx - half + patch_size, :3] @ _LUMA) + mse = np.mean((src_gray - tgt_gray) ** 2) + if mse < best_mse: + best_mse = mse + best_dx, best_dy = dx, dy + offsets.append((best_dx, best_dy)) + return offsets + + class CNNv3Dataset(Dataset): """Loads CNN v3 samples from dataset/full/ or dataset/simple/ directories. @@ -211,6 +248,9 @@ class CNNv3Dataset(Dataset): Full-image mode (--full-image): resizes entire image to image_size×image_size. + patch_search_window: when >0, the target patch is offset by up to this many + pixels (full-pixel search) to minimise grayscale MSE against the source patch. + Returns (feat, cond, target): feat: (20, H, W) f32 cond: (5,) f32 FiLM conditioning (random when augment=True) @@ -225,14 +265,16 @@ class CNNv3Dataset(Dataset): full_image: bool = False, channel_dropout_p: float = 0.3, detector: str = 'harris', - augment: bool = True): - self.patch_size = patch_size - self.patches_per_image = patches_per_image - self.image_size = image_size - self.full_image = full_image - self.channel_dropout_p = channel_dropout_p - self.detector = detector - self.augment = augment + augment: bool = True, + patch_search_window: int = 0): + self.patch_size = patch_size + self.patches_per_image = patches_per_image + self.image_size = image_size + self.full_image = full_image + self.channel_dropout_p = channel_dropout_p + self.detector = detector + self.augment = augment + self.patch_search_window = patch_search_window root = Path(dataset_dir) subdir = 'full' if input_mode == 'full' else 'simple' @@ -251,14 +293,21 @@ class CNNv3Dataset(Dataset): print(f"[CNNv3Dataset] Loading {len(self.samples)} samples into memory …") self._cache: List[tuple] = [self._load_sample(sd) for sd in self.samples] - # Pre-cache salient patch centres (albedo 
already loaded above) + # Pre-cache salient patch centres and (optionally) target offsets. self._patch_centers: List[List[Tuple[int, int]]] = [] + self._target_offsets: List[List[Tuple[int, int]]] = [] # (dx, dy) per patch if not full_image: print(f"[CNNv3Dataset] Detecting salient points " f"(detector={detector}, patch={patch_size}×{patch_size}) …") - for sd, (albedo, *_) in zip(self.samples, self._cache): + for albedo, *rest, target in self._cache: pts = detect_salient_points(albedo, patches_per_image, detector, patch_size) self._patch_centers.append(pts) + if patch_search_window > 0: + self._target_offsets.append( + _find_target_offsets(albedo, target, pts, patch_size, patch_search_window)) + if patch_search_window > 0: + print(f"[CNNv3Dataset] Target offset search done " + f"(window=±{patch_search_window})") print(f"[CNNv3Dataset] mode={input_mode} samples={len(self.samples)} " f"patch={patch_size} full_image={full_image}") @@ -285,10 +334,8 @@ class CNNv3Dataset(Dataset): def __getitem__(self, idx): if self.full_image: sample_idx = idx - sd = self.samples[idx] else: sample_idx = idx // self.patches_per_image - sd = self.samples[sample_idx] albedo, normal, depth, matid, shadow, transp, target = self._cache[sample_idx] h, w = albedo.shape[:2] @@ -314,9 +361,10 @@ class CNNv3Dataset(Dataset): transp = _resize_gray(transp) target = _resize_img(target) else: - ps = self.patch_size - half = ps // 2 - cx, cy = self._patch_centers[sample_idx][idx % self.patches_per_image] + ps = self.patch_size + half = ps // 2 + patch_idx = idx % self.patches_per_image + cx, cy = self._patch_centers[sample_idx][patch_idx] cx = max(half, min(cx, w - half)) cy = max(half, min(cy, h - half)) sl = (slice(cy - half, cy - half + ps), slice(cx - half, cx - half + ps)) @@ -327,7 +375,15 @@ class CNNv3Dataset(Dataset): matid = matid[sl] shadow = shadow[sl] transp = transp[sl] - target = target[sl] + + # Apply cached target offset (if search was enabled at init). 
+ if self._target_offsets: + dx, dy = self._target_offsets[sample_idx][patch_idx] + tcx = max(half, min(cx + dx, w - half)) + tcy = max(half, min(cy + dy, h - half)) + target = target[tcy - half:tcy - half + ps, tcx - half:tcx - half + ps] + else: + target = target[sl] feat = assemble_features(albedo, normal, depth, matid, shadow, transp) diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py index 083efb0..de10d6a 100644 --- a/cnn_v3/training/train_cnn_v3.py +++ b/cnn_v3/training/train_cnn_v3.py @@ -20,6 +20,7 @@ Weight budget: ~5.4 KB f16 (fits ≤6 KB target) """ import argparse +import signal import time from pathlib import Path @@ -113,6 +114,7 @@ def train(args): channel_dropout_p=args.channel_dropout_p, detector=args.detector, augment=True, + patch_search_window=args.patch_search_window, ) loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=False) @@ -122,44 +124,79 @@ def train(args): print(f"Model: enc={enc_channels} film_cond_dim={args.film_cond_dim} " f"params={nparams} (~{nparams*2/1024:.1f} KB f16)") - optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) - criterion = nn.MSELoss() - ckpt_dir = Path(args.checkpoint_dir) + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + criterion = nn.MSELoss() + ckpt_dir = Path(args.checkpoint_dir) ckpt_dir.mkdir(parents=True, exist_ok=True) + start_epoch = 1 - print(f"\nTraining {args.epochs} epochs batch={args.batch_size} lr={args.lr}") + if args.resume: + ckpt_path = Path(args.resume) + if not ckpt_path.exists(): + # Auto-find latest checkpoint in ckpt_dir + ckpts = sorted(ckpt_dir.glob('checkpoint_epoch_*.pth'), + key=lambda p: int(p.stem.split('_')[-1])) + if not ckpts: + raise FileNotFoundError(f"No checkpoints found in {ckpt_dir}") + ckpt_path = ckpts[-1] + print(f"Resuming from {ckpt_path}") + ckpt = torch.load(ckpt_path, map_location=device) + model.load_state_dict(ckpt['model_state_dict']) + 
optimizer.load_state_dict(ckpt['optimizer_state_dict']) + start_epoch = ckpt['epoch'] + 1 + print(f" Resumed at epoch {start_epoch} (last loss {ckpt['loss']:.6f})") + + print(f"\nTraining epochs {start_epoch}–{args.epochs} batch={args.batch_size} lr={args.lr}") start = time.time() avg_loss = float('nan') + epoch = start_epoch - 1 + + interrupted = False + + def _on_sigint(sig, frame): + nonlocal interrupted + interrupted = True - for epoch in range(1, args.epochs + 1): - model.train() - epoch_loss = 0.0 - n_batches = 0 + signal.signal(signal.SIGINT, _on_sigint) - for feat, cond, target in loader: - feat, cond, target = feat.to(device), cond.to(device), target.to(device) - optimizer.zero_grad() - loss = criterion(model(feat, cond), target) - loss.backward() - optimizer.step() - epoch_loss += loss.item() - n_batches += 1 + try: + for epoch in range(start_epoch, args.epochs + 1): + if interrupted: + break + model.train() + epoch_loss = 0.0 + n_batches = 0 - avg_loss = epoch_loss / max(n_batches, 1) - print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | " - f"{time.time()-start:.0f}s", end='', flush=True) + for feat, cond, target in loader: + if interrupted: + break + feat, cond, target = feat.to(device), cond.to(device), target.to(device) + optimizer.zero_grad() + loss = criterion(model(feat, cond), target) + loss.backward() + optimizer.step() + epoch_loss += loss.item() + n_batches += 1 - if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0: - print() - ckpt = ckpt_dir / f"checkpoint_epoch_{epoch}.pth" - torch.save(_checkpoint(model, optimizer, epoch, avg_loss, args), ckpt) - print(f" → {ckpt}") + avg_loss = epoch_loss / max(n_batches, 1) + print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | " + f"{time.time()-start:.0f}s", end='', flush=True) - print() - final = ckpt_dir / f"checkpoint_epoch_{args.epochs}.pth" - torch.save(_checkpoint(model, optimizer, args.epochs, avg_loss, args), final) - print(f"Final checkpoint: 
{final}") - print(f"Done. {time.time()-start:.1f}s") + if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0: + print() + ckpt = ckpt_dir / f"checkpoint_epoch_{epoch}.pth" + torch.save(_checkpoint(model, optimizer, epoch, avg_loss, args), ckpt) + print(f" → {ckpt}") + finally: + print() + if epoch >= start_epoch: # at least one epoch completed + final = ckpt_dir / f"checkpoint_epoch_{epoch}.pth" + torch.save(_checkpoint(model, optimizer, epoch, avg_loss, args), final) + if interrupted: + print(f"Interrupted. Checkpoint saved: {final}") + else: + print(f"Final checkpoint: {final}") + print(f"Done. {time.time()-start:.1f}s") return model @@ -204,6 +241,8 @@ def main(): p.add_argument('--detector', default='harris', choices=['harris', 'shi-tomasi', 'fast', 'gradient', 'random'], help='Salient point detector (default harris)') + p.add_argument('--patch-search-window', type=int, default=0, + help='Search ±N px in target to minimise grayscale MSE (default 0=disabled)') # Model p.add_argument('--enc-channels', default='4,8', @@ -218,6 +257,8 @@ def main(): p.add_argument('--checkpoint-dir', default='checkpoints') p.add_argument('--checkpoint-every', type=int, default=50, help='Save checkpoint every N epochs (0=disable)') + p.add_argument('--resume', default='', metavar='CKPT', + help='Resume from checkpoint path; if path missing, use latest in --checkpoint-dir') train(p.parse_args()) |
