diff options
Diffstat (limited to 'training/train_cnn_v2.py')
| -rwxr-xr-x | training/train_cnn_v2.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py index 8cac51a..758b044 100755 --- a/training/train_cnn_v2.py +++ b/training/train_cnn_v2.py @@ -312,12 +312,13 @@ def train(args): avg_loss = epoch_loss / len(dataloader) - if epoch % 100 == 0 or epoch == 1: - elapsed = time.time() - start_time - print(f"Epoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | Time: {elapsed:.1f}s") + # Print loss at every epoch (overwrite line with \r) + elapsed = time.time() - start_time + print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | Time: {elapsed:.1f}s", end='', flush=True) # Save checkpoint if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0: + print() # Newline before checkpoint message checkpoint_path = Path(args.checkpoint_dir) / f"checkpoint_epoch_{epoch}.pth" checkpoint_path.parent.mkdir(parents=True, exist_ok=True) torch.save({ |
