author    skal <pascal.massimino@gmail.com>    2026-02-12 11:50:52 +0100
committer skal <pascal.massimino@gmail.com>    2026-02-12 11:50:52 +0100
commit    7547e8ff4744339b92650b6ef3ff7405befe4beb (patch)
tree      0388b064c6bb2fcb2346796f9d1134c5ed9214b5
parent    c878631f24ddb7514dd4db3d7ace6a0a296d4157 (diff)
CNN v2: Patch-based training as default (like CNN v1)
Salient point detection on original images with patch extraction.

Changes:
- Added PatchDataset class (harris/fast/shi-tomasi/gradient detectors)
- Detects salient points on ORIGINAL images (no resize)
- Extracts 32×32 patches around salient points
- Default: 64 patches/image, harris detector
- Batch size: 16 (8 images × 64 patches = 512 patches per epoch)

Training modes:
1. Patch-based (default): --patch-size 32 --patches-per-image 64 --detector harris
2. Full-image (option):   --full-image --image-size 256

Benefits:
- Focuses training on interesting regions
- Handles variable image sizes naturally
- Matches CNN v1 workflow
- Better convergence with limited data (8 images → 512 patches)

Script updated:
- train_cnn_v2_full.sh: patch-based by default
- Configuration exposed for easy switching

Example:
  ./scripts/train_cnn_v2_full.sh        # Patch-based
  # Edit script: uncomment FULL_IMAGE for resize mode

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
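For reference, the direct trainer invocations the two modes imply, sketched
with placeholder data directories (data/input and data/target are illustrative,
not from this commit):

    # Patch-based mode (default): 64 harris-detected 32x32 patches per image
    python3 training/train_cnn_v2.py \
        --input data/input --target data/target \
        --patch-size 32 --patches-per-image 64 --detector harris \
        --epochs 10000 --batch-size 16

    # Full-image mode: resize every pair to 256x256
    python3 training/train_cnn_v2.py \
        --input data/input --target data/target \
        --full-image --image-size 256 \
        --epochs 10000 --batch-size 16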
-rwxr-xr-x  scripts/train_cnn_v2_full.sh  |  20
-rwxr-xr-x  training/train_cnn_v2.py      | 152
2 files changed, 161 insertions, 11 deletions
diff --git a/scripts/train_cnn_v2_full.sh b/scripts/train_cnn_v2_full.sh
index 119b788..4ddd9ac 100755
--- a/scripts/train_cnn_v2_full.sh
+++ b/scripts/train_cnn_v2_full.sh
@@ -14,8 +14,17 @@ CHECKPOINT_DIR="checkpoints"
VALIDATION_DIR="validation_results"
EPOCHS=10000
CHECKPOINT_EVERY=500
-BATCH_SIZE=8
-IMAGE_SIZE=256
+BATCH_SIZE=16
+
+# Patch-based training (default)
+PATCH_SIZE=32
+PATCHES_PER_IMAGE=64
+DETECTOR="harris"
+
+# Full-image training (alternative - uncomment to use; IMAGE_SIZE is optional)
+# FULL_IMAGE="--full-image"
+# IMAGE_SIZE=256
+
KERNEL_SIZES="1 3 5"
CHANNELS="16 8 4"
@@ -31,13 +40,16 @@ echo "[1/4] Training CNN v2 model..."
python3 training/train_cnn_v2.py \
--input "$INPUT_DIR" \
--target "$TARGET_DIR" \
- --image-size $IMAGE_SIZE \
+ --patch-size $PATCH_SIZE \
+ --patches-per-image $PATCHES_PER_IMAGE \
+ --detector $DETECTOR \
--kernel-sizes $KERNEL_SIZES \
--channels $CHANNELS \
--epochs $EPOCHS \
--batch-size $BATCH_SIZE \
--checkpoint-dir "$CHECKPOINT_DIR" \
- --checkpoint-every $CHECKPOINT_EVERY
+ --checkpoint-every $CHECKPOINT_EVERY \
+    $FULL_IMAGE ${IMAGE_SIZE:+--image-size $IMAGE_SIZE}
if [ $? -ne 0 ]; then
echo "Error: Training failed"
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
index e590b40..3ab1c0f 100755
--- a/training/train_cnn_v2.py
+++ b/training/train_cnn_v2.py
@@ -17,6 +17,7 @@ from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from PIL import Image
import time
+import cv2
def compute_static_features(rgb, depth=None):
@@ -97,8 +98,120 @@ class CNNv2(nn.Module):
return torch.sigmoid(output)
+class PatchDataset(Dataset):
+ """Patch-based dataset extracting salient regions from images."""
+
+ def __init__(self, input_dir, target_dir, patch_size=32, patches_per_image=64,
+ detector='harris'):
+ self.input_paths = sorted(Path(input_dir).glob("*.png"))
+ self.target_paths = sorted(Path(target_dir).glob("*.png"))
+ self.patch_size = patch_size
+ self.patches_per_image = patches_per_image
+ self.detector = detector
+
+ assert len(self.input_paths) == len(self.target_paths), \
+ f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets"
+
+ print(f"Found {len(self.input_paths)} image pairs")
+ print(f"Extracting {patches_per_image} patches per image using {detector} detector")
+ print(f"Total patches: {len(self.input_paths) * patches_per_image}")
+
+ def __len__(self):
+ return len(self.input_paths) * self.patches_per_image
+
+ def _detect_salient_points(self, img_array):
+ """Detect salient points on original image."""
+ gray = cv2.cvtColor((img_array * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
+ h, w = gray.shape
+ half_patch = self.patch_size // 2
+
+ corners = None
+        if self.detector == 'harris':
+            corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2, qualityLevel=0.01,
+                                              minDistance=half_patch, useHarrisDetector=True)
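+            # Note: cv2.goodFeaturesToTrack scores with Shi-Tomasi by default;
+            # useHarrisDetector=True switches it to the Harris corner response.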
+ elif self.detector == 'fast':
+ fast = cv2.FastFeatureDetector_create(threshold=20)
+            keypoints = sorted(fast.detect(gray, None), key=lambda kp: kp.response, reverse=True)
+ corners = np.array([[kp.pt[0], kp.pt[1]] for kp in keypoints[:self.patches_per_image * 2]])
+ corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None
+ elif self.detector == 'shi-tomasi':
+ corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2,
+ qualityLevel=0.01, minDistance=half_patch,
+ useHarrisDetector=False)
+ elif self.detector == 'gradient':
+ grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
+ grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
+ gradient_mag = np.sqrt(grad_x**2 + grad_y**2)
+ threshold = np.percentile(gradient_mag, 95)
+ y_coords, x_coords = np.where(gradient_mag > threshold)
+
+ if len(x_coords) > self.patches_per_image * 2:
+ indices = np.random.choice(len(x_coords), self.patches_per_image * 2, replace=False)
+ x_coords = x_coords[indices]
+ y_coords = y_coords[indices]
+
+ corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)])
+ corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None
+
+ # Fallback to random if no corners found
+ if corners is None or len(corners) == 0:
+ x_coords = np.random.randint(half_patch, w - half_patch, self.patches_per_image)
+ y_coords = np.random.randint(half_patch, h - half_patch, self.patches_per_image)
+ corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)])
+ corners = corners.reshape(-1, 1, 2)
+
+ # Filter valid corners
+ valid_corners = []
+ for corner in corners:
+ x, y = int(corner[0][0]), int(corner[0][1])
+ if half_patch <= x < w - half_patch and half_patch <= y < h - half_patch:
+ valid_corners.append((x, y))
+ if len(valid_corners) >= self.patches_per_image:
+ break
+
+ # Fill with random if not enough
+ while len(valid_corners) < self.patches_per_image:
+ x = np.random.randint(half_patch, w - half_patch)
+ y = np.random.randint(half_patch, h - half_patch)
+ valid_corners.append((x, y))
+
+ return valid_corners
+
+ def __getitem__(self, idx):
+ img_idx = idx // self.patches_per_image
+ patch_idx = idx % self.patches_per_image
+
+ # Load original images (no resize)
+ input_img = np.array(Image.open(self.input_paths[img_idx]).convert('RGB')) / 255.0
+ target_img = np.array(Image.open(self.target_paths[img_idx]).convert('RGB')) / 255.0
+
+ # Detect salient points on original image
+ salient_points = self._detect_salient_points(input_img)
+ cx, cy = salient_points[patch_idx]
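+        # Detection reruns on every fetch; cheap for a handful of images, but
+        # worth caching per-image salient points if the dataset grows.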
+
+ # Extract patch
+ half_patch = self.patch_size // 2
+ y1, y2 = cy - half_patch, cy + half_patch
+ x1, x2 = cx - half_patch, cx + half_patch
+
+ input_patch = input_img[y1:y2, x1:x2]
+ target_patch = target_img[y1:y2, x1:x2]
+
+ # Compute static features for patch
+ static_feat = compute_static_features(input_patch.astype(np.float32))
+
+ # Convert to tensors (C, H, W)
+ static_feat = torch.from_numpy(static_feat).permute(2, 0, 1)
+ target = torch.from_numpy(target_patch.astype(np.float32)).permute(2, 0, 1)
+
+ # Pad target to 4 channels (RGBA)
+ target = F.pad(target, (0, 0, 0, 0, 0, 1), value=1.0)
+
+ return static_feat, target
+
+
class ImagePairDataset(Dataset):
- """Dataset of input/target image pairs."""
+ """Dataset of input/target image pairs (full-image mode)."""
def __init__(self, input_dir, target_dir, target_size=(256, 256)):
self.input_paths = sorted(Path(input_dir).glob("*.png"))
@@ -140,11 +253,19 @@ def train(args):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Training on {device}")
- # Create dataset
- target_size = (args.image_size, args.image_size)
- dataset = ImagePairDataset(args.input, args.target, target_size=target_size)
- dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
- print(f"Loaded {len(dataset)} image pairs (resized to {args.image_size}x{args.image_size})")
+ # Create dataset (patch-based or full-image)
+ if args.full_image:
+ print(f"Mode: Full-image (resized to {args.image_size}x{args.image_size})")
+ target_size = (args.image_size, args.image_size)
+ dataset = ImagePairDataset(args.input, args.target, target_size=target_size)
+ dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
+ else:
+ print(f"Mode: Patch-based ({args.patch_size}x{args.patch_size} patches)")
+ dataset = PatchDataset(args.input, args.target,
+ patch_size=args.patch_size,
+ patches_per_image=args.patches_per_image,
+ detector=args.detector)
+ dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
# Create model
model = CNNv2(kernels=args.kernel_sizes, channels=args.channels).to(device)
@@ -206,12 +327,29 @@ def main():
parser = argparse.ArgumentParser(description='Train CNN v2 with parametric static features')
parser.add_argument('--input', type=str, required=True, help='Input images directory')
parser.add_argument('--target', type=str, required=True, help='Target images directory')
+
+ # Training mode
+ parser.add_argument('--full-image', action='store_true',
+ help='Use full-image mode (resize all images)')
parser.add_argument('--image-size', type=int, default=256,
- help='Resize images to this size (default: 256)')
+ help='Full-image mode: resize to this size (default: 256)')
+
+ # Patch-based mode (default)
+ parser.add_argument('--patch-size', type=int, default=32,
+ help='Patch mode: patch size (default: 32)')
+ parser.add_argument('--patches-per-image', type=int, default=64,
+ help='Patch mode: patches per image (default: 64)')
+ parser.add_argument('--detector', type=str, default='harris',
+ choices=['harris', 'fast', 'shi-tomasi', 'gradient'],
+ help='Patch mode: salient point detector (default: harris)')
+
+ # Model architecture
parser.add_argument('--kernel-sizes', type=int, nargs=3, default=[1, 3, 5],
help='Kernel sizes for 3 layers (default: 1 3 5)')
parser.add_argument('--channels', type=int, nargs=3, default=[16, 8, 4],
help='Output channels for 3 layers (default: 16 8 4)')
+
+ # Training parameters
parser.add_argument('--epochs', type=int, default=5000, help='Training epochs')
parser.add_argument('--batch-size', type=int, default=16, help='Batch size')
parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')