"""CNN v3 training utilities — image I/O, feature assembly, dataset.

Imported by train_cnn_v3.py and export_cnn_v3_weights.py.

20 feature channels assembled by CNNv3Dataset.__getitem__:
  [0-2]   albedo.rgb      f32 [0,1]
  [3-4]   normal.xy       f32 oct-encoded [0,1]
  [5]     depth           f32 [0,1]
  [6-7]   depth_grad.xy   finite diff, signed
  [8]     mat_id          f32 [0,1]
  [9-11]  prev.rgb        f32 (zero during training)
  [12-14] mip1.rgb        pyrdown(albedo)
  [15-17] mip2.rgb        pyrdown(mip1)
  [18]    shadow          f32 [0,1]
  [19]    transp          f32 [0,1]

Sample directory layout (per sample_xxx/):
  albedo.png   RGB uint8
  normal.png   RG uint8 oct-encoded (128,128 = no normal)
  depth.png    R uint16 [0,65535] → [0,1]
  matid.png    R uint8 [0,255] → [0,1]
  shadow.png   R uint8 [0=dark, 255=lit]
  transp.png   R uint8 [0=opaque, 255=clear]
  target.png   RGBA uint8
"""

import random
from pathlib import Path
from typing import List, Tuple

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

N_FEATURES = 20

GEOMETRIC_CHANNELS = [3, 4, 5, 6, 7]  # normal.xy, depth, depth_grad.xy
CONTEXT_CHANNELS = [8, 18, 19]        # mat_id, shadow, transp
TEMPORAL_CHANNELS = [9, 10, 11]       # prev.rgb


# ---------------------------------------------------------------------------
# Image I/O
# ---------------------------------------------------------------------------

def load_rgb(path: Path) -> np.ndarray:
    """Load PNG → (H,W,3) f32 [0,1]."""
    # Context manager releases the file handle deterministically.
    with Image.open(path) as img:
        return np.asarray(img.convert('RGB'), dtype=np.float32) / 255.0


def load_rg(path: Path) -> np.ndarray:
    """Load PNG RG channels → (H,W,2) f32 [0,1]."""
    with Image.open(path) as img:
        arr = np.asarray(img.convert('RGB'), dtype=np.float32) / 255.0
    return arr[..., :2]


def load_depth16(path: Path) -> np.ndarray:
    """Load greyscale depth PNG → f32 [0,1].

    The divisor is chosen from the stored sample type rather than from the
    pixel values: 8-bit images are divided by 255, every other integer type
    by 65535 (the documented 16-bit layout). Float images are passed through
    unless their maximum exceeds 1.0, in which case they are assumed to hold
    16-bit-range values and are divided by 65535.
    """
    with Image.open(path) as img:
        raw = np.asarray(img)
    arr = raw.astype(np.float32)
    if raw.dtype == np.uint8:
        return arr / 255.0
    if np.issubdtype(raw.dtype, np.integer):
        return arr / 65535.0
    if arr.max() > 1.0:
        arr = arr / 65535.0
    return arr


def load_gray(path: Path) -> np.ndarray:
    """Load 8-bit greyscale PNG → (H,W) f32 [0,1]."""
    with Image.open(path) as img:
        return np.asarray(img.convert('L'), dtype=np.float32) / 255.0


# ---------------------------------------------------------------------------
# Feature assembly helpers
# ---------------------------------------------------------------------------

def pyrdown(img: np.ndarray) -> np.ndarray:
    """2×2 average pool → half resolution. img: (H,W,C) f32.

    An odd trailing row/column is dropped before pooling.
    """
    h, w, _ = img.shape
    h2, w2 = h // 2, w // 2
    t = img[:h2 * 2, :w2 * 2, :]
    return 0.25 * (t[0::2, 0::2] + t[1::2, 0::2] +
                   t[0::2, 1::2] + t[1::2, 1::2])


def depth_gradient(depth: np.ndarray) -> np.ndarray:
    """Central finite difference of depth map → (H,W,2) [dzdx, dzdy].

    Borders are edge-padded before differencing.
    """
    px = np.pad(depth, ((0, 0), (1, 1)), mode='edge')
    py = np.pad(depth, ((1, 1), (0, 0)), mode='edge')
    dzdx = (px[:, 2:] - px[:, :-2]) * 0.5
    dzdy = (py[2:, :] - py[:-2, :]) * 0.5
    return np.stack([dzdx, dzdy], axis=-1)


def _upsample_nearest(a: np.ndarray, h: int, w: int) -> np.ndarray:
    """Nearest-neighbour upsample (H,W,C) f32 to (h,w,C) — pure numpy, no precision loss."""
    sh, sw = a.shape[:2]
    ys = np.arange(h) * sh // h
    xs = np.arange(w) * sw // w
    return a[np.ix_(ys, xs)]


def _to_uint8(a: np.ndarray) -> np.ndarray:
    """Convert a [0,1] float array to uint8, clipping and rounding to nearest."""
    # Rounding (not truncating) keeps k/255 inputs at k even if the float
    # round-trip lands marginally below the integer.
    return np.rint(np.clip(a, 0.0, 1.0) * 255.0).astype(np.uint8)


def _resize_u8(a: np.ndarray, size: int) -> np.ndarray:
    """Lanczos-resize a [0,1] image to (size,size) through an 8-bit PIL image.

    a: (H,W), (H,W,3) or (H,W,4); PIL infers L / RGB / RGBA from the shape.
    Returns f32 [0,1].
    """
    img = Image.fromarray(_to_uint8(a))
    return np.asarray(img.resize((size, size), Image.LANCZOS),
                      dtype=np.float32) / 255.0


def _resize_f32(a: np.ndarray, size: int) -> np.ndarray:
    """Lanczos-resize a single-channel [0,1] map to (size,size) in float32.

    Used for depth so that 16-bit precision is not quantised to 8 bits before
    the gradient is computed. The result is clipped because Lanczos can
    overshoot the input range.
    """
    src = np.ascontiguousarray(np.clip(a, 0.0, 1.0), dtype=np.float32)
    img = Image.fromarray(src)  # 2-D float32 → PIL mode 'F'
    out = np.asarray(img.resize((size, size), Image.LANCZOS), dtype=np.float32)
    return np.clip(out, 0.0, 1.0)


def assemble_features(albedo: np.ndarray, normal: np.ndarray,
                      depth: np.ndarray, matid: np.ndarray,
                      shadow: np.ndarray, transp: np.ndarray) -> np.ndarray:
    """Build (H,W,20) f32 feature tensor.

    prev set to zero (no temporal history during training).
    mip1/mip2 computed from albedo and upsampled back to (H,W).
    depth_grad computed via finite diff.
    """
    h, w = albedo.shape[:2]

    mip1 = _upsample_nearest(pyrdown(albedo), h, w)
    mip2 = _upsample_nearest(pyrdown(pyrdown(albedo)), h, w)
    dgrad = depth_gradient(depth)
    prev = np.zeros((h, w, 3), dtype=np.float32)

    return np.concatenate([
        albedo,             # [0-2]   albedo.rgb
        normal,             # [3-4]   normal.xy
        depth[..., None],   # [5]     depth
        dgrad,              # [6-7]   depth_grad.xy
        matid[..., None],   # [8]     mat_id
        prev,               # [9-11]  prev.rgb
        mip1,               # [12-14] mip1.rgb
        mip2,               # [15-17] mip2.rgb
        shadow[..., None],  # [18]    shadow
        transp[..., None],  # [19]    transp
    ], axis=-1).astype(np.float32)


def apply_channel_dropout(feat: np.ndarray,
                          p_geom: float = 0.3,
                          p_context: float = 0.2,
                          p_temporal: float = 0.5) -> np.ndarray:
    """Zero out channel groups with given probabilities.

    The input is not modified; a modified copy is returned. Each group is
    dropped independently using the stdlib `random` generator.
    """
    feat = feat.copy()
    if random.random() < p_geom:
        feat[..., GEOMETRIC_CHANNELS] = 0.0
    if random.random() < p_context:
        feat[..., CONTEXT_CHANNELS] = 0.0
    if random.random() < p_temporal:
        feat[..., TEMPORAL_CHANNELS] = 0.0
    return feat


# ---------------------------------------------------------------------------
# Salient point detection
# ---------------------------------------------------------------------------

def detect_salient_points(albedo: np.ndarray, n: int, detector: str,
                          patch_size: int) -> List[Tuple[int, int]]:
    """Return n (cx, cy) patch centres from albedo (H,W,3) f32 [0,1].

    Detects up to 2n candidates via the chosen method, filters to valid patch
    bounds, fills remainder with random points.

    detector: 'harris' | 'shi-tomasi' | 'fast' | 'gradient' | 'random'

    Raises:
        ValueError: if random fill is needed but the image is too small to
            hold a patch_size×patch_size patch.
    """
    h, w = albedo.shape[:2]
    half = patch_size // 2
    gray = cv2.cvtColor(_to_uint8(albedo), cv2.COLOR_RGB2GRAY)

    corners = None
    if detector in ('harris', 'shi-tomasi'):
        corners = cv2.goodFeaturesToTrack(
            gray, n * 2, qualityLevel=0.01, minDistance=half,
            useHarrisDetector=(detector == 'harris'))
    elif detector == 'fast':
        kps = cv2.FastFeatureDetector_create(threshold=20).detect(gray, None)
        if kps:
            # Keep the strongest responses; truncating the raw detector
            # output would pick candidates by output order, not saliency.
            strongest = sorted(kps, key=lambda kp: kp.response,
                               reverse=True)[:n * 2]
            pts = np.array([[kp.pt[0], kp.pt[1]] for kp in strongest])
            corners = pts.reshape(-1, 1, 2)
    elif detector == 'gradient':
        gx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        gy = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
        mag = np.sqrt(gx ** 2 + gy ** 2)
        ys, xs = np.where(mag > np.percentile(mag, 95))
        if len(xs) > n * 2:
            sel = np.random.choice(len(xs), n * 2, replace=False)
            xs, ys = xs[sel], ys[sel]
        if len(xs):
            corners = np.stack([xs, ys], axis=1).reshape(-1, 1, 2).astype(np.float32)
    # 'random' → corners stays None, falls through to random fill

    pts: List[Tuple[int, int]] = []
    if corners is not None:
        for c in corners:
            cx, cy = int(c[0][0]), int(c[0][1])
            if half <= cx < w - half and half <= cy < h - half:
                pts.append((cx, cy))
            if len(pts) >= n:
                break

    if len(pts) < n:
        max_x, max_y = w - half - 1, h - half - 1
        if max_x < half or max_y < half:
            # Same exception type random.randint would raise, but actionable.
            raise ValueError(
                f"image of size {w}x{h} is too small for patch_size={patch_size}")
        while len(pts) < n:
            pts.append((random.randint(half, max_x),
                        random.randint(half, max_y)))

    return pts[:n]


# ---------------------------------------------------------------------------
# Dataset
# ---------------------------------------------------------------------------

class CNNv3Dataset(Dataset):
    """Loads CNN v3 samples from dataset/full/ or dataset/simple/ directories.

    Patch mode (default): extracts patch_size×patch_size crops centred on
    salient points detected from albedo. Points are pre-cached at init.

    Full-image mode (--full-image): resizes entire image to image_size×image_size.

    Returns (feat, cond, target):
        feat:   (20, H, W) f32
        cond:   (5,) f32 FiLM conditioning (random when augment=True,
                zeros otherwise)
        target: (4, H, W) f32 RGBA [0,1]
    """

    def __init__(self, dataset_dir: str,
                 input_mode: str = 'simple',
                 patch_size: int = 64,
                 patches_per_image: int = 256,
                 image_size: int = 256,
                 full_image: bool = False,
                 channel_dropout_p: float = 0.3,
                 detector: str = 'harris',
                 augment: bool = True):
        """Discover samples, load them into memory and pre-compute patch centres.

        Samples are sub-directories containing albedo.png, searched under
        dataset_dir/full or dataset_dir/simple (by input_mode), falling back
        to dataset_dir itself when that sub-directory does not exist.

        Raises:
            RuntimeError: if no sample directory is found.
        """
        self.patch_size = patch_size
        self.patches_per_image = patches_per_image
        self.image_size = image_size
        self.full_image = full_image
        self.channel_dropout_p = channel_dropout_p
        self.detector = detector
        self.augment = augment

        root = Path(dataset_dir)
        subdir = 'full' if input_mode == 'full' else 'simple'
        search_dir = root / subdir
        if not search_dir.exists():
            search_dir = root

        self.samples = sorted([
            d for d in search_dir.iterdir()
            if d.is_dir() and (d / 'albedo.png').exists()
        ])
        if not self.samples:
            raise RuntimeError(f"No samples found in {search_dir}")

        # Pre-load all sample data into memory
        print(f"[CNNv3Dataset] Loading {len(self.samples)} samples into memory …")
        self._cache: List[tuple] = [self._load_sample(sd) for sd in self.samples]

        # Pre-cache salient patch centres (albedo already loaded above)
        self._patch_centers: List[List[Tuple[int, int]]] = []
        if not full_image:
            print(f"[CNNv3Dataset] Detecting salient points "
                  f"(detector={detector}, patch={patch_size}×{patch_size}) …")
            for albedo, *_ in self._cache:
                pts = detect_salient_points(albedo, patches_per_image,
                                            detector, patch_size)
                self._patch_centers.append(pts)

        print(f"[CNNv3Dataset] mode={input_mode} samples={len(self.samples)} "
              f"patch={patch_size} full_image={full_image}")

    def __len__(self):
        """Number of items: one per sample (full-image) or per patch."""
        if self.full_image:
            return len(self.samples)
        return len(self.samples) * self.patches_per_image

    def _load_sample(self, sd: Path):
        """Load every map of one sample directory as f32 arrays.

        Returns (albedo, normal, depth, matid, shadow, transp, target);
        target is resized to albedo's resolution when the sizes differ.
        """
        albedo = load_rgb(sd / 'albedo.png')
        normal = load_rg(sd / 'normal.png')
        depth = load_depth16(sd / 'depth.png')
        matid = load_gray(sd / 'matid.png')
        shadow = load_gray(sd / 'shadow.png')
        transp = load_gray(sd / 'transp.png')
        h, w = albedo.shape[:2]
        with Image.open(sd / 'target.png') as img:
            target_img = img.convert('RGBA')
        if target_img.size != (w, h):
            target_img = target_img.resize((w, h), Image.LANCZOS)
        target = np.asarray(target_img, dtype=np.float32) / 255.0
        return albedo, normal, depth, matid, shadow, transp, target

    def __getitem__(self, idx):
        """Return (feat, cond, target) tensors for item idx."""
        if self.full_image:
            sample_idx = idx
        else:
            sample_idx = idx // self.patches_per_image

        albedo, normal, depth, matid, shadow, transp, target = self._cache[sample_idx]
        h, w = albedo.shape[:2]

        if self.full_image:
            sz = self.image_size
            albedo = _resize_u8(albedo, sz)
            # PIL has no 2-channel mode: pad normal.xy to 3 channels, resize,
            # then drop the padding channel again.
            normal = _resize_u8(np.concatenate(
                [normal, np.zeros_like(normal[..., :1])], -1), sz)[..., :2]
            # Depth is 16-bit at source; resize in float32 so it is not
            # quantised to 8 bits before depth_gradient sees it.
            depth = _resize_f32(depth, sz)
            # NOTE(review): mat_id is Lanczos-interpolated like the other
            # maps, which blends neighbouring IDs — confirm this is intended.
            matid = _resize_u8(matid, sz)
            shadow = _resize_u8(shadow, sz)
            transp = _resize_u8(transp, sz)
            target = _resize_u8(target, sz)
        else:
            ps = self.patch_size
            half = ps // 2
            cx, cy = self._patch_centers[sample_idx][idx % self.patches_per_image]
            # Clamp the crop origin (not the centre) so the window always
            # fits, including for odd patch sizes where ps != 2 * half.
            x0 = max(0, min(cx - half, w - ps))
            y0 = max(0, min(cy - half, h - ps))
            sl = (slice(y0, y0 + ps), slice(x0, x0 + ps))

            albedo = albedo[sl]
            normal = normal[sl]
            depth = depth[sl]
            matid = matid[sl]
            shadow = shadow[sl]
            transp = transp[sl]
            target = target[sl]

        feat = assemble_features(albedo, normal, depth, matid, shadow, transp)

        if self.augment:
            feat = apply_channel_dropout(feat,
                                         p_geom=self.channel_dropout_p,
                                         p_context=self.channel_dropout_p * 0.67,
                                         p_temporal=0.5)
            cond = np.random.rand(5).astype(np.float32)
        else:
            cond = np.zeros(5, dtype=np.float32)

        # Contiguous copies: in patch mode `target` is a view into the
        # in-memory cache, so handing out a tensor that shares its storage
        # would let in-place ops downstream corrupt the cached sample.
        feat_chw = np.ascontiguousarray(feat.transpose(2, 0, 1))
        target_chw = np.ascontiguousarray(target.transpose(2, 0, 1))
        return (torch.from_numpy(feat_chw),     # (20,H,W)
                torch.from_numpy(cond),         # (5,)
                torch.from_numpy(target_chw))   # (4,H,W)