summaryrefslogtreecommitdiff
path: root/scripts/train_cnn_v2_full.sh
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/train_cnn_v2_full.sh')
-rwxr-xr-xscripts/train_cnn_v2_full.sh191
1 files changed, 160 insertions, 31 deletions
diff --git a/scripts/train_cnn_v2_full.sh b/scripts/train_cnn_v2_full.sh
index 6468d12..8b09191 100755
--- a/scripts/train_cnn_v2_full.sh
+++ b/scripts/train_cnn_v2_full.sh
@@ -3,35 +3,74 @@
# Train → Export → Build → Validate
# Usage: ./train_cnn_v2_full.sh [OPTIONS]
#
-# OPTIONS:
+# MODES:
# (none) Run complete pipeline: train → export → build → validate
# --validate Validate only (skip training, use existing weights)
# --validate CHECKPOINT Validate with specific checkpoint file
# --export-only CHECKPOINT Export weights only (skip training, build, validation)
-# --mip-level LEVEL Mip level for p0-p3 features: 0=original, 1=half, 2=quarter, 3=eighth (default: 0)
+#
+# TRAINING PARAMETERS:
+# --epochs N Training epochs (default: 200)
+# --batch-size N Batch size (default: 16)
+# --checkpoint-every N Checkpoint interval (default: 50)
+# --kernel-sizes K Comma-separated kernel sizes (default: 3,3,3)
+# --num-layers N Number of layers (default: 3)
+# --mip-level N Mip level for p0-p3 features: 0-3 (default: 0)
+#
+# PATCH PARAMETERS:
+# --patch-size N Patch size (default: 8)
+# --patches-per-image N Patches per image (default: 256)
+# --detector TYPE Detector: harris|fast|shi-tomasi|gradient (default: harris)
+# --full-image Use full-image training (disables patch mode)
+# --image-size N Image size for full-image mode (default: 256)
+#
+# DIRECTORIES:
+# --input DIR Input directory (default: training/input)
+# --target DIR Target directory (default: training/target_1)
+# --checkpoint-dir DIR Checkpoint directory (default: checkpoints)
+# --validation-dir DIR Validation directory (default: validation_results)
+#
+# OTHER:
# --help Show this help message
#
# Examples:
# ./train_cnn_v2_full.sh
+# ./train_cnn_v2_full.sh --epochs 500 --batch-size 32
# ./train_cnn_v2_full.sh --validate
# ./train_cnn_v2_full.sh --validate checkpoints/checkpoint_epoch_50.pth
# ./train_cnn_v2_full.sh --export-only checkpoints/checkpoint_epoch_100.pth
-# ./train_cnn_v2_full.sh --mip-level 1
+# ./train_cnn_v2_full.sh --mip-level 1 --kernel-sizes 3,5,3
set -e
PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
cd "$PROJECT_ROOT"
+# Default configuration
+INPUT_DIR="training/input"
+TARGET_DIR="training/target_1"
+CHECKPOINT_DIR="checkpoints"
+VALIDATION_DIR="validation_results"
+EPOCHS=200
+CHECKPOINT_EVERY=50
+BATCH_SIZE=16
+PATCH_SIZE=8
+PATCHES_PER_IMAGE=256
+DETECTOR="harris"
+KERNEL_SIZES="3,3,3"
+NUM_LAYERS=3
+MIP_LEVEL=0
+FULL_IMAGE_MODE=false
+IMAGE_SIZE=256
+
# Parse arguments
VALIDATE_ONLY=false
VALIDATE_CHECKPOINT=""
EXPORT_ONLY=false
EXPORT_CHECKPOINT=""
-MIP_LEVEL=0
if [ "$1" = "--help" ] || [ "$1" = "-h" ]; then
- head -21 "$0" | grep "^#" | grep -v "^#!/" | sed 's/^# *//'
+ head -47 "$0" | grep "^#" | grep -v "^#!/" | sed 's/^# *//'
exit 0
fi
@@ -56,6 +95,46 @@ while [[ $# -gt 0 ]]; do
shift
fi
;;
+ --epochs)
+ if [ -z "$2" ]; then
+ echo "Error: --epochs requires a number argument"
+ exit 1
+ fi
+ EPOCHS="$2"
+ shift 2
+ ;;
+ --batch-size)
+ if [ -z "$2" ]; then
+ echo "Error: --batch-size requires a number argument"
+ exit 1
+ fi
+ BATCH_SIZE="$2"
+ shift 2
+ ;;
+ --checkpoint-every)
+ if [ -z "$2" ]; then
+ echo "Error: --checkpoint-every requires a number argument"
+ exit 1
+ fi
+ CHECKPOINT_EVERY="$2"
+ shift 2
+ ;;
+ --kernel-sizes)
+ if [ -z "$2" ]; then
+ echo "Error: --kernel-sizes requires a comma-separated list"
+ exit 1
+ fi
+ KERNEL_SIZES="$2"
+ shift 2
+ ;;
+ --num-layers)
+ if [ -z "$2" ]; then
+ echo "Error: --num-layers requires a number argument"
+ exit 1
+ fi
+ NUM_LAYERS="$2"
+ shift 2
+ ;;
--mip-level)
if [ -z "$2" ]; then
echo "Error: --mip-level requires a level argument (0-3)"
@@ -64,6 +143,74 @@ while [[ $# -gt 0 ]]; do
MIP_LEVEL="$2"
shift 2
;;
+ --patch-size)
+ if [ -z "$2" ]; then
+ echo "Error: --patch-size requires a number argument"
+ exit 1
+ fi
+ PATCH_SIZE="$2"
+ shift 2
+ ;;
+ --patches-per-image)
+ if [ -z "$2" ]; then
+ echo "Error: --patches-per-image requires a number argument"
+ exit 1
+ fi
+ PATCHES_PER_IMAGE="$2"
+ shift 2
+ ;;
+ --detector)
+ if [ -z "$2" ]; then
+ echo "Error: --detector requires a type argument"
+ exit 1
+ fi
+ DETECTOR="$2"
+ shift 2
+ ;;
+ --full-image)
+ FULL_IMAGE_MODE=true
+ shift
+ ;;
+ --image-size)
+ if [ -z "$2" ]; then
+ echo "Error: --image-size requires a number argument"
+ exit 1
+ fi
+ IMAGE_SIZE="$2"
+ shift 2
+ ;;
+ --input)
+ if [ -z "$2" ]; then
+ echo "Error: --input requires a directory argument"
+ exit 1
+ fi
+ INPUT_DIR="$2"
+ shift 2
+ ;;
+ --target)
+ if [ -z "$2" ]; then
+ echo "Error: --target requires a directory argument"
+ exit 1
+ fi
+ TARGET_DIR="$2"
+ shift 2
+ ;;
+ --checkpoint-dir)
+ if [ -z "$2" ]; then
+ echo "Error: --checkpoint-dir requires a directory argument"
+ exit 1
+ fi
+ CHECKPOINT_DIR="$2"
+ shift 2
+ ;;
+ --validation-dir)
+ if [ -z "$2" ]; then
+ echo "Error: --validation-dir requires a directory argument"
+ exit 1
+ fi
+ VALIDATION_DIR="$2"
+ shift 2
+ ;;
*)
echo "Unknown option: $1"
exit 1
@@ -71,27 +218,12 @@ while [[ $# -gt 0 ]]; do
esac
done
-# Configuration
-INPUT_DIR="training/input"
-TARGET_DIR="training/target_2"
-CHECKPOINT_DIR="checkpoints"
-VALIDATION_DIR="validation_results"
-EPOCHS=200
-CHECKPOINT_EVERY=50
-BATCH_SIZE=16
-
-# Patch-based training (default)
-PATCH_SIZE=8
-PATCHES_PER_IMAGE=256
-DETECTOR="harris"
-
-# Full-image training (alternative - uncomment to use)
-# FULL_IMAGE="--full-image"
-# IMAGE_SIZE=256
-
-KERNEL_SIZES="3,3,3" # Comma-separated per-layer kernel sizes
-NUM_LAYERS=3
-# MIP_LEVEL set via --mip-level argument (default: 0)
+# Build training arguments
+if [ "$FULL_IMAGE_MODE" = true ]; then
+ TRAINING_MODE_ARGS="--full-image --image-size $IMAGE_SIZE"
+else
+ TRAINING_MODE_ARGS="--patch-size $PATCH_SIZE --patches-per-image $PATCHES_PER_IMAGE --detector $DETECTOR"
+fi
# Handle export-only mode
if [ "$EXPORT_ONLY" = true ]; then
@@ -138,17 +270,14 @@ if [ "$VALIDATE_ONLY" = false ]; then
python3 training/train_cnn_v2.py \
--input "$INPUT_DIR" \
--target "$TARGET_DIR" \
- --patch-size $PATCH_SIZE \
- --patches-per-image $PATCHES_PER_IMAGE \
- --detector $DETECTOR \
+ $TRAINING_MODE_ARGS \
--kernel-sizes $KERNEL_SIZES \
--num-layers $NUM_LAYERS \
--mip-level $MIP_LEVEL \
--epochs $EPOCHS \
--batch-size $BATCH_SIZE \
--checkpoint-dir "$CHECKPOINT_DIR" \
- --checkpoint-every $CHECKPOINT_EVERY \
- $FULL_IMAGE
+ --checkpoint-every $CHECKPOINT_EVERY
if [ $? -ne 0 ]; then
echo "Error: Training failed"