summaryrefslogtreecommitdiff
path: root/cnn_v3/training/cnn_v3_utils.py
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-03-25 08:27:39 +0100
committerskal <pascal.massimino@gmail.com>2026-03-25 08:27:39 +0100
commit3e4fece8fce11b368b4c7bab284242bf18e6a0b1 (patch)
tree108682f727f7668e1df346563f4576d7e567dcd2 /cnn_v3/training/cnn_v3_utils.py
parent64095c683f15e8bd7c19d32041fcc81b1bd6c214 (diff)
feat(cnn_v3/training): add --single-sample option + doc fixes
- train_cnn_v3.py: --single-sample <dir> implies --full-image + --batch-size 1 - cnn_v3_utils.py: CNNv3Dataset accepts single_sample= kwarg (explicit override) - HOWTO.md: document --single-sample workflow, fix pack_photo_sample.py usage (--target required) - HOW_TO_CNN.md: fix GBufferEffect seq input (prev_cnn→source), fix binary name (demo→demo64k), add --resume to flag table, remove stale "pack without target" block handoff(Gemini): --single-sample <dir> added to train_cnn_v3.py; docs audited and corrected
Diffstat (limited to 'cnn_v3/training/cnn_v3_utils.py')
-rw-r--r--cnn_v3/training/cnn_v3_utils.py25
1 files changed, 14 insertions, 11 deletions
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py
index bef4091..50707a2 100644
--- a/cnn_v3/training/cnn_v3_utils.py
+++ b/cnn_v3/training/cnn_v3_utils.py
@@ -286,7 +286,8 @@ class CNNv3Dataset(Dataset):
channel_dropout_p: float = 0.3,
detector: str = 'harris',
augment: bool = True,
- patch_search_window: int = 0):
+ patch_search_window: int = 0,
+ single_sample: str = ''):
self.patch_size = patch_size
self.patches_per_image = patches_per_image
self.image_size = image_size
@@ -296,16 +297,18 @@ class CNNv3Dataset(Dataset):
self.augment = augment
self.patch_search_window = patch_search_window
- root = Path(dataset_dir)
- subdir = 'full' if input_mode == 'full' else 'simple'
- search_dir = root / subdir
- if not search_dir.exists():
- search_dir = root
-
- self.samples = sorted([
- d for d in search_dir.iterdir()
- if d.is_dir() and (d / 'albedo.png').exists()
- ])
+ if single_sample:
+ self.samples = [Path(single_sample)]
+ else:
+ root = Path(dataset_dir)
+ subdir = 'full' if input_mode == 'full' else 'simple'
+ search_dir = root / subdir
+ if not search_dir.exists():
+ search_dir = root
+ self.samples = sorted([
+ d for d in search_dir.iterdir()
+ if d.is_dir() and (d / 'albedo.png').exists()
+ ])
if not self.samples:
raise RuntimeError(f"No samples found in {search_dir}")