diff options
Diffstat (limited to 'training')
| -rwxr-xr-x | training/train_cnn_v2.py | 12 |
1 files changed, 11 insertions, 1 deletions
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py index 3ab1c0f..5c93f20 100755 --- a/training/train_cnn_v2.py +++ b/training/train_cnn_v2.py @@ -120,7 +120,15 @@ class PatchDataset(Dataset): return len(self.input_paths) * self.patches_per_image def _detect_salient_points(self, img_array): - """Detect salient points on original image.""" + """Detect salient points on original image. + + TODO: Add random sampling to training vectors + - In addition to salient points, incorporate randomly-located samples + - Default: 10% random samples, 90% salient points + - Prevents overfitting to only high-gradient regions + - Improves generalization across entire image + - Configurable via --random-sample-percent parameter + """ gray = cv2.cvtColor((img_array * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY) h, w = gray.shape half_patch = self.patch_size // 2 @@ -342,6 +350,8 @@ def main(): parser.add_argument('--detector', type=str, default='harris', choices=['harris', 'fast', 'shi-tomasi', 'gradient'], help='Patch mode: salient point detector (default: harris)') + # TODO: Add --random-sample-percent parameter (default: 10) + # Mix salient points with random samples for better generalization # Model architecture parser.add_argument('--kernel-sizes', type=int, nargs=3, default=[1, 3, 5], |
