summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-02-13 19:59:46 +0100
committerskal <pascal.massimino@gmail.com>2026-02-13 19:59:46 +0100
commita370fcc440d99f728dfc4832e94c758e0b92bc63 (patch)
treebc67ba2bda96b59cf3db10b95dccf0be8c9045b6
parentedd549e1527444ae9c74c70f1e3e44b11862f3da (diff)
CNN v2 training: Refactor train_cnn_v2_full.sh for maintainability
- Add helper functions: export_weights(), find_latest_checkpoint(), build_target() - Eliminate duplicate export logic (3 instances → 1 function) - Eliminate duplicate checkpoint finding (2 instances → 1 function) - Consolidate build commands (4 instances → 1 function) - Simplify optional flags with inline command substitution - Fix validation mode: correct cnn_test argument order (positional args before --cnn-version) - 30 fewer lines, improved readability handoff(Claude): Refactored CNN v2 training script, fixed validation bug
-rwxr-xr-xscripts/train_cnn_v2_full.sh62
1 files changed, 26 insertions, 36 deletions
diff --git a/scripts/train_cnn_v2_full.sh b/scripts/train_cnn_v2_full.sh
index e444f20..9c235b6 100755
--- a/scripts/train_cnn_v2_full.sh
+++ b/scripts/train_cnn_v2_full.sh
@@ -47,6 +47,19 @@ set -e
PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
cd "$PROJECT_ROOT"
+# Helper functions
+export_weights() {
+ python3 training/export_cnn_v2_weights.py "$1" --output-weights "$2"
+}
+
+find_latest_checkpoint() {
+ ls -t "$CHECKPOINT_DIR"/checkpoint_epoch_*.pth 2>/dev/null | head -1
+}
+
+build_target() {
+ cmake --build build -j4 --target "$1" > /dev/null 2>&1
+}
+
# Default configuration
INPUT_DIR="training/input"
TARGET_DIR="training/target_1"
@@ -242,13 +255,10 @@ if [ "$EXPORT_ONLY" = true ]; then
exit 1
fi
- python3 training/export_cnn_v2_weights.py "$EXPORT_CHECKPOINT" \
- --output-weights workspaces/main/weights/cnn_v2_weights.bin
-
- if [ $? -ne 0 ]; then
+ export_weights "$EXPORT_CHECKPOINT" workspaces/main/weights/cnn_v2_weights.bin || {
echo "Error: Export failed"
exit 1
- fi
+ }
echo ""
echo "=== Export Complete ==="
@@ -274,12 +284,6 @@ if [ "$VALIDATE_ONLY" = false ]; then
# Step 1: Train model
echo "[1/4] Training CNN v2 model..."
-# Build optional flags
-OPTIONAL_FLAGS=""
-if [ "$GRAYSCALE_LOSS" = true ]; then
- OPTIONAL_FLAGS="$OPTIONAL_FLAGS --grayscale-loss"
-fi
-
python3 training/train_cnn_v2.py \
--input "$INPUT_DIR" \
--target "$TARGET_DIR" \
@@ -291,7 +295,7 @@ python3 training/train_cnn_v2.py \
--batch-size $BATCH_SIZE \
--checkpoint-dir "$CHECKPOINT_DIR" \
--checkpoint-every $CHECKPOINT_EVERY \
- $OPTIONAL_FLAGS
+ $([ "$GRAYSCALE_LOSS" = true ] && echo "--grayscale-loss")
if [ $? -ne 0 ]; then
echo "Error: Training failed"
@@ -307,30 +311,22 @@ FINAL_CHECKPOINT="$CHECKPOINT_DIR/checkpoint_epoch_${EPOCHS}.pth"
if [ ! -f "$FINAL_CHECKPOINT" ]; then
echo "Warning: Final checkpoint not found, using latest available..."
- FINAL_CHECKPOINT=$(ls -t "$CHECKPOINT_DIR"/checkpoint_epoch_*.pth | head -1)
+ FINAL_CHECKPOINT=$(find_latest_checkpoint)
fi
echo "[2/4] Exporting final checkpoint to binary weights..."
echo "Checkpoint: $FINAL_CHECKPOINT"
-python3 training/export_cnn_v2_weights.py "$FINAL_CHECKPOINT" \
- --output-weights workspaces/main/weights/cnn_v2_weights.bin
-
-if [ $? -ne 0 ]; then
+export_weights "$FINAL_CHECKPOINT" workspaces/main/weights/cnn_v2_weights.bin || {
echo "Error: Shader export failed"
exit 1
-fi
+}
echo ""
fi # End of training/export section
# Determine which checkpoint to use
if [ "$VALIDATE_ONLY" = true ]; then
- if [ -n "$VALIDATE_CHECKPOINT" ]; then
- FINAL_CHECKPOINT="$VALIDATE_CHECKPOINT"
- else
- # Use latest checkpoint
- FINAL_CHECKPOINT=$(ls -t "$CHECKPOINT_DIR"/checkpoint_epoch_*.pth | head -1)
- fi
+ FINAL_CHECKPOINT="${VALIDATE_CHECKPOINT:-$(find_latest_checkpoint)}"
echo "Using checkpoint: $FINAL_CHECKPOINT"
echo ""
fi
@@ -338,13 +334,10 @@ fi
# Step 3: Rebuild with new shaders
if [ "$VALIDATE_ONLY" = false ]; then
echo "[3/4] Rebuilding demo with new shaders..."
- cmake --build build -j4 --target demo64k > /dev/null 2>&1
-
- if [ $? -ne 0 ]; then
+ build_target demo64k || {
echo "Error: Build failed"
exit 1
- fi
-
+ }
echo " → Build complete"
echo ""
fi
@@ -361,24 +354,21 @@ echo " Using checkpoint: $FINAL_CHECKPOINT"
# Export weights for validation mode (already exported in step 2 for training mode)
if [ "$VALIDATE_ONLY" = true ]; then
- python3 training/export_cnn_v2_weights.py "$FINAL_CHECKPOINT" \
- --output-weights workspaces/main/weights/cnn_v2_weights.bin > /dev/null 2>&1
+ export_weights "$FINAL_CHECKPOINT" workspaces/main/weights/cnn_v2_weights.bin > /dev/null 2>&1
fi
# Build cnn_test
-cmake --build build -j4 --target cnn_test > /dev/null 2>&1
+build_target cnn_test
# Process all input images
for input_image in "$INPUT_DIR"/*.png; do
basename=$(basename "$input_image" .png)
echo " Processing $basename..."
- build/cnn_test "$input_image" "$VALIDATION_DIR/${basename}_output.png" 2>/dev/null
+ build/cnn_test "$input_image" "$VALIDATION_DIR/${basename}_output.png" --cnn-version 2 2>/dev/null
done
# Build demo only if not in validate mode
-if [ "$VALIDATE_ONLY" = false ]; then
- cmake --build build -j4 --target demo64k > /dev/null 2>&1
-fi
+[ "$VALIDATE_ONLY" = false ] && build_target demo64k
echo ""
if [ "$VALIDATE_ONLY" = true ]; then