summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rwxr-xr-xtraining/train_cnn_v2.py12
1 files changed, 11 insertions, 1 deletions
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
index 3ab1c0f..5c93f20 100755
--- a/training/train_cnn_v2.py
+++ b/training/train_cnn_v2.py
@@ -120,7 +120,15 @@ class PatchDataset(Dataset):
return len(self.input_paths) * self.patches_per_image
def _detect_salient_points(self, img_array):
- """Detect salient points on original image."""
+ """Detect salient points on original image.
+
+ TODO: Add random sampling to training vectors
+ - In addition to salient points, incorporate randomly-located samples
+ - Default: 10% random samples, 90% salient points
+ - Prevents overfitting to only high-gradient regions
+ - Improves generalization across entire image
+ - Configurable via --random-sample-percent parameter
+ """
gray = cv2.cvtColor((img_array * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
h, w = gray.shape
half_patch = self.patch_size // 2
@@ -342,6 +350,8 @@ def main():
parser.add_argument('--detector', type=str, default='harris',
choices=['harris', 'fast', 'shi-tomasi', 'gradient'],
help='Patch mode: salient point detector (default: harris)')
+ # TODO: Add --random-sample-percent parameter (default: 10)
+ # Mix salient points with random samples for better generalization
# Model architecture
parser.add_argument('--kernel-sizes', type=int, nargs=3, default=[1, 3, 5],