From 303286f34866a232bc18e2f2932ba57718fafbd5 Mon Sep 17 00:00:00 2001
From: skal
Date: Tue, 10 Feb 2026 10:39:25 +0100
Subject: feat: Add checkpointing support to CNN training script

- Save checkpoints every N epochs (--checkpoint-every)
- Resume from checkpoint (--resume)
- Store model, optimizer, epoch, loss, and architecture info
- Auto-create checkpoint directory

Co-Authored-By: Claude Sonnet 4.5
---
 training/train_cnn.py | 35 +++++++++++++++++++++++++++++++++--
 1 file changed, 33 insertions(+), 2 deletions(-)

diff --git a/training/train_cnn.py b/training/train_cnn.py
index c249947..82f0b48 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -246,9 +246,22 @@ def train(args):
     criterion = nn.MSELoss()
     optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
 
+    # Resume from checkpoint
+    start_epoch = 0
+    if args.resume:
+        if os.path.exists(args.resume):
+            print(f"Loading checkpoint from {args.resume}...")
+            checkpoint = torch.load(args.resume, map_location=device)
+            model.load_state_dict(checkpoint['model_state'])
+            optimizer.load_state_dict(checkpoint['optimizer_state'])
+            start_epoch = checkpoint['epoch'] + 1
+            print(f"Resumed from epoch {start_epoch}")
+        else:
+            print(f"Warning: Checkpoint file '{args.resume}' not found, starting from scratch")
+
     # Training loop
-    print(f"\nTraining for {args.epochs} epochs...")
-    for epoch in range(args.epochs):
+    print(f"\nTraining for {args.epochs} epochs (starting from epoch {start_epoch})...")
+    for epoch in range(start_epoch, args.epochs):
         epoch_loss = 0.0
         for batch_idx, (inputs, targets) in enumerate(dataloader):
             inputs, targets = inputs.to(device), targets.to(device)
@@ -265,6 +278,21 @@
         if (epoch + 1) % 10 == 0:
             print(f"Epoch [{epoch+1}/{args.epochs}], Loss: {avg_loss:.6f}")
 
+        # Save checkpoint
+        if args.checkpoint_every > 0 and (epoch + 1) % args.checkpoint_every == 0:
+            checkpoint_dir = args.checkpoint_dir or 'training/checkpoints'
+            os.makedirs(checkpoint_dir, exist_ok=True)
+            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')
+            torch.save({
+                'epoch': epoch,
+                'model_state': model.state_dict(),
+                'optimizer_state': optimizer.state_dict(),
+                'loss': avg_loss,
+                'kernel_sizes': kernel_sizes,
+                'num_layers': args.layers
+            }, checkpoint_path)
+            print(f"Saved checkpoint to {checkpoint_path}")
+
     # Export weights
     output_path = args.output or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl'
     print(f"\nExporting weights to {output_path}...")
@@ -284,6 +312,9 @@ def main():
     parser.add_argument('--batch_size', type=int, default=4, help='Batch size (default: 4)')
     parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (default: 0.001)')
     parser.add_argument('--output', help='Output WGSL file path (default: workspaces/main/shaders/cnn/cnn_weights_generated.wgsl)')
+    parser.add_argument('--checkpoint-every', type=int, default=0, help='Save checkpoint every N epochs (default: 0 = disabled)')
+    parser.add_argument('--checkpoint-dir', help='Checkpoint directory (default: training/checkpoints)')
+    parser.add_argument('--resume', help='Resume from checkpoint file')
     args = parser.parse_args()
-- 
cgit v1.2.3
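
A usage sketch to go with the patch: once a run has written a checkpoint
(file names follow the checkpoint_epoch_{N}.pth scheme above; epoch 50 below
is illustrative), training can be resumed with

    python training/train_cnn.py --resume training/checkpoints/checkpoint_epoch_50.pth --checkpoint-every 10

A saved checkpoint can also be inspected offline with a few lines of Python.
The key names match the torch.save() call in the diff; the commented-out
model reconstruction is hypothetical, since the model class and its
constructor are not shown in this patch:

    import torch

    # Load a checkpoint written by train_cnn.py; map_location='cpu'
    # lets a GPU-trained checkpoint open on a CPU-only machine.
    ckpt = torch.load('training/checkpoints/checkpoint_epoch_50.pth',
                      map_location='cpu')

    # These keys mirror the dict passed to torch.save() in the patch.
    print(f"epoch:        {ckpt['epoch']}")
    print(f"loss:         {ckpt['loss']:.6f}")
    print(f"num_layers:   {ckpt['num_layers']}")
    print(f"kernel_sizes: {ckpt['kernel_sizes']}")

    # Hypothetical reconstruction -- the model class is not part of this
    # patch, so adapt the constructor to train_cnn.py's actual definition:
    # model = CNN(num_layers=ckpt['num_layers'], kernel_sizes=ckpt['kernel_sizes'])
    # model.load_state_dict(ckpt['model_state'])
    # model.eval()

Note the off-by-one convention: 'epoch' stores the zero-based loop index
while the file name uses epoch + 1, so checkpoint_epoch_50.pth holds
'epoch' == 49, and resuming from it sets start_epoch = 50, restarting the
loop exactly where training left off.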