summaryrefslogtreecommitdiff
path: root/training/train_cnn.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/train_cnn.py')
-rwxr-xr-xtraining/train_cnn.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py
index 4fc3a6c..c249947 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -8,7 +8,7 @@ Usage:
python3 train_cnn.py --input input_dir/ --target target_dir/ [options]
Example:
- python3 train_cnn.py --input ./training/input --target ./training/output --layers 3 --epochs 100
+ python3 train_cnn.py --input ./input --target ./output --layers 3 --epochs 100
"""
import torch
@@ -236,6 +236,8 @@ def train(args):
# Parse kernel sizes
kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
+ if len(kernel_sizes) == 1 and args.layers > 1:
+ kernel_sizes = kernel_sizes * args.layers
# Create model
model = SimpleCNN(num_layers=args.layers, kernel_sizes=kernel_sizes).to(device)