diff options
Diffstat (limited to 'cnn_v3')
| -rw-r--r-- | cnn_v3/training/cnn_v3_utils.py | 20 |
1 files changed, 14 insertions, 6 deletions
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py index 50707a2..68c0798 100644 --- a/cnn_v3/training/cnn_v3_utils.py +++ b/cnn_v3/training/cnn_v3_utils.py @@ -128,10 +128,11 @@ def _upsample_nearest(a: np.ndarray, h: int, w: int) -> np.ndarray: def assemble_features(albedo: np.ndarray, normal: np.ndarray, depth: np.ndarray, matid: np.ndarray, - shadow: np.ndarray, transp: np.ndarray) -> np.ndarray: + shadow: np.ndarray, transp: np.ndarray, + prev: np.ndarray | None = None) -> np.ndarray: """Build (H,W,20) f32 feature tensor. - prev set to zero (no temporal history during training). + prev: (H,W,3) f32 [0,1] previous frame RGB, or None → zeros. mip1/mip2 computed from albedo. depth_grad computed via finite diff. dif (ch18) = max(0, dot(oct_decode(normal), KEY_LIGHT)) * shadow. """ @@ -140,7 +141,8 @@ def assemble_features(albedo: np.ndarray, normal: np.ndarray, mip1 = _upsample_nearest(pyrdown(albedo), h, w) mip2 = _upsample_nearest(pyrdown(pyrdown(albedo)), h, w) dgrad = depth_gradient(depth) - prev = np.zeros((h, w, 3), dtype=np.float32) + if prev is None: + prev = np.zeros((h, w, 3), dtype=np.float32) nor3 = oct_decode(normal) diffuse = np.maximum(0.0, (nor3 * _KEY_LIGHT).sum(-1)) dif = diffuse * shadow @@ -348,11 +350,13 @@ class CNNv3Dataset(Dataset): shadow = load_gray(sd / 'shadow.png') transp = load_gray(sd / 'transp.png') h, w = albedo.shape[:2] + prev_path = sd / 'prev.png' + prev = load_rgb(prev_path) if prev_path.exists() else None target_img = Image.open(sd / 'target.png').convert('RGBA') if target_img.size != (w, h): target_img = target_img.resize((w, h), Image.LANCZOS) target = np.asarray(target_img, dtype=np.float32) / 255.0 - return albedo, normal, depth, matid, shadow, transp, target + return albedo, normal, depth, matid, shadow, transp, prev, target def __getitem__(self, idx): if self.full_image: @@ -360,7 +364,7 @@ class CNNv3Dataset(Dataset): else: sample_idx = idx // self.patches_per_image - albedo, normal, depth, matid, shadow, transp, target = self._cache[sample_idx] + albedo, normal, depth, matid, shadow, transp, prev, target = self._cache[sample_idx] h, w = albedo.shape[:2] if self.full_image: @@ -382,6 +386,8 @@ class CNNv3Dataset(Dataset): matid = _resize_gray(matid) shadow = _resize_gray(shadow) transp = _resize_gray(transp) + if prev is not None: + prev = _resize_img(prev) target = _resize_img(target) else: ps = self.patch_size @@ -398,6 +404,8 @@ class CNNv3Dataset(Dataset): matid = matid[sl] shadow = shadow[sl] transp = transp[sl] + if prev is not None: + prev = prev[sl] # Apply cached target offset (if search was enabled at init). if self._target_offsets: @@ -408,7 +416,7 @@ class CNNv3Dataset(Dataset): else: target = target[sl] - feat = assemble_features(albedo, normal, depth, matid, shadow, transp) + feat = assemble_features(albedo, normal, depth, matid, shadow, transp, prev) if self.augment: feat = apply_channel_dropout(feat, |
