#!/usr/bin/env python3
"""CNN v2 Training Script - Uniform 12D→4D Architecture

Architecture:
- Static features (8D): p0-p3 (parametric), uv_x, uv_y, sin(10×uv_x), bias
- Input RGBD (4D): original image mip 0
- All layers: input RGBD (4D) + static (8D) = 12D → 4 channels
- Per-layer kernel sizes (e.g., 1×1, 3×3, 5×5)
- Uniform layer structure with bias=False (bias in static features)
"""

import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from PIL import Image
import time
import cv2


def compute_static_features(rgb, depth=None, mip_level=0, uv_origin=(0, 0), uv_size=None):
    """Generate 8D static features (parametric + spatial).

    Args:
        rgb: (H, W, 3) RGB image [0, 1]
        depth: (H, W) depth map [0, 1], optional (zeros when omitted)
        mip_level: Mip level for p0-p3 (0=original, 1=half, 2=quarter, 3=eighth)
        uv_origin: (x, y) pixel offset of ``rgb`` inside the full image that the
            UV coordinates should span. Only used when ``uv_size`` is given.
        uv_size: (full_h, full_w) of the full image. When given, uv_x / uv_y
            (and sin10_x) are the full-image coordinates of the pixels of
            ``rgb`` instead of spanning [0, 1] across ``rgb`` itself. Use this
            when ``rgb`` is a crop (training patch) of a larger image so the
            spatial features match those of the full image.

    Returns:
        (H, W, 8) float32 static features: [p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias]

    Raises:
        ValueError: if ``uv_origin``/``uv_size`` place ``rgb`` outside the full image.

    Note: p0-p3 are parametric features generated from specified mip level
    """
    h, w = rgb.shape[:2]

    # Generate mip level for p0-p3: blur by downsampling then upsampling back.
    if mip_level > 0:
        mip_rgb = rgb.copy()
        for _ in range(mip_level):
            mip_rgb = cv2.pyrDown(mip_rgb)
        for _ in range(mip_level):
            mip_rgb = cv2.pyrUp(mip_rgb)
        # pyrDown/pyrUp round sizes, so odd dimensions may not round-trip exactly.
        if mip_rgb.shape[:2] != (h, w):
            mip_rgb = cv2.resize(mip_rgb, (w, h), interpolation=cv2.INTER_LINEAR)
    else:
        mip_rgb = rgb

    # Parametric features (p0-p3) from mip level
    p0 = mip_rgb[:, :, 0].astype(np.float32)
    p1 = mip_rgb[:, :, 1].astype(np.float32)
    p2 = mip_rgb[:, :, 2].astype(np.float32)
    if depth is not None:
        p3 = np.asarray(depth, dtype=np.float32)
    else:
        p3 = np.zeros((h, w), dtype=np.float32)

    # UV coordinates (normalized [0, 1] across the full image)
    if uv_size is None:
        xs = np.linspace(0, 1, w)
        ys = np.linspace(0, 1, h)
    else:
        full_h, full_w = uv_size
        x0, y0 = uv_origin
        if x0 < 0 or y0 < 0 or x0 + w > full_w or y0 + h > full_h:
            raise ValueError(
                f"Region ({x0}, {y0}) + {w}x{h} lies outside full image {full_w}x{full_h}")
        # Slice the full-image ramp so values are identical to full-image mode.
        xs = np.linspace(0, 1, full_w)[x0:x0 + w]
        ys = np.linspace(0, 1, full_h)[y0:y0 + h]
    uv_x = xs[None, :].repeat(h, axis=0).astype(np.float32)
    uv_y = ys[:, None].repeat(w, axis=1).astype(np.float32)

    # Multi-frequency position encoding
    sin10_x = np.sin(10.0 * uv_x).astype(np.float32)

    # Bias dimension (always 1.0) - replaces Conv2d bias parameter
    bias = np.ones((h, w), dtype=np.float32)

    # Stack: [p0, p1, p2, p3, uv.x, uv.y, sin10_x, bias]
    features = np.stack([p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias], axis=-1)
    return features


class CNNv2(nn.Module):
    """CNN v2 - Uniform 12D→4D Architecture

    All layers: input RGBD (4D) + static (8D) = 12D → 4 channels
    Per-layer kernel sizes supported (e.g., [1, 3, 5])
    Uses bias=False (bias integrated in static features as 1.0)

    TODO: Add quantization-aware training (QAT) for 8-bit weights
    - Use torch.quantization.QuantStub/DeQuantStub
    - Train with fake quantization to adapt to 8-bit precision
    - Target: ~1.3 KB weights (vs 2.6 KB with f16)
    """

    def __init__(self, kernel_sizes, num_layers=3):
        """Build ``num_layers`` Conv2d(12 → 4) layers.

        Args:
            kernel_sizes: int (replicated for every layer) or a sequence with one
                odd kernel size per layer.
            num_layers: number of convolution layers.

        Raises:
            ValueError: if a kernel size is not a positive odd integer. Even sizes
                with ``padding=k//2`` change the spatial size, which breaks the
                concatenation with the static features in ``forward``.
        """
        super().__init__()

        if isinstance(kernel_sizes, int):
            kernel_sizes = [kernel_sizes] * num_layers

        assert len(kernel_sizes) == num_layers, "kernel_sizes must match num_layers"
        for kernel_size in kernel_sizes:
            if kernel_size < 1 or kernel_size % 2 == 0:
                raise ValueError(f"Kernel sizes must be positive odd integers, got {kernel_size}")

        self.kernel_sizes = kernel_sizes
        self.num_layers = num_layers
        self.layers = nn.ModuleList()

        # All layers: 12D input (4 RGBD + 8 static) → 4D output
        for kernel_size in kernel_sizes:
            self.layers.append(
                nn.Conv2d(12, 4, kernel_size=kernel_size,
                          padding=kernel_size // 2, bias=False)
            )

    def forward(self, input_rgbd, static_features):
        """Forward pass with uniform 12D→4D layers.

        Args:
            input_rgbd: (B, 4, H, W) input image RGBD (mip 0)
            static_features: (B, 8, H, W) static features

        Returns:
            (B, 4, H, W) RGBA output [0, 1]
        """
        # Layer 0: input RGBD (4D) + static (8D) = 12D
        x = torch.cat([input_rgbd, static_features], dim=1)
        x = self.layers[0](x)
        x = torch.clamp(x, 0, 1)  # Output [0,1] for layer 0

        # Layer 1+: previous (4D) + static (8D) = 12D
        for i in range(1, self.num_layers):
            x_input = torch.cat([x, static_features], dim=1)
            x = self.layers[i](x_input)

            if i < self.num_layers - 1:
                x = F.relu(x)
            else:
                x = torch.clamp(x, 0, 1)  # Final output [0,1]

        return x


class PatchDataset(Dataset):
    """Patch-based dataset extracting salient regions from images.

    Item ``idx`` is patch ``idx % patches_per_image`` of image
    ``idx // patches_per_image``. Inputs and targets are paired by sorted
    filename order and must have identical dimensions.
    """

    def __init__(self, input_dir, target_dir, patch_size=32, patches_per_image=64,
                 detector='harris', mip_level=0):
        self.input_paths = sorted(Path(input_dir).glob("*.png"))
        self.target_paths = sorted(Path(target_dir).glob("*.png"))
        self.patch_size = patch_size
        self.patches_per_image = patches_per_image
        self.detector = detector
        self.mip_level = mip_level

        assert len(self.input_paths) == len(self.target_paths), \
            f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets"

        # Salient points per image index. Detection does not depend on which
        # patch is requested, so it is run once per image rather than once per
        # patch. Cached points (including any random fill) stay fixed across
        # epochs. Each DataLoader worker process keeps its own copy.
        self._points_cache = {}

        print(f"Found {len(self.input_paths)} image pairs")
        print(f"Extracting {patches_per_image} patches per image using {detector} detector")
        print(f"Total patches: {len(self.input_paths) * patches_per_image}")

    def __len__(self):
        return len(self.input_paths) * self.patches_per_image

    def _center_range(self, size):
        """Return the inclusive (lo, hi) range of patch centres keeping a patch inside ``size`` pixels."""
        half_patch = self.patch_size // 2
        lo = half_patch
        # Patches span [c - half_patch, c - half_patch + patch_size), which also
        # handles odd patch sizes.
        hi = size - (self.patch_size - half_patch)
        if hi < lo:
            raise ValueError(
                f"Image dimension {size} is smaller than patch size {self.patch_size}")
        return lo, hi

    def _detect_salient_points(self, img_array):
        """Detect salient points on original image.

        Args:
            img_array: (H, W, 3) RGB image in [0, 1].

        Returns:
            list of exactly ``patches_per_image`` integer (x, y) patch centres
            whose patches lie fully inside the image. Detector hits are used
            first (in the order the detector returns them). Remaining slots
            are filled with uniformly random centres.

        TODO: Add random sampling to training vectors
        - In addition to salient points, incorporate randomly-located samples
        - Default: 10% random samples, 90% salient points
        - Prevents overfitting to only high-gradient regions
        - Improves generalization across entire image
        - Configurable via --random-sample-percent parameter
        """
        gray = cv2.cvtColor((img_array * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
        h, w = gray.shape
        half_patch = self.patch_size // 2
        x_lo, x_hi = self._center_range(w)
        y_lo, y_hi = self._center_range(h)
        max_candidates = self.patches_per_image * 2

        corners = None
        if self.detector == 'harris':
            # goodFeaturesToTrack defaults to the Shi-Tomasi score; request Harris explicitly.
            corners = cv2.goodFeaturesToTrack(gray, max_candidates,
                                              qualityLevel=0.01, minDistance=half_patch,
                                              useHarrisDetector=True)
        elif self.detector == 'fast':
            fast = cv2.FastFeatureDetector_create(threshold=20)
            # FAST returns keypoints in raster order. Rank by response before
            # truncating so the strongest corners, not the top rows, are kept.
            keypoints = sorted(fast.detect(gray, None),
                               key=lambda kp: kp.response, reverse=True)
            corners = np.array([[kp.pt[0], kp.pt[1]] for kp in keypoints[:max_candidates]])
            corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None
        elif self.detector == 'shi-tomasi':
            corners = cv2.goodFeaturesToTrack(gray, max_candidates,
                                              qualityLevel=0.01, minDistance=half_patch,
                                              useHarrisDetector=False)
        elif self.detector == 'gradient':
            grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
            grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
            gradient_mag = np.sqrt(grad_x**2 + grad_y**2)
            threshold = np.percentile(gradient_mag, 95)
            y_coords, x_coords = np.where(gradient_mag > threshold)

            if len(x_coords) > max_candidates:
                indices = np.random.choice(len(x_coords), max_candidates, replace=False)
                x_coords = x_coords[indices]
                y_coords = y_coords[indices]

            corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)])
            corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None

        # Keep detector hits whose patch fits inside the image.
        valid_corners = []
        for corner in (corners if corners is not None else []):
            x, y = int(corner[0][0]), int(corner[0][1])
            if x_lo <= x <= x_hi and y_lo <= y <= y_hi:
                valid_corners.append((x, y))
                if len(valid_corners) >= self.patches_per_image:
                    break

        # Fill with random centres if the detector found too few (or none).
        while len(valid_corners) < self.patches_per_image:
            x = int(np.random.randint(x_lo, x_hi + 1))
            y = int(np.random.randint(y_lo, y_hi + 1))
            valid_corners.append((x, y))

        return valid_corners

    def __getitem__(self, idx):
        img_idx, patch_idx = divmod(idx, self.patches_per_image)

        # Load original images (no resize)
        with Image.open(self.input_paths[img_idx]) as im:
            input_img = np.array(im.convert('RGB')) / 255.0
        with Image.open(self.target_paths[img_idx]) as im:
            target_img = np.array(im.convert('RGBA')) / 255.0  # Preserve alpha

        h, w = input_img.shape[:2]
        if target_img.shape[:2] != (h, w):
            raise ValueError(
                f"Size mismatch: {self.input_paths[img_idx]} is {w}x{h} but "
                f"{self.target_paths[img_idx]} is {target_img.shape[1]}x{target_img.shape[0]}")

        # Detect salient points on original image (use RGB only), once per image.
        salient_points = self._points_cache.get(img_idx)
        if salient_points is None:
            salient_points = self._detect_salient_points(input_img)
            self._points_cache[img_idx] = salient_points
        cx, cy = salient_points[patch_idx]

        # Extract exactly patch_size × patch_size pixels (works for odd sizes too).
        half_patch = self.patch_size // 2
        y1 = cy - half_patch
        x1 = cx - half_patch
        y2 = y1 + self.patch_size
        x2 = x1 + self.patch_size

        input_patch = input_img[y1:y2, x1:x2]
        target_patch = target_img[y1:y2, x1:x2]  # RGBA

        # Static features for the patch. UVs are taken from the patch's position
        # in the full image so they match full-image mode. The mip-level p0-p3
        # blur is still computed from the patch pixels alone.
        static_feat = compute_static_features(input_patch.astype(np.float32),
                                              mip_level=self.mip_level,
                                              uv_origin=(x1, y1), uv_size=(h, w))

        # Input RGBD (mip 0) - add depth channel
        patch_h, patch_w = input_patch.shape[:2]
        input_rgbd = np.concatenate([input_patch, np.zeros((patch_h, patch_w, 1))], axis=-1)

        # Convert to tensors (C, H, W)
        input_rgbd = torch.from_numpy(input_rgbd.astype(np.float32)).permute(2, 0, 1)
        static_feat = torch.from_numpy(static_feat).permute(2, 0, 1)
        target = torch.from_numpy(target_patch.astype(np.float32)).permute(2, 0, 1)  # RGBA from image

        return input_rgbd, static_feat, target


class ImagePairDataset(Dataset):
    """Dataset of input/target image pairs (full-image mode).

    Pairs are matched by sorted filename order. Both images are resized to
    ``target_size`` (PIL ``(width, height)``).
    """

    def __init__(self, input_dir, target_dir, target_size=(256, 256), mip_level=0):
        self.input_paths = sorted(Path(input_dir).glob("*.png"))
        self.target_paths = sorted(Path(target_dir).glob("*.png"))
        self.target_size = target_size
        self.mip_level = mip_level
        assert len(self.input_paths) == len(self.target_paths), \
            f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets"

    def __len__(self):
        return len(self.input_paths)

    def __getitem__(self, idx):
        # Convert before resizing. Pillow falls back to nearest-neighbour for
        # palette ('P') images, so the target is converted to RGBA first so
        # that LANCZOS actually applies.
        with Image.open(self.input_paths[idx]) as im:
            input_pil = im.convert('RGB')
        with Image.open(self.target_paths[idx]) as im:
            target_pil = im.convert('RGBA')  # Preserve alpha

        # Resize to target size
        input_pil = input_pil.resize(self.target_size, Image.LANCZOS)
        target_pil = target_pil.resize(self.target_size, Image.LANCZOS)

        input_img = np.array(input_pil) / 255.0
        target_img = np.array(target_pil) / 255.0

        # Compute static features
        static_feat = compute_static_features(input_img.astype(np.float32), mip_level=self.mip_level)

        # Input RGBD (mip 0) - add depth channel
        h, w = input_img.shape[:2]
        input_rgbd = np.concatenate([input_img, np.zeros((h, w, 1))], axis=-1)

        # Convert to tensors (C, H, W)
        input_rgbd = torch.from_numpy(input_rgbd.astype(np.float32)).permute(2, 0, 1)
        static_feat = torch.from_numpy(static_feat).permute(2, 0, 1)
        target = torch.from_numpy(target_img.astype(np.float32)).permute(2, 0, 1)  # RGBA from image

        return input_rgbd, static_feat, target


def _save_checkpoint(path, model, optimizer, epoch, loss, kernel_sizes, args):
    """Write weights, optimizer state, loss and model config for ``epoch`` to ``path``."""
    path.parent.mkdir(parents=True, exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': {
            'kernel_sizes': kernel_sizes,
            'num_layers': args.num_layers,
            'mip_level': args.mip_level,
            'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin10_x', 'bias']
        }
    }, path)
    print(f"  → Saved checkpoint: {path}")


def train(args):
    """Train CNN v2 model.

    Saves a checkpoint every ``args.checkpoint_every`` epochs (when > 0) and
    always saves the final epoch, so the trained weights are never lost.

    Returns:
        The trained ``CNNv2`` model.

    Raises:
        ValueError: if the dataset yields no training samples.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training on {device}")

    # Create dataset (patch-based or full-image)
    if args.full_image:
        print(f"Mode: Full-image (resized to {args.image_size}x{args.image_size})")
        target_size = (args.image_size, args.image_size)
        dataset = ImagePairDataset(args.input, args.target, target_size=target_size,
                                   mip_level=args.mip_level)
    else:
        print(f"Mode: Patch-based ({args.patch_size}x{args.patch_size} patches)")
        dataset = PatchDataset(args.input, args.target,
                               patch_size=args.patch_size,
                               patches_per_image=args.patches_per_image,
                               detector=args.detector,
                               mip_level=args.mip_level)

    if len(dataset) == 0:
        raise ValueError(f"No training samples found (input: {args.input}, target: {args.target})")
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    # Parse kernel sizes
    kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
    if len(kernel_sizes) == 1:
        kernel_sizes = kernel_sizes * args.num_layers

    # Create model
    model = CNNv2(kernel_sizes=kernel_sizes, num_layers=args.num_layers).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    kernel_desc = ','.join(map(str, kernel_sizes))
    print(f"Model: {args.num_layers} layers, kernel sizes [{kernel_desc}], {total_params} weights")
    print(f"Using mip level {args.mip_level} for p0-p3 features")

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.MSELoss()

    checkpoint_dir = Path(args.checkpoint_dir)

    # Training loop
    print(f"\nTraining for {args.epochs} epochs...")
    start_time = time.time()
    last_saved_epoch = None
    avg_loss = None

    for epoch in range(1, args.epochs + 1):
        model.train()
        epoch_loss = 0.0

        for input_rgbd, static_feat, target in dataloader:
            input_rgbd = input_rgbd.to(device)
            static_feat = static_feat.to(device)
            target = target.to(device)

            optimizer.zero_grad()
            output = model(input_rgbd, static_feat)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(dataloader)

        # Print loss at every epoch (overwrite line with \r)
        elapsed = time.time() - start_time
        print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | Time: {elapsed:.1f}s",
              end='', flush=True)

        # Save periodic checkpoint
        if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0:
            print()  # Newline before checkpoint message
            checkpoint_path = checkpoint_dir / f"checkpoint_epoch_{epoch}.pth"
            _save_checkpoint(checkpoint_path, model, optimizer, epoch, avg_loss, kernel_sizes, args)
            last_saved_epoch = epoch

    # Always persist the final weights. main() discards the returned model, so
    # an unsaved final epoch would otherwise be lost.
    if args.epochs > 0 and last_saved_epoch != args.epochs:
        print()
        checkpoint_path = checkpoint_dir / f"checkpoint_epoch_{args.epochs}.pth"
        _save_checkpoint(checkpoint_path, model, optimizer, args.epochs, avg_loss, kernel_sizes, args)

    print(f"\nTraining complete! Total time: {time.time() - start_time:.1f}s")
    return model


def main():
    """Parse command-line arguments and run training."""
    parser = argparse.ArgumentParser(description='Train CNN v2 with parametric static features')
    parser.add_argument('--input', type=str, required=True, help='Input images directory')
    parser.add_argument('--target', type=str, required=True, help='Target images directory')

    # Training mode
    parser.add_argument('--full-image', action='store_true',
                        help='Use full-image mode (resize all images)')
    parser.add_argument('--image-size', type=int, default=256,
                        help='Full-image mode: resize to this size (default: 256)')

    # Patch-based mode (default)
    parser.add_argument('--patch-size', type=int, default=32,
                        help='Patch mode: patch size (default: 32)')
    parser.add_argument('--patches-per-image', type=int, default=64,
                        help='Patch mode: patches per image (default: 64)')
    parser.add_argument('--detector', type=str, default='harris',
                        choices=['harris', 'fast', 'shi-tomasi', 'gradient'],
                        help='Patch mode: salient point detector (default: harris)')
    # TODO: Add --random-sample-percent parameter (default: 10)
    # Mix salient points with random samples for better generalization

    # Model architecture
    parser.add_argument('--kernel-sizes', type=str, default='3',
                        help='Comma-separated kernel sizes per layer (e.g., "3,5,3"), single value replicates (default: 3)')
    parser.add_argument('--num-layers', type=int, default=3,
                        help='Number of CNN layers (default: 3)')
    parser.add_argument('--mip-level', type=int, default=0, choices=[0, 1, 2, 3],
                        help='Mip level for p0-p3 features: 0=original, 1=half, 2=quarter, 3=eighth (default: 0)')

    # Training parameters
    parser.add_argument('--epochs', type=int, default=5000, help='Training epochs')
    parser.add_argument('--batch-size', type=int, default=16, help='Batch size')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
    parser.add_argument('--checkpoint-dir', type=str, default='checkpoints',
                        help='Checkpoint directory')
    parser.add_argument('--checkpoint-every', type=int, default=1000,
                        help='Save checkpoint every N epochs (0 = disable periodic checkpoints; '
                             'the final epoch is always saved)')

    args = parser.parse_args()
    train(args)


if __name__ == '__main__':
    main()