diff options
| author | skal <pascal.massimino@gmail.com> | 2026-02-10 22:51:46 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-02-10 22:51:46 +0100 |
| commit | 58f276378735e0b51f4d1517a844357e45e376a7 (patch) | |
| tree | 6e4516920837ab3748b5ee54494218d6c844ee07 /training/train_cnn.py | |
| parent | 5b31395fbfffabdd1cc9b452eb12d9dd63110a6d (diff) | |
feat: Add salient-point patch extraction for CNN training
Preserve natural pixel scale by extracting patches at salient points
instead of resizing entire images.
Features:
- Multiple detectors: Harris (default), FAST, Shi-Tomasi, gradient
- Configurable patch size (e.g., 32×32) and patches per image
- Automatic fallback to random patches if insufficient features
Usage:
# Patch-based training (preserves scale)
python3 train_cnn.py --input dir/ --target dir/ --patch-size 32 --patches-per-image 64 --detector harris
# Original resize mode (if --patch-size omitted)
python3 train_cnn.py --input dir/ --target dir/
Arguments:
--patch-size: Patch dimension (e.g., 32 for 32×32 patches)
--patches-per-image: Number of patches to extract per image (default: 64)
--detector: harris|fast|shi-tomasi|gradient (default: harris)
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'training/train_cnn.py')
| -rwxr-xr-x | training/train_cnn.py | 165 |
1 file changed, 160 insertions(+), 5 deletions(-)
diff --git a/training/train_cnn.py b/training/train_cnn.py index 2b60d15..17cceb3 100755 --- a/training/train_cnn.py +++ b/training/train_cnn.py @@ -22,6 +22,8 @@ import torch.optim as optim from torch.utils.data import Dataset, DataLoader from torchvision import transforms from PIL import Image +import numpy as np +import cv2 import os import sys import argparse @@ -78,6 +80,143 @@ class ImagePairDataset(Dataset): return input_img, target_img +class PatchDataset(Dataset): + """Dataset for extracting salient patches from image pairs""" + + def __init__(self, input_dir, target_dir, patch_size=32, patches_per_image=64, + detector='harris', transform=None): + self.input_dir = input_dir + self.target_dir = target_dir + self.patch_size = patch_size + self.patches_per_image = patches_per_image + self.detector = detector + self.transform = transform + + # Find all image pairs + input_patterns = ['*.png', '*.jpg', '*.jpeg', '*.PNG', '*.JPG', '*.JPEG'] + self.image_pairs = [] + + for pattern in input_patterns: + input_files = glob.glob(os.path.join(input_dir, pattern)) + for input_path in input_files: + filename = os.path.basename(input_path) + target_path = None + for ext in ['png', 'jpg', 'jpeg', 'PNG', 'JPG', 'JPEG']: + base_name = os.path.splitext(filename)[0] + candidate = os.path.join(target_dir, f"{base_name}.{ext}") + if os.path.exists(candidate): + target_path = candidate + break + + if target_path: + self.image_pairs.append((input_path, target_path)) + + if not self.image_pairs: + raise ValueError(f"No matching image pairs found between {input_dir} and {target_dir}") + + print(f"Found {len(self.image_pairs)} image pairs") + print(f"Extracting {patches_per_image} patches per image using {detector} detector") + print(f"Total patches: {len(self.image_pairs) * patches_per_image}") + + def __len__(self): + return len(self.image_pairs) * self.patches_per_image + + def _detect_salient_points(self, img_array): + """Detect salient points using specified detector""" + gray 
= cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY) + h, w = gray.shape + half_patch = self.patch_size // 2 + + if self.detector == 'harris': + # Harris corner detection + corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2, + qualityLevel=0.01, minDistance=half_patch) + elif self.detector == 'fast': + # FAST feature detection + fast = cv2.FastFeatureDetector_create(threshold=20) + keypoints = fast.detect(gray, None) + corners = np.array([[kp.pt[0], kp.pt[1]] for kp in keypoints[:self.patches_per_image * 2]]) + corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None + elif self.detector == 'shi-tomasi': + # Shi-Tomasi corner detection (goodFeaturesToTrack with different params) + corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2, + qualityLevel=0.01, minDistance=half_patch, + useHarrisDetector=False) + elif self.detector == 'gradient': + # High-gradient regions + grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3) + grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3) + gradient_mag = np.sqrt(grad_x**2 + grad_y**2) + + # Find top gradient locations + threshold = np.percentile(gradient_mag, 95) + y_coords, x_coords = np.where(gradient_mag > threshold) + + if len(x_coords) > self.patches_per_image * 2: + indices = np.random.choice(len(x_coords), self.patches_per_image * 2, replace=False) + x_coords = x_coords[indices] + y_coords = y_coords[indices] + + corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)]) + corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None + else: + raise ValueError(f"Unknown detector: {self.detector}") + + # Fallback to random if no corners found + if corners is None or len(corners) == 0: + x_coords = np.random.randint(half_patch, w - half_patch, self.patches_per_image) + y_coords = np.random.randint(half_patch, h - half_patch, self.patches_per_image) + corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)]) + corners = corners.reshape(-1, 1, 2) + + # Filter valid corners (within 
bounds) + valid_corners = [] + for corner in corners: + x, y = int(corner[0][0]), int(corner[0][1]) + if half_patch <= x < w - half_patch and half_patch <= y < h - half_patch: + valid_corners.append((x, y)) + if len(valid_corners) >= self.patches_per_image: + break + + # Fill with random if not enough + while len(valid_corners) < self.patches_per_image: + x = np.random.randint(half_patch, w - half_patch) + y = np.random.randint(half_patch, h - half_patch) + valid_corners.append((x, y)) + + return valid_corners + + def __getitem__(self, idx): + img_idx = idx // self.patches_per_image + patch_idx = idx % self.patches_per_image + + input_path, target_path = self.image_pairs[img_idx] + + # Load images + input_img = Image.open(input_path).convert('RGBA') + target_img = Image.open(target_path).convert('RGB') + + # Detect salient points (use input image for detection) + input_array = np.array(input_img)[:, :, :3] # Use RGB for detection + corners = self._detect_salient_points(input_array) + + # Extract patch at specified index + x, y = corners[patch_idx] + half_patch = self.patch_size // 2 + + # Crop patches + input_patch = input_img.crop((x - half_patch, y - half_patch, + x + half_patch, y + half_patch)) + target_patch = target_img.crop((x - half_patch, y - half_patch, + x + half_patch, y + half_patch)) + + if self.transform: + input_patch = self.transform(input_patch) + target_patch = self.transform(target_patch) + + return input_patch, target_patch + + class SimpleCNN(nn.Module): """CNN for RGBD→grayscale with 7-channel input (RGBD + UV + gray)""" @@ -362,12 +501,24 @@ def train(args): print(f"Using device: {device}") # Prepare dataset - transform = transforms.Compose([ - transforms.Resize((256, 256)), - transforms.ToTensor(), - ]) + if args.patch_size: + # Patch-based training (preserves natural scale) + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + dataset = PatchDataset(args.input, args.target, + patch_size=args.patch_size, + 
patches_per_image=args.patches_per_image, + detector=args.detector, + transform=transform) + else: + # Full-image training (resize mode) + transform = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + ]) + dataset = ImagePairDataset(args.input, args.target, transform=transform) - dataset = ImagePairDataset(args.input, args.target, transform=transform) dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True) # Parse kernel sizes @@ -578,6 +729,10 @@ def main(): parser.add_argument('--resume', help='Resume from checkpoint file') parser.add_argument('--export-only', help='Export WGSL from checkpoint without training') parser.add_argument('--infer', help='Run inference on single image (requires --export-only for checkpoint)') + parser.add_argument('--patch-size', type=int, help='Extract patches of this size (e.g., 32) instead of resizing (default: None = resize to 256x256)') + parser.add_argument('--patches-per-image', type=int, default=64, help='Number of patches to extract per image (default: 64)') + parser.add_argument('--detector', default='harris', choices=['harris', 'fast', 'shi-tomasi', 'gradient'], + help='Salient point detector for patch extraction (default: harris)') args = parser.parse_args() |
