summaryrefslogtreecommitdiff
path: root/cnn_v2/scripts/train_cnn_v2_full.sh
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v2/scripts/train_cnn_v2_full.sh')
-rwxr-xr-xcnn_v2/scripts/train_cnn_v2_full.sh428
1 files changed, 428 insertions, 0 deletions
diff --git a/cnn_v2/scripts/train_cnn_v2_full.sh b/cnn_v2/scripts/train_cnn_v2_full.sh
new file mode 100755
index 0000000..a21c1ac
--- /dev/null
+++ b/cnn_v2/scripts/train_cnn_v2_full.sh
@@ -0,0 +1,428 @@
+#!/bin/bash
+# Complete CNN v2 Training Pipeline
+# Train → Export → Build → Validate
+# Usage: ./train_cnn_v2_full.sh [OPTIONS]
+#
+# MODES:
+# (none) Run complete pipeline: train → export → build → validate
+# --validate Validate only (skip training, use existing weights)
+# --validate CHECKPOINT Validate with specific checkpoint file
+# --export-only CHECKPOINT Export weights only (skip training, build, validation)
+#
+# TRAINING PARAMETERS:
+# --epochs N Training epochs (default: 200)
+# --batch-size N Batch size (default: 16)
+# --lr FLOAT Learning rate (default: 1e-3)
+# --checkpoint-every N Checkpoint interval (default: 50)
+# --kernel-sizes K Comma-separated kernel sizes (default: 3,3,3)
+# --num-layers N Number of layers (default: 3)
+# --mip-level N Mip level for p0-p3 features: 0-3 (default: 0)
+# --grayscale-loss Compute loss on grayscale instead of RGBA
+#
+# PATCH PARAMETERS:
+# --patch-size N Patch size (default: 8)
+# --patches-per-image N Patches per image (default: 256)
+# --detector TYPE Detector: harris|fast|shi-tomasi|gradient (default: harris)
+# --full-image Use full-image training (disables patch mode)
+# --image-size N Image size for full-image mode (default: 256)
+#
+# DIRECTORIES:
+# --input DIR Input directory (default: training/input)
+# --target DIR Target directory (default: training/target_1)
+# --checkpoint-dir DIR Checkpoint directory (default: checkpoints)
+# --validation-dir DIR Validation directory (default: validation_results)
+#
+# OUTPUT:
+# --output-weights PATH Output binary weights file (default: workspaces/main/weights/cnn_v2_weights.bin)
+#
+# OTHER:
+# --help Show this help message
+#
+# Examples:
+# ./train_cnn_v2_full.sh
+# ./train_cnn_v2_full.sh --epochs 500 --batch-size 32
+# ./train_cnn_v2_full.sh --validate
+# ./train_cnn_v2_full.sh --validate checkpoints/checkpoint_epoch_50.pth
+# ./train_cnn_v2_full.sh --export-only checkpoints/checkpoint_epoch_100.pth
+# ./train_cnn_v2_full.sh --mip-level 1 --kernel-sizes 3,5,3
+
+set -e
+
+PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "$PROJECT_ROOT"
+
+# Helper functions
+export_weights() {
+ python3 "$SCRIPT_DIR/../training/export_cnn_v2_weights.py" "$1" --output-weights "$2" --quiet
+}
+
+find_latest_checkpoint() {
+ ls -t "$CHECKPOINT_DIR"/checkpoint_epoch_*.pth 2>/dev/null | head -1
+}
+
+build_target() {
+ cmake --build build -j4 --target "$1" > /dev/null 2>&1
+}
+
+# Path resolution for running from any directory
+SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
+PROJECT_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)"
+
+# Default configuration
+INPUT_DIR="training/input"
+TARGET_DIR="training/target_1"
+CHECKPOINT_DIR="checkpoints"
+VALIDATION_DIR="validation_results"
+EPOCHS=200
+CHECKPOINT_EVERY=50
+BATCH_SIZE=16
+LEARNING_RATE=1e-3
+PATCH_SIZE=8
+PATCHES_PER_IMAGE=256
+DETECTOR="harris"
+KERNEL_SIZES="3,3,3"
+NUM_LAYERS=3
+MIP_LEVEL=0
+GRAYSCALE_LOSS=false
+FULL_IMAGE_MODE=false
+IMAGE_SIZE=256
+OUTPUT_WEIGHTS="${PROJECT_ROOT}/workspaces/main/weights/cnn_v2_weights.bin"
+
+# Parse arguments
+VALIDATE_ONLY=false
+VALIDATE_CHECKPOINT=""
+EXPORT_ONLY=false
+EXPORT_CHECKPOINT=""
+
+if [ "$1" = "--help" ] || [ "$1" = "-h" ]; then
+ head -47 "$0" | grep "^#" | grep -v "^#!/" | sed 's/^# *//'
+ exit 0
+fi
+
+# Parse all arguments
+while [[ $# -gt 0 ]]; do
+ case "$1" in
+ --export-only)
+ EXPORT_ONLY=true
+ if [ -z "$2" ]; then
+ echo "Error: --export-only requires a checkpoint file argument"
+ exit 1
+ fi
+ EXPORT_CHECKPOINT="$2"
+ shift 2
+ ;;
+ --validate)
+ VALIDATE_ONLY=true
+ if [ -n "$2" ] && [[ ! "$2" =~ ^-- ]]; then
+ VALIDATE_CHECKPOINT="$2"
+ shift 2
+ else
+ shift
+ fi
+ ;;
+ --epochs)
+ if [ -z "$2" ]; then
+ echo "Error: --epochs requires a number argument"
+ exit 1
+ fi
+ EPOCHS="$2"
+ shift 2
+ ;;
+ --batch-size)
+ if [ -z "$2" ]; then
+ echo "Error: --batch-size requires a number argument"
+ exit 1
+ fi
+ BATCH_SIZE="$2"
+ shift 2
+ ;;
+ --checkpoint-every)
+ if [ -z "$2" ]; then
+ echo "Error: --checkpoint-every requires a number argument"
+ exit 1
+ fi
+ CHECKPOINT_EVERY="$2"
+ shift 2
+ ;;
+ --kernel-sizes)
+ if [ -z "$2" ]; then
+ echo "Error: --kernel-sizes requires a comma-separated list"
+ exit 1
+ fi
+ KERNEL_SIZES="$2"
+ shift 2
+ ;;
+ --num-layers)
+ if [ -z "$2" ]; then
+ echo "Error: --num-layers requires a number argument"
+ exit 1
+ fi
+ NUM_LAYERS="$2"
+ shift 2
+ ;;
+ --mip-level)
+ if [ -z "$2" ]; then
+ echo "Error: --mip-level requires a level argument (0-3)"
+ exit 1
+ fi
+ MIP_LEVEL="$2"
+ shift 2
+ ;;
+ --grayscale-loss)
+ GRAYSCALE_LOSS=true
+ shift
+ ;;
+ --lr)
+ if [ -z "$2" ]; then
+ echo "Error: --lr requires a float argument"
+ exit 1
+ fi
+ LEARNING_RATE="$2"
+ shift 2
+ ;;
+ --patch-size)
+ if [ -z "$2" ]; then
+ echo "Error: --patch-size requires a number argument"
+ exit 1
+ fi
+ PATCH_SIZE="$2"
+ shift 2
+ ;;
+ --patches-per-image)
+ if [ -z "$2" ]; then
+ echo "Error: --patches-per-image requires a number argument"
+ exit 1
+ fi
+ PATCHES_PER_IMAGE="$2"
+ shift 2
+ ;;
+ --detector)
+ if [ -z "$2" ]; then
+ echo "Error: --detector requires a type argument"
+ exit 1
+ fi
+ DETECTOR="$2"
+ shift 2
+ ;;
+ --full-image)
+ FULL_IMAGE_MODE=true
+ shift
+ ;;
+ --image-size)
+ if [ -z "$2" ]; then
+ echo "Error: --image-size requires a number argument"
+ exit 1
+ fi
+ IMAGE_SIZE="$2"
+ shift 2
+ ;;
+ --input)
+ if [ -z "$2" ]; then
+ echo "Error: --input requires a directory argument"
+ exit 1
+ fi
+ INPUT_DIR="$2"
+ shift 2
+ ;;
+ --target)
+ if [ -z "$2" ]; then
+ echo "Error: --target requires a directory argument"
+ exit 1
+ fi
+ TARGET_DIR="$2"
+ shift 2
+ ;;
+ --checkpoint-dir)
+ if [ -z "$2" ]; then
+ echo "Error: --checkpoint-dir requires a directory argument"
+ exit 1
+ fi
+ CHECKPOINT_DIR="$2"
+ shift 2
+ ;;
+ --validation-dir)
+ if [ -z "$2" ]; then
+ echo "Error: --validation-dir requires a directory argument"
+ exit 1
+ fi
+ VALIDATION_DIR="$2"
+ shift 2
+ ;;
+ --output-weights)
+ if [ -z "$2" ]; then
+ echo "Error: --output-weights requires a file path argument"
+ exit 1
+ fi
+ OUTPUT_WEIGHTS="$2"
+ shift 2
+ ;;
+ *)
+ echo "Unknown option: $1"
+ exit 1
+ ;;
+ esac
+done
+
+# Build training arguments
+if [ "$FULL_IMAGE_MODE" = true ]; then
+ TRAINING_MODE_ARGS="--full-image --image-size $IMAGE_SIZE"
+else
+ TRAINING_MODE_ARGS="--patch-size $PATCH_SIZE --patches-per-image $PATCHES_PER_IMAGE --detector $DETECTOR"
+fi
+
+# Handle export-only mode
+if [ "$EXPORT_ONLY" = true ]; then
+ echo "=== CNN v2 Export Weights Only ==="
+ echo "Checkpoint: $EXPORT_CHECKPOINT"
+ echo ""
+
+ if [ ! -f "$EXPORT_CHECKPOINT" ]; then
+ echo "Error: Checkpoint file not found: $EXPORT_CHECKPOINT"
+ exit 1
+ fi
+
+ export_weights "$EXPORT_CHECKPOINT" "$OUTPUT_WEIGHTS" || {
+ echo "Error: Export failed"
+ exit 1
+ }
+
+ echo ""
+ echo "=== Export Complete ==="
+ echo "Output: $OUTPUT_WEIGHTS"
+ exit 0
+fi
+
+if [ "$VALIDATE_ONLY" = true ]; then
+ echo "=== CNN v2 Validation Only ==="
+ echo "Skipping training, using existing weights"
+ echo ""
+else
+ echo "=== CNN v2 Complete Training Pipeline ==="
+ echo "Input: $INPUT_DIR"
+ echo "Target: $TARGET_DIR"
+ echo "Epochs: $EPOCHS"
+ echo "Checkpoint interval: $CHECKPOINT_EVERY"
+ echo "Mip level: $MIP_LEVEL (p0-p3 features)"
+ echo ""
+fi
+
+if [ "$VALIDATE_ONLY" = false ]; then
+ # Step 1: Train model
+ echo "[1/4] Training CNN v2 model..."
+
+python3 "$SCRIPT_DIR/../training/train_cnn_v2.py" \
+ --input "$INPUT_DIR" \
+ --target "$TARGET_DIR" \
+ $TRAINING_MODE_ARGS \
+ --kernel-sizes "$KERNEL_SIZES" \
+ --num-layers "$NUM_LAYERS" \
+ --mip-level "$MIP_LEVEL" \
+ --epochs "$EPOCHS" \
+ --batch-size "$BATCH_SIZE" \
+ --lr "$LEARNING_RATE" \
+ --checkpoint-dir "$CHECKPOINT_DIR" \
+ --checkpoint-every "$CHECKPOINT_EVERY" \
+ $([ "$GRAYSCALE_LOSS" = true ] && echo "--grayscale-loss")
+
+if [ $? -ne 0 ]; then
+ echo "Error: Training failed"
+ exit 1
+fi
+
+echo ""
+echo "Training complete!"
+echo ""
+
+# Step 2: Export final checkpoint to shaders
+FINAL_CHECKPOINT="$CHECKPOINT_DIR/checkpoint_epoch_${EPOCHS}.pth"
+
+if [ ! -f "$FINAL_CHECKPOINT" ]; then
+ echo "Warning: Final checkpoint not found, using latest available..."
+ FINAL_CHECKPOINT=$(find_latest_checkpoint)
+fi
+
+if [ -z "$FINAL_CHECKPOINT" ] || [ ! -f "$FINAL_CHECKPOINT" ]; then
+ echo "Error: No checkpoint found in $CHECKPOINT_DIR"
+ exit 1
+fi
+
+echo "[2/4] Exporting final checkpoint to binary weights..."
+echo "Checkpoint: $FINAL_CHECKPOINT"
+export_weights "$FINAL_CHECKPOINT" "$OUTPUT_WEIGHTS" || {
+ echo "Error: Shader export failed"
+ exit 1
+}
+
+echo ""
+fi # End of training/export section
+
+# Determine which checkpoint to use
+if [ "$VALIDATE_ONLY" = true ]; then
+ FINAL_CHECKPOINT="${VALIDATE_CHECKPOINT:-$(find_latest_checkpoint)}"
+ echo "Using checkpoint: $FINAL_CHECKPOINT"
+ echo ""
+fi
+
+# Step 3: Rebuild with new shaders
+if [ "$VALIDATE_ONLY" = false ]; then
+ echo "[3/4] Rebuilding demo with new shaders..."
+ build_target demo64k || {
+ echo "Error: Build failed"
+ exit 1
+ }
+ echo " → Build complete"
+ echo ""
+fi
+
+# Step 4: Visual assessment - process final checkpoint only
+if [ "$VALIDATE_ONLY" = true ]; then
+ echo "Validation on all input images (using existing weights)..."
+else
+ echo "[4/4] Visual assessment on all input images..."
+fi
+
+mkdir -p "$VALIDATION_DIR"
+echo " Using checkpoint: $FINAL_CHECKPOINT"
+
+# Export weights for validation mode (already exported in step 2 for training mode)
+if [ "$VALIDATE_ONLY" = true ]; then
+ export_weights "$FINAL_CHECKPOINT" "$OUTPUT_WEIGHTS" > /dev/null 2>&1
+fi
+
+# Build cnn_test
+build_target cnn_test
+
+# Process all input images
+echo -n " Processing images: "
+for input_image in "$INPUT_DIR"/*.png; do
+ basename=$(basename "$input_image" .png)
+ echo -n "$basename "
+ build/cnn_test "$input_image" "$VALIDATION_DIR/${basename}_output.png" --weights "$OUTPUT_WEIGHTS" > /dev/null 2>&1
+done
+echo "✓"
+
+# Build demo only if not in validate mode
+[ "$VALIDATE_ONLY" = false ] && build_target demo64k
+
+echo ""
+if [ "$VALIDATE_ONLY" = true ]; then
+ echo "=== Validation Complete ==="
+else
+ echo "=== Training Pipeline Complete ==="
+fi
+echo ""
+echo "Results:"
+if [ "$VALIDATE_ONLY" = false ]; then
+ echo " - Checkpoints: $CHECKPOINT_DIR"
+ echo " - Final weights: $OUTPUT_WEIGHTS"
+fi
+echo " - Validation outputs: $VALIDATION_DIR"
+echo ""
+echo "Opening results directory..."
+open "$VALIDATION_DIR" 2>/dev/null || xdg-open "$VALIDATION_DIR" 2>/dev/null || true
+
+if [ "$VALIDATE_ONLY" = false ]; then
+ echo ""
+ echo "Run demo to see final result:"
+ echo " ./build/demo64k"
+fi