summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rwxr-xr-xtraining/train_cnn_v2.py22
1 files changed, 22 insertions, 0 deletions
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
index 134a5ae..d80e3a5 100755
--- a/training/train_cnn_v2.py
+++ b/training/train_cnn_v2.py
@@ -329,6 +329,9 @@ def train(args):
kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
if len(kernel_sizes) == 1:
kernel_sizes = kernel_sizes * args.num_layers
+ else:
+ # When multiple kernel sizes provided, derive num_layers from list length
+ args.num_layers = len(kernel_sizes)
# Create model
model = CNNv2(kernel_sizes=kernel_sizes, num_layers=args.num_layers).to(device)
@@ -397,6 +400,25 @@ def train(args):
}, checkpoint_path)
print(f" → Saved checkpoint: {checkpoint_path}")
+ # Always save final checkpoint
+ print() # Newline after training
+ final_checkpoint = Path(args.checkpoint_dir) / f"checkpoint_epoch_{args.epochs}.pth"
+ final_checkpoint.parent.mkdir(parents=True, exist_ok=True)
+ torch.save({
+ 'epoch': args.epochs,
+ 'model_state_dict': model.state_dict(),
+ 'optimizer_state_dict': optimizer.state_dict(),
+ 'loss': avg_loss,
+ 'config': {
+ 'kernel_sizes': kernel_sizes,
+ 'num_layers': args.num_layers,
+ 'mip_level': args.mip_level,
+ 'grayscale_loss': args.grayscale_loss,
+ 'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin20_y', 'bias']
+ }
+ }, final_checkpoint)
+ print(f" → Saved final checkpoint: {final_checkpoint}")
+
print(f"\nTraining complete! Total time: {time.time() - start_time:.1f}s")
return model