diff options
| author | skal <pascal.massimino@gmail.com> | 2026-03-25 08:27:39 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-03-25 08:27:39 +0100 |
| commit | 3e4fece8fce11b368b4c7bab284242bf18e6a0b1 (patch) | |
| tree | 108682f727f7668e1df346563f4576d7e567dcd2 /cnn_v3/training/cnn_v3_utils.py | |
| parent | 64095c683f15e8bd7c19d32041fcc81b1bd6c214 (diff) | |
feat(cnn_v3/training): add --single-sample option + doc fixes
- train_cnn_v3.py: --single-sample <dir> implies --full-image + --batch-size 1
- cnn_v3_utils.py: CNNv3Dataset accepts single_sample= kwarg (explicit override)
- HOWTO.md: document --single-sample workflow, fix pack_photo_sample.py usage (--target required)
- HOW_TO_CNN.md: fix GBufferEffect seq input (prev_cnn→source), fix binary name (demo→demo64k), add --resume to flag table, remove stale "pack without target" block
handoff(Gemini): --single-sample <dir> added to train_cnn_v3.py; docs audited and corrected
Diffstat (limited to 'cnn_v3/training/cnn_v3_utils.py')
| -rw-r--r-- | cnn_v3/training/cnn_v3_utils.py | 25 |
1 files changed, 14 insertions, 11 deletions
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py index bef4091..50707a2 100644 --- a/cnn_v3/training/cnn_v3_utils.py +++ b/cnn_v3/training/cnn_v3_utils.py @@ -286,7 +286,8 @@ class CNNv3Dataset(Dataset): channel_dropout_p: float = 0.3, detector: str = 'harris', augment: bool = True, - patch_search_window: int = 0): + patch_search_window: int = 0, + single_sample: str = ''): self.patch_size = patch_size self.patches_per_image = patches_per_image self.image_size = image_size @@ -296,16 +297,18 @@ class CNNv3Dataset(Dataset): self.augment = augment self.patch_search_window = patch_search_window - root = Path(dataset_dir) - subdir = 'full' if input_mode == 'full' else 'simple' - search_dir = root / subdir - if not search_dir.exists(): - search_dir = root - - self.samples = sorted([ - d for d in search_dir.iterdir() - if d.is_dir() and (d / 'albedo.png').exists() - ]) + if single_sample: + self.samples = [Path(single_sample)] + else: + root = Path(dataset_dir) + subdir = 'full' if input_mode == 'full' else 'simple' + search_dir = root / subdir + if not search_dir.exists(): + search_dir = root + self.samples = sorted([ + d for d in search_dir.iterdir() + if d.is_dir() and (d / 'albedo.png').exists() + ]) if not self.samples: raise RuntimeError(f"No samples found in {search_dir}") |
