diff options
Diffstat (limited to 'cnn_v3/training/cnn_v3_utils.py')
| -rw-r--r-- | cnn_v3/training/cnn_v3_utils.py | 88 |
1 file changed, 72 insertions, 16 deletions
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py
index b32e548..5a3d56c 100644
--- a/cnn_v3/training/cnn_v3_utils.py
+++ b/cnn_v3/training/cnn_v3_utils.py
@@ -22,6 +22,13 @@ Sample directory layout (per sample_xxx/):
     shadow.png  R    uint8 [0=dark, 255=lit]
     transp.png  R    uint8 [0=opaque, 255=clear]
     target.png  RGBA uint8
+
+Patch alignment (patch_search_window > 0):
+    Source (albedo) and target images may not be perfectly co-registered.
+    When patch_search_window=N, each target patch centre is shifted by the
+    (dx, dy) in [-N, N]² that minimises grayscale MSE against the source
+    albedo patch. The search runs once at dataset init and results are
+    cached, so __getitem__ pays only a list-lookup per sample.
 """
 
 import random
@@ -44,6 +51,8 @@
 GEOMETRIC_CHANNELS = [3, 4, 5, 6, 7]  # normal.xy, depth, depth_grad.xy
 CONTEXT_CHANNELS = [8, 18, 19]        # mat_id, shadow, transp
 TEMPORAL_CHANNELS = [9, 10, 11]       # prev.rgb
+_LUMA = np.array([0.2126, 0.7152, 0.0722], dtype=np.float32)  # BT.709
+
 # ---------------------------------------------------------------------------
 # Image I/O
 # ---------------------------------------------------------------------------
@@ -203,6 +212,34 @@ def detect_salient_points(albedo: np.ndarray, n: int, detector: str,
 # ---------------------------------------------------------------------------
 # Dataset
 # ---------------------------------------------------------------------------
+def _find_target_offsets(albedo: np.ndarray, target: np.ndarray,
+                         centers: List[Tuple[int, int]],
+                         patch_size: int, window: int) -> List[Tuple[int, int]]:
+    """For each source centre, find the (dx, dy) offset in target that minimises
+    grayscale MSE between the source albedo patch and the target patch."""
+    h, w = albedo.shape[:2]
+    half = patch_size // 2
+    offsets = []
+    for cx, cy in centers:
+        cx = max(half, min(cx, w - half))
+        cy = max(half, min(cy, h - half))
+        src_gray = (albedo[cy - half:cy - half + patch_size,
+                           cx - half:cx - half + patch_size, :3] @ _LUMA)
+        best_dx, best_dy, best_mse = 0, 0, float('inf')
+        for dy in range(-window, window + 1):
+            for dx in range(-window, window + 1):
+                tcx = max(half, min(cx + dx, w - half))
+                tcy = max(half, min(cy + dy, h - half))
+                tgt_gray = (target[tcy - half:tcy - half + patch_size,
+                                   tcx - half:tcx - half + patch_size, :3] @ _LUMA)
+                mse = np.mean((src_gray - tgt_gray) ** 2)
+                if mse < best_mse:
+                    best_mse = mse
+                    best_dx, best_dy = dx, dy
+        offsets.append((best_dx, best_dy))
+    return offsets
+
+
 class CNNv3Dataset(Dataset):
     """Loads CNN v3 samples from dataset/full/ or dataset/simple/ directories.
 
@@ -211,6 +248,9 @@
     Full-image mode (--full-image): resizes entire image to image_size×image_size.
 
+    patch_search_window: when >0, the target patch is offset by up to this many
+    pixels (full-pixel search) to minimise grayscale MSE against the source patch.
+
     Returns (feat, cond, target):
         feat:   (20, H, W) f32
         cond:   (5,) f32 FiLM conditioning (random when augment=True)
@@ -225,14 +265,16 @@
                  full_image: bool = False,
                  channel_dropout_p: float = 0.3,
                  detector: str = 'harris',
-                 augment: bool = True):
-        self.patch_size = patch_size
-        self.patches_per_image = patches_per_image
-        self.image_size = image_size
-        self.full_image = full_image
-        self.channel_dropout_p = channel_dropout_p
-        self.detector = detector
-        self.augment = augment
+                 augment: bool = True,
+                 patch_search_window: int = 0):
+        self.patch_size = patch_size
+        self.patches_per_image = patches_per_image
+        self.image_size = image_size
+        self.full_image = full_image
+        self.channel_dropout_p = channel_dropout_p
+        self.detector = detector
+        self.augment = augment
+        self.patch_search_window = patch_search_window
 
         root = Path(dataset_dir)
         subdir = 'full' if input_mode == 'full' else 'simple'
@@ -251,14 +293,21 @@
         print(f"[CNNv3Dataset] Loading {len(self.samples)} samples into memory …")
         self._cache: List[tuple] = [self._load_sample(sd) for sd in self.samples]
 
-        # Pre-cache salient patch centres (albedo already loaded above)
+        # Pre-cache salient patch centres and (optionally) target offsets.
         self._patch_centers: List[List[Tuple[int, int]]] = []
+        self._target_offsets: List[List[Tuple[int, int]]] = []  # (dx, dy) per patch
         if not full_image:
             print(f"[CNNv3Dataset] Detecting salient points "
                   f"(detector={detector}, patch={patch_size}×{patch_size}) …")
-            for sd, (albedo, *_) in zip(self.samples, self._cache):
+            for albedo, *rest, target in self._cache:
                 pts = detect_salient_points(albedo, patches_per_image, detector, patch_size)
                 self._patch_centers.append(pts)
+                if patch_search_window > 0:
+                    self._target_offsets.append(
+                        _find_target_offsets(albedo, target, pts, patch_size, patch_search_window))
+            if patch_search_window > 0:
+                print(f"[CNNv3Dataset] Target offset search done "
+                      f"(window=±{patch_search_window})")
 
         print(f"[CNNv3Dataset] mode={input_mode} samples={len(self.samples)} "
               f"patch={patch_size} full_image={full_image}")
@@ -285,10 +334,8 @@
     def __getitem__(self, idx):
         if self.full_image:
             sample_idx = idx
-            sd = self.samples[idx]
         else:
             sample_idx = idx // self.patches_per_image
-            sd = self.samples[sample_idx]
 
         albedo, normal, depth, matid, shadow, transp, target = self._cache[sample_idx]
         h, w = albedo.shape[:2]
@@ -314,9 +361,10 @@
             transp = _resize_gray(transp)
             target = _resize_img(target)
         else:
-            ps = self.patch_size
-            half = ps // 2
-            cx, cy = self._patch_centers[sample_idx][idx % self.patches_per_image]
+            ps = self.patch_size
+            half = ps // 2
+            patch_idx = idx % self.patches_per_image
+            cx, cy = self._patch_centers[sample_idx][patch_idx]
             cx = max(half, min(cx, w - half))
             cy = max(half, min(cy, h - half))
             sl = (slice(cy - half, cy - half + ps), slice(cx - half, cx - half + ps))
@@ -327,7 +375,15 @@
             matid = matid[sl]
             shadow = shadow[sl]
             transp = transp[sl]
-            target = target[sl]
+
+            # Apply cached target offset (if search was enabled at init).
+            if self._target_offsets:
+                dx, dy = self._target_offsets[sample_idx][patch_idx]
+                tcx = max(half, min(cx + dx, w - half))
+                tcy = max(half, min(cy + dy, h - half))
+                target = target[tcy - half:tcy - half + ps, tcx - half:tcx - half + ps]
+            else:
+                target = target[sl]
 
         feat = assemble_features(albedo, normal, depth, matid, shadow, transp)
