summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author skal <pascal.massimino@gmail.com> 2026-02-14 01:33:37 +0100
committer skal <pascal.massimino@gmail.com> 2026-02-14 01:33:37 +0100
commit 4f107cc2d215a8bff69ea85eb10ee91920e797a3 (patch)
tree 0204ff67cec9c5a3347692cd2511c5bda9d541f3
parent 3ef1e484ff1328ac51511a8a8ccab397392a8491 (diff)
Fix CNN v2 training: always save final checkpoint, derive num_layers
- Always save final checkpoint after training completes
- Derive num_layers from kernel_sizes list when multiple values provided
- Add checkpoint validation in training pipeline script
- Quote shell variables when passing args to Python

Fixes issue where no checkpoint was saved when epochs < checkpoint_every.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
-rwxr-xr-x scripts/train_cnn_v2_full.sh 17
-rwxr-xr-x training/train_cnn_v2.py 22
2 files changed, 33 insertions, 6 deletions
diff --git a/scripts/train_cnn_v2_full.sh b/scripts/train_cnn_v2_full.sh
index c98ff2d..1c683c2 100755
--- a/scripts/train_cnn_v2_full.sh
+++ b/scripts/train_cnn_v2_full.sh
@@ -300,13 +300,13 @@ python3 training/train_cnn_v2.py \
--input "$INPUT_DIR" \
--target "$TARGET_DIR" \
$TRAINING_MODE_ARGS \
- --kernel-sizes $KERNEL_SIZES \
- --num-layers $NUM_LAYERS \
- --mip-level $MIP_LEVEL \
- --epochs $EPOCHS \
- --batch-size $BATCH_SIZE \
+ --kernel-sizes "$KERNEL_SIZES" \
+ --num-layers "$NUM_LAYERS" \
+ --mip-level "$MIP_LEVEL" \
+ --epochs "$EPOCHS" \
+ --batch-size "$BATCH_SIZE" \
--checkpoint-dir "$CHECKPOINT_DIR" \
- --checkpoint-every $CHECKPOINT_EVERY \
+ --checkpoint-every "$CHECKPOINT_EVERY" \
$([ "$GRAYSCALE_LOSS" = true ] && echo "--grayscale-loss")
if [ $? -ne 0 ]; then
@@ -326,6 +326,11 @@ if [ ! -f "$FINAL_CHECKPOINT" ]; then
FINAL_CHECKPOINT=$(find_latest_checkpoint)
fi
+if [ -z "$FINAL_CHECKPOINT" ] || [ ! -f "$FINAL_CHECKPOINT" ]; then
+ echo "Error: No checkpoint found in $CHECKPOINT_DIR"
+ exit 1
+fi
+
echo "[2/4] Exporting final checkpoint to binary weights..."
echo "Checkpoint: $FINAL_CHECKPOINT"
export_weights "$FINAL_CHECKPOINT" "$OUTPUT_WEIGHTS" || {
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
index 134a5ae..d80e3a5 100755
--- a/training/train_cnn_v2.py
+++ b/training/train_cnn_v2.py
@@ -329,6 +329,9 @@ def train(args):
kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
if len(kernel_sizes) == 1:
kernel_sizes = kernel_sizes * args.num_layers
+ else:
+ # When multiple kernel sizes provided, derive num_layers from list length
+ args.num_layers = len(kernel_sizes)
# Create model
model = CNNv2(kernel_sizes=kernel_sizes, num_layers=args.num_layers).to(device)
@@ -397,6 +400,25 @@ def train(args):
}, checkpoint_path)
print(f" → Saved checkpoint: {checkpoint_path}")
+ # Always save final checkpoint
+ print() # Newline after training
+ final_checkpoint = Path(args.checkpoint_dir) / f"checkpoint_epoch_{args.epochs}.pth"
+ final_checkpoint.parent.mkdir(parents=True, exist_ok=True)
+ torch.save({
+ 'epoch': args.epochs,
+ 'model_state_dict': model.state_dict(),
+ 'optimizer_state_dict': optimizer.state_dict(),
+ 'loss': avg_loss,
+ 'config': {
+ 'kernel_sizes': kernel_sizes,
+ 'num_layers': args.num_layers,
+ 'mip_level': args.mip_level,
+ 'grayscale_loss': args.grayscale_loss,
+ 'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin20_y', 'bias']
+ }
+ }, final_checkpoint)
+ print(f" → Saved final checkpoint: {final_checkpoint}")
+
print(f"\nTraining complete! Total time: {time.time() - start_time:.1f}s")
return model