diff options
Diffstat (limited to 'training/train_cnn.py')
| -rwxr-xr-x | training/train_cnn.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py index 4fc3a6c..c249947 100755 --- a/training/train_cnn.py +++ b/training/train_cnn.py @@ -8,7 +8,7 @@ Usage: python3 train_cnn.py --input input_dir/ --target target_dir/ [options] Example: - python3 train_cnn.py --input ./training/input --target ./training/output --layers 3 --epochs 100 + python3 train_cnn.py --input ./input --target ./output --layers 3 --epochs 100 """ import torch @@ -236,6 +236,8 @@ def train(args): # Parse kernel sizes kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')] + if len(kernel_sizes) == 1 and args.layers > 1: + kernel_sizes = kernel_sizes * args.layers # Create model model = SimpleCNN(num_layers=args.layers, kernel_sizes=kernel_sizes).to(device) |
