#!/usr/bin/env python3
"""CNN v2 Training Script - Parametric Static Features

Trains a multi-layer CNN on 7D static features plus a bias channel
(8 input channels total):
- RGBD (4D)
- UV coordinates (2D)
- sin(10*uv.x) position encoding (1D)
- Bias dimension (1D, always 1.0)
"""

import argparse
import time
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset, DataLoader


def compute_static_features(rgb, depth=None):
    """Generate 7D static features + bias dimension.

    Args:
        rgb: (H, W, 3) RGB image in [0, 1]
        depth: (H, W) depth map in [0, 1], optional

    Returns:
        (H, W, 8) static features array
    """
    h, w = rgb.shape[:2]

    # RGBD channels
    r, g, b = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
    d = depth if depth is not None else np.zeros((h, w), dtype=np.float32)

    # UV coordinates (normalized to [0, 1])
    uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0).astype(np.float32)
    uv_y = np.linspace(0, 1, h)[:, None].repeat(w, axis=1).astype(np.float32)

    # Sinusoidal position encoding
    sin10_x = np.sin(10.0 * uv_x).astype(np.float32)

    # Bias dimension (always 1.0)
    bias = np.ones((h, w), dtype=np.float32)

    # Stack: [R, G, B, D, uv.x, uv.y, sin10_x, bias]
    features = np.stack([r, g, b, d, uv_x, uv_y, sin10_x, bias], axis=-1)
    return features


class CNNv2(nn.Module):
    """CNN v2 with parametric static features.

    TODO: Add quantization-aware training (QAT) for 8-bit weights
    - Use torch.quantization.QuantStub/DeQuantStub
    - Train with fake quantization to adapt to 8-bit precision
    - Target: ~1.6 KB weights (vs 3.2 KB with f16)
    (An illustrative sketch of this wiring follows this class.)
    """

    def __init__(self, kernels=[1, 3, 5], channels=[16, 8, 4]):
        super().__init__()
        self.kernels = kernels
        self.channels = channels

        # Input layer: 8D (7 features + bias) → channels[0]
        self.layer0 = nn.Conv2d(8, channels[0], kernel_size=kernels[0],
                                padding=kernels[0] // 2, bias=False)

        # Inner layer: (8 + C_prev) → C_next
        in_ch_1 = 8 + channels[0]
        self.layer1 = nn.Conv2d(in_ch_1, channels[1], kernel_size=kernels[1],
                                padding=kernels[1] // 2, bias=False)

        # Output layer: (8 + C_last) → 4 (RGBA)
        # Note: channels[2] is unused; the output layer is fixed at 4 channels.
        in_ch_2 = 8 + channels[1]
        self.layer2 = nn.Conv2d(in_ch_2, 4, kernel_size=kernels[2],
                                padding=kernels[2] // 2, bias=False)

    def forward(self, static_features):
        """Forward pass with static feature concatenation.

        Args:
            static_features: (B, 8, H, W) static features

        Returns:
            (B, 4, H, W) RGBA output in [0, 1]
        """
        # Layer 0: use full 8D static features
        x0 = self.layer0(static_features)
        x0 = F.relu(x0)

        # Layer 1: concatenate static features + layer0 output
        x1_input = torch.cat([static_features, x0], dim=1)
        x1 = self.layer1(x1_input)
        x1 = F.relu(x1)

        # Layer 2: concatenate static features + layer1 output
        x2_input = torch.cat([static_features, x1], dim=1)
        output = self.layer2(x2_input)
        return torch.sigmoid(output)
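
# --- Illustrative QAT sketch for the TODO above (not used by this script) ---
# A minimal sketch, assuming eager-mode quantization via torch.quantization.
# The wrapper class, helper name, and 'fbgemm' backend are assumptions for
# illustration, not part of the tested training path.
class CNNv2QATWrapper(nn.Module):
    """Hypothetical wrapper adding Quant/DeQuant stubs around CNNv2."""

    def __init__(self, model):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.model = model
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        # Fake-quantize the input, run the float graph, dequantize the output.
        return self.dequant(self.model(self.quant(x)))


def prepare_qat_model(model):
    """Hypothetical helper: prepare a CNNv2 for fake-quantized training."""
    wrapped = CNNv2QATWrapper(model)
    wrapped.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    # Model must be in train() mode here (freshly built modules are).
    torch.quantization.prepare_qat(wrapped, inplace=True)
    # Train as usual, then: torch.quantization.convert(wrapped.eval())
    return wrapped
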
class PatchDataset(Dataset):
    """Patch-based dataset extracting salient regions from images."""

    def __init__(self, input_dir, target_dir, patch_size=32,
                 patches_per_image=64, detector='harris'):
        self.input_paths = sorted(Path(input_dir).glob("*.png"))
        self.target_paths = sorted(Path(target_dir).glob("*.png"))
        self.patch_size = patch_size
        self.patches_per_image = patches_per_image
        self.detector = detector

        assert len(self.input_paths) == len(self.target_paths), \
            f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets"

        print(f"Found {len(self.input_paths)} image pairs")
        print(f"Extracting {patches_per_image} patches per image using {detector} detector")
        print(f"Total patches: {len(self.input_paths) * patches_per_image}")

    def __len__(self):
        return len(self.input_paths) * self.patches_per_image

    def _detect_salient_points(self, img_array):
        """Detect salient points on the original (unresized) image.

        TODO: Add random sampling to training vectors
        - In addition to salient points, incorporate randomly-located samples
        - Default: 10% random samples, 90% salient points
        - Prevents overfitting to only high-gradient regions
        - Improves generalization across the entire image
        - Configurable via a --random-sample-percent parameter
        (A hedged sketch of this mixing follows this method.)
        """
        gray = cv2.cvtColor((img_array * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
        h, w = gray.shape
        half_patch = self.patch_size // 2
        corners = None

        if self.detector == 'harris':
            # useHarrisDetector=True is required here; OpenCV's default
            # (False) would silently use Shi-Tomasi scoring instead.
            corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2,
                                              qualityLevel=0.01,
                                              minDistance=half_patch,
                                              useHarrisDetector=True)
        elif self.detector == 'fast':
            fast = cv2.FastFeatureDetector_create(threshold=20)
            keypoints = fast.detect(gray, None)
            corners = np.array([[kp.pt[0], kp.pt[1]]
                                for kp in keypoints[:self.patches_per_image * 2]])
            corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None
        elif self.detector == 'shi-tomasi':
            corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2,
                                              qualityLevel=0.01,
                                              minDistance=half_patch,
                                              useHarrisDetector=False)
        elif self.detector == 'gradient':
            grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
            grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
            gradient_mag = np.sqrt(grad_x**2 + grad_y**2)
            threshold = np.percentile(gradient_mag, 95)
            y_coords, x_coords = np.where(gradient_mag > threshold)
            if len(x_coords) > self.patches_per_image * 2:
                indices = np.random.choice(len(x_coords),
                                           self.patches_per_image * 2,
                                           replace=False)
                x_coords = x_coords[indices]
                y_coords = y_coords[indices]
            corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)])
            corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None

        # Fall back to random sampling if no corners were found
        if corners is None or len(corners) == 0:
            x_coords = np.random.randint(half_patch, w - half_patch, self.patches_per_image)
            y_coords = np.random.randint(half_patch, h - half_patch, self.patches_per_image)
            corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)])
            corners = corners.reshape(-1, 1, 2)

        # Keep only corners whose patch fits entirely inside the image
        valid_corners = []
        for corner in corners:
            x, y = int(corner[0][0]), int(corner[0][1])
            if half_patch <= x < w - half_patch and half_patch <= y < h - half_patch:
                valid_corners.append((x, y))
            if len(valid_corners) >= self.patches_per_image:
                break

        # Top up with random points if the detector came up short
        while len(valid_corners) < self.patches_per_image:
            x = np.random.randint(half_patch, w - half_patch)
            y = np.random.randint(half_patch, h - half_patch)
            valid_corners.append((x, y))

        return valid_corners
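
    # --- Illustrative sketch for the random-sampling TODO above ---
    # A minimal sketch, assuming a 10% default mix; the method name and the
    # random_sample_percent parameter are hypothetical, not existing options.
    def _mix_random_samples(self, salient_points, w, h, random_sample_percent=10):
        """Hypothetical helper: swap a fraction of salient points for random ones."""
        half_patch = self.patch_size // 2
        n_random = int(len(salient_points) * random_sample_percent / 100)
        # Keep the strongest salient points, then append random patch centers.
        mixed = list(salient_points[:len(salient_points) - n_random])
        for _ in range(n_random):
            x = np.random.randint(half_patch, w - half_patch)
            y = np.random.randint(half_patch, h - half_patch)
            mixed.append((x, y))
        return mixed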
    def __getitem__(self, idx):
        img_idx = idx // self.patches_per_image
        patch_idx = idx % self.patches_per_image

        # Load original images (no resize)
        input_img = np.array(Image.open(self.input_paths[img_idx]).convert('RGB')) / 255.0
        target_img = np.array(Image.open(self.target_paths[img_idx]).convert('RGB')) / 255.0

        # Detect salient points on the original image
        salient_points = self._detect_salient_points(input_img)
        cx, cy = salient_points[patch_idx]

        # Extract patch
        half_patch = self.patch_size // 2
        y1, y2 = cy - half_patch, cy + half_patch
        x1, x2 = cx - half_patch, cx + half_patch
        input_patch = input_img[y1:y2, x1:x2]
        target_patch = target_img[y1:y2, x1:x2]

        # Compute static features for the patch
        static_feat = compute_static_features(input_patch.astype(np.float32))

        # Convert to tensors (C, H, W)
        static_feat = torch.from_numpy(static_feat).permute(2, 0, 1)
        target = torch.from_numpy(target_patch.astype(np.float32)).permute(2, 0, 1)

        # Pad target to 4 channels (RGBA); alpha is a constant 1.0
        target = F.pad(target, (0, 0, 0, 0, 0, 1), value=1.0)

        return static_feat, target


class ImagePairDataset(Dataset):
    """Dataset of input/target image pairs (full-image mode)."""

    def __init__(self, input_dir, target_dir, target_size=(256, 256)):
        self.input_paths = sorted(Path(input_dir).glob("*.png"))
        self.target_paths = sorted(Path(target_dir).glob("*.png"))
        self.target_size = target_size

        assert len(self.input_paths) == len(self.target_paths), \
            f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets"

    def __len__(self):
        return len(self.input_paths)

    def __getitem__(self, idx):
        # Load and resize images to a fixed size
        input_pil = Image.open(self.input_paths[idx]).convert('RGB')
        target_pil = Image.open(self.target_paths[idx]).convert('RGB')
        input_pil = input_pil.resize(self.target_size, Image.LANCZOS)
        target_pil = target_pil.resize(self.target_size, Image.LANCZOS)

        input_img = np.array(input_pil) / 255.0
        target_img = np.array(target_pil) / 255.0

        # Compute static features
        static_feat = compute_static_features(input_img.astype(np.float32))

        # Convert to tensors (C, H, W)
        static_feat = torch.from_numpy(static_feat).permute(2, 0, 1)
        target = torch.from_numpy(target_img.astype(np.float32)).permute(2, 0, 1)

        # Pad target to 4 channels (RGBA); alpha is a constant 1.0
        target = F.pad(target, (0, 0, 0, 0, 0, 1), value=1.0)

        return static_feat, target
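
# --- Illustrative shape check (not called anywhere) ---
# A minimal sketch showing the expected tensor shapes end to end; the 64x64
# size and the helper name are arbitrary choices for illustration.
def _shape_smoke_test():
    rgb = np.random.rand(64, 64, 3).astype(np.float32)        # dummy input image
    feat = compute_static_features(rgb)                        # (64, 64, 8)
    x = torch.from_numpy(feat).permute(2, 0, 1).unsqueeze(0)   # (1, 8, 64, 64)
    out = CNNv2()(x)                                           # (1, 4, 64, 64) RGBA
    assert out.shape == (1, 4, 64, 64)
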
def train(args):
    """Train the CNN v2 model."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training on {device}")

    # Create dataset (patch-based or full-image)
    if args.full_image:
        print(f"Mode: Full-image (resized to {args.image_size}x{args.image_size})")
        target_size = (args.image_size, args.image_size)
        dataset = ImagePairDataset(args.input, args.target, target_size=target_size)
    else:
        print(f"Mode: Patch-based ({args.patch_size}x{args.patch_size} patches)")
        dataset = PatchDataset(args.input, args.target,
                               patch_size=args.patch_size,
                               patches_per_image=args.patches_per_image,
                               detector=args.detector)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    # Create model
    model = CNNv2(kernels=args.kernel_sizes, channels=args.channels).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model: {args.channels} channels, {args.kernel_sizes} kernels, {total_params} weights")

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.MSELoss()

    # Training loop
    print(f"\nTraining for {args.epochs} epochs...")
    start_time = time.time()

    for epoch in range(1, args.epochs + 1):
        model.train()
        epoch_loss = 0.0

        for static_feat, target in dataloader:
            static_feat = static_feat.to(device)
            target = target.to(device)

            optimizer.zero_grad()
            output = model(static_feat)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(dataloader)

        # Report progress every epoch (overwrite the line with \r)
        elapsed = time.time() - start_time
        print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | Time: {elapsed:.1f}s",
              end='', flush=True)

        # Save checkpoint
        if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0:
            print()  # newline before the checkpoint message
            checkpoint_path = Path(args.checkpoint_dir) / f"checkpoint_epoch_{epoch}.pth"
            checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
                'config': {
                    'kernels': args.kernel_sizes,
                    'channels': args.channels,
                    'features': ['R', 'G', 'B', 'D', 'uv.x', 'uv.y', 'sin10_x', 'bias']
                }
            }, checkpoint_path)
            print(f" → Saved checkpoint: {checkpoint_path}")

    print(f"\nTraining complete! Total time: {time.time() - start_time:.1f}s")
    return model


def main():
    parser = argparse.ArgumentParser(description='Train CNN v2 with parametric static features')
    parser.add_argument('--input', type=str, required=True, help='Input images directory')
    parser.add_argument('--target', type=str, required=True, help='Target images directory')

    # Training mode
    parser.add_argument('--full-image', action='store_true',
                        help='Use full-image mode (resize all images)')
    parser.add_argument('--image-size', type=int, default=256,
                        help='Full-image mode: resize to this size (default: 256)')

    # Patch-based mode (default)
    parser.add_argument('--patch-size', type=int, default=32,
                        help='Patch mode: patch size (default: 32)')
    parser.add_argument('--patches-per-image', type=int, default=64,
                        help='Patch mode: patches per image (default: 64)')
    parser.add_argument('--detector', type=str, default='harris',
                        choices=['harris', 'fast', 'shi-tomasi', 'gradient'],
                        help='Patch mode: salient point detector (default: harris)')
    # TODO: Add --random-sample-percent parameter (default: 10)
    # Mix salient points with random samples for better generalization

    # Model architecture
    parser.add_argument('--kernel-sizes', type=int, nargs=3, default=[1, 3, 5],
                        help='Kernel sizes for the 3 layers (default: 1 3 5)')
    parser.add_argument('--channels', type=int, nargs=3, default=[16, 8, 4],
                        help='Output channels for the 3 layers (default: 16 8 4)')

    # Training parameters
    parser.add_argument('--epochs', type=int, default=5000, help='Training epochs')
    parser.add_argument('--batch-size', type=int, default=16, help='Batch size')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
    parser.add_argument('--checkpoint-dir', type=str, default='checkpoints',
                        help='Checkpoint directory')
    parser.add_argument('--checkpoint-every', type=int, default=1000,
                        help='Save checkpoint every N epochs (0 = disable)')

    args = parser.parse_args()
    train(args)


if __name__ == '__main__':
    main()
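
# Example invocations (the script name and data paths are placeholders):
#
#   # Patch-based training with the Harris detector (default mode)
#   python train_cnn_v2.py --input data/inputs --target data/targets \
#       --patch-size 32 --patches-per-image 64 --detector harris
#
#   # Full-image training at 256x256
#   python train_cnn_v2.py --input data/inputs --target data/targets \
#       --full-image --image-size 256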