diff options
Diffstat (limited to 'cnn_v3/training/train_cnn_v3.py')
| -rw-r--r-- | cnn_v3/training/train_cnn_v3.py | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py index de10d6a..31cfd9d 100644 --- a/cnn_v3/training/train_cnn_v3.py +++ b/cnn_v3/training/train_cnn_v3.py @@ -104,6 +104,10 @@ def train(args): enc_channels = [int(c) for c in args.enc_channels.split(',')] print(f"Device: {device}") + if args.single_sample: + args.full_image = True + args.batch_size = 1 + dataset = CNNv3Dataset( dataset_dir=args.input, input_mode=args.input_mode, @@ -115,6 +119,7 @@ def train(args): detector=args.detector, augment=True, patch_search_window=args.patch_search_window, + single_sample=args.single_sample, ) loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=False) @@ -222,6 +227,8 @@ def main(): p = argparse.ArgumentParser(description='Train CNN v3 (U-Net + FiLM)') # Dataset + p.add_argument('--single-sample', default='', metavar='DIR', + help='Train on a single sample directory; implies --full-image and --batch-size 1') p.add_argument('--input', default='training/dataset', help='Dataset root (contains full/ or simple/ subdirs)') p.add_argument('--input-mode', default='simple', choices=['simple', 'full'], |
