diff options
| author | skal <pascal.massimino@gmail.com> | 2026-03-25 08:27:39 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-03-25 08:27:39 +0100 |
| commit | 3e4fece8fce11b368b4c7bab284242bf18e6a0b1 (patch) | |
| tree | 108682f727f7668e1df346563f4576d7e567dcd2 /cnn_v3/training | |
| parent | 64095c683f15e8bd7c19d32041fcc81b1bd6c214 (diff) | |
feat(cnn_v3/training): add --single-sample option + doc fixes
- train_cnn_v3.py: --single-sample <dir> implies --full-image + --batch-size 1
- cnn_v3_utils.py: CNNv3Dataset accepts single_sample= kwarg (explicit override)
- HOWTO.md: document --single-sample workflow, fix pack_photo_sample.py usage (--target required)
- HOW_TO_CNN.md: fix GBufferEffect seq input (prev_cnn→source), fix binary name (demo→demo64k), add --resume to flag table, remove stale "pack without target" block
handoff(Gemini): --single-sample <dir> added to train_cnn_v3.py; docs audited and corrected
Diffstat (limited to 'cnn_v3/training')
| -rw-r--r-- | cnn_v3/training/cnn_v3_utils.py | 25 | ||||
| -rw-r--r-- | cnn_v3/training/train_cnn_v3.py | 7 |
2 files changed, 21 insertions, 11 deletions
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py index bef4091..50707a2 100644 --- a/cnn_v3/training/cnn_v3_utils.py +++ b/cnn_v3/training/cnn_v3_utils.py @@ -286,7 +286,8 @@ class CNNv3Dataset(Dataset): channel_dropout_p: float = 0.3, detector: str = 'harris', augment: bool = True, - patch_search_window: int = 0): + patch_search_window: int = 0, + single_sample: str = ''): self.patch_size = patch_size self.patches_per_image = patches_per_image self.image_size = image_size @@ -296,16 +297,18 @@ class CNNv3Dataset(Dataset): self.augment = augment self.patch_search_window = patch_search_window - root = Path(dataset_dir) - subdir = 'full' if input_mode == 'full' else 'simple' - search_dir = root / subdir - if not search_dir.exists(): - search_dir = root - - self.samples = sorted([ - d for d in search_dir.iterdir() - if d.is_dir() and (d / 'albedo.png').exists() - ]) + if single_sample: + self.samples = [Path(single_sample)] + else: + root = Path(dataset_dir) + subdir = 'full' if input_mode == 'full' else 'simple' + search_dir = root / subdir + if not search_dir.exists(): + search_dir = root + self.samples = sorted([ + d for d in search_dir.iterdir() + if d.is_dir() and (d / 'albedo.png').exists() + ]) if not self.samples: raise RuntimeError(f"No samples found in {search_dir}") diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py index de10d6a..31cfd9d 100644 --- a/cnn_v3/training/train_cnn_v3.py +++ b/cnn_v3/training/train_cnn_v3.py @@ -104,6 +104,10 @@ def train(args): enc_channels = [int(c) for c in args.enc_channels.split(',')] print(f"Device: {device}") + if args.single_sample: + args.full_image = True + args.batch_size = 1 + dataset = CNNv3Dataset( dataset_dir=args.input, input_mode=args.input_mode, @@ -115,6 +119,7 @@ def train(args): detector=args.detector, augment=True, patch_search_window=args.patch_search_window, + single_sample=args.single_sample, ) loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=False) @@ -222,6 +227,8 @@ def main(): p = argparse.ArgumentParser(description='Train CNN v3 (U-Net + FiLM)') # Dataset + p.add_argument('--single-sample', default='', metavar='DIR', + help='Train on a single sample directory; implies --full-image and --batch-size 1') p.add_argument('--input', default='training/dataset', help='Dataset root (contains full/ or simple/ subdirs)') p.add_argument('--input-mode', default='simple', choices=['simple', 'full'], |
