summaryrefslogtreecommitdiff
path: root/cnn_v3/training/cnn_v3_utils.py
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-03-22 12:17:30 +0100
committerskal <pascal.massimino@gmail.com>2026-03-22 12:17:30 +0100
commitfbc7cfdbcf4e33453b9ed4706f9d30190b1225f4 (patch)
tree767b854da2d3171505db52211d3259afdea05573 /cnn_v3/training/cnn_v3_utils.py
parent24397204670dff183df2c4b56fa3fcdf87411f08 (diff)
feat(cnn_v3): patch alignment search, resume, Ctrl-C save
- --patch-search-window N: at dataset init, find per-patch (dx,dy) in [-N,N]² that minimises grayscale MSE between source albedo and target; result cached so __getitem__ pays only a list-lookup per sample. - --resume [CKPT]: restore model + Adam state from a checkpoint; omit path to auto-select the latest in --checkpoint-dir. - Ctrl-C (SIGINT) finishes the current batch, then saves a checkpoint before exiting; finally-block guarded so no spurious epoch-0 save. - Review: remove unused sd variable, lift patch_idx out of duplicate computation, move _LUMA to Constants block, update module docstring. handoff(Gemini): cnn_v3/training updated — no C++ or test changes.
Diffstat (limited to 'cnn_v3/training/cnn_v3_utils.py')
-rw-r--r--cnn_v3/training/cnn_v3_utils.py88
1 files changed, 72 insertions, 16 deletions
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py
index b32e548..5a3d56c 100644
--- a/cnn_v3/training/cnn_v3_utils.py
+++ b/cnn_v3/training/cnn_v3_utils.py
@@ -22,6 +22,13 @@ Sample directory layout (per sample_xxx/):
shadow.png R uint8 [0=dark, 255=lit]
transp.png R uint8 [0=opaque, 255=clear]
target.png RGBA uint8
+
+Patch alignment (patch_search_window > 0):
+ Source (albedo) and target images may not be perfectly co-registered.
+ When patch_search_window=N, each target patch centre is shifted by the
+ (dx, dy) in [-N, N]² that minimises grayscale MSE against the source
+ albedo patch. The search runs once at dataset init and results are
+ cached, so __getitem__ pays only a list-lookup per sample.
"""
import random
@@ -44,6 +51,8 @@ GEOMETRIC_CHANNELS = [3, 4, 5, 6, 7] # normal.xy, depth, depth_grad.xy
CONTEXT_CHANNELS = [8, 18, 19] # mat_id, shadow, transp
TEMPORAL_CHANNELS = [9, 10, 11] # prev.rgb
+_LUMA = np.array([0.2126, 0.7152, 0.0722], dtype=np.float32) # BT.709
+
# ---------------------------------------------------------------------------
# Image I/O
# ---------------------------------------------------------------------------
@@ -203,6 +212,34 @@ def detect_salient_points(albedo: np.ndarray, n: int, detector: str,
# Dataset
# ---------------------------------------------------------------------------
+def _find_target_offsets(albedo: np.ndarray, target: np.ndarray,
+ centers: List[Tuple[int, int]],
+ patch_size: int, window: int) -> List[Tuple[int, int]]:
+ """For each source centre, find the (dx, dy) offset in target that minimises
+ grayscale MSE between the source albedo patch and the target patch."""
+ h, w = albedo.shape[:2]
+ half = patch_size // 2
+ offsets = []
+ for cx, cy in centers:
+ cx = max(half, min(cx, w - half))
+ cy = max(half, min(cy, h - half))
+ src_gray = (albedo[cy - half:cy - half + patch_size,
+ cx - half:cx - half + patch_size, :3] @ _LUMA)
+ best_dx, best_dy, best_mse = 0, 0, float('inf')
+ for dy in range(-window, window + 1):
+ for dx in range(-window, window + 1):
+ tcx = max(half, min(cx + dx, w - half))
+ tcy = max(half, min(cy + dy, h - half))
+ tgt_gray = (target[tcy - half:tcy - half + patch_size,
+ tcx - half:tcx - half + patch_size, :3] @ _LUMA)
+ mse = np.mean((src_gray - tgt_gray) ** 2)
+ if mse < best_mse:
+ best_mse = mse
+ best_dx, best_dy = dx, dy
+ offsets.append((best_dx, best_dy))
+ return offsets
+
+
class CNNv3Dataset(Dataset):
"""Loads CNN v3 samples from dataset/full/ or dataset/simple/ directories.
@@ -211,6 +248,9 @@ class CNNv3Dataset(Dataset):
Full-image mode (--full-image): resizes entire image to image_size×image_size.
+ patch_search_window: when >0, the target patch is offset by up to this many
+ pixels (full-pixel search) to minimise grayscale MSE against the source patch.
+
Returns (feat, cond, target):
feat: (20, H, W) f32
cond: (5,) f32 FiLM conditioning (random when augment=True)
@@ -225,14 +265,16 @@ class CNNv3Dataset(Dataset):
full_image: bool = False,
channel_dropout_p: float = 0.3,
detector: str = 'harris',
- augment: bool = True):
- self.patch_size = patch_size
- self.patches_per_image = patches_per_image
- self.image_size = image_size
- self.full_image = full_image
- self.channel_dropout_p = channel_dropout_p
- self.detector = detector
- self.augment = augment
+ augment: bool = True,
+ patch_search_window: int = 0):
+ self.patch_size = patch_size
+ self.patches_per_image = patches_per_image
+ self.image_size = image_size
+ self.full_image = full_image
+ self.channel_dropout_p = channel_dropout_p
+ self.detector = detector
+ self.augment = augment
+ self.patch_search_window = patch_search_window
root = Path(dataset_dir)
subdir = 'full' if input_mode == 'full' else 'simple'
@@ -251,14 +293,21 @@ class CNNv3Dataset(Dataset):
print(f"[CNNv3Dataset] Loading {len(self.samples)} samples into memory …")
self._cache: List[tuple] = [self._load_sample(sd) for sd in self.samples]
- # Pre-cache salient patch centres (albedo already loaded above)
+ # Pre-cache salient patch centres and (optionally) target offsets.
self._patch_centers: List[List[Tuple[int, int]]] = []
+ self._target_offsets: List[List[Tuple[int, int]]] = [] # (dx, dy) per patch
if not full_image:
print(f"[CNNv3Dataset] Detecting salient points "
f"(detector={detector}, patch={patch_size}×{patch_size}) …")
- for sd, (albedo, *_) in zip(self.samples, self._cache):
+ for albedo, *rest, target in self._cache:
pts = detect_salient_points(albedo, patches_per_image, detector, patch_size)
self._patch_centers.append(pts)
+ if patch_search_window > 0:
+ self._target_offsets.append(
+ _find_target_offsets(albedo, target, pts, patch_size, patch_search_window))
+ if patch_search_window > 0:
+ print(f"[CNNv3Dataset] Target offset search done "
+ f"(window=±{patch_search_window})")
print(f"[CNNv3Dataset] mode={input_mode} samples={len(self.samples)} "
f"patch={patch_size} full_image={full_image}")
@@ -285,10 +334,8 @@ class CNNv3Dataset(Dataset):
def __getitem__(self, idx):
if self.full_image:
sample_idx = idx
- sd = self.samples[idx]
else:
sample_idx = idx // self.patches_per_image
- sd = self.samples[sample_idx]
albedo, normal, depth, matid, shadow, transp, target = self._cache[sample_idx]
h, w = albedo.shape[:2]
@@ -314,9 +361,10 @@ class CNNv3Dataset(Dataset):
transp = _resize_gray(transp)
target = _resize_img(target)
else:
- ps = self.patch_size
- half = ps // 2
- cx, cy = self._patch_centers[sample_idx][idx % self.patches_per_image]
+ ps = self.patch_size
+ half = ps // 2
+ patch_idx = idx % self.patches_per_image
+ cx, cy = self._patch_centers[sample_idx][patch_idx]
cx = max(half, min(cx, w - half))
cy = max(half, min(cy, h - half))
sl = (slice(cy - half, cy - half + ps), slice(cx - half, cx - half + ps))
@@ -327,7 +375,15 @@ class CNNv3Dataset(Dataset):
matid = matid[sl]
shadow = shadow[sl]
transp = transp[sl]
- target = target[sl]
+
+ # Apply cached target offset (if search was enabled at init).
+ if self._target_offsets:
+ dx, dy = self._target_offsets[sample_idx][patch_idx]
+ tcx = max(half, min(cx + dx, w - half))
+ tcy = max(half, min(cy + dy, h - half))
+ target = target[tcy - half:tcy - half + ps, tcx - half:tcx - half + ps]
+ else:
+ target = target[sl]
feat = assemble_features(albedo, normal, depth, matid, shadow, transp)