summaryrefslogtreecommitdiff
path: root/cnn_v3/training
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v3/training')
-rw-r--r--cnn_v3/training/cnn_v3_utils.py25
-rw-r--r--cnn_v3/training/train_cnn_v3.py7
2 files changed, 21 insertions, 11 deletions
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py
index bef4091..50707a2 100644
--- a/cnn_v3/training/cnn_v3_utils.py
+++ b/cnn_v3/training/cnn_v3_utils.py
@@ -286,7 +286,8 @@ class CNNv3Dataset(Dataset):
channel_dropout_p: float = 0.3,
detector: str = 'harris',
augment: bool = True,
- patch_search_window: int = 0):
+ patch_search_window: int = 0,
+ single_sample: str = ''):
self.patch_size = patch_size
self.patches_per_image = patches_per_image
self.image_size = image_size
@@ -296,16 +297,18 @@ class CNNv3Dataset(Dataset):
self.augment = augment
self.patch_search_window = patch_search_window
- root = Path(dataset_dir)
- subdir = 'full' if input_mode == 'full' else 'simple'
- search_dir = root / subdir
- if not search_dir.exists():
- search_dir = root
-
- self.samples = sorted([
- d for d in search_dir.iterdir()
- if d.is_dir() and (d / 'albedo.png').exists()
- ])
+ if single_sample:
+ self.samples = [Path(single_sample)]
+ else:
+ root = Path(dataset_dir)
+ subdir = 'full' if input_mode == 'full' else 'simple'
+ search_dir = root / subdir
+ if not search_dir.exists():
+ search_dir = root
+ self.samples = sorted([
+ d for d in search_dir.iterdir()
+ if d.is_dir() and (d / 'albedo.png').exists()
+ ])
if not self.samples:
raise RuntimeError(f"No samples found in {search_dir}")
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py
index de10d6a..31cfd9d 100644
--- a/cnn_v3/training/train_cnn_v3.py
+++ b/cnn_v3/training/train_cnn_v3.py
@@ -104,6 +104,10 @@ def train(args):
enc_channels = [int(c) for c in args.enc_channels.split(',')]
print(f"Device: {device}")
+ if args.single_sample:
+ args.full_image = True
+ args.batch_size = 1
+
dataset = CNNv3Dataset(
dataset_dir=args.input,
input_mode=args.input_mode,
@@ -115,6 +119,7 @@ def train(args):
detector=args.detector,
augment=True,
patch_search_window=args.patch_search_window,
+ single_sample=args.single_sample,
)
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
num_workers=0, drop_last=False)
@@ -222,6 +227,8 @@ def main():
p = argparse.ArgumentParser(description='Train CNN v3 (U-Net + FiLM)')
# Dataset
+ p.add_argument('--single-sample', default='', metavar='DIR',
+ help='Train on a single sample directory; implies --full-image and --batch-size 1')
p.add_argument('--input', default='training/dataset',
help='Dataset root (contains full/ or simple/ subdirs)')
p.add_argument('--input-mode', default='simple', choices=['simple', 'full'],