summary | refs | log | tree | commit | diff
path: root/cnn_v3
diff options
context:
space:
mode:
author: skal <pascal.massimino@gmail.com> 2026-03-22 11:03:32 +0100
committer: skal <pascal.massimino@gmail.com> 2026-03-22 11:03:32 +0100
commit: e6d847b711d68c9eeaa589e632b8f6edf9fec2f2 (patch)
tree: 3a1d6a0b7bb47b7e815ec33e70e22a7670e71b24 /cnn_v3
parent: 0d255535bbc135b5455a21701c31fdeecbe812d9 (diff)
perf(cnn_v3): cache dataset images at init to avoid per-patch disk I/O
handoff(Gemini): CNNv3Dataset now loads all samples once in __init__ into self._cache; __getitem__ reads from cache instead of reloading PNGs each call. Eliminates N×patches_per_image file loads per epoch.
Diffstat (limited to 'cnn_v3')
-rw-r--r-- cnn_v3/training/cnn_v3_utils.py 14
1 files changed, 8 insertions, 6 deletions
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py
index 5b43a4d..b32e548 100644
--- a/cnn_v3/training/cnn_v3_utils.py
+++ b/cnn_v3/training/cnn_v3_utils.py
@@ -247,15 +247,17 @@ class CNNv3Dataset(Dataset):
if not self.samples:
raise RuntimeError(f"No samples found in {search_dir}")
- # Pre-cache salient patch centres (albedo-only load — cheap)
+ # Pre-load all sample data into memory
+ print(f"[CNNv3Dataset] Loading {len(self.samples)} samples into memory …")
+ self._cache: List[tuple] = [self._load_sample(sd) for sd in self.samples]
+
+ # Pre-cache salient patch centres (albedo already loaded above)
self._patch_centers: List[List[Tuple[int, int]]] = []
if not full_image:
print(f"[CNNv3Dataset] Detecting salient points "
f"(detector={detector}, patch={patch_size}×{patch_size}) …")
- for sd in self.samples:
- pts = detect_salient_points(
- load_rgb(sd / 'albedo.png'),
- patches_per_image, detector, patch_size)
+ for sd, (albedo, *_) in zip(self.samples, self._cache):
+ pts = detect_salient_points(albedo, patches_per_image, detector, patch_size)
self._patch_centers.append(pts)
print(f"[CNNv3Dataset] mode={input_mode} samples={len(self.samples)} "
@@ -288,7 +290,7 @@ class CNNv3Dataset(Dataset):
sample_idx = idx // self.patches_per_image
sd = self.samples[sample_idx]
- albedo, normal, depth, matid, shadow, transp, target = self._load_sample(sd)
+ albedo, normal, depth, matid, shadow, transp, target = self._cache[sample_idx]
h, w = albedo.shape[:2]
if self.full_image: