summaryrefslogtreecommitdiff
path: root/training/train_cnn.py
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-02-10 10:37:29 +0100
committerskal <pascal.massimino@gmail.com>2026-02-10 10:37:29 +0100
commit5515301560451549f228867a72ca850cffeb3714 (patch)
tree558b139666e24d818e2201bf9524ebb6a04765d4 /training/train_cnn.py
parentee47830f43d575dc917ad480e180c3be7ea23b3a (diff)
fix: Auto-expand single kernel size to all layers in training script
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'training/train_cnn.py')
-rwxr-xr-xtraining/train_cnn.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py
index 4fc3a6c..c249947 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -8,7 +8,7 @@ Usage:
python3 train_cnn.py --input input_dir/ --target target_dir/ [options]
Example:
- python3 train_cnn.py --input ./training/input --target ./training/output --layers 3 --epochs 100
+ python3 train_cnn.py --input ./input --target ./output --layers 3 --epochs 100
"""
import torch
@@ -236,6 +236,8 @@ def train(args):
# Parse kernel sizes
kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
+ if len(kernel_sizes) == 1 and args.layers > 1:
+ kernel_sizes = kernel_sizes * args.layers
# Create model
model = SimpleCNN(num_layers=args.layers, kernel_sizes=kernel_sizes).to(device)