summaryrefslogtreecommitdiff
path: root/scripts/train_cnn_v2_full.sh
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/train_cnn_v2_full.sh')
-rwxr-xr-xscripts/train_cnn_v2_full.sh58
1 files changed, 44 insertions, 14 deletions
diff --git a/scripts/train_cnn_v2_full.sh b/scripts/train_cnn_v2_full.sh
index 9c235b6..078ea28 100755
--- a/scripts/train_cnn_v2_full.sh
+++ b/scripts/train_cnn_v2_full.sh
@@ -12,6 +12,7 @@
# TRAINING PARAMETERS:
# --epochs N Training epochs (default: 200)
# --batch-size N Batch size (default: 16)
+# --lr FLOAT Learning rate (default: 1e-3)
# --checkpoint-every N Checkpoint interval (default: 50)
# --kernel-sizes K Comma-separated kernel sizes (default: 3,3,3)
# --num-layers N Number of layers (default: 3)
@@ -31,6 +32,9 @@
# --checkpoint-dir DIR Checkpoint directory (default: checkpoints)
# --validation-dir DIR Validation directory (default: validation_results)
#
+# OUTPUT:
+# --output-weights PATH Output binary weights file (default: workspaces/main/weights/cnn_v2_weights.bin)
+#
# OTHER:
# --help Show this help message
#
@@ -49,7 +53,7 @@ cd "$PROJECT_ROOT"
# Helper functions
export_weights() {
- python3 training/export_cnn_v2_weights.py "$1" --output-weights "$2"
+ python3 training/export_cnn_v2_weights.py "$1" --output-weights "$2" --quiet
}
find_latest_checkpoint() {
@@ -68,6 +72,7 @@ VALIDATION_DIR="validation_results"
EPOCHS=200
CHECKPOINT_EVERY=50
BATCH_SIZE=16
+LEARNING_RATE=1e-3
PATCH_SIZE=8
PATCHES_PER_IMAGE=256
DETECTOR="harris"
@@ -77,6 +82,7 @@ MIP_LEVEL=0
GRAYSCALE_LOSS=false
FULL_IMAGE_MODE=false
IMAGE_SIZE=256
+OUTPUT_WEIGHTS="workspaces/main/weights/cnn_v2_weights.bin"
# Parse arguments
VALIDATE_ONLY=false
@@ -162,6 +168,14 @@ while [[ $# -gt 0 ]]; do
GRAYSCALE_LOSS=true
shift
;;
+ --lr)
+ if [ -z "$2" ]; then
+ echo "Error: --lr requires a float argument"
+ exit 1
+ fi
+ LEARNING_RATE="$2"
+ shift 2
+ ;;
--patch-size)
if [ -z "$2" ]; then
echo "Error: --patch-size requires a number argument"
@@ -230,6 +244,14 @@ while [[ $# -gt 0 ]]; do
VALIDATION_DIR="$2"
shift 2
;;
+ --output-weights)
+ if [ -z "$2" ]; then
+ echo "Error: --output-weights requires a file path argument"
+ exit 1
+ fi
+ OUTPUT_WEIGHTS="$2"
+ shift 2
+ ;;
*)
echo "Unknown option: $1"
exit 1
@@ -255,14 +277,14 @@ if [ "$EXPORT_ONLY" = true ]; then
exit 1
fi
- export_weights "$EXPORT_CHECKPOINT" workspaces/main/weights/cnn_v2_weights.bin || {
+ export_weights "$EXPORT_CHECKPOINT" "$OUTPUT_WEIGHTS" || {
echo "Error: Export failed"
exit 1
}
echo ""
echo "=== Export Complete ==="
- echo "Output: workspaces/main/weights/cnn_v2_weights.bin"
+ echo "Output: $OUTPUT_WEIGHTS"
exit 0
fi
@@ -288,13 +310,14 @@ python3 training/train_cnn_v2.py \
--input "$INPUT_DIR" \
--target "$TARGET_DIR" \
$TRAINING_MODE_ARGS \
- --kernel-sizes $KERNEL_SIZES \
- --num-layers $NUM_LAYERS \
- --mip-level $MIP_LEVEL \
- --epochs $EPOCHS \
- --batch-size $BATCH_SIZE \
+ --kernel-sizes "$KERNEL_SIZES" \
+ --num-layers "$NUM_LAYERS" \
+ --mip-level "$MIP_LEVEL" \
+ --epochs "$EPOCHS" \
+ --batch-size "$BATCH_SIZE" \
+ --lr "$LEARNING_RATE" \
--checkpoint-dir "$CHECKPOINT_DIR" \
- --checkpoint-every $CHECKPOINT_EVERY \
+ --checkpoint-every "$CHECKPOINT_EVERY" \
$([ "$GRAYSCALE_LOSS" = true ] && echo "--grayscale-loss")
if [ $? -ne 0 ]; then
@@ -314,9 +337,14 @@ if [ ! -f "$FINAL_CHECKPOINT" ]; then
FINAL_CHECKPOINT=$(find_latest_checkpoint)
fi
+if [ -z "$FINAL_CHECKPOINT" ] || [ ! -f "$FINAL_CHECKPOINT" ]; then
+ echo "Error: No checkpoint found in $CHECKPOINT_DIR"
+ exit 1
+fi
+
echo "[2/4] Exporting final checkpoint to binary weights..."
echo "Checkpoint: $FINAL_CHECKPOINT"
-export_weights "$FINAL_CHECKPOINT" workspaces/main/weights/cnn_v2_weights.bin || {
+export_weights "$FINAL_CHECKPOINT" "$OUTPUT_WEIGHTS" || {
echo "Error: Shader export failed"
exit 1
}
@@ -354,18 +382,20 @@ echo " Using checkpoint: $FINAL_CHECKPOINT"
# Export weights for validation mode (already exported in step 2 for training mode)
if [ "$VALIDATE_ONLY" = true ]; then
- export_weights "$FINAL_CHECKPOINT" workspaces/main/weights/cnn_v2_weights.bin > /dev/null 2>&1
+ export_weights "$FINAL_CHECKPOINT" "$OUTPUT_WEIGHTS" > /dev/null 2>&1
fi
# Build cnn_test
build_target cnn_test
# Process all input images
+echo -n " Processing images: "
for input_image in "$INPUT_DIR"/*.png; do
basename=$(basename "$input_image" .png)
- echo " Processing $basename..."
- build/cnn_test "$input_image" "$VALIDATION_DIR/${basename}_output.png" --cnn-version 2 2>/dev/null
+ echo -n "$basename "
+ build/cnn_test "$input_image" "$VALIDATION_DIR/${basename}_output.png" --weights "$OUTPUT_WEIGHTS" > /dev/null 2>&1
done
+echo "✓"
# Build demo only if not in validate mode
[ "$VALIDATE_ONLY" = false ] && build_target demo64k
@@ -380,7 +410,7 @@ echo ""
echo "Results:"
if [ "$VALIDATE_ONLY" = false ]; then
echo " - Checkpoints: $CHECKPOINT_DIR"
- echo " - Final weights: workspaces/main/weights/cnn_v2_weights.bin"
+ echo " - Final weights: $OUTPUT_WEIGHTS"
fi
echo " - Validation outputs: $VALIDATION_DIR"
echo ""