#!/usr/bin/env python3
"""CNN v2 Training Script - Uniform 12D→4D Architecture

Architecture:
- Static features (8D): p0-p3 (parametric), uv_x, uv_y, sin(10×uv_x), bias
- Input RGBD (4D): original image mip 0
- All layers: input RGBD (4D) + static (8D) = 12D → 4 channels
- Uniform layer structure with bias=False (bias in static features)
"""

import argparse
import time
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image


def compute_static_features(rgb, depth=None):
    """Generate 8D static features (parametric + spatial).

    Args:
        rgb: (H, W, 3) RGB image in [0, 1]
        depth: (H, W) depth map in [0, 1], optional

    Returns:
        (H, W, 8) static features: [p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias]

    Note:
        p0-p3 are parametric features (can be mips, gradients, etc.).
        For training we use RGBD as the default, but mip1/mip2 could be used instead.
    """
    h, w = rgb.shape[:2]

    # Parametric features (p0-p3) - using RGBD as default
    # TODO: Experiment with mip1 grayscale, gradients, etc.
    p0 = rgb[:, :, 0].astype(np.float32)
    p1 = rgb[:, :, 1].astype(np.float32)
    p2 = rgb[:, :, 2].astype(np.float32)
    p3 = depth.astype(np.float32) if depth is not None else np.zeros((h, w), dtype=np.float32)

    # UV coordinates (normalized [0, 1])
    uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0).astype(np.float32)
    uv_y = np.linspace(0, 1, h)[:, None].repeat(w, axis=1).astype(np.float32)

    # Multi-frequency position encoding
    sin10_x = np.sin(10.0 * uv_x).astype(np.float32)

    # Bias dimension (always 1.0) - replaces the Conv2d bias parameter
    bias = np.ones((h, w), dtype=np.float32)

    # Stack: [p0, p1, p2, p3, uv.x, uv.y, sin10_x, bias]
    features = np.stack([p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias], axis=-1)
    return features
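
# Hedged usage sketch (never called by the training pipeline): illustrates the
# channel order produced by compute_static_features() on a dummy patch. The
# function name `_example_static_features` is illustrative only.
def _example_static_features():
    rgb = np.random.rand(4, 4, 3).astype(np.float32)  # dummy 4x4 RGB patch
    feats = compute_static_features(rgb)               # -> (4, 4, 8)
    assert feats.shape == (4, 4, 8)
    assert np.allclose(feats[..., 7], 1.0)             # bias plane is constant 1.0
    return feats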

class CNNv2(nn.Module):
    """CNN v2 - Uniform 12D→4D Architecture

    All layers: input RGBD (4D) + static (8D) = 12D → 4 channels
    Uses bias=False (bias integrated in static features as 1.0)

    TODO: Add quantization-aware training (QAT) for 8-bit weights
    - Use torch.quantization.QuantStub/DeQuantStub
    - Train with fake quantization to adapt to 8-bit precision
    - Target: ~1.3 KB weights (vs 2.6 KB with f16)
    """

    def __init__(self, kernel_size=3, num_layers=3):
        super().__init__()
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.layers = nn.ModuleList()

        # All layers: 12D input (4 RGBD + 8 static) → 4D output
        for _ in range(num_layers):
            self.layers.append(
                nn.Conv2d(12, 4, kernel_size=kernel_size,
                          padding=kernel_size // 2, bias=False)
            )

    def forward(self, input_rgbd, static_features):
        """Forward pass with uniform 12D→4D layers.

        Args:
            input_rgbd: (B, 4, H, W) input image RGBD (mip 0)
            static_features: (B, 8, H, W) static features

        Returns:
            (B, 4, H, W) RGBA output [0, 1]
        """
        # Layer 0: input RGBD (4D) + static (8D) = 12D
        x = torch.cat([input_rgbd, static_features], dim=1)
        x = self.layers[0](x)
        x = torch.clamp(x, 0, 1)  # Output [0, 1] for layer 0

        # Layer 1+: previous output (4D) + static (8D) = 12D
        for i in range(1, self.num_layers):
            x_input = torch.cat([x, static_features], dim=1)
            x = self.layers[i](x_input)
            if i < self.num_layers - 1:
                x = F.relu(x)
            else:
                x = torch.clamp(x, 0, 1)  # Final output [0, 1]

        return x
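
# Hedged sketch for the QAT TODO in the CNNv2 docstring: prepares a model for
# quantization-aware training so the weights adapt to 8-bit precision. The
# helper name `prepare_qat_for_8bit` and the 'fbgemm' backend choice are
# assumptions, not part of the current pipeline; a complete version would also
# wrap forward() with QuantStub/DeQuantStub and use quant-friendly cat/clamp ops.
def prepare_qat_for_8bit(model):
    """Attach fake-quantization observers to a CNNv2 instance (sketch only)."""
    model.train()
    model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    torch.quantization.prepare_qat(model, inplace=True)  # insert fake-quant modules
    return model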

class PatchDataset(Dataset):
    """Patch-based dataset extracting salient regions from images."""

    def __init__(self, input_dir, target_dir, patch_size=32,
                 patches_per_image=64, detector='harris'):
        self.input_paths = sorted(Path(input_dir).glob("*.png"))
        self.target_paths = sorted(Path(target_dir).glob("*.png"))
        self.patch_size = patch_size
        self.patches_per_image = patches_per_image
        self.detector = detector

        assert len(self.input_paths) == len(self.target_paths), \
            f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets"

        print(f"Found {len(self.input_paths)} image pairs")
        print(f"Extracting {patches_per_image} patches per image using {detector} detector")
        print(f"Total patches: {len(self.input_paths) * patches_per_image}")

    def __len__(self):
        return len(self.input_paths) * self.patches_per_image

    def _detect_salient_points(self, img_array):
        """Detect salient points on original image.

        TODO: Add random sampling to training vectors
        - In addition to salient points, incorporate randomly-located samples
        - Default: 10% random samples, 90% salient points
        - Prevents overfitting to only high-gradient regions
        - Improves generalization across entire image
        - Configurable via --random-sample-percent parameter
        """
        gray = cv2.cvtColor((img_array * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
        h, w = gray.shape
        half_patch = self.patch_size // 2

        corners = None
        if self.detector == 'harris':
            corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2,
                                              qualityLevel=0.01, minDistance=half_patch)
        elif self.detector == 'fast':
            fast = cv2.FastFeatureDetector_create(threshold=20)
            keypoints = fast.detect(gray, None)
            corners = np.array([[kp.pt[0], kp.pt[1]]
                                for kp in keypoints[:self.patches_per_image * 2]])
            corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None
        elif self.detector == 'shi-tomasi':
            corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2,
                                              qualityLevel=0.01, minDistance=half_patch,
                                              useHarrisDetector=False)
        elif self.detector == 'gradient':
            grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
            grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
            gradient_mag = np.sqrt(grad_x**2 + grad_y**2)
            threshold = np.percentile(gradient_mag, 95)
            y_coords, x_coords = np.where(gradient_mag > threshold)
            if len(x_coords) > self.patches_per_image * 2:
                indices = np.random.choice(len(x_coords), self.patches_per_image * 2,
                                           replace=False)
                x_coords = x_coords[indices]
                y_coords = y_coords[indices]
            corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)])
            corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None

        # Fallback to random if no corners found
        if corners is None or len(corners) == 0:
            x_coords = np.random.randint(half_patch, w - half_patch, self.patches_per_image)
            y_coords = np.random.randint(half_patch, h - half_patch, self.patches_per_image)
            corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)])
            corners = corners.reshape(-1, 1, 2)

        # Filter valid corners
        valid_corners = []
        for corner in corners:
            x, y = int(corner[0][0]), int(corner[0][1])
            if half_patch <= x < w - half_patch and half_patch <= y < h - half_patch:
                valid_corners.append((x, y))
            if len(valid_corners) >= self.patches_per_image:
                break

        # Fill with random if not enough
        while len(valid_corners) < self.patches_per_image:
            x = np.random.randint(half_patch, w - half_patch)
            y = np.random.randint(half_patch, h - half_patch)
            valid_corners.append((x, y))

        return valid_corners

    def __getitem__(self, idx):
        img_idx = idx // self.patches_per_image
        patch_idx = idx % self.patches_per_image

        # Load original images (no resize)
        input_img = np.array(Image.open(self.input_paths[img_idx]).convert('RGB')) / 255.0
        target_img = np.array(Image.open(self.target_paths[img_idx]).convert('RGB')) / 255.0

        # Detect salient points on the original image
        salient_points = self._detect_salient_points(input_img)
        cx, cy = salient_points[patch_idx]

        # Extract patch
        half_patch = self.patch_size // 2
        y1, y2 = cy - half_patch, cy + half_patch
        x1, x2 = cx - half_patch, cx + half_patch
        input_patch = input_img[y1:y2, x1:x2]
        target_patch = target_img[y1:y2, x1:x2]

        # Compute static features for the patch
        static_feat = compute_static_features(input_patch.astype(np.float32))

        # Input RGBD (mip 0) - add depth channel
        input_rgbd = np.concatenate(
            [input_patch, np.zeros((self.patch_size, self.patch_size, 1))], axis=-1)

        # Convert to tensors (C, H, W)
        input_rgbd = torch.from_numpy(input_rgbd.astype(np.float32)).permute(2, 0, 1)
        static_feat = torch.from_numpy(static_feat).permute(2, 0, 1)
        target = torch.from_numpy(target_patch.astype(np.float32)).permute(2, 0, 1)

        # Pad target to 4 channels (RGBA)
        target = F.pad(target, (0, 0, 0, 0, 0, 1), value=1.0)

        return input_rgbd, static_feat, target
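
# Hedged sketch for the random-sampling TODO in PatchDataset._detect_salient_points:
# replaces a fraction of the salient centers with uniformly random, in-bounds
# patch centers so training is not limited to high-gradient regions. The helper
# name `mix_random_points` and the 10% default follow the TODO but are
# assumptions; the dataset does not call this yet.
def mix_random_points(salient_points, patch_size, image_shape, random_percent=10):
    h, w = image_shape[:2]
    half_patch = patch_size // 2
    n_random = (len(salient_points) * random_percent) // 100
    mixed = list(salient_points)
    for i in range(n_random):
        x = np.random.randint(half_patch, w - half_patch)
        y = np.random.randint(half_patch, h - half_patch)
        mixed[-(i + 1)] = (x, y)  # overwrite the lowest-ranked salient points
    return mixed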

class ImagePairDataset(Dataset):
    """Dataset of input/target image pairs (full-image mode)."""

    def __init__(self, input_dir, target_dir, target_size=(256, 256)):
        self.input_paths = sorted(Path(input_dir).glob("*.png"))
        self.target_paths = sorted(Path(target_dir).glob("*.png"))
        self.target_size = target_size

        assert len(self.input_paths) == len(self.target_paths), \
            f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets"

    def __len__(self):
        return len(self.input_paths)

    def __getitem__(self, idx):
        # Load and resize images to a fixed size
        input_pil = Image.open(self.input_paths[idx]).convert('RGB')
        target_pil = Image.open(self.target_paths[idx]).convert('RGB')

        # Resize to target size
        input_pil = input_pil.resize(self.target_size, Image.LANCZOS)
        target_pil = target_pil.resize(self.target_size, Image.LANCZOS)

        input_img = np.array(input_pil) / 255.0
        target_img = np.array(target_pil) / 255.0

        # Compute static features
        static_feat = compute_static_features(input_img.astype(np.float32))

        # Input RGBD (mip 0) - add depth channel
        h, w = input_img.shape[:2]
        input_rgbd = np.concatenate([input_img, np.zeros((h, w, 1))], axis=-1)

        # Convert to tensors (C, H, W)
        input_rgbd = torch.from_numpy(input_rgbd.astype(np.float32)).permute(2, 0, 1)
        static_feat = torch.from_numpy(static_feat).permute(2, 0, 1)
        target = torch.from_numpy(target_img.astype(np.float32)).permute(2, 0, 1)

        # Pad target to 4 channels (RGBA)
        target = F.pad(target, (0, 0, 0, 0, 0, 1), value=1.0)

        return input_rgbd, static_feat, target


def train(args):
    """Train CNN v2 model."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training on {device}")

    # Create dataset (patch-based or full-image)
    if args.full_image:
        print(f"Mode: Full-image (resized to {args.image_size}x{args.image_size})")
        target_size = (args.image_size, args.image_size)
        dataset = ImagePairDataset(args.input, args.target, target_size=target_size)
        dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
    else:
        print(f"Mode: Patch-based ({args.patch_size}x{args.patch_size} patches)")
        dataset = PatchDataset(args.input, args.target,
                               patch_size=args.patch_size,
                               patches_per_image=args.patches_per_image,
                               detector=args.detector)
        dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    # Create model
    model = CNNv2(kernel_size=args.kernel_size, num_layers=args.num_layers).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    weights_per_layer = 12 * args.kernel_size * args.kernel_size * 4
    print(f"Model: {args.num_layers} layers, {args.kernel_size}×{args.kernel_size} kernels, "
          f"{total_params} weights ({weights_per_layer}/layer)")

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.MSELoss()

    # Training loop
    print(f"\nTraining for {args.epochs} epochs...")
    start_time = time.time()

    for epoch in range(1, args.epochs + 1):
        model.train()
        epoch_loss = 0.0

        for input_rgbd, static_feat, target in dataloader:
            input_rgbd = input_rgbd.to(device)
            static_feat = static_feat.to(device)
            target = target.to(device)

            optimizer.zero_grad()
            output = model(input_rgbd, static_feat)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(dataloader)

        # Print loss at every epoch (overwrite line with \r)
        elapsed = time.time() - start_time
        print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | Time: {elapsed:.1f}s",
              end='', flush=True)

        # Save checkpoint
        if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0:
            print()  # Newline before checkpoint message
            checkpoint_path = Path(args.checkpoint_dir) / f"checkpoint_epoch_{epoch}.pth"
            checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
                'config': {
                    'kernel_size': args.kernel_size,
                    'num_layers': args.num_layers,
                    'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin10_x', 'bias']
                }
            }, checkpoint_path)
            print(f" → Saved checkpoint: {checkpoint_path}")

    print(f"\nTraining complete! Total time: {time.time() - start_time:.1f}s")
    return model
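
# Hedged sketch (not wired into main()): rebuilds a CNNv2 model from a
# checkpoint written by train() above, using the saved 'config' entry. The
# helper name `load_checkpoint` is illustrative.
def load_checkpoint(checkpoint_path, device='cpu'):
    ckpt = torch.load(checkpoint_path, map_location=device)
    cfg = ckpt['config']
    model = CNNv2(kernel_size=cfg['kernel_size'], num_layers=cfg['num_layers'])
    model.load_state_dict(ckpt['model_state_dict'])
    return model.to(device).eval()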

def main():
    parser = argparse.ArgumentParser(description='Train CNN v2 with parametric static features')
    parser.add_argument('--input', type=str, required=True, help='Input images directory')
    parser.add_argument('--target', type=str, required=True, help='Target images directory')

    # Training mode
    parser.add_argument('--full-image', action='store_true',
                        help='Use full-image mode (resize all images)')
    parser.add_argument('--image-size', type=int, default=256,
                        help='Full-image mode: resize to this size (default: 256)')

    # Patch-based mode (default)
    parser.add_argument('--patch-size', type=int, default=32,
                        help='Patch mode: patch size (default: 32)')
    parser.add_argument('--patches-per-image', type=int, default=64,
                        help='Patch mode: patches per image (default: 64)')
    parser.add_argument('--detector', type=str, default='harris',
                        choices=['harris', 'fast', 'shi-tomasi', 'gradient'],
                        help='Patch mode: salient point detector (default: harris)')
    # TODO: Add --random-sample-percent parameter (default: 10)
    #       Mix salient points with random samples for better generalization

    # Model architecture
    parser.add_argument('--kernel-size', type=int, default=3,
                        help='Kernel size (uniform for all layers, default: 3)')
    parser.add_argument('--num-layers', type=int, default=3,
                        help='Number of CNN layers (default: 3)')

    # Training parameters
    parser.add_argument('--epochs', type=int, default=5000, help='Training epochs')
    parser.add_argument('--batch-size', type=int, default=16, help='Batch size')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
    parser.add_argument('--checkpoint-dir', type=str, default='checkpoints',
                        help='Checkpoint directory')
    parser.add_argument('--checkpoint-every', type=int, default=1000,
                        help='Save checkpoint every N epochs (0 = disable)')

    args = parser.parse_args()
    train(args)


if __name__ == '__main__':
    main()
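
# Example invocations (script filename and paths are placeholders):
#   python train_cnn_v2.py --input data/input --target data/target
#   python train_cnn_v2.py --input data/input --target data/target \
#       --full-image --image-size 256 --epochs 2000 --checkpoint-every 500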