summaryrefslogtreecommitdiff
path: root/cnn_v3/training
diff options
context:
space:
mode:
author    skal <pascal.massimino@gmail.com>  2026-03-22 12:17:30 +0100
committer skal <pascal.massimino@gmail.com>  2026-03-22 12:17:30 +0100
commit    fbc7cfdbcf4e33453b9ed4706f9d30190b1225f4 (patch)
tree      767b854da2d3171505db52211d3259afdea05573 /cnn_v3/training
parent    24397204670dff183df2c4b56fa3fcdf87411f08 (diff)
feat(cnn_v3): patch alignment search, resume, Ctrl-C save
- --patch-search-window N: at dataset init, find per-patch (dx,dy) in [-N,N]²
  that minimises grayscale MSE between source albedo and target; result cached
  so __getitem__ pays only a list-lookup per sample.
- --resume [CKPT]: restore model + Adam state from a checkpoint; omit path to
  auto-select the latest in --checkpoint-dir.
- Ctrl-C (SIGINT) finishes the current batch, then saves a checkpoint before
  exiting; finally-block guarded so no spurious epoch-0 save.
- Review: remove unused sd variable, lift patch_idx out of duplicate
  computation, move _LUMA to Constants block, update module docstring.

handoff(Gemini): cnn_v3/training updated — no C++ or test changes.
Diffstat (limited to 'cnn_v3/training')
-rw-r--r--  cnn_v3/training/cnn_v3_utils.py  | 88
-rw-r--r--  cnn_v3/training/train_cnn_v3.py  | 99
2 files changed, 142 insertions(+), 45 deletions(-)
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py
index b32e548..5a3d56c 100644
--- a/cnn_v3/training/cnn_v3_utils.py
+++ b/cnn_v3/training/cnn_v3_utils.py
@@ -22,6 +22,13 @@ Sample directory layout (per sample_xxx/):
shadow.png R uint8 [0=dark, 255=lit]
transp.png R uint8 [0=opaque, 255=clear]
target.png RGBA uint8
+
+Patch alignment (patch_search_window > 0):
+ Source (albedo) and target images may not be perfectly co-registered.
+ When patch_search_window=N, each target patch centre is shifted by the
+ (dx, dy) in [-N, N]² that minimises grayscale MSE against the source
+ albedo patch. The search runs once at dataset init and results are
+ cached, so __getitem__ pays only a list-lookup per sample.
"""
import random
@@ -44,6 +51,8 @@ GEOMETRIC_CHANNELS = [3, 4, 5, 6, 7] # normal.xy, depth, depth_grad.xy
CONTEXT_CHANNELS = [8, 18, 19] # mat_id, shadow, transp
TEMPORAL_CHANNELS = [9, 10, 11] # prev.rgb
+_LUMA = np.array([0.2126, 0.7152, 0.0722], dtype=np.float32) # BT.709
+
# ---------------------------------------------------------------------------
# Image I/O
# ---------------------------------------------------------------------------
@@ -203,6 +212,34 @@ def detect_salient_points(albedo: np.ndarray, n: int, detector: str,
# Dataset
# ---------------------------------------------------------------------------
+def _find_target_offsets(albedo: np.ndarray, target: np.ndarray,
+ centers: List[Tuple[int, int]],
+ patch_size: int, window: int) -> List[Tuple[int, int]]:
+ """For each source centre, find the (dx, dy) offset in target that minimises
+ grayscale MSE between the source albedo patch and the target patch."""
+ h, w = albedo.shape[:2]
+ half = patch_size // 2
+ offsets = []
+ for cx, cy in centers:
+ cx = max(half, min(cx, w - half))
+ cy = max(half, min(cy, h - half))
+ src_gray = (albedo[cy - half:cy - half + patch_size,
+ cx - half:cx - half + patch_size, :3] @ _LUMA)
+ best_dx, best_dy, best_mse = 0, 0, float('inf')
+ for dy in range(-window, window + 1):
+ for dx in range(-window, window + 1):
+ tcx = max(half, min(cx + dx, w - half))
+ tcy = max(half, min(cy + dy, h - half))
+ tgt_gray = (target[tcy - half:tcy - half + patch_size,
+ tcx - half:tcx - half + patch_size, :3] @ _LUMA)
+ mse = np.mean((src_gray - tgt_gray) ** 2)
+ if mse < best_mse:
+ best_mse = mse
+ best_dx, best_dy = dx, dy
+ offsets.append((best_dx, best_dy))
+ return offsets
+
+
class CNNv3Dataset(Dataset):
"""Loads CNN v3 samples from dataset/full/ or dataset/simple/ directories.
@@ -211,6 +248,9 @@ class CNNv3Dataset(Dataset):
Full-image mode (--full-image): resizes entire image to image_size×image_size.
+ patch_search_window: when >0, the target patch is offset by up to this many
+ pixels (full-pixel search) to minimise grayscale MSE against the source patch.
+
Returns (feat, cond, target):
feat: (20, H, W) f32
cond: (5,) f32 FiLM conditioning (random when augment=True)
@@ -225,14 +265,16 @@ class CNNv3Dataset(Dataset):
full_image: bool = False,
channel_dropout_p: float = 0.3,
detector: str = 'harris',
- augment: bool = True):
- self.patch_size = patch_size
- self.patches_per_image = patches_per_image
- self.image_size = image_size
- self.full_image = full_image
- self.channel_dropout_p = channel_dropout_p
- self.detector = detector
- self.augment = augment
+ augment: bool = True,
+ patch_search_window: int = 0):
+ self.patch_size = patch_size
+ self.patches_per_image = patches_per_image
+ self.image_size = image_size
+ self.full_image = full_image
+ self.channel_dropout_p = channel_dropout_p
+ self.detector = detector
+ self.augment = augment
+ self.patch_search_window = patch_search_window
root = Path(dataset_dir)
subdir = 'full' if input_mode == 'full' else 'simple'
@@ -251,14 +293,21 @@ class CNNv3Dataset(Dataset):
print(f"[CNNv3Dataset] Loading {len(self.samples)} samples into memory …")
self._cache: List[tuple] = [self._load_sample(sd) for sd in self.samples]
- # Pre-cache salient patch centres (albedo already loaded above)
+ # Pre-cache salient patch centres and (optionally) target offsets.
self._patch_centers: List[List[Tuple[int, int]]] = []
+ self._target_offsets: List[List[Tuple[int, int]]] = [] # (dx, dy) per patch
if not full_image:
print(f"[CNNv3Dataset] Detecting salient points "
f"(detector={detector}, patch={patch_size}×{patch_size}) …")
- for sd, (albedo, *_) in zip(self.samples, self._cache):
+ for albedo, *rest, target in self._cache:
pts = detect_salient_points(albedo, patches_per_image, detector, patch_size)
self._patch_centers.append(pts)
+ if patch_search_window > 0:
+ self._target_offsets.append(
+ _find_target_offsets(albedo, target, pts, patch_size, patch_search_window))
+ if patch_search_window > 0:
+ print(f"[CNNv3Dataset] Target offset search done "
+ f"(window=±{patch_search_window})")
print(f"[CNNv3Dataset] mode={input_mode} samples={len(self.samples)} "
f"patch={patch_size} full_image={full_image}")
@@ -285,10 +334,8 @@ class CNNv3Dataset(Dataset):
def __getitem__(self, idx):
if self.full_image:
sample_idx = idx
- sd = self.samples[idx]
else:
sample_idx = idx // self.patches_per_image
- sd = self.samples[sample_idx]
albedo, normal, depth, matid, shadow, transp, target = self._cache[sample_idx]
h, w = albedo.shape[:2]
@@ -314,9 +361,10 @@ class CNNv3Dataset(Dataset):
transp = _resize_gray(transp)
target = _resize_img(target)
else:
- ps = self.patch_size
- half = ps // 2
- cx, cy = self._patch_centers[sample_idx][idx % self.patches_per_image]
+ ps = self.patch_size
+ half = ps // 2
+ patch_idx = idx % self.patches_per_image
+ cx, cy = self._patch_centers[sample_idx][patch_idx]
cx = max(half, min(cx, w - half))
cy = max(half, min(cy, h - half))
sl = (slice(cy - half, cy - half + ps), slice(cx - half, cx - half + ps))
@@ -327,7 +375,15 @@ class CNNv3Dataset(Dataset):
matid = matid[sl]
shadow = shadow[sl]
transp = transp[sl]
- target = target[sl]
+
+ # Apply cached target offset (if search was enabled at init).
+ if self._target_offsets:
+ dx, dy = self._target_offsets[sample_idx][patch_idx]
+ tcx = max(half, min(cx + dx, w - half))
+ tcy = max(half, min(cy + dy, h - half))
+ target = target[tcy - half:tcy - half + ps, tcx - half:tcx - half + ps]
+ else:
+ target = target[sl]
feat = assemble_features(albedo, normal, depth, matid, shadow, transp)
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py
index 083efb0..de10d6a 100644
--- a/cnn_v3/training/train_cnn_v3.py
+++ b/cnn_v3/training/train_cnn_v3.py
@@ -20,6 +20,7 @@ Weight budget: ~5.4 KB f16 (fits ≤6 KB target)
"""
import argparse
+import signal
import time
from pathlib import Path
@@ -113,6 +114,7 @@ def train(args):
channel_dropout_p=args.channel_dropout_p,
detector=args.detector,
augment=True,
+ patch_search_window=args.patch_search_window,
)
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
num_workers=0, drop_last=False)
@@ -122,44 +124,79 @@ def train(args):
print(f"Model: enc={enc_channels} film_cond_dim={args.film_cond_dim} "
f"params={nparams} (~{nparams*2/1024:.1f} KB f16)")
- optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
- criterion = nn.MSELoss()
- ckpt_dir = Path(args.checkpoint_dir)
+ optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+ criterion = nn.MSELoss()
+ ckpt_dir = Path(args.checkpoint_dir)
ckpt_dir.mkdir(parents=True, exist_ok=True)
+ start_epoch = 1
- print(f"\nTraining {args.epochs} epochs batch={args.batch_size} lr={args.lr}")
+ if args.resume:
+ ckpt_path = Path(args.resume)
+ if not ckpt_path.exists():
+ # Auto-find latest checkpoint in ckpt_dir
+ ckpts = sorted(ckpt_dir.glob('checkpoint_epoch_*.pth'),
+ key=lambda p: int(p.stem.split('_')[-1]))
+ if not ckpts:
+ raise FileNotFoundError(f"No checkpoints found in {ckpt_dir}")
+ ckpt_path = ckpts[-1]
+ print(f"Resuming from {ckpt_path}")
+ ckpt = torch.load(ckpt_path, map_location=device)
+ model.load_state_dict(ckpt['model_state_dict'])
+ optimizer.load_state_dict(ckpt['optimizer_state_dict'])
+ start_epoch = ckpt['epoch'] + 1
+ print(f" Resumed at epoch {start_epoch} (last loss {ckpt['loss']:.6f})")
+
+ print(f"\nTraining epochs {start_epoch}–{args.epochs} batch={args.batch_size} lr={args.lr}")
start = time.time()
avg_loss = float('nan')
+ epoch = start_epoch - 1
+
+ interrupted = False
+
+ def _on_sigint(sig, frame):
+ nonlocal interrupted
+ interrupted = True
- for epoch in range(1, args.epochs + 1):
- model.train()
- epoch_loss = 0.0
- n_batches = 0
+ signal.signal(signal.SIGINT, _on_sigint)
- for feat, cond, target in loader:
- feat, cond, target = feat.to(device), cond.to(device), target.to(device)
- optimizer.zero_grad()
- loss = criterion(model(feat, cond), target)
- loss.backward()
- optimizer.step()
- epoch_loss += loss.item()
- n_batches += 1
+ try:
+ for epoch in range(start_epoch, args.epochs + 1):
+ if interrupted:
+ break
+ model.train()
+ epoch_loss = 0.0
+ n_batches = 0
- avg_loss = epoch_loss / max(n_batches, 1)
- print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | "
- f"{time.time()-start:.0f}s", end='', flush=True)
+ for feat, cond, target in loader:
+ if interrupted:
+ break
+ feat, cond, target = feat.to(device), cond.to(device), target.to(device)
+ optimizer.zero_grad()
+ loss = criterion(model(feat, cond), target)
+ loss.backward()
+ optimizer.step()
+ epoch_loss += loss.item()
+ n_batches += 1
- if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0:
- print()
- ckpt = ckpt_dir / f"checkpoint_epoch_{epoch}.pth"
- torch.save(_checkpoint(model, optimizer, epoch, avg_loss, args), ckpt)
- print(f" → {ckpt}")
+ avg_loss = epoch_loss / max(n_batches, 1)
+ print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | "
+ f"{time.time()-start:.0f}s", end='', flush=True)
- print()
- final = ckpt_dir / f"checkpoint_epoch_{args.epochs}.pth"
- torch.save(_checkpoint(model, optimizer, args.epochs, avg_loss, args), final)
- print(f"Final checkpoint: {final}")
- print(f"Done. {time.time()-start:.1f}s")
+ if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0:
+ print()
+ ckpt = ckpt_dir / f"checkpoint_epoch_{epoch}.pth"
+ torch.save(_checkpoint(model, optimizer, epoch, avg_loss, args), ckpt)
+ print(f" → {ckpt}")
+ finally:
+ print()
+ if epoch >= start_epoch: # at least one epoch completed
+ final = ckpt_dir / f"checkpoint_epoch_{epoch}.pth"
+ torch.save(_checkpoint(model, optimizer, epoch, avg_loss, args), final)
+ if interrupted:
+ print(f"Interrupted. Checkpoint saved: {final}")
+ else:
+ print(f"Final checkpoint: {final}")
+ print(f"Done. {time.time()-start:.1f}s")
return model
@@ -204,6 +241,8 @@ def main():
p.add_argument('--detector', default='harris',
choices=['harris', 'shi-tomasi', 'fast', 'gradient', 'random'],
help='Salient point detector (default harris)')
+ p.add_argument('--patch-search-window', type=int, default=0,
+ help='Search ±N px in target to minimise grayscale MSE (default 0=disabled)')
# Model
p.add_argument('--enc-channels', default='4,8',
@@ -218,6 +257,8 @@ def main():
p.add_argument('--checkpoint-dir', default='checkpoints')
p.add_argument('--checkpoint-every', type=int, default=50,
help='Save checkpoint every N epochs (0=disable)')
+ p.add_argument('--resume', default='', metavar='CKPT',
+ help='Resume from checkpoint path; if path missing, use latest in --checkpoint-dir')
train(p.parse_args())