summaryrefslogtreecommitdiff
path: root/scripts/train_cnn_v2_full.sh
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-02-13 17:46:09 +0100
committerskal <pascal.massimino@gmail.com>2026-02-13 17:46:09 +0100
commita7340d378909cadbfd72dbd1f5b756f907c2a3e0 (patch)
tree60a34dd084d746f0c6dad50b0d5cc7f20bc0c409 /scripts/train_cnn_v2_full.sh
parentf6b3ea72a03850654b69986bc82bb249aaabe2e3 (diff)
CNN v2 training: Add --grayscale-loss option for luminance-based loss computation
Add option to compute loss on grayscale (Y = 0.299*R + 0.587*G + 0.114*B) instead of full RGBA channels. Useful for training models that prioritize luminance accuracy over color accuracy. Changes: - training/train_cnn_v2.py: Add --grayscale-loss flag and grayscale conversion in loss computation - scripts/train_cnn_v2_full.sh: Add --grayscale-loss parameter support - doc/CNN_V2.md: Document grayscale loss in training configuration and checkpoint format - doc/HOWTO.md: Add usage examples for --grayscale-loss flag Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'scripts/train_cnn_v2_full.sh')
-rwxr-xr-xscripts/train_cnn_v2_full.sh16
1 files changed, 15 insertions, 1 deletions
diff --git a/scripts/train_cnn_v2_full.sh b/scripts/train_cnn_v2_full.sh
index 8b09191..e444f20 100755
--- a/scripts/train_cnn_v2_full.sh
+++ b/scripts/train_cnn_v2_full.sh
@@ -16,6 +16,7 @@
# --kernel-sizes K Comma-separated kernel sizes (default: 3,3,3)
# --num-layers N Number of layers (default: 3)
# --mip-level N Mip level for p0-p3 features: 0-3 (default: 0)
+# --grayscale-loss Compute loss on grayscale instead of RGBA
#
# PATCH PARAMETERS:
# --patch-size N Patch size (default: 8)
@@ -60,6 +61,7 @@ DETECTOR="harris"
KERNEL_SIZES="3,3,3"
NUM_LAYERS=3
MIP_LEVEL=0
+GRAYSCALE_LOSS=false
FULL_IMAGE_MODE=false
IMAGE_SIZE=256
@@ -143,6 +145,10 @@ while [[ $# -gt 0 ]]; do
MIP_LEVEL="$2"
shift 2
;;
+ --grayscale-loss)
+ GRAYSCALE_LOSS=true
+ shift
+ ;;
--patch-size)
if [ -z "$2" ]; then
echo "Error: --patch-size requires a number argument"
@@ -267,6 +273,13 @@ fi
if [ "$VALIDATE_ONLY" = false ]; then
# Step 1: Train model
echo "[1/4] Training CNN v2 model..."
+
+# Build optional flags
+OPTIONAL_FLAGS=""
+if [ "$GRAYSCALE_LOSS" = true ]; then
+ OPTIONAL_FLAGS="$OPTIONAL_FLAGS --grayscale-loss"
+fi
+
python3 training/train_cnn_v2.py \
--input "$INPUT_DIR" \
--target "$TARGET_DIR" \
@@ -277,7 +290,8 @@ python3 training/train_cnn_v2.py \
--epochs $EPOCHS \
--batch-size $BATCH_SIZE \
--checkpoint-dir "$CHECKPOINT_DIR" \
- --checkpoint-every $CHECKPOINT_EVERY
+ --checkpoint-every $CHECKPOINT_EVERY \
+ $OPTIONAL_FLAGS
if [ $? -ne 0 ]; then
echo "Error: Training failed"