diff options
Diffstat (limited to 'scripts/train_cnn_v2_full.sh')
| -rwxr-xr-x | scripts/train_cnn_v2_full.sh | 58 |
1 files changed, 44 insertions, 14 deletions
diff --git a/scripts/train_cnn_v2_full.sh b/scripts/train_cnn_v2_full.sh index 9c235b6..078ea28 100755 --- a/scripts/train_cnn_v2_full.sh +++ b/scripts/train_cnn_v2_full.sh @@ -12,6 +12,7 @@ # TRAINING PARAMETERS: # --epochs N Training epochs (default: 200) # --batch-size N Batch size (default: 16) +# --lr FLOAT Learning rate (default: 1e-3) # --checkpoint-every N Checkpoint interval (default: 50) # --kernel-sizes K Comma-separated kernel sizes (default: 3,3,3) # --num-layers N Number of layers (default: 3) @@ -31,6 +32,9 @@ # --checkpoint-dir DIR Checkpoint directory (default: checkpoints) # --validation-dir DIR Validation directory (default: validation_results) # +# OUTPUT: +# --output-weights PATH Output binary weights file (default: workspaces/main/weights/cnn_v2_weights.bin) +# # OTHER: # --help Show this help message # @@ -49,7 +53,7 @@ cd "$PROJECT_ROOT" # Helper functions export_weights() { - python3 training/export_cnn_v2_weights.py "$1" --output-weights "$2" + python3 training/export_cnn_v2_weights.py "$1" --output-weights "$2" --quiet } find_latest_checkpoint() { @@ -68,6 +72,7 @@ VALIDATION_DIR="validation_results" EPOCHS=200 CHECKPOINT_EVERY=50 BATCH_SIZE=16 +LEARNING_RATE=1e-3 PATCH_SIZE=8 PATCHES_PER_IMAGE=256 DETECTOR="harris" @@ -77,6 +82,7 @@ MIP_LEVEL=0 GRAYSCALE_LOSS=false FULL_IMAGE_MODE=false IMAGE_SIZE=256 +OUTPUT_WEIGHTS="workspaces/main/weights/cnn_v2_weights.bin" # Parse arguments VALIDATE_ONLY=false @@ -162,6 +168,14 @@ while [[ $# -gt 0 ]]; do GRAYSCALE_LOSS=true shift ;; + --lr) + if [ -z "$2" ]; then + echo "Error: --lr requires a float argument" + exit 1 + fi + LEARNING_RATE="$2" + shift 2 + ;; --patch-size) if [ -z "$2" ]; then echo "Error: --patch-size requires a number argument" @@ -230,6 +244,14 @@ while [[ $# -gt 0 ]]; do VALIDATION_DIR="$2" shift 2 ;; + --output-weights) + if [ -z "$2" ]; then + echo "Error: --output-weights requires a file path argument" + exit 1 + fi + OUTPUT_WEIGHTS="$2" + shift 2 + ;; *) echo "Unknown option: $1" exit 1 @@ -255,14 +277,14 @@ if [ "$EXPORT_ONLY" = true ]; then exit 1 fi - export_weights "$EXPORT_CHECKPOINT" workspaces/main/weights/cnn_v2_weights.bin || { + export_weights "$EXPORT_CHECKPOINT" "$OUTPUT_WEIGHTS" || { echo "Error: Export failed" exit 1 } echo "" echo "=== Export Complete ===" - echo "Output: workspaces/main/weights/cnn_v2_weights.bin" + echo "Output: $OUTPUT_WEIGHTS" exit 0 fi @@ -288,13 +310,14 @@ python3 training/train_cnn_v2.py \ --input "$INPUT_DIR" \ --target "$TARGET_DIR" \ $TRAINING_MODE_ARGS \ - --kernel-sizes $KERNEL_SIZES \ - --num-layers $NUM_LAYERS \ - --mip-level $MIP_LEVEL \ - --epochs $EPOCHS \ - --batch-size $BATCH_SIZE \ + --kernel-sizes "$KERNEL_SIZES" \ + --num-layers "$NUM_LAYERS" \ + --mip-level "$MIP_LEVEL" \ + --epochs "$EPOCHS" \ + --batch-size "$BATCH_SIZE" \ + --lr "$LEARNING_RATE" \ --checkpoint-dir "$CHECKPOINT_DIR" \ - --checkpoint-every $CHECKPOINT_EVERY \ + --checkpoint-every "$CHECKPOINT_EVERY" \ $([ "$GRAYSCALE_LOSS" = true ] && echo "--grayscale-loss") if [ $? -ne 0 ]; then @@ -314,9 +337,14 @@ if [ ! -f "$FINAL_CHECKPOINT" ]; then FINAL_CHECKPOINT=$(find_latest_checkpoint) fi +if [ -z "$FINAL_CHECKPOINT" ] || [ ! -f "$FINAL_CHECKPOINT" ]; then + echo "Error: No checkpoint found in $CHECKPOINT_DIR" + exit 1 +fi + echo "[2/4] Exporting final checkpoint to binary weights..." echo "Checkpoint: $FINAL_CHECKPOINT" -export_weights "$FINAL_CHECKPOINT" workspaces/main/weights/cnn_v2_weights.bin || { +export_weights "$FINAL_CHECKPOINT" "$OUTPUT_WEIGHTS" || { echo "Error: Shader export failed" exit 1 } @@ -354,18 +382,20 @@ echo " Using checkpoint: $FINAL_CHECKPOINT" # Export weights for validation mode (already exported in step 2 for training mode) if [ "$VALIDATE_ONLY" = true ]; then - export_weights "$FINAL_CHECKPOINT" workspaces/main/weights/cnn_v2_weights.bin > /dev/null 2>&1 + export_weights "$FINAL_CHECKPOINT" "$OUTPUT_WEIGHTS" > /dev/null 2>&1 fi # Build cnn_test build_target cnn_test # Process all input images +echo -n " Processing images: " for input_image in "$INPUT_DIR"/*.png; do basename=$(basename "$input_image" .png) - echo " Processing $basename..." - build/cnn_test "$input_image" "$VALIDATION_DIR/${basename}_output.png" --cnn-version 2 2>/dev/null + echo -n "$basename " + build/cnn_test "$input_image" "$VALIDATION_DIR/${basename}_output.png" --weights "$OUTPUT_WEIGHTS" > /dev/null 2>&1 done +echo "✓" # Build demo only if not in validate mode [ "$VALIDATE_ONLY" = false ] && build_target demo64k @@ -380,7 +410,7 @@ echo "" echo "Results:" if [ "$VALIDATE_ONLY" = false ]; then echo " - Checkpoints: $CHECKPOINT_DIR" - echo " - Final weights: workspaces/main/weights/cnn_v2_weights.bin" + echo " - Final weights: $OUTPUT_WEIGHTS" fi echo " - Validation outputs: $VALIDATION_DIR" echo "" |
