summaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
Diffstat (limited to 'scripts')
-rwxr-xr-xscripts/train_cnn_v2_full.sh17
1 files changed, 11 insertions, 6 deletions
diff --git a/scripts/train_cnn_v2_full.sh b/scripts/train_cnn_v2_full.sh
index c98ff2d..1c683c2 100755
--- a/scripts/train_cnn_v2_full.sh
+++ b/scripts/train_cnn_v2_full.sh
@@ -300,13 +300,13 @@ python3 training/train_cnn_v2.py \
--input "$INPUT_DIR" \
--target "$TARGET_DIR" \
$TRAINING_MODE_ARGS \
- --kernel-sizes $KERNEL_SIZES \
- --num-layers $NUM_LAYERS \
- --mip-level $MIP_LEVEL \
- --epochs $EPOCHS \
- --batch-size $BATCH_SIZE \
+ --kernel-sizes "$KERNEL_SIZES" \
+ --num-layers "$NUM_LAYERS" \
+ --mip-level "$MIP_LEVEL" \
+ --epochs "$EPOCHS" \
+ --batch-size "$BATCH_SIZE" \
--checkpoint-dir "$CHECKPOINT_DIR" \
- --checkpoint-every $CHECKPOINT_EVERY \
+ --checkpoint-every "$CHECKPOINT_EVERY" \
$([ "$GRAYSCALE_LOSS" = true ] && echo "--grayscale-loss")
if [ $? -ne 0 ]; then
@@ -326,6 +326,11 @@ if [ ! -f "$FINAL_CHECKPOINT" ]; then
FINAL_CHECKPOINT=$(find_latest_checkpoint)
fi
+if [ -z "$FINAL_CHECKPOINT" ] || [ ! -f "$FINAL_CHECKPOINT" ]; then
+ echo "Error: No checkpoint found in $CHECKPOINT_DIR"
+ exit 1
+fi
+
echo "[2/4] Exporting final checkpoint to binary weights..."
echo "Checkpoint: $FINAL_CHECKPOINT"
export_weights "$FINAL_CHECKPOINT" "$OUTPUT_WEIGHTS" || {