diff options
Diffstat (limited to 'scripts/train_cnn_v2_full.sh')
| -rwxr-xr-x | scripts/train_cnn_v2_full.sh | 16 |
1 files changed, 15 insertions, 1 deletions
diff --git a/scripts/train_cnn_v2_full.sh b/scripts/train_cnn_v2_full.sh index 8b09191..e444f20 100755 --- a/scripts/train_cnn_v2_full.sh +++ b/scripts/train_cnn_v2_full.sh @@ -16,6 +16,7 @@ # --kernel-sizes K Comma-separated kernel sizes (default: 3,3,3) # --num-layers N Number of layers (default: 3) # --mip-level N Mip level for p0-p3 features: 0-3 (default: 0) +# --grayscale-loss Compute loss on grayscale instead of RGBA # # PATCH PARAMETERS: # --patch-size N Patch size (default: 8) @@ -60,6 +61,7 @@ DETECTOR="harris" KERNEL_SIZES="3,3,3" NUM_LAYERS=3 MIP_LEVEL=0 +GRAYSCALE_LOSS=false FULL_IMAGE_MODE=false IMAGE_SIZE=256 @@ -143,6 +145,10 @@ while [[ $# -gt 0 ]]; do MIP_LEVEL="$2" shift 2 ;; + --grayscale-loss) + GRAYSCALE_LOSS=true + shift + ;; --patch-size) if [ -z "$2" ]; then echo "Error: --patch-size requires a number argument" @@ -267,6 +273,13 @@ fi if [ "$VALIDATE_ONLY" = false ]; then # Step 1: Train model echo "[1/4] Training CNN v2 model..." + +# Build optional flags +OPTIONAL_FLAGS="" +if [ "$GRAYSCALE_LOSS" = true ]; then + OPTIONAL_FLAGS="$OPTIONAL_FLAGS --grayscale-loss" +fi + python3 training/train_cnn_v2.py \ --input "$INPUT_DIR" \ --target "$TARGET_DIR" \ @@ -277,7 +290,8 @@ python3 training/train_cnn_v2.py \ --epochs $EPOCHS \ --batch-size $BATCH_SIZE \ --checkpoint-dir "$CHECKPOINT_DIR" \ - --checkpoint-every $CHECKPOINT_EVERY + --checkpoint-every $CHECKPOINT_EVERY \ + $OPTIONAL_FLAGS if [ $? -ne 0 ]; then echo "Error: Training failed" |
