diff options
| -rwxr-xr-x | scripts/train_cnn_v2_full.sh | 62 |
1 files changed, 26 insertions, 36 deletions
diff --git a/scripts/train_cnn_v2_full.sh b/scripts/train_cnn_v2_full.sh index e444f20..9c235b6 100755 --- a/scripts/train_cnn_v2_full.sh +++ b/scripts/train_cnn_v2_full.sh @@ -47,6 +47,19 @@ set -e PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" cd "$PROJECT_ROOT" +# Helper functions +export_weights() { + python3 training/export_cnn_v2_weights.py "$1" --output-weights "$2" +} + +find_latest_checkpoint() { + ls -t "$CHECKPOINT_DIR"/checkpoint_epoch_*.pth 2>/dev/null | head -1 +} + +build_target() { + cmake --build build -j4 --target "$1" > /dev/null 2>&1 +} + # Default configuration INPUT_DIR="training/input" TARGET_DIR="training/target_1" @@ -242,13 +255,10 @@ if [ "$EXPORT_ONLY" = true ]; then exit 1 fi - python3 training/export_cnn_v2_weights.py "$EXPORT_CHECKPOINT" \ - --output-weights workspaces/main/weights/cnn_v2_weights.bin - - if [ $? -ne 0 ]; then + export_weights "$EXPORT_CHECKPOINT" workspaces/main/weights/cnn_v2_weights.bin || { echo "Error: Export failed" exit 1 - fi + } echo "" echo "=== Export Complete ===" @@ -274,12 +284,6 @@ if [ "$VALIDATE_ONLY" = false ]; then # Step 1: Train model echo "[1/4] Training CNN v2 model..." -# Build optional flags -OPTIONAL_FLAGS="" -if [ "$GRAYSCALE_LOSS" = true ]; then - OPTIONAL_FLAGS="$OPTIONAL_FLAGS --grayscale-loss" -fi - python3 training/train_cnn_v2.py \ --input "$INPUT_DIR" \ --target "$TARGET_DIR" \ @@ -291,7 +295,7 @@ python3 training/train_cnn_v2.py \ --batch-size $BATCH_SIZE \ --checkpoint-dir "$CHECKPOINT_DIR" \ --checkpoint-every $CHECKPOINT_EVERY \ - $OPTIONAL_FLAGS + $([ "$GRAYSCALE_LOSS" = true ] && echo "--grayscale-loss") if [ $? -ne 0 ]; then echo "Error: Training failed" @@ -307,30 +311,22 @@ FINAL_CHECKPOINT="$CHECKPOINT_DIR/checkpoint_epoch_${EPOCHS}.pth" if [ ! -f "$FINAL_CHECKPOINT" ]; then echo "Warning: Final checkpoint not found, using latest available..." - FINAL_CHECKPOINT=$(ls -t "$CHECKPOINT_DIR"/checkpoint_epoch_*.pth | head -1) + FINAL_CHECKPOINT=$(find_latest_checkpoint) fi echo "[2/4] Exporting final checkpoint to binary weights..." echo "Checkpoint: $FINAL_CHECKPOINT" -python3 training/export_cnn_v2_weights.py "$FINAL_CHECKPOINT" \ - --output-weights workspaces/main/weights/cnn_v2_weights.bin - -if [ $? -ne 0 ]; then +export_weights "$FINAL_CHECKPOINT" workspaces/main/weights/cnn_v2_weights.bin || { echo "Error: Shader export failed" exit 1 -fi +} echo "" fi # End of training/export section # Determine which checkpoint to use if [ "$VALIDATE_ONLY" = true ]; then - if [ -n "$VALIDATE_CHECKPOINT" ]; then - FINAL_CHECKPOINT="$VALIDATE_CHECKPOINT" - else - # Use latest checkpoint - FINAL_CHECKPOINT=$(ls -t "$CHECKPOINT_DIR"/checkpoint_epoch_*.pth | head -1) - fi + FINAL_CHECKPOINT="${VALIDATE_CHECKPOINT:-$(find_latest_checkpoint)}" echo "Using checkpoint: $FINAL_CHECKPOINT" echo "" fi @@ -338,13 +334,10 @@ fi # Step 3: Rebuild with new shaders if [ "$VALIDATE_ONLY" = false ]; then echo "[3/4] Rebuilding demo with new shaders..." - cmake --build build -j4 --target demo64k > /dev/null 2>&1 - - if [ $? -ne 0 ]; then + build_target demo64k || { echo "Error: Build failed" exit 1 - fi - + } echo " → Build complete" echo "" fi @@ -361,24 +354,21 @@ echo " Using checkpoint: $FINAL_CHECKPOINT" # Export weights for validation mode (already exported in step 2 for training mode) if [ "$VALIDATE_ONLY" = true ]; then - python3 training/export_cnn_v2_weights.py "$FINAL_CHECKPOINT" \ - --output-weights workspaces/main/weights/cnn_v2_weights.bin > /dev/null 2>&1 + export_weights "$FINAL_CHECKPOINT" workspaces/main/weights/cnn_v2_weights.bin > /dev/null 2>&1 fi # Build cnn_test -cmake --build build -j4 --target cnn_test > /dev/null 2>&1 +build_target cnn_test # Process all input images for input_image in "$INPUT_DIR"/*.png; do basename=$(basename "$input_image" .png) echo " Processing $basename..." - build/cnn_test "$input_image" "$VALIDATION_DIR/${basename}_output.png" 2>/dev/null + build/cnn_test "$input_image" "$VALIDATION_DIR/${basename}_output.png" --cnn-version 2 2>/dev/null done # Build demo only if not in validate mode -if [ "$VALIDATE_ONLY" = false ]; then - cmake --build build -j4 --target demo64k > /dev/null 2>&1 -fi +[ "$VALIDATE_ONLY" = false ] && build_target demo64k echo "" if [ "$VALIDATE_ONLY" = true ]; then |
