From e6d847b711d68c9eeaa589e632b8f6edf9fec2f2 Mon Sep 17 00:00:00 2001 From: skal Date: Sun, 22 Mar 2026 11:03:32 +0100 Subject: perf(cnn_v3): cache dataset images at init to avoid per-patch disk I/O MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit handoff(Gemini): CNNv3Dataset now loads all samples once in __init__ into self._cache; __getitem__ reads from cache instead of reloading PNGs each call. Eliminates N×patches_per_image file loads per epoch. --- cnn_v3/training/cnn_v3_utils.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) (limited to 'cnn_v3') diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py index 5b43a4d..b32e548 100644 --- a/cnn_v3/training/cnn_v3_utils.py +++ b/cnn_v3/training/cnn_v3_utils.py @@ -247,15 +247,17 @@ class CNNv3Dataset(Dataset): if not self.samples: raise RuntimeError(f"No samples found in {search_dir}") - # Pre-cache salient patch centres (albedo-only load — cheap) + # Pre-load all sample data into memory + print(f"[CNNv3Dataset] Loading {len(self.samples)} samples into memory …") + self._cache: List[tuple] = [self._load_sample(sd) for sd in self.samples] + + # Pre-cache salient patch centres (albedo already loaded above) self._patch_centers: List[List[Tuple[int, int]]] = [] if not full_image: print(f"[CNNv3Dataset] Detecting salient points " f"(detector={detector}, patch={patch_size}×{patch_size}) …") - for sd in self.samples: - pts = detect_salient_points( - load_rgb(sd / 'albedo.png'), - patches_per_image, detector, patch_size) + for sd, (albedo, *_) in zip(self.samples, self._cache): + pts = detect_salient_points(albedo, patches_per_image, detector, patch_size) self._patch_centers.append(pts) print(f"[CNNv3Dataset] mode={input_mode} samples={len(self.samples)} " @@ -288,7 +290,7 @@ class CNNv3Dataset(Dataset): sample_idx = idx // self.patches_per_image sd = self.samples[sample_idx] - albedo, normal, depth, matid, shadow, transp, target = self._load_sample(sd) + albedo, normal, depth, matid, shadow, transp, target = self._cache[sample_idx] h, w = albedo.shape[:2] if self.full_image: -- cgit v1.2.3