summaryrefslogtreecommitdiff
path: root/scripts/train_cnn_v2_full.sh
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/train_cnn_v2_full.sh')
-rwxr-xr-xscripts/train_cnn_v2_full.sh16
1 files changed, 15 insertions, 1 deletions
diff --git a/scripts/train_cnn_v2_full.sh b/scripts/train_cnn_v2_full.sh
index 8b09191..e444f20 100755
--- a/scripts/train_cnn_v2_full.sh
+++ b/scripts/train_cnn_v2_full.sh
@@ -16,6 +16,7 @@
# --kernel-sizes K Comma-separated kernel sizes (default: 3,3,3)
# --num-layers N Number of layers (default: 3)
# --mip-level N Mip level for p0-p3 features: 0-3 (default: 0)
+# --grayscale-loss Compute loss on grayscale instead of RGBA
#
# PATCH PARAMETERS:
# --patch-size N Patch size (default: 8)
@@ -60,6 +61,7 @@ DETECTOR="harris"
KERNEL_SIZES="3,3,3"
NUM_LAYERS=3
MIP_LEVEL=0
+GRAYSCALE_LOSS=false
FULL_IMAGE_MODE=false
IMAGE_SIZE=256
@@ -143,6 +145,10 @@ while [[ $# -gt 0 ]]; do
MIP_LEVEL="$2"
shift 2
;;
+ --grayscale-loss)
+ GRAYSCALE_LOSS=true
+ shift
+ ;;
--patch-size)
if [ -z "$2" ]; then
echo "Error: --patch-size requires a number argument"
@@ -267,6 +273,13 @@ fi
if [ "$VALIDATE_ONLY" = false ]; then
# Step 1: Train model
echo "[1/4] Training CNN v2 model..."
+
+# Build optional flags
+OPTIONAL_FLAGS=""
+if [ "$GRAYSCALE_LOSS" = true ]; then
+ OPTIONAL_FLAGS="$OPTIONAL_FLAGS --grayscale-loss"
+fi
+
python3 training/train_cnn_v2.py \
--input "$INPUT_DIR" \
--target "$TARGET_DIR" \
@@ -277,7 +290,8 @@ python3 training/train_cnn_v2.py \
--epochs $EPOCHS \
--batch-size $BATCH_SIZE \
--checkpoint-dir "$CHECKPOINT_DIR" \
- --checkpoint-every $CHECKPOINT_EVERY
+ --checkpoint-every $CHECKPOINT_EVERY \
+ $OPTIONAL_FLAGS
if [ $? -ne 0 ]; then
echo "Error: Training failed"