summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xtraining/train_cnn.py21
1 files changed, 21 insertions, 0 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py
index 89c50d5..d974ce7 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -540,9 +540,16 @@ def train(args):
num_layers = args.layers
border = num_layers # Each 3x3 layer needs 1px, accumulates across layers
+ # Early stopping setup
+ loss_history = []
+ early_stop_triggered = False
+
# Training loop
print(f"\nTraining for {args.epochs} epochs (starting from epoch {start_epoch})...")
print(f"Computing loss on center region only (excluding {border}px border)")
+ if args.early_stop_patience > 0:
+ print(f"Early stopping: patience={args.early_stop_patience}, eps={args.early_stop_eps}")
+
for epoch in range(start_epoch, args.epochs):
epoch_loss = 0.0
for batch_idx, (inputs, targets) in enumerate(dataloader):
@@ -568,6 +575,18 @@ def train(args):
if (epoch + 1) % 10 == 0:
print(f"Epoch [{epoch+1}/{args.epochs}], Loss: {avg_loss:.6f}")
+ # Early stopping check
+ if args.early_stop_patience > 0:
+ loss_history.append(avg_loss)
+ if len(loss_history) >= args.early_stop_patience:
+ oldest_loss = loss_history[-args.early_stop_patience]
+ loss_change = abs(avg_loss - oldest_loss)
+ if loss_change < args.early_stop_eps:
+ print(f"Early stopping triggered at epoch {epoch+1}")
+ print(f"Loss change over last {args.early_stop_patience} epochs: {loss_change:.8f} < {args.early_stop_eps}")
+ early_stop_triggered = True
+ break
+
# Save checkpoint
if args.checkpoint_every > 0 and (epoch + 1) % args.checkpoint_every == 0:
checkpoint_dir = args.checkpoint_dir or 'training/checkpoints'
@@ -741,6 +760,8 @@ def main():
parser.add_argument('--patches-per-image', type=int, default=64, help='Number of patches to extract per image (default: 64)')
parser.add_argument('--detector', default='harris', choices=['harris', 'fast', 'shi-tomasi', 'gradient'],
help='Salient point detector for patch extraction (default: harris)')
+ parser.add_argument('--early-stop-patience', type=int, default=0, help='Stop if loss changes less than eps over N epochs (default: 0 = disabled)')
+ parser.add_argument('--early-stop-eps', type=float, default=1e-6, help='Loss change threshold for early stopping (default: 1e-6)')
args = parser.parse_args()