diff options
Diffstat (limited to 'training/train_cnn_v2.py')
| -rwxr-xr-x | training/train_cnn_v2.py | 383 |
1 files changed, 383 insertions, 0 deletions
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py new file mode 100755 index 0000000..758b044 --- /dev/null +++ b/training/train_cnn_v2.py @@ -0,0 +1,383 @@ +#!/usr/bin/env python3 +"""CNN v2 Training Script - Parametric Static Features + +Trains a multi-layer CNN with 7D static feature input: +- RGBD (4D) +- UV coordinates (2D) +- sin(10*uv.x) position encoding (1D) +- Bias dimension (1D, always 1.0) +""" + +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import Dataset, DataLoader +from pathlib import Path +from PIL import Image +import time +import cv2 + + +def compute_static_features(rgb, depth=None): + """Generate 7D static features + bias dimension. + + Args: + rgb: (H, W, 3) RGB image [0, 1] + depth: (H, W) depth map [0, 1], optional + + Returns: + (H, W, 8) static features tensor + """ + h, w = rgb.shape[:2] + + # RGBD channels + r, g, b = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2] + d = depth if depth is not None else np.zeros((h, w), dtype=np.float32) + + # UV coordinates (normalized [0, 1]) + uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0).astype(np.float32) + uv_y = np.linspace(0, 1, h)[:, None].repeat(w, axis=1).astype(np.float32) + + # Multi-frequency position encoding + sin10_x = np.sin(10.0 * uv_x).astype(np.float32) + + # Bias dimension (always 1.0) + bias = np.ones((h, w), dtype=np.float32) + + # Stack: [R, G, B, D, uv.x, uv.y, sin10_x, bias] + features = np.stack([r, g, b, d, uv_x, uv_y, sin10_x, bias], axis=-1) + return features + + +class CNNv2(nn.Module): + """CNN v2 with parametric static features. + + TODO: Add quantization-aware training (QAT) for 8-bit weights + - Use torch.quantization.QuantStub/DeQuantStub + - Train with fake quantization to adapt to 8-bit precision + - Target: ~1.6 KB weights (vs 3.2 KB with f16) + """ + + def __init__(self, kernels=[1, 3, 5], channels=[16, 8, 4]): + super().__init__() + self.kernels = kernels + self.channels = channels + + # Input layer: 8D (7 features + bias) → channels[0] + self.layer0 = nn.Conv2d(8, channels[0], kernel_size=kernels[0], + padding=kernels[0]//2, bias=False) + + # Inner layers: (8 + C_prev) → C_next + in_ch_1 = 8 + channels[0] + self.layer1 = nn.Conv2d(in_ch_1, channels[1], kernel_size=kernels[1], + padding=kernels[1]//2, bias=False) + + # Output layer: (8 + C_last) → 4 (RGBA) + in_ch_2 = 8 + channels[1] + self.layer2 = nn.Conv2d(in_ch_2, 4, kernel_size=kernels[2], + padding=kernels[2]//2, bias=False) + + def forward(self, static_features): + """Forward pass with static feature concatenation. + + Args: + static_features: (B, 8, H, W) static features + + Returns: + (B, 4, H, W) RGBA output [0, 1] + """ + # Layer 0: Use full 8D static features + x0 = self.layer0(static_features) + x0 = F.relu(x0) + + # Layer 1: Concatenate static + layer0 output + x1_input = torch.cat([static_features, x0], dim=1) + x1 = self.layer1(x1_input) + x1 = F.relu(x1) + + # Layer 2: Concatenate static + layer1 output + x2_input = torch.cat([static_features, x1], dim=1) + output = self.layer2(x2_input) + + return torch.sigmoid(output) + + +class PatchDataset(Dataset): + """Patch-based dataset extracting salient regions from images.""" + + def __init__(self, input_dir, target_dir, patch_size=32, patches_per_image=64, + detector='harris'): + self.input_paths = sorted(Path(input_dir).glob("*.png")) + self.target_paths = sorted(Path(target_dir).glob("*.png")) + self.patch_size = patch_size + self.patches_per_image = patches_per_image + self.detector = detector + + assert len(self.input_paths) == len(self.target_paths), \ + f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets" + + print(f"Found {len(self.input_paths)} image pairs") + print(f"Extracting {patches_per_image} patches per image using {detector} detector") + print(f"Total patches: {len(self.input_paths) * patches_per_image}") + + def __len__(self): + return len(self.input_paths) * self.patches_per_image + + def _detect_salient_points(self, img_array): + """Detect salient points on original image. + + TODO: Add random sampling to training vectors + - In addition to salient points, incorporate randomly-located samples + - Default: 10% random samples, 90% salient points + - Prevents overfitting to only high-gradient regions + - Improves generalization across entire image + - Configurable via --random-sample-percent parameter + """ + gray = cv2.cvtColor((img_array * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY) + h, w = gray.shape + half_patch = self.patch_size // 2 + + corners = None + if self.detector == 'harris': + corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2, + qualityLevel=0.01, minDistance=half_patch) + elif self.detector == 'fast': + fast = cv2.FastFeatureDetector_create(threshold=20) + keypoints = fast.detect(gray, None) + corners = np.array([[kp.pt[0], kp.pt[1]] for kp in keypoints[:self.patches_per_image * 2]]) + corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None + elif self.detector == 'shi-tomasi': + corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2, + qualityLevel=0.01, minDistance=half_patch, + useHarrisDetector=False) + elif self.detector == 'gradient': + grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3) + grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3) + gradient_mag = np.sqrt(grad_x**2 + grad_y**2) + threshold = np.percentile(gradient_mag, 95) + y_coords, x_coords = np.where(gradient_mag > threshold) + + if len(x_coords) > self.patches_per_image * 2: + indices = np.random.choice(len(x_coords), self.patches_per_image * 2, replace=False) + x_coords = x_coords[indices] + y_coords = y_coords[indices] + + corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)]) + corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None + + # Fallback to random if no corners found + if corners is None or len(corners) == 0: + x_coords = np.random.randint(half_patch, w - half_patch, self.patches_per_image) + y_coords = np.random.randint(half_patch, h - half_patch, self.patches_per_image) + corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)]) + corners = corners.reshape(-1, 1, 2) + + # Filter valid corners + valid_corners = [] + for corner in corners: + x, y = int(corner[0][0]), int(corner[0][1]) + if half_patch <= x < w - half_patch and half_patch <= y < h - half_patch: + valid_corners.append((x, y)) + if len(valid_corners) >= self.patches_per_image: + break + + # Fill with random if not enough + while len(valid_corners) < self.patches_per_image: + x = np.random.randint(half_patch, w - half_patch) + y = np.random.randint(half_patch, h - half_patch) + valid_corners.append((x, y)) + + return valid_corners + + def __getitem__(self, idx): + img_idx = idx // self.patches_per_image + patch_idx = idx % self.patches_per_image + + # Load original images (no resize) + input_img = np.array(Image.open(self.input_paths[img_idx]).convert('RGB')) / 255.0 + target_img = np.array(Image.open(self.target_paths[img_idx]).convert('RGB')) / 255.0 + + # Detect salient points on original image + salient_points = self._detect_salient_points(input_img) + cx, cy = salient_points[patch_idx] + + # Extract patch + half_patch = self.patch_size // 2 + y1, y2 = cy - half_patch, cy + half_patch + x1, x2 = cx - half_patch, cx + half_patch + + input_patch = input_img[y1:y2, x1:x2] + target_patch = target_img[y1:y2, x1:x2] + + # Compute static features for patch + static_feat = compute_static_features(input_patch.astype(np.float32)) + + # Convert to tensors (C, H, W) + static_feat = torch.from_numpy(static_feat).permute(2, 0, 1) + target = torch.from_numpy(target_patch.astype(np.float32)).permute(2, 0, 1) + + # Pad target to 4 channels (RGBA) + target = F.pad(target, (0, 0, 0, 0, 0, 1), value=1.0) + + return static_feat, target + + +class ImagePairDataset(Dataset): + """Dataset of input/target image pairs (full-image mode).""" + + def __init__(self, input_dir, target_dir, target_size=(256, 256)): + self.input_paths = sorted(Path(input_dir).glob("*.png")) + self.target_paths = sorted(Path(target_dir).glob("*.png")) + self.target_size = target_size + assert len(self.input_paths) == len(self.target_paths), \ + f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets" + + def __len__(self): + return len(self.input_paths) + + def __getitem__(self, idx): + # Load and resize images to fixed size + input_pil = Image.open(self.input_paths[idx]).convert('RGB') + target_pil = Image.open(self.target_paths[idx]).convert('RGB') + + # Resize to target size + input_pil = input_pil.resize(self.target_size, Image.LANCZOS) + target_pil = target_pil.resize(self.target_size, Image.LANCZOS) + + input_img = np.array(input_pil) / 255.0 + target_img = np.array(target_pil) / 255.0 + + # Compute static features + static_feat = compute_static_features(input_img.astype(np.float32)) + + # Convert to tensors (C, H, W) + static_feat = torch.from_numpy(static_feat).permute(2, 0, 1) + target = torch.from_numpy(target_img.astype(np.float32)).permute(2, 0, 1) + + # Pad target to 4 channels (RGBA) + target = F.pad(target, (0, 0, 0, 0, 0, 1), value=1.0) + + return static_feat, target + + +def train(args): + """Train CNN v2 model.""" + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f"Training on {device}") + + # Create dataset (patch-based or full-image) + if args.full_image: + print(f"Mode: Full-image (resized to {args.image_size}x{args.image_size})") + target_size = (args.image_size, args.image_size) + dataset = ImagePairDataset(args.input, args.target, target_size=target_size) + dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True) + else: + print(f"Mode: Patch-based ({args.patch_size}x{args.patch_size} patches)") + dataset = PatchDataset(args.input, args.target, + patch_size=args.patch_size, + patches_per_image=args.patches_per_image, + detector=args.detector) + dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True) + + # Create model + model = CNNv2(kernels=args.kernel_sizes, channels=args.channels).to(device) + total_params = sum(p.numel() for p in model.parameters()) + print(f"Model: {args.channels} channels, {args.kernel_sizes} kernels, {total_params} weights") + + # Optimizer and loss + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + criterion = nn.MSELoss() + + # Training loop + print(f"\nTraining for {args.epochs} epochs...") + start_time = time.time() + + for epoch in range(1, args.epochs + 1): + model.train() + epoch_loss = 0.0 + + for static_feat, target in dataloader: + static_feat = static_feat.to(device) + target = target.to(device) + + optimizer.zero_grad() + output = model(static_feat) + loss = criterion(output, target) + loss.backward() + optimizer.step() + + epoch_loss += loss.item() + + avg_loss = epoch_loss / len(dataloader) + + # Print loss at every epoch (overwrite line with \r) + elapsed = time.time() - start_time + print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | Time: {elapsed:.1f}s", end='', flush=True) + + # Save checkpoint + if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0: + print() # Newline before checkpoint message + checkpoint_path = Path(args.checkpoint_dir) / f"checkpoint_epoch_{epoch}.pth" + checkpoint_path.parent.mkdir(parents=True, exist_ok=True) + torch.save({ + 'epoch': epoch, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': avg_loss, + 'config': { + 'kernels': args.kernel_sizes, + 'channels': args.channels, + 'features': ['R', 'G', 'B', 'D', 'uv.x', 'uv.y', 'sin10_x', 'bias'] + } + }, checkpoint_path) + print(f" → Saved checkpoint: {checkpoint_path}") + + print(f"\nTraining complete! Total time: {time.time() - start_time:.1f}s") + return model + + +def main(): + parser = argparse.ArgumentParser(description='Train CNN v2 with parametric static features') + parser.add_argument('--input', type=str, required=True, help='Input images directory') + parser.add_argument('--target', type=str, required=True, help='Target images directory') + + # Training mode + parser.add_argument('--full-image', action='store_true', + help='Use full-image mode (resize all images)') + parser.add_argument('--image-size', type=int, default=256, + help='Full-image mode: resize to this size (default: 256)') + + # Patch-based mode (default) + parser.add_argument('--patch-size', type=int, default=32, + help='Patch mode: patch size (default: 32)') + parser.add_argument('--patches-per-image', type=int, default=64, + help='Patch mode: patches per image (default: 64)') + parser.add_argument('--detector', type=str, default='harris', + choices=['harris', 'fast', 'shi-tomasi', 'gradient'], + help='Patch mode: salient point detector (default: harris)') + # TODO: Add --random-sample-percent parameter (default: 10) + # Mix salient points with random samples for better generalization + + # Model architecture + parser.add_argument('--kernel-sizes', type=int, nargs=3, default=[1, 3, 5], + help='Kernel sizes for 3 layers (default: 1 3 5)') + parser.add_argument('--channels', type=int, nargs=3, default=[16, 8, 4], + help='Output channels for 3 layers (default: 16 8 4)') + + # Training parameters + parser.add_argument('--epochs', type=int, default=5000, help='Training epochs') + parser.add_argument('--batch-size', type=int, default=16, help='Batch size') + parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate') + parser.add_argument('--checkpoint-dir', type=str, default='checkpoints', + help='Checkpoint directory') + parser.add_argument('--checkpoint-every', type=int, default=1000, + help='Save checkpoint every N epochs (0 = disable)') + + args = parser.parse_args() + train(args) + + +if __name__ == '__main__': + main() |
