From a7340d378909cadbfd72dbd1f5b756f907c2a3e0 Mon Sep 17 00:00:00 2001 From: skal Date: Fri, 13 Feb 2026 17:46:09 +0100 Subject: CNN v2 training: Add --grayscale-loss option for luminance-based loss computation Add option to compute loss on grayscale (Y = 0.299*R + 0.587*G + 0.114*B) instead of full RGBA channels. Useful for training models that prioritize luminance accuracy over color accuracy. Changes: - training/train_cnn_v2.py: Add --grayscale-loss flag and grayscale conversion in loss computation - scripts/train_cnn_v2_full.sh: Add --grayscale-loss parameter support - doc/CNN_V2.md: Document grayscale loss in training configuration and checkpoint format - doc/HOWTO.md: Add usage examples for --grayscale-loss flag Co-Authored-By: Claude Sonnet 4.5 --- scripts/train_cnn_v2_full.sh | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) (limited to 'scripts/train_cnn_v2_full.sh') diff --git a/scripts/train_cnn_v2_full.sh b/scripts/train_cnn_v2_full.sh index 8b09191..e444f20 100755 --- a/scripts/train_cnn_v2_full.sh +++ b/scripts/train_cnn_v2_full.sh @@ -16,6 +16,7 @@ # --kernel-sizes K Comma-separated kernel sizes (default: 3,3,3) # --num-layers N Number of layers (default: 3) # --mip-level N Mip level for p0-p3 features: 0-3 (default: 0) +# --grayscale-loss Compute loss on grayscale instead of RGBA # # PATCH PARAMETERS: # --patch-size N Patch size (default: 8) @@ -60,6 +61,7 @@ DETECTOR="harris" KERNEL_SIZES="3,3,3" NUM_LAYERS=3 MIP_LEVEL=0 +GRAYSCALE_LOSS=false FULL_IMAGE_MODE=false IMAGE_SIZE=256 @@ -143,6 +145,10 @@ while [[ $# -gt 0 ]]; do MIP_LEVEL="$2" shift 2 ;; + --grayscale-loss) + GRAYSCALE_LOSS=true + shift + ;; --patch-size) if [ -z "$2" ]; then echo "Error: --patch-size requires a number argument" @@ -267,6 +273,13 @@ fi if [ "$VALIDATE_ONLY" = false ]; then # Step 1: Train model echo "[1/4] Training CNN v2 model..." + +# Build optional flags +OPTIONAL_FLAGS="" +if [ "$GRAYSCALE_LOSS" = true ]; then + OPTIONAL_FLAGS="$OPTIONAL_FLAGS --grayscale-loss" +fi + python3 training/train_cnn_v2.py \ --input "$INPUT_DIR" \ --target "$TARGET_DIR" \ @@ -277,7 +290,8 @@ python3 training/train_cnn_v2.py \ --epochs $EPOCHS \ --batch-size $BATCH_SIZE \ --checkpoint-dir "$CHECKPOINT_DIR" \ - --checkpoint-every $CHECKPOINT_EVERY + --checkpoint-every $CHECKPOINT_EVERY \ + $OPTIONAL_FLAGS if [ $? -ne 0 ]; then echo "Error: Training failed" -- cgit v1.2.3