diff options
Diffstat (limited to 'cnn_v3/training/cnn_v3_utils.py')
| -rw-r--r-- | cnn_v3/training/cnn_v3_utils.py | 45 |
1 file changed, 28 insertions, 17 deletions
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py index bef4091..68c0798 100644 --- a/cnn_v3/training/cnn_v3_utils.py +++ b/cnn_v3/training/cnn_v3_utils.py @@ -128,10 +128,11 @@ def _upsample_nearest(a: np.ndarray, h: int, w: int) -> np.ndarray: def assemble_features(albedo: np.ndarray, normal: np.ndarray, depth: np.ndarray, matid: np.ndarray, - shadow: np.ndarray, transp: np.ndarray) -> np.ndarray: + shadow: np.ndarray, transp: np.ndarray, + prev: np.ndarray | None = None) -> np.ndarray: """Build (H,W,20) f32 feature tensor. - prev set to zero (no temporal history during training). + prev: (H,W,3) f32 [0,1] previous frame RGB, or None → zeros. mip1/mip2 computed from albedo. depth_grad computed via finite diff. dif (ch18) = max(0, dot(oct_decode(normal), KEY_LIGHT)) * shadow. """ @@ -140,7 +141,8 @@ def assemble_features(albedo: np.ndarray, normal: np.ndarray, mip1 = _upsample_nearest(pyrdown(albedo), h, w) mip2 = _upsample_nearest(pyrdown(pyrdown(albedo)), h, w) dgrad = depth_gradient(depth) - prev = np.zeros((h, w, 3), dtype=np.float32) + if prev is None: + prev = np.zeros((h, w, 3), dtype=np.float32) nor3 = oct_decode(normal) diffuse = np.maximum(0.0, (nor3 * _KEY_LIGHT).sum(-1)) dif = diffuse * shadow @@ -286,7 +288,8 @@ class CNNv3Dataset(Dataset): channel_dropout_p: float = 0.3, detector: str = 'harris', augment: bool = True, - patch_search_window: int = 0): + patch_search_window: int = 0, + single_sample: str = ''): self.patch_size = patch_size self.patches_per_image = patches_per_image self.image_size = image_size @@ -296,16 +299,18 @@ class CNNv3Dataset(Dataset): self.augment = augment self.patch_search_window = patch_search_window - root = Path(dataset_dir) - subdir = 'full' if input_mode == 'full' else 'simple' - search_dir = root / subdir - if not search_dir.exists(): - search_dir = root - - self.samples = sorted([ - d for d in search_dir.iterdir() - if d.is_dir() and (d / 'albedo.png').exists() - ]) + if single_sample: + 
self.samples = [Path(single_sample)] + else: + root = Path(dataset_dir) + subdir = 'full' if input_mode == 'full' else 'simple' + search_dir = root / subdir + if not search_dir.exists(): + search_dir = root + self.samples = sorted([ + d for d in search_dir.iterdir() + if d.is_dir() and (d / 'albedo.png').exists() + ]) if not self.samples: raise RuntimeError(f"No samples found in {search_dir}") @@ -345,11 +350,13 @@ class CNNv3Dataset(Dataset): shadow = load_gray(sd / 'shadow.png') transp = load_gray(sd / 'transp.png') h, w = albedo.shape[:2] + prev_path = sd / 'prev.png' + prev = load_rgb(prev_path) if prev_path.exists() else None target_img = Image.open(sd / 'target.png').convert('RGBA') if target_img.size != (w, h): target_img = target_img.resize((w, h), Image.LANCZOS) target = np.asarray(target_img, dtype=np.float32) / 255.0 - return albedo, normal, depth, matid, shadow, transp, target + return albedo, normal, depth, matid, shadow, transp, prev, target def __getitem__(self, idx): if self.full_image: @@ -357,7 +364,7 @@ class CNNv3Dataset(Dataset): else: sample_idx = idx // self.patches_per_image - albedo, normal, depth, matid, shadow, transp, target = self._cache[sample_idx] + albedo, normal, depth, matid, shadow, transp, prev, target = self._cache[sample_idx] h, w = albedo.shape[:2] if self.full_image: @@ -379,6 +386,8 @@ class CNNv3Dataset(Dataset): matid = _resize_gray(matid) shadow = _resize_gray(shadow) transp = _resize_gray(transp) + if prev is not None: + prev = _resize_img(prev) target = _resize_img(target) else: ps = self.patch_size @@ -395,6 +404,8 @@ class CNNv3Dataset(Dataset): matid = matid[sl] shadow = shadow[sl] transp = transp[sl] + if prev is not None: + prev = prev[sl] # Apply cached target offset (if search was enabled at init). 
if self._target_offsets: @@ -405,7 +416,7 @@ class CNNv3Dataset(Dataset): else: target = target[sl] - feat = assemble_features(albedo, normal, depth, matid, shadow, transp) + feat = assemble_features(albedo, normal, depth, matid, shadow, transp, prev) if self.augment: feat = apply_channel_dropout(feat, |
