summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xscripts/train_cnn_v2_full.sh11
1 files changed, 11 insertions, 0 deletions
diff --git a/scripts/train_cnn_v2_full.sh b/scripts/train_cnn_v2_full.sh
index 766bd8b..078ea28 100755
--- a/scripts/train_cnn_v2_full.sh
+++ b/scripts/train_cnn_v2_full.sh
@@ -12,6 +12,7 @@
# TRAINING PARAMETERS:
# --epochs N Training epochs (default: 200)
# --batch-size N Batch size (default: 16)
+# --lr FLOAT Learning rate (default: 1e-3)
# --checkpoint-every N Checkpoint interval (default: 50)
# --kernel-sizes K Comma-separated kernel sizes (default: 3,3,3)
# --num-layers N Number of layers (default: 3)
@@ -71,6 +72,7 @@ VALIDATION_DIR="validation_results"
EPOCHS=200
CHECKPOINT_EVERY=50
BATCH_SIZE=16
+LEARNING_RATE=1e-3
PATCH_SIZE=8
PATCHES_PER_IMAGE=256
DETECTOR="harris"
@@ -166,6 +168,14 @@ while [[ $# -gt 0 ]]; do
GRAYSCALE_LOSS=true
shift
;;
+ --lr)
+ if [ -z "$2" ]; then
+ echo "Error: --lr requires a float argument"
+ exit 1
+ fi
+ LEARNING_RATE="$2"
+ shift 2
+ ;;
--patch-size)
if [ -z "$2" ]; then
echo "Error: --patch-size requires a number argument"
@@ -305,6 +315,7 @@ python3 training/train_cnn_v2.py \
--mip-level "$MIP_LEVEL" \
--epochs "$EPOCHS" \
--batch-size "$BATCH_SIZE" \
+ --lr "$LEARNING_RATE" \
--checkpoint-dir "$CHECKPOINT_DIR" \
--checkpoint-every "$CHECKPOINT_EVERY" \
$([ "$GRAYSCALE_LOSS" = true ] && echo "--grayscale-loss")