summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-02-12 12:17:59 +0100
committerskal <pascal.massimino@gmail.com>2026-02-12 12:17:59 +0100
commitff4c1213636e66d4457a95cad12300c58e8d6781 (patch)
treeb47a9fc5c860c4eff39054b2ffc248ffbe19fa10 /training
parenteaf0bd855306e70ca03f2d6579b4d6551aff6482 (diff)
Refine training script output and validation
1. Loss printed at every epoch with \r (no scrolling) 2. Validation only on final epoch (not all checkpoints) 3. Process all input images (not just img_000.png) Training output now shows live progress with single line update.
Diffstat (limited to 'training')
-rwxr-xr-xtraining/train_cnn_v2.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
index 8cac51a..758b044 100755
--- a/training/train_cnn_v2.py
+++ b/training/train_cnn_v2.py
@@ -312,12 +312,13 @@ def train(args):
avg_loss = epoch_loss / len(dataloader)
- if epoch % 100 == 0 or epoch == 1:
- elapsed = time.time() - start_time
- print(f"Epoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | Time: {elapsed:.1f}s")
+ # Print loss at every epoch (overwrite line with \r)
+ elapsed = time.time() - start_time
+ print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | Time: {elapsed:.1f}s", end='', flush=True)
# Save checkpoint
if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0:
+ print() # Newline before checkpoint message
checkpoint_path = Path(args.checkpoint_dir) / f"checkpoint_epoch_{epoch}.pth"
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
torch.save({