diff options
| author | skal <pascal.massimino@gmail.com> | 2026-03-22 11:03:32 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-03-22 11:03:32 +0100 |
| commit | e6d847b711d68c9eeaa589e632b8f6edf9fec2f2 (patch) | |
| tree | 3a1d6a0b7bb47b7e815ec33e70e22a7670e71b24 /cnn_v3/training/cnn_v3_utils.py | |
| parent | 0d255535bbc135b5455a21701c31fdeecbe812d9 (diff) | |
perf(cnn_v3): cache dataset images at init to avoid per-patch disk I/O
handoff(Gemini): CNNv3Dataset now loads all samples once in __init__ into
self._cache; __getitem__ reads from cache instead of reloading PNGs each
call. Eliminates N×patches_per_image file loads per epoch.
Diffstat (limited to 'cnn_v3/training/cnn_v3_utils.py')
| -rw-r--r-- | cnn_v3/training/cnn_v3_utils.py | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py index 5b43a4d..b32e548 100644 --- a/cnn_v3/training/cnn_v3_utils.py +++ b/cnn_v3/training/cnn_v3_utils.py @@ -247,15 +247,17 @@ class CNNv3Dataset(Dataset): if not self.samples: raise RuntimeError(f"No samples found in {search_dir}") - # Pre-cache salient patch centres (albedo-only load — cheap) + # Pre-load all sample data into memory + print(f"[CNNv3Dataset] Loading {len(self.samples)} samples into memory …") + self._cache: List[tuple] = [self._load_sample(sd) for sd in self.samples] + + # Pre-cache salient patch centres (albedo already loaded above) self._patch_centers: List[List[Tuple[int, int]]] = [] if not full_image: print(f"[CNNv3Dataset] Detecting salient points " f"(detector={detector}, patch={patch_size}×{patch_size}) …") - for sd in self.samples: - pts = detect_salient_points( - load_rgb(sd / 'albedo.png'), - patches_per_image, detector, patch_size) + for sd, (albedo, *_) in zip(self.samples, self._cache): + pts = detect_salient_points(albedo, patches_per_image, detector, patch_size) self._patch_centers.append(pts) print(f"[CNNv3Dataset] mode={input_mode} samples={len(self.samples)} " @@ -288,7 +290,7 @@ class CNNv3Dataset(Dataset): sample_idx = idx // self.patches_per_image sd = self.samples[sample_idx] - albedo, normal, depth, matid, shadow, transp, target = self._load_sample(sd) + albedo, normal, depth, matid, shadow, transp, target = self._cache[sample_idx] h, w = albedo.shape[:2] if self.full_image: |
