summaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
Diffstat (limited to 'scripts')
-rwxr-xr-xscripts/train_cnn_v2_full.sh197
-rwxr-xr-xscripts/validate_cnn_v2.sh60
2 files changed, 257 insertions, 0 deletions
diff --git a/scripts/train_cnn_v2_full.sh b/scripts/train_cnn_v2_full.sh
new file mode 100755
index 0000000..fc9355a
--- /dev/null
+++ b/scripts/train_cnn_v2_full.sh
@@ -0,0 +1,197 @@
+#!/bin/bash
+# Complete CNN v2 Training Pipeline
+# Train → Export → Build → Validate
+# Usage: ./train_cnn_v2_full.sh [OPTIONS]
+#
+# OPTIONS:
+# (none) Run complete pipeline: train → export → build → validate
+# --validate Validate only (skip training, use existing weights)
+# --validate CHECKPOINT Validate with specific checkpoint file
+# --help Show this help message
+#
+# Examples:
+# ./train_cnn_v2_full.sh
+# ./train_cnn_v2_full.sh --validate
+# ./train_cnn_v2_full.sh --validate checkpoints/checkpoint_epoch_50.pth
+
+set -e
+
+PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "$PROJECT_ROOT"
+
+# Parse arguments
+VALIDATE_ONLY=false
+VALIDATE_CHECKPOINT=""
+
+if [ "$1" = "--help" ] || [ "$1" = "-h" ]; then
+ head -20 "$0" | grep "^#" | grep -v "^#!/" | sed 's/^# *//'
+ exit 0
+fi
+
+if [ "$1" = "--validate" ]; then
+ VALIDATE_ONLY=true
+ if [ -n "$2" ]; then
+ VALIDATE_CHECKPOINT="$2"
+ fi
+fi
+
+# Configuration
+INPUT_DIR="training/input"
+TARGET_DIR="training/target_2"
+CHECKPOINT_DIR="checkpoints"
+VALIDATION_DIR="validation_results"
+EPOCHS=100
+CHECKPOINT_EVERY=5
+BATCH_SIZE=16
+
+# Patch-based training (default)
+PATCH_SIZE=32
+PATCHES_PER_IMAGE=64
+DETECTOR="harris"
+
+# Full-image training (alternative - uncomment to use)
+# FULL_IMAGE="--full-image"
+# IMAGE_SIZE=256
+
+KERNEL_SIZES="3 3 3"
+CHANNELS="8 4 4"
+
+if [ "$VALIDATE_ONLY" = true ]; then
+ echo "=== CNN v2 Validation Only ==="
+ echo "Skipping training, using existing weights"
+ echo ""
+else
+ echo "=== CNN v2 Complete Training Pipeline ==="
+ echo "Input: $INPUT_DIR"
+ echo "Target: $TARGET_DIR"
+ echo "Epochs: $EPOCHS"
+ echo "Checkpoint interval: $CHECKPOINT_EVERY"
+ echo ""
+fi
+
+if [ "$VALIDATE_ONLY" = false ]; then
+ # Step 1: Train model
+ echo "[1/4] Training CNN v2 model..."
+python3 training/train_cnn_v2.py \
+ --input "$INPUT_DIR" \
+ --target "$TARGET_DIR" \
+ --patch-size $PATCH_SIZE \
+ --patches-per-image $PATCHES_PER_IMAGE \
+ --detector $DETECTOR \
+ --kernel-sizes $KERNEL_SIZES \
+ --channels $CHANNELS \
+ --epochs $EPOCHS \
+ --batch-size $BATCH_SIZE \
+ --checkpoint-dir "$CHECKPOINT_DIR" \
+ --checkpoint-every $CHECKPOINT_EVERY \
+ $FULL_IMAGE
+
+if [ $? -ne 0 ]; then
+ echo "Error: Training failed"
+ exit 1
+fi
+
+echo ""
+echo "Training complete!"
+echo ""
+
+# Step 2: Export final checkpoint to shaders
+FINAL_CHECKPOINT="$CHECKPOINT_DIR/checkpoint_epoch_${EPOCHS}.pth"
+
+if [ ! -f "$FINAL_CHECKPOINT" ]; then
+ echo "Warning: Final checkpoint not found, using latest available..."
+ FINAL_CHECKPOINT=$(ls -t "$CHECKPOINT_DIR"/checkpoint_epoch_*.pth | head -1)
+fi
+
+echo "[2/4] Exporting final checkpoint to WGSL shaders..."
+echo "Checkpoint: $FINAL_CHECKPOINT"
+python3 training/export_cnn_v2_shader.py "$FINAL_CHECKPOINT" \
+ --output-dir workspaces/main/shaders
+
+if [ $? -ne 0 ]; then
+ echo "Error: Shader export failed"
+ exit 1
+fi
+
+echo ""
+fi # End of training/export section
+
+# Determine which checkpoint to use
+if [ "$VALIDATE_ONLY" = true ]; then
+ if [ -n "$VALIDATE_CHECKPOINT" ]; then
+ FINAL_CHECKPOINT="$VALIDATE_CHECKPOINT"
+ else
+ # Use latest checkpoint
+ FINAL_CHECKPOINT=$(ls -t "$CHECKPOINT_DIR"/checkpoint_epoch_*.pth | head -1)
+ fi
+ echo "Using checkpoint: $FINAL_CHECKPOINT"
+ echo ""
+fi
+
+# Step 3: Rebuild with new shaders
+if [ "$VALIDATE_ONLY" = false ]; then
+ echo "[3/4] Rebuilding demo with new shaders..."
+ cmake --build build -j4 --target demo64k > /dev/null 2>&1
+
+ if [ $? -ne 0 ]; then
+ echo "Error: Build failed"
+ exit 1
+ fi
+
+ echo " → Build complete"
+ echo ""
+fi
+
+# Step 4: Visual assessment - process final checkpoint only
+if [ "$VALIDATE_ONLY" = true ]; then
+ echo "Validation on all input images (using existing weights)..."
+else
+ echo "[4/4] Visual assessment on all input images..."
+fi
+
+mkdir -p "$VALIDATION_DIR"
+echo " Using checkpoint: $FINAL_CHECKPOINT"
+
+# Export weights only if not in validate mode
+if [ "$VALIDATE_ONLY" = false ]; then
+ python3 training/export_cnn_v2_weights.py "$FINAL_CHECKPOINT" \
+ --output-weights workspaces/main/cnn_v2_weights.bin > /dev/null 2>&1
+fi
+
+# Build cnn_test
+cmake --build build -j4 --target cnn_test > /dev/null 2>&1
+
+# Process all input images
+for input_image in "$INPUT_DIR"/*.png; do
+ basename=$(basename "$input_image" .png)
+ echo " Processing $basename..."
+ build/cnn_test "$input_image" "$VALIDATION_DIR/${basename}_output.png" 2>/dev/null
+done
+
+# Build demo only if not in validate mode
+if [ "$VALIDATE_ONLY" = false ]; then
+ cmake --build build -j4 --target demo64k > /dev/null 2>&1
+fi
+
+echo ""
+if [ "$VALIDATE_ONLY" = true ]; then
+ echo "=== Validation Complete ==="
+else
+ echo "=== Training Pipeline Complete ==="
+fi
+echo ""
+echo "Results:"
+if [ "$VALIDATE_ONLY" = false ]; then
+ echo " - Checkpoints: $CHECKPOINT_DIR"
+ echo " - Final weights: workspaces/main/cnn_v2_weights.bin"
+fi
+echo " - Validation outputs: $VALIDATION_DIR"
+echo ""
+echo "Opening results directory..."
+open "$VALIDATION_DIR" 2>/dev/null || xdg-open "$VALIDATION_DIR" 2>/dev/null || true
+
+if [ "$VALIDATE_ONLY" = false ]; then
+ echo ""
+ echo "Run demo to see final result:"
+ echo " ./build/demo64k"
+fi
diff --git a/scripts/validate_cnn_v2.sh b/scripts/validate_cnn_v2.sh
new file mode 100755
index 0000000..06a4e01
--- /dev/null
+++ b/scripts/validate_cnn_v2.sh
@@ -0,0 +1,60 @@
+#!/bin/bash
+# CNN v2 Validation - End-to-end pipeline
+
+set -e
+PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+BUILD_DIR="$PROJECT_ROOT/build"
+WORKSPACE="main"
+
+usage() {
+ echo "Usage: $0 <checkpoint.pth> [options]"
+ echo "Options:"
+ echo " -i DIR Test images (default: training/validation)"
+ echo " -o DIR Output (default: validation_results)"
+ echo " --skip-build Skip rebuild"
+ exit 1
+}
+
+[ $# -eq 0 ] && usage
+CHECKPOINT="$1"
+shift
+
+TEST_IMAGES="$PROJECT_ROOT/training/validation"
+OUTPUT="$PROJECT_ROOT/validation_results"
+SKIP_BUILD=false
+
+while [[ $# -gt 0 ]]; do
+ case $1 in
+ -i) TEST_IMAGES="$2"; shift 2 ;;
+ -o) OUTPUT="$2"; shift 2 ;;
+ --skip-build) SKIP_BUILD=true; shift ;;
+ -h) usage ;;
+ *) usage ;;
+ esac
+done
+
+echo "=== CNN v2 Validation ==="
+echo "Checkpoint: $CHECKPOINT"
+
+# Export
+echo "[1/3] Exporting shaders..."
+python3 "$PROJECT_ROOT/training/export_cnn_v2_shader.py" "$CHECKPOINT" \
+ --output-dir "$PROJECT_ROOT/workspaces/$WORKSPACE/shaders"
+
+# Build
+if [ "$SKIP_BUILD" = false ]; then
+ echo "[2/3] Building..."
+ cmake --build "$BUILD_DIR" -j4 --target cnn_test >/dev/null 2>&1
+fi
+
+# Process
+echo "[3/3] Processing images..."
+mkdir -p "$OUTPUT"
+count=0
+for img in "$TEST_IMAGES"/*.png; do
+ [ -f "$img" ] || continue
+ name=$(basename "$img" .png)
+ "$BUILD_DIR/cnn_test" "$img" "$OUTPUT/${name}_output.png" 2>/dev/null && count=$((count+1))
+done
+
+echo "Done! Processed $count images → $OUTPUT"