summaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
Diffstat (limited to 'scripts')
-rwxr-xr-xscripts/train_cnn_v2_full.sh24
1 files changed, 18 insertions, 6 deletions
diff --git a/scripts/train_cnn_v2_full.sh b/scripts/train_cnn_v2_full.sh
index 9c235b6..c98ff2d 100755
--- a/scripts/train_cnn_v2_full.sh
+++ b/scripts/train_cnn_v2_full.sh
@@ -31,6 +31,9 @@
# --checkpoint-dir DIR Checkpoint directory (default: checkpoints)
# --validation-dir DIR Validation directory (default: validation_results)
#
+# OUTPUT:
+# --output-weights PATH Output binary weights file (default: workspaces/main/weights/cnn_v2_weights.bin)
+#
# OTHER:
# --help Show this help message
#
@@ -77,6 +80,7 @@ MIP_LEVEL=0
GRAYSCALE_LOSS=false
FULL_IMAGE_MODE=false
IMAGE_SIZE=256
+OUTPUT_WEIGHTS="workspaces/main/weights/cnn_v2_weights.bin"
# Parse arguments
VALIDATE_ONLY=false
@@ -230,6 +234,14 @@ while [[ $# -gt 0 ]]; do
VALIDATION_DIR="$2"
shift 2
;;
+ --output-weights)
+ if [ -z "$2" ]; then
+ echo "Error: --output-weights requires a file path argument"
+ exit 1
+ fi
+ OUTPUT_WEIGHTS="$2"
+ shift 2
+ ;;
*)
echo "Unknown option: $1"
exit 1
@@ -255,14 +267,14 @@ if [ "$EXPORT_ONLY" = true ]; then
exit 1
fi
- export_weights "$EXPORT_CHECKPOINT" workspaces/main/weights/cnn_v2_weights.bin || {
+ export_weights "$EXPORT_CHECKPOINT" "$OUTPUT_WEIGHTS" || {
echo "Error: Export failed"
exit 1
}
echo ""
echo "=== Export Complete ==="
- echo "Output: workspaces/main/weights/cnn_v2_weights.bin"
+ echo "Output: $OUTPUT_WEIGHTS"
exit 0
fi
@@ -316,7 +328,7 @@ fi
echo "[2/4] Exporting final checkpoint to binary weights..."
echo "Checkpoint: $FINAL_CHECKPOINT"
-export_weights "$FINAL_CHECKPOINT" workspaces/main/weights/cnn_v2_weights.bin || {
+export_weights "$FINAL_CHECKPOINT" "$OUTPUT_WEIGHTS" || {
echo "Error: Shader export failed"
exit 1
}
@@ -354,7 +366,7 @@ echo " Using checkpoint: $FINAL_CHECKPOINT"
# Export weights for validation mode (already exported in step 2 for training mode)
if [ "$VALIDATE_ONLY" = true ]; then
- export_weights "$FINAL_CHECKPOINT" workspaces/main/weights/cnn_v2_weights.bin > /dev/null 2>&1
+ export_weights "$FINAL_CHECKPOINT" "$OUTPUT_WEIGHTS" > /dev/null 2>&1
fi
# Build cnn_test
@@ -364,7 +376,7 @@ build_target cnn_test
for input_image in "$INPUT_DIR"/*.png; do
basename=$(basename "$input_image" .png)
echo " Processing $basename..."
- build/cnn_test "$input_image" "$VALIDATION_DIR/${basename}_output.png" --cnn-version 2 2>/dev/null
+ build/cnn_test "$input_image" "$VALIDATION_DIR/${basename}_output.png" --weights "$OUTPUT_WEIGHTS" 2>/dev/null
done
# Build demo only if not in validate mode
@@ -380,7 +392,7 @@ echo ""
echo "Results:"
if [ "$VALIDATE_ONLY" = false ]; then
echo " - Checkpoints: $CHECKPOINT_DIR"
- echo " - Final weights: workspaces/main/weights/cnn_v2_weights.bin"
+ echo " - Final weights: $OUTPUT_WEIGHTS"
fi
echo " - Validation outputs: $VALIDATION_DIR"
echo ""