diff options
| author | skal <pascal.massimino@gmail.com> | 2026-02-12 11:50:52 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-02-12 11:50:52 +0100 |
| commit | 7547e8ff4744339b92650b6ef3ff7405befe4beb (patch) | |
| tree | 0388b064c6bb2fcb2346796f9d1134c5ed9214b5 | |
| parent | c878631f24ddb7514dd4db3d7ace6a0a296d4157 (diff) | |
CNN v2: Patch-based training as default (like CNN v1)
Salient point detection on original images with patch extraction.
Changes:
- Added PatchDataset class (harris/fast/shi-tomasi/gradient detectors)
- Detects salient points on ORIGINAL images (no resize)
- Extracts 32×32 patches around salient points
- Default: 64 patches/image, harris detector
- Batch size: 16 (512 patches per batch)
Training modes:
1. Patch-based (default): --patch-size 32 --patches-per-image 64 --detector harris
2. Full-image (option): --full-image --image-size 256
Benefits:
- Focuses training on interesting regions
- Handles variable image sizes naturally
- Matches CNN v1 workflow
- Better convergence with limited data (8 images → 512 patches)
Script updated:
- train_cnn_v2_full.sh: Patch-based by default
- Configuration exposed for easy switching
Example:
./scripts/train_cnn_v2_full.sh # Patch-based
# Edit script: uncomment FULL_IMAGE for resize mode
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
| -rwxr-xr-x | scripts/train_cnn_v2_full.sh | 20 | ||||
| -rwxr-xr-x | training/train_cnn_v2.py | 152 |
2 files changed, 161 insertions, 11 deletions
diff --git a/scripts/train_cnn_v2_full.sh b/scripts/train_cnn_v2_full.sh index 119b788..4ddd9ac 100755 --- a/scripts/train_cnn_v2_full.sh +++ b/scripts/train_cnn_v2_full.sh @@ -14,8 +14,17 @@ CHECKPOINT_DIR="checkpoints" VALIDATION_DIR="validation_results" EPOCHS=10000 CHECKPOINT_EVERY=500 -BATCH_SIZE=8 -IMAGE_SIZE=256 +BATCH_SIZE=16 + +# Patch-based training (default) +PATCH_SIZE=32 +PATCHES_PER_IMAGE=64 +DETECTOR="harris" + +# Full-image training (alternative - uncomment to use) +# FULL_IMAGE="--full-image" +# IMAGE_SIZE=256 + KERNEL_SIZES="1 3 5" CHANNELS="16 8 4" @@ -31,13 +40,16 @@ echo "[1/4] Training CNN v2 model..." python3 training/train_cnn_v2.py \ --input "$INPUT_DIR" \ --target "$TARGET_DIR" \ - --image-size $IMAGE_SIZE \ + --patch-size $PATCH_SIZE \ + --patches-per-image $PATCHES_PER_IMAGE \ + --detector $DETECTOR \ --kernel-sizes $KERNEL_SIZES \ --channels $CHANNELS \ --epochs $EPOCHS \ --batch-size $BATCH_SIZE \ --checkpoint-dir "$CHECKPOINT_DIR" \ - --checkpoint-every $CHECKPOINT_EVERY + --checkpoint-every $CHECKPOINT_EVERY \ + $FULL_IMAGE if [ $? -ne 0 ]; then echo "Error: Training failed" diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py index e590b40..3ab1c0f 100755 --- a/training/train_cnn_v2.py +++ b/training/train_cnn_v2.py @@ -17,6 +17,7 @@ from torch.utils.data import Dataset, DataLoader from pathlib import Path from PIL import Image import time +import cv2 def compute_static_features(rgb, depth=None): @@ -97,8 +98,120 @@ class CNNv2(nn.Module): return torch.sigmoid(output) +class PatchDataset(Dataset): + """Patch-based dataset extracting salient regions from images.""" + + def __init__(self, input_dir, target_dir, patch_size=32, patches_per_image=64, + detector='harris'): + self.input_paths = sorted(Path(input_dir).glob("*.png")) + self.target_paths = sorted(Path(target_dir).glob("*.png")) + self.patch_size = patch_size + self.patches_per_image = patches_per_image + self.detector = detector + + assert len(self.input_paths) == len(self.target_paths), \ + f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets" + + print(f"Found {len(self.input_paths)} image pairs") + print(f"Extracting {patches_per_image} patches per image using {detector} detector") + print(f"Total patches: {len(self.input_paths) * patches_per_image}") + + def __len__(self): + return len(self.input_paths) * self.patches_per_image + + def _detect_salient_points(self, img_array): + """Detect salient points on original image.""" + gray = cv2.cvtColor((img_array * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY) + h, w = gray.shape + half_patch = self.patch_size // 2 + + corners = None + if self.detector == 'harris': + corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2, + qualityLevel=0.01, minDistance=half_patch) + elif self.detector == 'fast': + fast = cv2.FastFeatureDetector_create(threshold=20) + keypoints = fast.detect(gray, None) + corners = np.array([[kp.pt[0], kp.pt[1]] for kp in keypoints[:self.patches_per_image * 2]]) + corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None + elif self.detector == 'shi-tomasi': + corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2, + qualityLevel=0.01, minDistance=half_patch, + useHarrisDetector=False) + elif self.detector == 'gradient': + grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3) + grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3) + gradient_mag = np.sqrt(grad_x**2 + grad_y**2) + threshold = np.percentile(gradient_mag, 95) + y_coords, x_coords = np.where(gradient_mag > threshold) + + if len(x_coords) > self.patches_per_image * 2: + indices = np.random.choice(len(x_coords), self.patches_per_image * 2, replace=False) + x_coords = x_coords[indices] + y_coords = y_coords[indices] + + corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)]) + corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None + + # Fallback to random if no corners found + if corners is None or len(corners) == 0: + x_coords = np.random.randint(half_patch, w - half_patch, self.patches_per_image) + y_coords = np.random.randint(half_patch, h - half_patch, self.patches_per_image) + corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)]) + corners = corners.reshape(-1, 1, 2) + + # Filter valid corners + valid_corners = [] + for corner in corners: + x, y = int(corner[0][0]), int(corner[0][1]) + if half_patch <= x < w - half_patch and half_patch <= y < h - half_patch: + valid_corners.append((x, y)) + if len(valid_corners) >= self.patches_per_image: + break + + # Fill with random if not enough + while len(valid_corners) < self.patches_per_image: + x = np.random.randint(half_patch, w - half_patch) + y = np.random.randint(half_patch, h - half_patch) + valid_corners.append((x, y)) + + return valid_corners + + def __getitem__(self, idx): + img_idx = idx // self.patches_per_image + patch_idx = idx % self.patches_per_image + + # Load original images (no resize) + input_img = np.array(Image.open(self.input_paths[img_idx]).convert('RGB')) / 255.0 + target_img = np.array(Image.open(self.target_paths[img_idx]).convert('RGB')) / 255.0 + + # Detect salient points on original image + salient_points = self._detect_salient_points(input_img) + cx, cy = salient_points[patch_idx] + + # Extract patch + half_patch = self.patch_size // 2 + y1, y2 = cy - half_patch, cy + half_patch + x1, x2 = cx - half_patch, cx + half_patch + + input_patch = input_img[y1:y2, x1:x2] + target_patch = target_img[y1:y2, x1:x2] + + # Compute static features for patch + static_feat = compute_static_features(input_patch.astype(np.float32)) + + # Convert to tensors (C, H, W) + static_feat = torch.from_numpy(static_feat).permute(2, 0, 1) + target = torch.from_numpy(target_patch.astype(np.float32)).permute(2, 0, 1) + + # Pad target to 4 channels (RGBA) + target = F.pad(target, (0, 0, 0, 0, 0, 1), value=1.0) + + return static_feat, target + + class ImagePairDataset(Dataset): - """Dataset of input/target image pairs.""" + """Dataset of input/target image pairs (full-image mode).""" def __init__(self, input_dir, target_dir, target_size=(256, 256)): self.input_paths = sorted(Path(input_dir).glob("*.png")) @@ -140,11 +253,19 @@ def train(args): device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Training on {device}") - # Create dataset - target_size = (args.image_size, args.image_size) - dataset = ImagePairDataset(args.input, args.target, target_size=target_size) - dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True) - print(f"Loaded {len(dataset)} image pairs (resized to {args.image_size}x{args.image_size})") + # Create dataset (patch-based or full-image) + if args.full_image: + print(f"Mode: Full-image (resized to {args.image_size}x{args.image_size})") + target_size = (args.image_size, args.image_size) + dataset = ImagePairDataset(args.input, args.target, target_size=target_size) + dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True) + else: + print(f"Mode: Patch-based ({args.patch_size}x{args.patch_size} patches)") + dataset = PatchDataset(args.input, args.target, + patch_size=args.patch_size, + patches_per_image=args.patches_per_image, + detector=args.detector) + dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True) # Create model model = CNNv2(kernels=args.kernel_sizes, channels=args.channels).to(device) @@ -206,12 +327,29 @@ def main(): parser = argparse.ArgumentParser(description='Train CNN v2 with parametric static features') parser.add_argument('--input', type=str, required=True, help='Input images directory') parser.add_argument('--target', type=str, required=True, help='Target images directory') + + # Training mode + parser.add_argument('--full-image', action='store_true', + help='Use full-image mode (resize all images)') parser.add_argument('--image-size', type=int, default=256, - help='Resize images to this size (default: 256)') + help='Full-image mode: resize to this size (default: 256)') + + # Patch-based mode (default) + parser.add_argument('--patch-size', type=int, default=32, + help='Patch mode: patch size (default: 32)') + parser.add_argument('--patches-per-image', type=int, default=64, + help='Patch mode: patches per image (default: 64)') + parser.add_argument('--detector', type=str, default='harris', + choices=['harris', 'fast', 'shi-tomasi', 'gradient'], + help='Patch mode: salient point detector (default: harris)') + + # Model architecture parser.add_argument('--kernel-sizes', type=int, nargs=3, default=[1, 3, 5], help='Kernel sizes for 3 layers (default: 1 3 5)') parser.add_argument('--channels', type=int, nargs=3, default=[16, 8, 4], help='Output channels for 3 layers (default: 16 8 4)') + + # Training parameters parser.add_argument('--epochs', type=int, default=5000, help='Training epochs') parser.add_argument('--batch-size', type=int, default=16, help='Batch size') parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate') |
