diff options
Diffstat (limited to 'cnn_v3/training/cnn_v3_utils.py')
| -rw-r--r-- | cnn_v3/training/cnn_v3_utils.py | 339 |
1 files changed, 339 insertions, 0 deletions
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py new file mode 100644 index 0000000..ecdbd6b --- /dev/null +++ b/cnn_v3/training/cnn_v3_utils.py @@ -0,0 +1,339 @@ +"""CNN v3 training utilities — image I/O, feature assembly, dataset. + +Imported by train_cnn_v3.py and export_cnn_v3_weights.py. + +20 feature channels assembled by CNNv3Dataset.__getitem__: + [0-2] albedo.rgb f32 [0,1] + [3-4] normal.xy f32 oct-encoded [0,1] + [5] depth f32 [0,1] + [6-7] depth_grad.xy finite diff, signed + [8] mat_id f32 [0,1] + [9-11] prev.rgb f32 (zero during training) + [12-14] mip1.rgb pyrdown(albedo) + [15-17] mip2.rgb pyrdown(mip1) + [18] shadow f32 [0,1] + [19] transp f32 [0,1] + +Sample directory layout (per sample_xxx/): + albedo.png RGB uint8 + normal.png RG uint8 oct-encoded (128,128 = no normal) + depth.png R uint16 [0,65535] → [0,1] + matid.png R uint8 [0,255] → [0,1] + shadow.png R uint8 [0=dark, 255=lit] + transp.png R uint8 [0=opaque, 255=clear] + target.png RGBA uint8 +""" + +import random +from pathlib import Path +from typing import List, Tuple + +import cv2 +import numpy as np +import torch +from PIL import Image +from torch.utils.data import Dataset + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +N_FEATURES = 20 + +GEOMETRIC_CHANNELS = [3, 4, 5, 6, 7] # normal.xy, depth, depth_grad.xy +CONTEXT_CHANNELS = [8, 18, 19] # mat_id, shadow, transp +TEMPORAL_CHANNELS = [9, 10, 11] # prev.rgb + +# --------------------------------------------------------------------------- +# Image I/O +# --------------------------------------------------------------------------- + +def load_rgb(path: Path) -> np.ndarray: + """Load PNG → (H,W,3) f32 [0,1].""" + return np.asarray(Image.open(path).convert('RGB'), dtype=np.float32) / 255.0 + + +def load_rg(path: Path) -> np.ndarray: + """Load PNG RG channels → (H,W,2) f32 [0,1].""" + arr = np.asarray(Image.open(path).convert('RGB'), dtype=np.float32) / 255.0 + return arr[..., :2] + + +def load_depth16(path: Path) -> np.ndarray: + """Load 16-bit greyscale depth PNG → (H,W) f32 [0,1].""" + arr = np.asarray(Image.open(path), dtype=np.float32) + if arr.max() > 1.0: + arr = arr / 65535.0 + return arr + + +def load_gray(path: Path) -> np.ndarray: + """Load 8-bit greyscale PNG → (H,W) f32 [0,1].""" + return np.asarray(Image.open(path).convert('L'), dtype=np.float32) / 255.0 + + +# --------------------------------------------------------------------------- +# Feature assembly helpers +# --------------------------------------------------------------------------- + +def pyrdown(img: np.ndarray) -> np.ndarray: + """2×2 average pool → half resolution. img: (H,W,C) f32.""" + h, w, _ = img.shape + h2, w2 = h // 2, w // 2 + t = img[:h2 * 2, :w2 * 2, :] + return 0.25 * (t[0::2, 0::2] + t[1::2, 0::2] + t[0::2, 1::2] + t[1::2, 1::2]) + + +def depth_gradient(depth: np.ndarray) -> np.ndarray: + """Central finite difference of depth map → (H,W,2) [dzdx, dzdy].""" + px = np.pad(depth, ((0, 0), (1, 1)), mode='edge') + py = np.pad(depth, ((1, 1), (0, 0)), mode='edge') + dzdx = (px[:, 2:] - px[:, :-2]) * 0.5 + dzdy = (py[2:, :] - py[:-2, :]) * 0.5 + return np.stack([dzdx, dzdy], axis=-1) + + +def _upsample_nearest(a: np.ndarray, h: int, w: int) -> np.ndarray: + """Nearest-neighbour upsample (H,W,C) f32 [0,1] to (h,w,C).""" + img = Image.fromarray((np.clip(a, 0, 1) * 255).astype(np.uint8)) + img = img.resize((w, h), Image.NEAREST) + return np.asarray(img, dtype=np.float32) / 255.0 + + +def assemble_features(albedo: np.ndarray, normal: np.ndarray, + depth: np.ndarray, matid: np.ndarray, + shadow: np.ndarray, transp: np.ndarray) -> np.ndarray: + """Build (H,W,20) f32 feature tensor. + + prev set to zero (no temporal history during training). + mip1/mip2 computed from albedo. depth_grad computed via finite diff. + """ + h, w = albedo.shape[:2] + + mip1 = _upsample_nearest(pyrdown(albedo), h, w) + mip2 = _upsample_nearest(pyrdown(pyrdown(albedo)), h, w) + dgrad = depth_gradient(depth) + prev = np.zeros((h, w, 3), dtype=np.float32) + + return np.concatenate([ + albedo, # [0-2] albedo.rgb + normal, # [3-4] normal.xy + depth[..., None], # [5] depth + dgrad, # [6-7] depth_grad.xy + matid[..., None], # [8] mat_id + prev, # [9-11] prev.rgb + mip1, # [12-14] mip1.rgb + mip2, # [15-17] mip2.rgb + shadow[..., None], # [18] shadow + transp[..., None], # [19] transp + ], axis=-1).astype(np.float32) + + +def apply_channel_dropout(feat: np.ndarray, + p_geom: float = 0.3, + p_context: float = 0.2, + p_temporal: float = 0.5) -> np.ndarray: + """Zero out channel groups with given probabilities (in-place copy).""" + feat = feat.copy() + if random.random() < p_geom: + feat[..., GEOMETRIC_CHANNELS] = 0.0 + if random.random() < p_context: + feat[..., CONTEXT_CHANNELS] = 0.0 + if random.random() < p_temporal: + feat[..., TEMPORAL_CHANNELS] = 0.0 + return feat + + +# --------------------------------------------------------------------------- +# Salient point detection +# --------------------------------------------------------------------------- + +def detect_salient_points(albedo: np.ndarray, n: int, detector: str, + patch_size: int) -> List[Tuple[int, int]]: + """Return n (cx, cy) patch centres from albedo (H,W,3) f32 [0,1]. + + Detects up to 2n candidates via the chosen method, filters to valid patch + bounds, fills remainder with random points. + + detector: 'harris' | 'shi-tomasi' | 'fast' | 'gradient' | 'random' + """ + h, w = albedo.shape[:2] + half = patch_size // 2 + gray = cv2.cvtColor((albedo * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY) + + corners = None + if detector in ('harris', 'shi-tomasi'): + corners = cv2.goodFeaturesToTrack( + gray, n * 2, qualityLevel=0.01, minDistance=half, + useHarrisDetector=(detector == 'harris')) + elif detector == 'fast': + kps = cv2.FastFeatureDetector_create(threshold=20).detect(gray, None) + if kps: + pts = np.array([[kp.pt[0], kp.pt[1]] for kp in kps[:n * 2]]) + corners = pts.reshape(-1, 1, 2) + elif detector == 'gradient': + gx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3) + gy = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3) + mag = np.sqrt(gx ** 2 + gy ** 2) + ys, xs = np.where(mag > np.percentile(mag, 95)) + if len(xs) > n * 2: + sel = np.random.choice(len(xs), n * 2, replace=False) + xs, ys = xs[sel], ys[sel] + if len(xs): + corners = np.stack([xs, ys], axis=1).reshape(-1, 1, 2).astype(np.float32) + # 'random' → corners stays None, falls through to random fill + + pts: List[Tuple[int, int]] = [] + if corners is not None: + for c in corners: + cx, cy = int(c[0][0]), int(c[0][1]) + if half <= cx < w - half and half <= cy < h - half: + pts.append((cx, cy)) + if len(pts) >= n: + break + + while len(pts) < n: + pts.append((random.randint(half, w - half - 1), + random.randint(half, h - half - 1))) + return pts[:n] + + +# --------------------------------------------------------------------------- +# Dataset +# --------------------------------------------------------------------------- + +class CNNv3Dataset(Dataset): + """Loads CNN v3 samples from dataset/full/ or dataset/simple/ directories. + + Patch mode (default): extracts patch_size×patch_size crops centred on + salient points detected from albedo. Points are pre-cached at init. + + Full-image mode (--full-image): resizes entire image to image_size×image_size. + + Returns (feat, cond, target): + feat: (20, H, W) f32 + cond: (5,) f32 FiLM conditioning (random when augment=True) + target: (4, H, W) f32 RGBA [0,1] + """ + + def __init__(self, dataset_dir: str, + input_mode: str = 'simple', + patch_size: int = 64, + patches_per_image: int = 256, + image_size: int = 256, + full_image: bool = False, + channel_dropout_p: float = 0.3, + detector: str = 'harris', + augment: bool = True): + self.patch_size = patch_size + self.patches_per_image = patches_per_image + self.image_size = image_size + self.full_image = full_image + self.channel_dropout_p = channel_dropout_p + self.detector = detector + self.augment = augment + + root = Path(dataset_dir) + subdir = 'full' if input_mode == 'full' else 'simple' + search_dir = root / subdir + if not search_dir.exists(): + search_dir = root + + self.samples = sorted([ + d for d in search_dir.iterdir() + if d.is_dir() and (d / 'albedo.png').exists() + ]) + if not self.samples: + raise RuntimeError(f"No samples found in {search_dir}") + + # Pre-cache salient patch centres (albedo-only load — cheap) + self._patch_centers: List[List[Tuple[int, int]]] = [] + if not full_image: + print(f"[CNNv3Dataset] Detecting salient points " + f"(detector={detector}, patch={patch_size}×{patch_size}) …") + for sd in self.samples: + pts = detect_salient_points( + load_rgb(sd / 'albedo.png'), + patches_per_image, detector, patch_size) + self._patch_centers.append(pts) + + print(f"[CNNv3Dataset] mode={input_mode} samples={len(self.samples)} " + f"patch={patch_size} full_image={full_image}") + + def __len__(self): + if self.full_image: + return len(self.samples) + return len(self.samples) * self.patches_per_image + + def _load_sample(self, sd: Path): + albedo = load_rgb(sd / 'albedo.png') + normal = load_rg(sd / 'normal.png') + depth = load_depth16(sd / 'depth.png') + matid = load_gray(sd / 'matid.png') + shadow = load_gray(sd / 'shadow.png') + transp = load_gray(sd / 'transp.png') + target = np.asarray( + Image.open(sd / 'target.png').convert('RGBA'), + dtype=np.float32) / 255.0 + return albedo, normal, depth, matid, shadow, transp, target + + def __getitem__(self, idx): + if self.full_image: + sample_idx = idx + sd = self.samples[idx] + else: + sample_idx = idx // self.patches_per_image + sd = self.samples[sample_idx] + + albedo, normal, depth, matid, shadow, transp, target = self._load_sample(sd) + h, w = albedo.shape[:2] + + if self.full_image: + sz = self.image_size + + def _resize_rgb(a): + img = Image.fromarray((np.clip(a, 0, 1) * 255).astype(np.uint8)) + return np.asarray(img.resize((sz, sz), Image.LANCZOS), dtype=np.float32) / 255.0 + + def _resize_gray(a): + img = Image.fromarray((np.clip(a, 0, 1) * 255).astype(np.uint8), mode='L') + return np.asarray(img.resize((sz, sz), Image.LANCZOS), dtype=np.float32) / 255.0 + + albedo = _resize_rgb(albedo) + normal = _resize_rgb(np.concatenate( + [normal, np.zeros_like(normal[..., :1])], -1))[..., :2] + depth = _resize_gray(depth) + matid = _resize_gray(matid) + shadow = _resize_gray(shadow) + transp = _resize_gray(transp) + target = _resize_rgb(target) + else: + ps = self.patch_size + half = ps // 2 + cx, cy = self._patch_centers[sample_idx][idx % self.patches_per_image] + cx = max(half, min(cx, w - half)) + cy = max(half, min(cy, h - half)) + sl = (slice(cy - half, cy - half + ps), slice(cx - half, cx - half + ps)) + + albedo = albedo[sl] + normal = normal[sl] + depth = depth[sl] + matid = matid[sl] + shadow = shadow[sl] + transp = transp[sl] + target = target[sl] + + feat = assemble_features(albedo, normal, depth, matid, shadow, transp) + + if self.augment: + feat = apply_channel_dropout(feat, + p_geom=self.channel_dropout_p, + p_context=self.channel_dropout_p * 0.67, + p_temporal=0.5) + cond = np.random.rand(5).astype(np.float32) + else: + cond = np.zeros(5, dtype=np.float32) + + return (torch.from_numpy(feat).permute(2, 0, 1), # (20,H,W) + torch.from_numpy(cond), # (5,) + torch.from_numpy(target).permute(2, 0, 1)) # (4,H,W) |
