diff options
| -rwxr-xr-x | training/train_cnn.py | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py index 89c50d5..d974ce7 100755 --- a/training/train_cnn.py +++ b/training/train_cnn.py @@ -540,9 +540,16 @@ def train(args): num_layers = args.layers border = num_layers # Each 3x3 layer needs 1px, accumulates across layers + # Early stopping setup + loss_history = [] + early_stop_triggered = False + # Training loop print(f"\nTraining for {args.epochs} epochs (starting from epoch {start_epoch})...") print(f"Computing loss on center region only (excluding {border}px border)") + if args.early_stop_patience > 0: + print(f"Early stopping: patience={args.early_stop_patience}, eps={args.early_stop_eps}") + for epoch in range(start_epoch, args.epochs): epoch_loss = 0.0 for batch_idx, (inputs, targets) in enumerate(dataloader): @@ -568,6 +575,18 @@ def train(args): if (epoch + 1) % 10 == 0: print(f"Epoch [{epoch+1}/{args.epochs}], Loss: {avg_loss:.6f}") + # Early stopping check + if args.early_stop_patience > 0: + loss_history.append(avg_loss) + if len(loss_history) >= args.early_stop_patience: + oldest_loss = loss_history[-args.early_stop_patience] + loss_change = abs(avg_loss - oldest_loss) + if loss_change < args.early_stop_eps: + print(f"Early stopping triggered at epoch {epoch+1}") + print(f"Loss change over last {args.early_stop_patience} epochs: {loss_change:.8f} < {args.early_stop_eps}") + early_stop_triggered = True + break + # Save checkpoint if args.checkpoint_every > 0 and (epoch + 1) % args.checkpoint_every == 0: checkpoint_dir = args.checkpoint_dir or 'training/checkpoints' @@ -741,6 +760,8 @@ def main(): parser.add_argument('--patches-per-image', type=int, default=64, help='Number of patches to extract per image (default: 64)') parser.add_argument('--detector', default='harris', choices=['harris', 'fast', 'shi-tomasi', 'gradient'], help='Salient point detector for patch extraction (default: harris)') + parser.add_argument('--early-stop-patience', type=int, default=0, help='Stop if loss changes less than eps over N epochs (default: 0 = disabled)') + parser.add_argument('--early-stop-eps', type=float, default=1e-6, help='Loss change threshold for early stopping (default: 1e-6)') args = parser.parse_args() |
