diff options
Diffstat (limited to 'scripts')
| -rwxr-xr-x | scripts/train_cnn_v2_full.sh | 24 |
1 files changed, 18 insertions, 6 deletions
diff --git a/scripts/train_cnn_v2_full.sh b/scripts/train_cnn_v2_full.sh index 9c235b6..c98ff2d 100755 --- a/scripts/train_cnn_v2_full.sh +++ b/scripts/train_cnn_v2_full.sh @@ -31,6 +31,9 @@ # --checkpoint-dir DIR Checkpoint directory (default: checkpoints) # --validation-dir DIR Validation directory (default: validation_results) # +# OUTPUT: +# --output-weights PATH Output binary weights file (default: workspaces/main/weights/cnn_v2_weights.bin) +# # OTHER: # --help Show this help message # @@ -77,6 +80,7 @@ MIP_LEVEL=0 GRAYSCALE_LOSS=false FULL_IMAGE_MODE=false IMAGE_SIZE=256 +OUTPUT_WEIGHTS="workspaces/main/weights/cnn_v2_weights.bin" # Parse arguments VALIDATE_ONLY=false @@ -230,6 +234,14 @@ while [[ $# -gt 0 ]]; do VALIDATION_DIR="$2" shift 2 ;; + --output-weights) + if [ -z "$2" ]; then + echo "Error: --output-weights requires a file path argument" + exit 1 + fi + OUTPUT_WEIGHTS="$2" + shift 2 + ;; *) echo "Unknown option: $1" exit 1 @@ -255,14 +267,14 @@ if [ "$EXPORT_ONLY" = true ]; then exit 1 fi - export_weights "$EXPORT_CHECKPOINT" workspaces/main/weights/cnn_v2_weights.bin || { + export_weights "$EXPORT_CHECKPOINT" "$OUTPUT_WEIGHTS" || { echo "Error: Export failed" exit 1 } echo "" echo "=== Export Complete ===" - echo "Output: workspaces/main/weights/cnn_v2_weights.bin" + echo "Output: $OUTPUT_WEIGHTS" exit 0 fi @@ -316,7 +328,7 @@ fi echo "[2/4] Exporting final checkpoint to binary weights..." echo "Checkpoint: $FINAL_CHECKPOINT" -export_weights "$FINAL_CHECKPOINT" workspaces/main/weights/cnn_v2_weights.bin || { +export_weights "$FINAL_CHECKPOINT" "$OUTPUT_WEIGHTS" || { echo "Error: Shader export failed" exit 1 } @@ -354,7 +366,7 @@ echo " Using checkpoint: $FINAL_CHECKPOINT" # Export weights for validation mode (already exported in step 2 for training mode) if [ "$VALIDATE_ONLY" = true ]; then - export_weights "$FINAL_CHECKPOINT" workspaces/main/weights/cnn_v2_weights.bin > /dev/null 2>&1 + export_weights "$FINAL_CHECKPOINT" "$OUTPUT_WEIGHTS" > /dev/null 2>&1 fi # Build cnn_test @@ -364,7 +376,7 @@ build_target cnn_test for input_image in "$INPUT_DIR"/*.png; do basename=$(basename "$input_image" .png) echo " Processing $basename..." - build/cnn_test "$input_image" "$VALIDATION_DIR/${basename}_output.png" --cnn-version 2 2>/dev/null + build/cnn_test "$input_image" "$VALIDATION_DIR/${basename}_output.png" --weights "$OUTPUT_WEIGHTS" 2>/dev/null done # Build demo only if not in validate mode @@ -380,7 +392,7 @@ echo "" echo "Results:" if [ "$VALIDATE_ONLY" = false ]; then echo " - Checkpoints: $CHECKPOINT_DIR" - echo " - Final weights: workspaces/main/weights/cnn_v2_weights.bin" + echo " - Final weights: $OUTPUT_WEIGHTS" fi echo " - Validation outputs: $VALIDATION_DIR" echo "" |
