summary refs log tree commit diff
path: root/cnn_v3/training/cnn_v3_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v3/training/cnn_v3_utils.py')
-rw-r--r--  cnn_v3/training/cnn_v3_utils.py  45
1 file changed, 28 insertions(+), 17 deletions(-)
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py
index bef4091..68c0798 100644
--- a/cnn_v3/training/cnn_v3_utils.py
+++ b/cnn_v3/training/cnn_v3_utils.py
@@ -128,10 +128,11 @@ def _upsample_nearest(a: np.ndarray, h: int, w: int) -> np.ndarray:
def assemble_features(albedo: np.ndarray, normal: np.ndarray,
depth: np.ndarray, matid: np.ndarray,
- shadow: np.ndarray, transp: np.ndarray) -> np.ndarray:
+ shadow: np.ndarray, transp: np.ndarray,
+ prev: np.ndarray | None = None) -> np.ndarray:
"""Build (H,W,20) f32 feature tensor.
- prev set to zero (no temporal history during training).
+ prev: (H,W,3) f32 [0,1] previous frame RGB, or None → zeros.
mip1/mip2 computed from albedo. depth_grad computed via finite diff.
dif (ch18) = max(0, dot(oct_decode(normal), KEY_LIGHT)) * shadow.
"""
@@ -140,7 +141,8 @@ def assemble_features(albedo: np.ndarray, normal: np.ndarray,
mip1 = _upsample_nearest(pyrdown(albedo), h, w)
mip2 = _upsample_nearest(pyrdown(pyrdown(albedo)), h, w)
dgrad = depth_gradient(depth)
- prev = np.zeros((h, w, 3), dtype=np.float32)
+ if prev is None:
+ prev = np.zeros((h, w, 3), dtype=np.float32)
nor3 = oct_decode(normal)
diffuse = np.maximum(0.0, (nor3 * _KEY_LIGHT).sum(-1))
dif = diffuse * shadow
@@ -286,7 +288,8 @@ class CNNv3Dataset(Dataset):
channel_dropout_p: float = 0.3,
detector: str = 'harris',
augment: bool = True,
- patch_search_window: int = 0):
+ patch_search_window: int = 0,
+ single_sample: str = ''):
self.patch_size = patch_size
self.patches_per_image = patches_per_image
self.image_size = image_size
@@ -296,16 +299,18 @@ class CNNv3Dataset(Dataset):
self.augment = augment
self.patch_search_window = patch_search_window
- root = Path(dataset_dir)
- subdir = 'full' if input_mode == 'full' else 'simple'
- search_dir = root / subdir
- if not search_dir.exists():
- search_dir = root
-
- self.samples = sorted([
- d for d in search_dir.iterdir()
- if d.is_dir() and (d / 'albedo.png').exists()
- ])
+ if single_sample:
+ self.samples = [Path(single_sample)]
+ else:
+ root = Path(dataset_dir)
+ subdir = 'full' if input_mode == 'full' else 'simple'
+ search_dir = root / subdir
+ if not search_dir.exists():
+ search_dir = root
+ self.samples = sorted([
+ d for d in search_dir.iterdir()
+ if d.is_dir() and (d / 'albedo.png').exists()
+ ])
if not self.samples:
raise RuntimeError(f"No samples found in {search_dir}")
@@ -345,11 +350,13 @@ class CNNv3Dataset(Dataset):
shadow = load_gray(sd / 'shadow.png')
transp = load_gray(sd / 'transp.png')
h, w = albedo.shape[:2]
+ prev_path = sd / 'prev.png'
+ prev = load_rgb(prev_path) if prev_path.exists() else None
target_img = Image.open(sd / 'target.png').convert('RGBA')
if target_img.size != (w, h):
target_img = target_img.resize((w, h), Image.LANCZOS)
target = np.asarray(target_img, dtype=np.float32) / 255.0
- return albedo, normal, depth, matid, shadow, transp, target
+ return albedo, normal, depth, matid, shadow, transp, prev, target
def __getitem__(self, idx):
if self.full_image:
@@ -357,7 +364,7 @@ class CNNv3Dataset(Dataset):
else:
sample_idx = idx // self.patches_per_image
- albedo, normal, depth, matid, shadow, transp, target = self._cache[sample_idx]
+ albedo, normal, depth, matid, shadow, transp, prev, target = self._cache[sample_idx]
h, w = albedo.shape[:2]
if self.full_image:
@@ -379,6 +386,8 @@ class CNNv3Dataset(Dataset):
matid = _resize_gray(matid)
shadow = _resize_gray(shadow)
transp = _resize_gray(transp)
+ if prev is not None:
+ prev = _resize_img(prev)
target = _resize_img(target)
else:
ps = self.patch_size
@@ -395,6 +404,8 @@ class CNNv3Dataset(Dataset):
matid = matid[sl]
shadow = shadow[sl]
transp = transp[sl]
+ if prev is not None:
+ prev = prev[sl]
# Apply cached target offset (if search was enabled at init).
if self._target_offsets:
@@ -405,7 +416,7 @@ class CNNv3Dataset(Dataset):
else:
target = target[sl]
- feat = assemble_features(albedo, normal, depth, matid, shadow, transp)
+ feat = assemble_features(albedo, normal, depth, matid, shadow, transp, prev)
if self.augment:
feat = apply_channel_dropout(feat,