diff options
Diffstat (limited to 'cnn_v3/training/cnn_v3_utils.py')
| -rw-r--r-- | cnn_v3/training/cnn_v3_utils.py | 25 |
1 files changed, 14 insertions, 11 deletions
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py index bef4091..50707a2 100644 --- a/cnn_v3/training/cnn_v3_utils.py +++ b/cnn_v3/training/cnn_v3_utils.py @@ -286,7 +286,8 @@ class CNNv3Dataset(Dataset): channel_dropout_p: float = 0.3, detector: str = 'harris', augment: bool = True, - patch_search_window: int = 0): + patch_search_window: int = 0, + single_sample: str = ''): self.patch_size = patch_size self.patches_per_image = patches_per_image self.image_size = image_size @@ -296,16 +297,18 @@ class CNNv3Dataset(Dataset): self.augment = augment self.patch_search_window = patch_search_window - root = Path(dataset_dir) - subdir = 'full' if input_mode == 'full' else 'simple' - search_dir = root / subdir - if not search_dir.exists(): - search_dir = root - - self.samples = sorted([ - d for d in search_dir.iterdir() - if d.is_dir() and (d / 'albedo.png').exists() - ]) + if single_sample: + self.samples = [Path(single_sample)] + else: + root = Path(dataset_dir) + subdir = 'full' if input_mode == 'full' else 'simple' + search_dir = root / subdir + if not search_dir.exists(): + search_dir = root + self.samples = sorted([ + d for d in search_dir.iterdir() + if d.is_dir() and (d / 'albedo.png').exists() + ]) if not self.samples: raise RuntimeError(f"No samples found in {search_dir}") |
