diff options
Diffstat (limited to 'cnn_v2/scripts/train_cnn_v2_full.sh')
| -rwxr-xr-x | cnn_v2/scripts/train_cnn_v2_full.sh | 428 |
1 files changed, 428 insertions, 0 deletions
#!/bin/bash
# Complete CNN v2 Training Pipeline
# Train → Export → Build → Validate
# Usage: ./train_cnn_v2_full.sh [OPTIONS]
#
# MODES:
#   (none)                     Run complete pipeline: train → export → build → validate
#   --validate                 Validate only (skip training, use existing weights)
#   --validate CHECKPOINT      Validate with specific checkpoint file
#   --export-only CHECKPOINT   Export weights only (skip training, build, validation)
#
# TRAINING PARAMETERS:
#   --epochs N                 Training epochs (default: 200)
#   --batch-size N             Batch size (default: 16)
#   --lr FLOAT                 Learning rate (default: 1e-3)
#   --checkpoint-every N       Checkpoint interval (default: 50)
#   --kernel-sizes K           Comma-separated kernel sizes (default: 3,3,3)
#   --num-layers N             Number of layers (default: 3)
#   --mip-level N              Mip level for p0-p3 features: 0-3 (default: 0)
#   --grayscale-loss           Compute loss on grayscale instead of RGBA
#
# PATCH PARAMETERS:
#   --patch-size N             Patch size (default: 8)
#   --patches-per-image N      Patches per image (default: 256)
#   --detector TYPE            Detector: harris|fast|shi-tomasi|gradient (default: harris)
#   --full-image               Use full-image training (disables patch mode)
#   --image-size N             Image size for full-image mode (default: 256)
#
# DIRECTORIES:
#   --input DIR                Input directory (default: training/input)
#   --target DIR               Target directory (default: training/target_1)
#   --checkpoint-dir DIR       Checkpoint directory (default: checkpoints)
#   --validation-dir DIR       Validation directory (default: validation_results)
#
# OUTPUT:
#   --output-weights PATH      Output binary weights file (default: workspaces/main/weights/cnn_v2_weights.bin)
#
# OTHER:
#   --help                     Show this help message
#
# Examples:
#   ./train_cnn_v2_full.sh
#   ./train_cnn_v2_full.sh --epochs 500 --batch-size 32
#   ./train_cnn_v2_full.sh --validate
#   ./train_cnn_v2_full.sh --validate checkpoints/checkpoint_epoch_50.pth
#   ./train_cnn_v2_full.sh --export-only checkpoints/checkpoint_epoch_100.pth
#   ./train_cnn_v2_full.sh --mip-level 1 --kernel-sizes 3,5,3

set -e

# ---------------------------------------------------------------------------
# Path resolution
# ---------------------------------------------------------------------------
# Resolve the script's own location BEFORE changing directory: a relative
# invocation path (e.g. ./cnn_v2/scripts/train_cnn_v2_full.sh) can no longer
# be resolved once the working directory has changed.
SCRIPT_SOURCE="${BASH_SOURCE[0]}"
SCRIPT_DIR="$(cd "$(dirname "$SCRIPT_SOURCE")" && pwd)"
SCRIPT_FILE="$SCRIPT_DIR/$(basename "$SCRIPT_SOURCE")"
PROJECT_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)"

# All relative paths below (input/target/checkpoint/validation directories and
# build/) are resolved against the parent of the script directory.
# NOTE(review): the default weights path is anchored at $PROJECT_ROOT (two
# levels up) while the working directory is one level up — confirm which of
# the two the relative defaults are meant to use.
cd "$SCRIPT_DIR/.."

# ---------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------

# Print "Error: <message>" to stderr and abort with exit status 1.
die() {
  echo "Error: $*" >&2
  exit 1
}

# Print the comment header at the top of this file as the help text.
# Skips the shebang and stops at the first non-comment line, so the help
# stays correct when the header grows or shrinks.
show_help() {
  awk 'NR == 1 && /^#!/ { next } !/^#/ { exit } { sub(/^# ?/, ""); print }' \
    "$SCRIPT_FILE"
}

# Abort unless an option received a non-empty value.
#   $1 - option name (e.g. --epochs)
#   $2 - description of the expected value (e.g. "a number argument")
#   $3 - the value that followed the option (may be empty/missing)
require_value() {
  if [ -z "$3" ]; then
    die "$1 requires $2"
  fi
}

# Convert a checkpoint into a binary weights file.
#   $1 - checkpoint path (.pth)
#   $2 - output weights path
export_weights() {
  python3 "$SCRIPT_DIR/../training/export_cnn_v2_weights.py" "$1" \
    --output-weights "$2" --quiet
}

# Print the most recently modified checkpoint in $CHECKPOINT_DIR, or nothing
# when there is none. `ls -t` is used because a glob cannot sort by mtime.
find_latest_checkpoint() {
  ls -t "$CHECKPOINT_DIR"/checkpoint_epoch_*.pth 2>/dev/null | head -1
}

# Build one CMake target in ./build, discarding the build output.
#   $1 - target name
build_target() {
  cmake --build build -j4 --target "$1" > /dev/null 2>&1
}

# ---------------------------------------------------------------------------
# Default configuration
# ---------------------------------------------------------------------------
INPUT_DIR="training/input"
TARGET_DIR="training/target_1"
CHECKPOINT_DIR="checkpoints"
VALIDATION_DIR="validation_results"
EPOCHS=200
CHECKPOINT_EVERY=50
BATCH_SIZE=16
LEARNING_RATE=1e-3
PATCH_SIZE=8
PATCHES_PER_IMAGE=256
DETECTOR="harris"
KERNEL_SIZES="3,3,3"
NUM_LAYERS=3
MIP_LEVEL=0
GRAYSCALE_LOSS=false
FULL_IMAGE_MODE=false
IMAGE_SIZE=256
OUTPUT_WEIGHTS="${PROJECT_ROOT}/workspaces/main/weights/cnn_v2_weights.bin"

VALIDATE_ONLY=false
VALIDATE_CHECKPOINT=""
EXPORT_ONLY=false
EXPORT_CHECKPOINT=""

# ---------------------------------------------------------------------------
# Argument parsing
# ---------------------------------------------------------------------------
while [[ $# -gt 0 ]]; do
  case "$1" in
    -h|--help)
      show_help
      exit 0
      ;;
    --export-only)
      require_value "$1" "a checkpoint file argument" "${2:-}"
      EXPORT_ONLY=true
      EXPORT_CHECKPOINT="$2"
      shift 2
      ;;
    --validate)
      VALIDATE_ONLY=true
      # The checkpoint is optional: only consume the next word when it is
      # not itself an option.
      if [ -n "${2:-}" ] && [[ ! "$2" =~ ^-- ]]; then
        VALIDATE_CHECKPOINT="$2"
        shift 2
      else
        shift
      fi
      ;;
    --epochs)
      require_value "$1" "a number argument" "${2:-}"
      EPOCHS="$2"
      shift 2
      ;;
    --batch-size)
      require_value "$1" "a number argument" "${2:-}"
      BATCH_SIZE="$2"
      shift 2
      ;;
    --checkpoint-every)
      require_value "$1" "a number argument" "${2:-}"
      CHECKPOINT_EVERY="$2"
      shift 2
      ;;
    --kernel-sizes)
      require_value "$1" "a comma-separated list" "${2:-}"
      KERNEL_SIZES="$2"
      shift 2
      ;;
    --num-layers)
      require_value "$1" "a number argument" "${2:-}"
      NUM_LAYERS="$2"
      shift 2
      ;;
    --mip-level)
      require_value "$1" "a level argument (0-3)" "${2:-}"
      MIP_LEVEL="$2"
      shift 2
      ;;
    --grayscale-loss)
      GRAYSCALE_LOSS=true
      shift
      ;;
    --lr)
      require_value "$1" "a float argument" "${2:-}"
      LEARNING_RATE="$2"
      shift 2
      ;;
    --patch-size)
      require_value "$1" "a number argument" "${2:-}"
      PATCH_SIZE="$2"
      shift 2
      ;;
    --patches-per-image)
      require_value "$1" "a number argument" "${2:-}"
      PATCHES_PER_IMAGE="$2"
      shift 2
      ;;
    --detector)
      require_value "$1" "a type argument" "${2:-}"
      DETECTOR="$2"
      shift 2
      ;;
    --full-image)
      FULL_IMAGE_MODE=true
      shift
      ;;
    --image-size)
      require_value "$1" "a number argument" "${2:-}"
      IMAGE_SIZE="$2"
      shift 2
      ;;
    --input)
      require_value "$1" "a directory argument" "${2:-}"
      INPUT_DIR="$2"
      shift 2
      ;;
    --target)
      require_value "$1" "a directory argument" "${2:-}"
      TARGET_DIR="$2"
      shift 2
      ;;
    --checkpoint-dir)
      require_value "$1" "a directory argument" "${2:-}"
      CHECKPOINT_DIR="$2"
      shift 2
      ;;
    --validation-dir)
      require_value "$1" "a directory argument" "${2:-}"
      VALIDATION_DIR="$2"
      shift 2
      ;;
    --output-weights)
      require_value "$1" "a file path argument" "${2:-}"
      OUTPUT_WEIGHTS="$2"
      shift 2
      ;;
    *)
      echo "Unknown option: $1" >&2
      exit 1
      ;;
  esac
done

# Mode-specific training arguments, kept as an array so that values are passed
# through verbatim instead of relying on unquoted word splitting.
if [ "$FULL_IMAGE_MODE" = true ]; then
  TRAINING_MODE_ARGS=(--full-image --image-size "$IMAGE_SIZE")
else
  TRAINING_MODE_ARGS=(
    --patch-size "$PATCH_SIZE"
    --patches-per-image "$PATCHES_PER_IMAGE"
    --detector "$DETECTOR"
  )
fi

# ---------------------------------------------------------------------------
# Export-only mode
# ---------------------------------------------------------------------------
if [ "$EXPORT_ONLY" = true ]; then
  echo "=== CNN v2 Export Weights Only ==="
  echo "Checkpoint: $EXPORT_CHECKPOINT"
  echo ""

  if [ ! -f "$EXPORT_CHECKPOINT" ]; then
    die "Checkpoint file not found: $EXPORT_CHECKPOINT"
  fi

  export_weights "$EXPORT_CHECKPOINT" "$OUTPUT_WEIGHTS" || die "Export failed"

  echo ""
  echo "=== Export Complete ==="
  echo "Output: $OUTPUT_WEIGHTS"
  exit 0
fi

# ---------------------------------------------------------------------------
# Banner
# ---------------------------------------------------------------------------
if [ "$VALIDATE_ONLY" = true ]; then
  echo "=== CNN v2 Validation Only ==="
  echo "Skipping training, using existing weights"
  echo ""
else
  echo "=== CNN v2 Complete Training Pipeline ==="
  echo "Input: $INPUT_DIR"
  echo "Target: $TARGET_DIR"
  echo "Epochs: $EPOCHS"
  echo "Checkpoint interval: $CHECKPOINT_EVERY"
  echo "Mip level: $MIP_LEVEL (p0-p3 features)"
  echo ""
fi

# ---------------------------------------------------------------------------
# Steps 1-2: train and export (skipped in validate-only mode)
# ---------------------------------------------------------------------------
if [ "$VALIDATE_ONLY" = false ]; then
  # Step 1: Train model
  echo "[1/4] Training CNN v2 model..."

  train_args=(
    --input "$INPUT_DIR"
    --target "$TARGET_DIR"
    "${TRAINING_MODE_ARGS[@]}"
    --kernel-sizes "$KERNEL_SIZES"
    --num-layers "$NUM_LAYERS"
    --mip-level "$MIP_LEVEL"
    --epochs "$EPOCHS"
    --batch-size "$BATCH_SIZE"
    --lr "$LEARNING_RATE"
    --checkpoint-dir "$CHECKPOINT_DIR"
    --checkpoint-every "$CHECKPOINT_EVERY"
  )
  if [ "$GRAYSCALE_LOSS" = true ]; then
    train_args+=(--grayscale-loss)
  fi

  # Test the command directly: under `set -e` a separate `$?` check after the
  # command would never run, and the failure message would be lost.
  if ! python3 "$SCRIPT_DIR/../training/train_cnn_v2.py" "${train_args[@]}"; then
    die "Training failed"
  fi

  echo ""
  echo "Training complete!"
  echo ""

  # Step 2: Export final checkpoint to binary weights
  FINAL_CHECKPOINT="$CHECKPOINT_DIR/checkpoint_epoch_${EPOCHS}.pth"

  if [ ! -f "$FINAL_CHECKPOINT" ]; then
    echo "Warning: Final checkpoint not found, using latest available..."
    FINAL_CHECKPOINT=$(find_latest_checkpoint)
  fi

  if [ -z "$FINAL_CHECKPOINT" ] || [ ! -f "$FINAL_CHECKPOINT" ]; then
    die "No checkpoint found in $CHECKPOINT_DIR"
  fi

  echo "[2/4] Exporting final checkpoint to binary weights..."
  echo "Checkpoint: $FINAL_CHECKPOINT"
  export_weights "$FINAL_CHECKPOINT" "$OUTPUT_WEIGHTS" \
    || die "Shader export failed"

  echo ""
fi

# ---------------------------------------------------------------------------
# Validate-only: choose the checkpoint to validate
# ---------------------------------------------------------------------------
if [ "$VALIDATE_ONLY" = true ]; then
  FINAL_CHECKPOINT="${VALIDATE_CHECKPOINT:-$(find_latest_checkpoint)}"

  # Fail early with a clear message instead of letting the (silenced) export
  # below abort the script without any output.
  if [ -z "$FINAL_CHECKPOINT" ]; then
    die "No checkpoint found in $CHECKPOINT_DIR"
  fi
  if [ ! -f "$FINAL_CHECKPOINT" ]; then
    die "Checkpoint file not found: $FINAL_CHECKPOINT"
  fi

  echo "Using checkpoint: $FINAL_CHECKPOINT"
  echo ""
fi

# ---------------------------------------------------------------------------
# Step 3: rebuild the demo with the new weights
# ---------------------------------------------------------------------------
if [ "$VALIDATE_ONLY" = false ]; then
  echo "[3/4] Rebuilding demo with new shaders..."
  build_target demo64k || die "Build failed"
  echo " → Build complete"
  echo ""
fi

# ---------------------------------------------------------------------------
# Step 4: visual assessment — run the final weights over every input image
# ---------------------------------------------------------------------------
if [ "$VALIDATE_ONLY" = true ]; then
  echo "Validation on all input images (using existing weights)..."
else
  echo "[4/4] Visual assessment on all input images..."
fi

mkdir -p "$VALIDATION_DIR"
echo " Using checkpoint: $FINAL_CHECKPOINT"

# Training mode already exported the weights in step 2; validate-only mode
# has to export them from the selected checkpoint here.
if [ "$VALIDATE_ONLY" = true ]; then
  export_weights "$FINAL_CHECKPOINT" "$OUTPUT_WEIGHTS" > /dev/null 2>&1 \
    || die "Export failed"
fi

build_target cnn_test || die "Build failed (cnn_test)"

echo -n " Processing images: "
processed_any=false
for input_image in "$INPUT_DIR"/*.png; do
  # An unmatched glob is left as the literal pattern; skip it.
  [ -e "$input_image" ] || continue
  processed_any=true

  image_name=$(basename "$input_image" .png)
  echo -n "$image_name "
  if ! build/cnn_test "$input_image" \
      "$VALIDATION_DIR/${image_name}_output.png" \
      --weights "$OUTPUT_WEIGHTS" > /dev/null 2>&1; then
    echo ""
    die "cnn_test failed on $input_image"
  fi
done

if [ "$processed_any" = false ]; then
  echo ""
  die "No PNG images found in $INPUT_DIR"
fi
echo "✓"

# Build demo only if not in validate mode
if [ "$VALIDATE_ONLY" = false ]; then
  build_target demo64k || die "Build failed"
fi

# ---------------------------------------------------------------------------
# Summary
# ---------------------------------------------------------------------------
echo ""
if [ "$VALIDATE_ONLY" = true ]; then
  echo "=== Validation Complete ==="
else
  echo "=== Training Pipeline Complete ==="
fi
echo ""
echo "Results:"
if [ "$VALIDATE_ONLY" = false ]; then
  echo " - Checkpoints: $CHECKPOINT_DIR"
  echo " - Final weights: $OUTPUT_WEIGHTS"
fi
echo " - Validation outputs: $VALIDATION_DIR"
echo ""
echo "Opening results directory..."
# Best effort: `open` (macOS) or `xdg-open` (Linux); never fail the run if
# neither is available.
open "$VALIDATION_DIR" 2>/dev/null || xdg-open "$VALIDATION_DIR" 2>/dev/null || true

if [ "$VALIDATE_ONLY" = false ]; then
  echo ""
  echo "Run demo to see final result:"
  echo " ./build/demo64k"
fi
