summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rwxr-xr-xtraining/train_cnn_v2.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
index 8cac51a..758b044 100755
--- a/training/train_cnn_v2.py
+++ b/training/train_cnn_v2.py
@@ -312,12 +312,13 @@ def train(args):
avg_loss = epoch_loss / len(dataloader)
- if epoch % 100 == 0 or epoch == 1:
- elapsed = time.time() - start_time
- print(f"Epoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | Time: {elapsed:.1f}s")
+ # Print loss at every epoch (overwrite line with \r)
+ elapsed = time.time() - start_time
+ print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | Time: {elapsed:.1f}s", end='', flush=True)
# Save checkpoint
if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0:
+ print() # Newline before checkpoint message
checkpoint_path = Path(args.checkpoint_dir) / f"checkpoint_epoch_{epoch}.pth"
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
torch.save({