summaryrefslogtreecommitdiff
path: root/training/train_cnn.py
diff options
context:
space:
mode:
author     skal <pascal.massimino@gmail.com>  2026-02-10 22:51:46 +0100
committer  skal <pascal.massimino@gmail.com>  2026-02-10 22:51:46 +0100
commit     58f276378735e0b51f4d1517a844357e45e376a7 (patch)
tree       6e4516920837ab3748b5ee54494218d6c844ee07 /training/train_cnn.py
parent     5b31395fbfffabdd1cc9b452eb12d9dd63110a6d (diff)
feat: Add salient-point patch extraction for CNN training
Preserve natural pixel scale by extracting patches at salient points instead of
resizing entire images.

Features:
- Multiple detectors: Harris (default), FAST, Shi-Tomasi, gradient
- Configurable patch size (e.g., 32×32) and patches per image
- Automatic fallback to random patches if insufficient features

Usage:
  # Patch-based training (preserves scale)
  python3 train_cnn.py --input dir/ --target dir/ --patch-size 32 --patches-per-image 64 --detector harris

  # Original resize mode (if --patch-size omitted)
  python3 train_cnn.py --input dir/ --target dir/

Arguments:
  --patch-size: Patch dimension (e.g., 32 for 32×32 patches)
  --patches-per-image: Number of patches to extract per image (default: 64)
  --detector: harris|fast|shi-tomasi|gradient (default: harris)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'training/train_cnn.py')
-rwxr-xr-x  training/train_cnn.py  | 165
1 files changed, 160 insertions, 5 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py
index 2b60d15..17cceb3 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -22,6 +22,8 @@ import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
+import numpy as np
+import cv2
import os
import sys
import argparse
@@ -78,6 +80,143 @@ class ImagePairDataset(Dataset):
return input_img, target_img
class PatchDataset(Dataset):
    """Dataset serving fixed-size patches cropped at salient points of image pairs.

    Pairs are matched by basename between input_dir and target_dir. Each image
    contributes exactly `patches_per_image` patches, all centered on salient
    points detected on the *input* image, so the natural pixel scale is
    preserved (no resizing).

    Args:
        input_dir: Directory of input images (matched by *.png/*.jpg/*.jpeg).
        target_dir: Directory of target images with the same basenames.
        patch_size: Side length of the square patches to extract.
        patches_per_image: Number of patches drawn from each image pair.
        detector: One of 'harris', 'fast', 'shi-tomasi', 'gradient'.
        transform: Optional torchvision transform applied to both patches.

    Raises:
        ValueError: If no matching pairs are found, if `detector` is unknown,
            or if an image is smaller than `patch_size`.
    """

    def __init__(self, input_dir, target_dir, patch_size=32, patches_per_image=64,
                 detector='harris', transform=None):
        self.input_dir = input_dir
        self.target_dir = target_dir
        self.patch_size = patch_size
        self.patches_per_image = patches_per_image
        self.detector = detector
        self.transform = transform
        # Per-image cache of patch centers. Without it the detector would be
        # re-run for EVERY patch request, i.e. patches_per_image times per
        # image per epoch; it also keeps the random fallback centers stable
        # across patches of the same image.
        self._corner_cache = {}

        # Find all image pairs (input file matched to a target with the same
        # basename and any supported extension).
        input_patterns = ['*.png', '*.jpg', '*.jpeg', '*.PNG', '*.JPG', '*.JPEG']
        self.image_pairs = []

        for pattern in input_patterns:
            for input_path in glob.glob(os.path.join(input_dir, pattern)):
                base_name = os.path.splitext(os.path.basename(input_path))[0]
                target_path = None
                for ext in ('png', 'jpg', 'jpeg', 'PNG', 'JPG', 'JPEG'):
                    candidate = os.path.join(target_dir, f"{base_name}.{ext}")
                    if os.path.exists(candidate):
                        target_path = candidate
                        break

                if target_path:
                    self.image_pairs.append((input_path, target_path))

        if not self.image_pairs:
            raise ValueError(f"No matching image pairs found between {input_dir} and {target_dir}")

        print(f"Found {len(self.image_pairs)} image pairs")
        print(f"Extracting {patches_per_image} patches per image using {detector} detector")
        print(f"Total patches: {len(self.image_pairs) * patches_per_image}")

    def __len__(self):
        return len(self.image_pairs) * self.patches_per_image

    def _detect_salient_points(self, img_array):
        """Return exactly `patches_per_image` (x, y) patch centers for one image.

        Centers are constrained so the full patch_size crop stays inside the
        image; if the detector yields too few valid points, the list is topped
        up with uniformly random centers.
        """
        gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
        h, w = gray.shape
        half_patch = self.patch_size // 2
        # A center (x, y) maps to crop [x - half, x - half + patch_size);
        # for odd patch sizes the crop extends (patch_size - half) past the
        # center on the right/bottom, so bound both sides explicitly.
        x_max = w - (self.patch_size - half_patch)
        y_max = h - (self.patch_size - half_patch)
        if x_max < half_patch or y_max < half_patch:
            # Fail loudly instead of letting np.random.randint raise an
            # opaque "low >= high" error below.
            raise ValueError(
                f"Image ({w}x{h}) is smaller than patch size {self.patch_size}")

        if self.detector == 'harris':
            # Harris corner detection. BUG FIX: goodFeaturesToTrack defaults
            # to the Shi-Tomasi response; request Harris explicitly, otherwise
            # this branch is identical to 'shi-tomasi'.
            corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2,
                                              qualityLevel=0.01, minDistance=half_patch,
                                              useHarrisDetector=True)
        elif self.detector == 'fast':
            # FAST feature detection
            fast = cv2.FastFeatureDetector_create(threshold=20)
            keypoints = fast.detect(gray, None)
            corners = np.array([[kp.pt[0], kp.pt[1]] for kp in keypoints[:self.patches_per_image * 2]])
            corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None
        elif self.detector == 'shi-tomasi':
            # Shi-Tomasi corner detection (goodFeaturesToTrack default response)
            corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2,
                                              qualityLevel=0.01, minDistance=half_patch,
                                              useHarrisDetector=False)
        elif self.detector == 'gradient':
            # High-gradient regions: sample among the top-5% gradient pixels.
            grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
            grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
            gradient_mag = np.sqrt(grad_x**2 + grad_y**2)

            threshold = np.percentile(gradient_mag, 95)
            y_coords, x_coords = np.where(gradient_mag > threshold)

            if len(x_coords) > self.patches_per_image * 2:
                indices = np.random.choice(len(x_coords), self.patches_per_image * 2, replace=False)
                x_coords = x_coords[indices]
                y_coords = y_coords[indices]

            if len(x_coords) > 0:
                corners = np.stack([x_coords, y_coords], axis=1).reshape(-1, 1, 2)
            else:
                corners = None
        else:
            raise ValueError(f"Unknown detector: {self.detector}")

        # Keep only centers whose full patch lies within the image.
        valid_corners = []
        if corners is not None:
            for corner in np.asarray(corners).reshape(-1, 2):
                x, y = int(corner[0]), int(corner[1])
                if half_patch <= x <= x_max and half_patch <= y <= y_max:
                    valid_corners.append((x, y))
                    if len(valid_corners) >= self.patches_per_image:
                        break

        # Fallback: top up with random in-bounds centers if the detector
        # produced too few (or none).
        while len(valid_corners) < self.patches_per_image:
            x = int(np.random.randint(half_patch, x_max + 1))
            y = int(np.random.randint(half_patch, y_max + 1))
            valid_corners.append((x, y))

        return valid_corners

    def __getitem__(self, idx):
        img_idx = idx // self.patches_per_image
        patch_idx = idx % self.patches_per_image

        input_path, target_path = self.image_pairs[img_idx]

        # Load images (RGBA input presumably carries RGB+depth — the model
        # below takes an RGBD-style input; target is plain RGB).
        input_img = Image.open(input_path).convert('RGBA')
        target_img = Image.open(target_path).convert('RGB')

        # Detect salient points once per image (on the input's RGB channels)
        # and reuse them for all of its patches.
        corners = self._corner_cache.get(img_idx)
        if corners is None:
            input_array = np.array(input_img)[:, :, :3]
            corners = self._detect_salient_points(input_array)
            self._corner_cache[img_idx] = corners

        x, y = corners[patch_idx]
        half_patch = self.patch_size // 2

        # Crop an exactly patch_size x patch_size box (anchoring on the left/
        # top edge handles odd patch sizes correctly).
        left, top = x - half_patch, y - half_patch
        box = (left, top, left + self.patch_size, top + self.patch_size)
        input_patch = input_img.crop(box)
        target_patch = target_img.crop(box)

        if self.transform:
            input_patch = self.transform(input_patch)
            target_patch = self.transform(target_patch)

        return input_patch, target_patch
+
+
class SimpleCNN(nn.Module):
"""CNN for RGBD→grayscale with 7-channel input (RGBD + UV + gray)"""
@@ -362,12 +501,24 @@ def train(args):
print(f"Using device: {device}")
# Prepare dataset
- transform = transforms.Compose([
- transforms.Resize((256, 256)),
- transforms.ToTensor(),
- ])
+ if args.patch_size:
+ # Patch-based training (preserves natural scale)
+ transform = transforms.Compose([
+ transforms.ToTensor(),
+ ])
+ dataset = PatchDataset(args.input, args.target,
+ patch_size=args.patch_size,
+ patches_per_image=args.patches_per_image,
+ detector=args.detector,
+ transform=transform)
+ else:
+ # Full-image training (resize mode)
+ transform = transforms.Compose([
+ transforms.Resize((256, 256)),
+ transforms.ToTensor(),
+ ])
+ dataset = ImagePairDataset(args.input, args.target, transform=transform)
- dataset = ImagePairDataset(args.input, args.target, transform=transform)
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
# Parse kernel sizes
@@ -578,6 +729,10 @@ def main():
parser.add_argument('--resume', help='Resume from checkpoint file')
parser.add_argument('--export-only', help='Export WGSL from checkpoint without training')
parser.add_argument('--infer', help='Run inference on single image (requires --export-only for checkpoint)')
+ parser.add_argument('--patch-size', type=int, help='Extract patches of this size (e.g., 32) instead of resizing (default: None = resize to 256x256)')
+ parser.add_argument('--patches-per-image', type=int, default=64, help='Number of patches to extract per image (default: 64)')
+ parser.add_argument('--detector', default='harris', choices=['harris', 'fast', 'shi-tomasi', 'gradient'],
+ help='Salient point detector for patch extraction (default: harris)')
args = parser.parse_args()