author    skal <pascal.massimino@gmail.com>  2026-02-10 10:39:25 +0100
committer skal <pascal.massimino@gmail.com>  2026-02-10 10:39:25 +0100
commit    303286f34866a232bc18e2f2932ba57718fafbd5 (patch)
tree      83a76cb1c59842b4f72592a4d4a397f24a775e5c /training/train_cnn.py
parent    5515301560451549f228867a72ca850cffeb3714 (diff)
feat: Add checkpointing support to CNN training script
- Save checkpoints every N epochs (--checkpoint-every)
- Resume from checkpoint (--resume)
- Store model, optimizer, epoch, loss, and architecture info
- Auto-create checkpoint directory

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'training/train_cnn.py')
-rwxr-xr-x  training/train_cnn.py  35
1 file changed, 33 insertions, 2 deletions
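The new flags combine with the existing training options; a hypothetical invocation sketch follows (epoch counts and the checkpoint file name are illustrative, and --epochs is assumed from the script's existing args.epochs usage):

# Save a checkpoint every 25 epochs into the default training/checkpoints directory:
python training/train_cnn.py --epochs 200 --checkpoint-every 25

# Later, resume that run from a saved checkpoint:
python training/train_cnn.py --epochs 200 --checkpoint-every 25 \
    --resume training/checkpoints/checkpoint_epoch_50.pth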
diff --git a/training/train_cnn.py b/training/train_cnn.py
index c249947..82f0b48 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -246,9 +246,22 @@ def train(args):
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
+ # Resume from checkpoint
+ start_epoch = 0
+ if args.resume:
+ if os.path.exists(args.resume):
+ print(f"Loading checkpoint from {args.resume}...")
+ checkpoint = torch.load(args.resume, map_location=device)
+ model.load_state_dict(checkpoint['model_state'])
+ optimizer.load_state_dict(checkpoint['optimizer_state'])
+ start_epoch = checkpoint['epoch'] + 1
+ print(f"Resumed from epoch {start_epoch}")
+ else:
+ print(f"Warning: Checkpoint file '{args.resume}' not found, starting from scratch")
+
# Training loop
- print(f"\nTraining for {args.epochs} epochs...")
- for epoch in range(args.epochs):
+ print(f"\nTraining for {args.epochs} epochs (starting from epoch {start_epoch})...")
+ for epoch in range(start_epoch, args.epochs):
epoch_loss = 0.0
for batch_idx, (inputs, targets) in enumerate(dataloader):
inputs, targets = inputs.to(device), targets.to(device)
@@ -265,6 +278,21 @@ def train(args):
if (epoch + 1) % 10 == 0:
print(f"Epoch [{epoch+1}/{args.epochs}], Loss: {avg_loss:.6f}")
+ # Save checkpoint
+ if args.checkpoint_every > 0 and (epoch + 1) % args.checkpoint_every == 0:
+ checkpoint_dir = args.checkpoint_dir or 'training/checkpoints'
+ os.makedirs(checkpoint_dir, exist_ok=True)
+ checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')
+ torch.save({
+ 'epoch': epoch,
+ 'model_state': model.state_dict(),
+ 'optimizer_state': optimizer.state_dict(),
+ 'loss': avg_loss,
+ 'kernel_sizes': kernel_sizes,
+ 'num_layers': args.layers
+ }, checkpoint_path)
+ print(f"Saved checkpoint to {checkpoint_path}")
+
# Export weights
output_path = args.output or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl'
print(f"\nExporting weights to {output_path}...")
@@ -284,6 +312,9 @@ def main():
parser.add_argument('--batch_size', type=int, default=4, help='Batch size (default: 4)')
parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (default: 0.001)')
parser.add_argument('--output', help='Output WGSL file path (default: workspaces/main/shaders/cnn/cnn_weights_generated.wgsl)')
+ parser.add_argument('--checkpoint-every', type=int, default=0, help='Save checkpoint every N epochs (default: 0 = disabled)')
+ parser.add_argument('--checkpoint-dir', help='Checkpoint directory (default: training/checkpoints)')
+ parser.add_argument('--resume', help='Resume from checkpoint file')
args = parser.parse_args()
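
Since the checkpoint is saved as a plain dict, it can be inspected outside the training script. A minimal Python sketch, assuming a hypothetical checkpoint path and using only the keys written by the code above:

import torch

# Hypothetical path; actual files are named checkpoint_epoch_{N}.pth under --checkpoint-dir.
ckpt = torch.load('training/checkpoints/checkpoint_epoch_50.pth', map_location='cpu')

print(f"epoch={ckpt['epoch']}  loss={ckpt['loss']:.6f}")
print(f"num_layers={ckpt['num_layers']}  kernel_sizes={ckpt['kernel_sizes']}")

# Restoring state mirrors the --resume branch in train():
#   model.load_state_dict(ckpt['model_state'])
#   optimizer.load_state_dict(ckpt['optimizer_state'])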