diff options
| author | skal <pascal.massimino@gmail.com> | 2026-02-12 12:17:59 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-02-12 12:17:59 +0100 |
| commit | ff4c1213636e66d4457a95cad12300c58e8d6781 (patch) | |
| tree | b47a9fc5c860c4eff39054b2ffc248ffbe19fa10 /training/train_cnn_v2.py | |
| parent | eaf0bd855306e70ca03f2d6579b4d6551aff6482 (diff) | |
Refine training script output and validation
1. Loss printed at every epoch with \r (no scrolling)
2. Validation only on final epoch (not all checkpoints)
3. Process all input images (not just img_000.png)
Training output now shows live progress with single line update.
Diffstat (limited to 'training/train_cnn_v2.py')
| -rwxr-xr-x | training/train_cnn_v2.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py index 8cac51a..758b044 100755 --- a/training/train_cnn_v2.py +++ b/training/train_cnn_v2.py @@ -312,12 +312,13 @@ def train(args): avg_loss = epoch_loss / len(dataloader) - if epoch % 100 == 0 or epoch == 1: - elapsed = time.time() - start_time - print(f"Epoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | Time: {elapsed:.1f}s") + # Print loss at every epoch (overwrite line with \r) + elapsed = time.time() - start_time + print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | Time: {elapsed:.1f}s", end='', flush=True) # Save checkpoint if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0: + print() # Newline before checkpoint message checkpoint_path = Path(args.checkpoint_dir) / f"checkpoint_epoch_{epoch}.pth" checkpoint_path.parent.mkdir(parents=True, exist_ok=True) torch.save({ |
