summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cnn_v3/training/cnn_v3_utils.py14
1 files changed, 8 insertions, 6 deletions
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py
index 5b43a4d..b32e548 100644
--- a/cnn_v3/training/cnn_v3_utils.py
+++ b/cnn_v3/training/cnn_v3_utils.py
@@ -247,15 +247,17 @@ class CNNv3Dataset(Dataset):
if not self.samples:
raise RuntimeError(f"No samples found in {search_dir}")
- # Pre-cache salient patch centres (albedo-only load — cheap)
+ # Pre-load all sample data into memory
+ print(f"[CNNv3Dataset] Loading {len(self.samples)} samples into memory …")
+ self._cache: List[tuple] = [self._load_sample(sd) for sd in self.samples]
+
+ # Pre-cache salient patch centres (albedo already loaded above)
self._patch_centers: List[List[Tuple[int, int]]] = []
if not full_image:
print(f"[CNNv3Dataset] Detecting salient points "
f"(detector={detector}, patch={patch_size}×{patch_size}) …")
- for sd in self.samples:
- pts = detect_salient_points(
- load_rgb(sd / 'albedo.png'),
- patches_per_image, detector, patch_size)
+ for sd, (albedo, *_) in zip(self.samples, self._cache):
+ pts = detect_salient_points(albedo, patches_per_image, detector, patch_size)
self._patch_centers.append(pts)
print(f"[CNNv3Dataset] mode={input_mode} samples={len(self.samples)} "
@@ -288,7 +290,7 @@ class CNNv3Dataset(Dataset):
sample_idx = idx // self.patches_per_image
sd = self.samples[sample_idx]
- albedo, normal, depth, matid, shadow, transp, target = self._load_sample(sd)
+ albedo, normal, depth, matid, shadow, transp, target = self._cache[sample_idx]
h, w = albedo.shape[:2]
if self.full_image: