#!/usr/bin/env python3
"""CNN v2 Training Script - Parametric Static Features

Trains a multi-layer CNN with 7D static feature input:
- RGBD (4D)
- UV coordinates (2D)
- sin(10*uv.x) position encoding (1D)
- Bias dimension (1D, always 1.0)
"""

import argparse
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Channel order produced by compute_static_features(); also recorded in checkpoints.
STATIC_FEATURE_NAMES = ('R', 'G', 'B', 'D', 'uv.x', 'uv.y', 'sin10_x', 'bias')
NUM_STATIC_FEATURES = len(STATIC_FEATURE_NAMES)


def compute_static_features(rgb, depth=None):
    """Generate 7D static features + bias dimension.

    Args:
        rgb: (H, W, 3) RGB image in [0, 1].
        depth: (H, W) depth map in [0, 1], optional. Zeros are used when omitted.

    Returns:
        (H, W, 8) array ordered as [R, G, B, D, uv.x, uv.y, sin10_x, bias].
        The coordinate, encoding and bias channels are float32; the stacked
        result takes the promoted dtype of ``rgb``/``depth``, so pass float32
        inputs to get a float32 result.
    """
    h, w = rgb.shape[:2]

    # RGBD channels
    r, g, b = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
    d = depth if depth is not None else np.zeros((h, w), dtype=np.float32)

    # UV coordinates, normalized to [0, 1] along each axis
    uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0).astype(np.float32)
    uv_y = np.linspace(0, 1, h)[:, None].repeat(w, axis=1).astype(np.float32)

    # Single-frequency sinusoidal encoding of the horizontal position
    sin10_x = np.sin(10.0 * uv_x).astype(np.float32)

    # Constant 1.0 channel: the model's convolutions are bias-free, so this
    # channel is what lets each layer learn an offset.
    bias = np.ones((h, w), dtype=np.float32)

    return np.stack([r, g, b, d, uv_x, uv_y, sin10_x, bias], axis=-1)


class CNNv2(nn.Module):
    """CNN v2 with parametric static features.

    The 8 static feature channels are concatenated to the input of every layer:
        layer0: 8               -> channels[0]  (ReLU)
        layer1: 8 + channels[0] -> channels[1]  (ReLU)
        layer2: 8 + channels[1] -> 4 RGBA       (sigmoid)
    All convolutions are bias-free; the constant bias feature channel plays that role.

    Args:
        kernels: Kernel sizes for the three layers. Each must be a positive odd
            integer so that ``padding = k // 2`` preserves the spatial size,
            which the channel concatenation requires.
        channels: Output channels of layers 0 and 1. A third entry is accepted
            (the CLI passes three values) but is not used: the output layer
            always produces 4 channels.

    Raises:
        ValueError: If fewer than 3 kernel sizes or 2 channel counts are given,
            or a kernel size is even or non-positive.
    """

    def __init__(self, kernels=(1, 3, 5), channels=(16, 8, 4)):
        super().__init__()
        # Copy so the model never aliases a caller's (or a shared default) list.
        kernels = list(kernels)
        channels = list(channels)
        if len(kernels) < 3:
            raise ValueError(f"Expected 3 kernel sizes, got {kernels}")
        if len(channels) < 2:
            raise ValueError(f"Expected at least 2 channel counts, got {channels}")
        bad = [k for k in kernels[:3] if k < 1 or k % 2 == 0]
        if bad:
            raise ValueError(
                f"Kernel sizes must be positive odd integers (to preserve H, W), got {kernels}")

        self.kernels = kernels
        self.channels = channels

        def conv(in_ch, out_ch, k):
            return nn.Conv2d(in_ch, out_ch, kernel_size=k, padding=k // 2, bias=False)

        # Input layer: 8D (7 features + bias) -> channels[0]
        self.layer0 = conv(NUM_STATIC_FEATURES, channels[0], kernels[0])
        # Inner layer: (8 + C_prev) -> C_next
        self.layer1 = conv(NUM_STATIC_FEATURES + channels[0], channels[1], kernels[1])
        # Output layer: (8 + C_last) -> 4 (RGBA)
        self.layer2 = conv(NUM_STATIC_FEATURES + channels[1], 4, kernels[2])

    def forward(self, static_features):
        """Forward pass with static feature concatenation.

        Args:
            static_features: (B, 8, H, W) static features.

        Returns:
            (B, 4, H, W) RGBA output in [0, 1].
        """
        # Layer 0: use the full 8D static features
        x0 = F.relu(self.layer0(static_features))

        # Layer 1: concatenate static features + layer0 output
        x1 = F.relu(self.layer1(torch.cat([static_features, x0], dim=1)))

        # Layer 2: concatenate static features + layer1 output
        output = self.layer2(torch.cat([static_features, x1], dim=1))
        return torch.sigmoid(output)


def _load_rgb(path, size):
    """Load ``path`` as an RGB float32 array in [0, 1], resized to ``size`` (width, height)."""
    # Deferred import: only dataset loading needs Pillow, so the feature and
    # model code stay importable (e.g. for export or tests) without it.
    from PIL import Image

    with Image.open(path) as src:
        img = src.convert('RGB')
    img = img.resize(size, Image.LANCZOS)
    return (np.asarray(img) / 255.0).astype(np.float32)


class ImagePairDataset(Dataset):
    """Dataset of input/target image pairs.

    Inputs and targets are paired by position after sorting each directory's
    ``*.png`` files by path, so corresponding files must sort identically
    (e.g. share file names).

    Raises:
        ValueError: If the directories hold different numbers of PNGs, or none.
    """

    def __init__(self, input_dir, target_dir, target_size=(256, 256)):
        self.input_paths = sorted(Path(input_dir).glob("*.png"))
        self.target_paths = sorted(Path(target_dir).glob("*.png"))
        self.target_size = target_size

        # Explicit exceptions rather than assert, which is stripped under `python -O`.
        if len(self.input_paths) != len(self.target_paths):
            raise ValueError(
                f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets")
        if not self.input_paths:
            # An empty dataset would otherwise surface as a ZeroDivisionError in train().
            raise ValueError(f"No *.png images found in {input_dir}")

    def __len__(self):
        return len(self.input_paths)

    def __getitem__(self, idx):
        # Load and resize both images to the fixed training size
        input_img = _load_rgb(self.input_paths[idx], self.target_size)
        target_img = _load_rgb(self.target_paths[idx], self.target_size)

        # Static features from the input, converted to (C, H, W) tensors
        static_feat = torch.from_numpy(compute_static_features(input_img)).permute(2, 0, 1)
        target = torch.from_numpy(target_img).permute(2, 0, 1)

        # Targets are RGB; append an opaque alpha channel (1.0) to match the RGBA output.
        target = F.pad(target, (0, 0, 0, 0, 0, 1), value=1.0)

        return static_feat, target


def _save_checkpoint(path, epoch, model, optimizer, loss, args):
    """Write model/optimizer state, the epoch loss and the architecture config to ``path``."""
    path.parent.mkdir(parents=True, exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': {
            'kernels': args.kernel_sizes,
            'channels': args.channels,
            'features': list(STATIC_FEATURE_NAMES),
        },
    }, path)


def train(args):
    """Train CNN v2 model.

    Args:
        args: argparse namespace with the options defined in main().

    Returns:
        The trained CNNv2 model.

    When ``args.checkpoint_every > 0``, a checkpoint is written every
    ``checkpoint_every`` epochs and always after the final epoch.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training on {device}")

    # Create dataset
    target_size = (args.image_size, args.image_size)
    dataset = ImagePairDataset(args.input, args.target, target_size=target_size)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
    print(f"Loaded {len(dataset)} image pairs (resized to {args.image_size}x{args.image_size})")

    # Create model
    model = CNNv2(kernels=args.kernel_sizes, channels=args.channels).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model: {args.channels} channels, {args.kernel_sizes} kernels, {total_params} weights")

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.MSELoss()

    # Training loop
    print(f"\nTraining for {args.epochs} epochs...")
    start_time = time.perf_counter()

    for epoch in range(1, args.epochs + 1):
        model.train()
        epoch_loss = 0.0

        for static_feat, target in dataloader:
            static_feat = static_feat.to(device)
            target = target.to(device)

            optimizer.zero_grad()
            output = model(static_feat)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(dataloader)
        is_last = epoch == args.epochs

        if epoch % 100 == 0 or epoch == 1 or is_last:
            elapsed = time.perf_counter() - start_time
            print(f"Epoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | Time: {elapsed:.1f}s")

        # Save periodically, and always after the final epoch so training that
        # doesn't end on a multiple of checkpoint_every is not lost.
        if args.checkpoint_every > 0 and (epoch % args.checkpoint_every == 0 or is_last):
            checkpoint_path = Path(args.checkpoint_dir) / f"checkpoint_epoch_{epoch}.pth"
            _save_checkpoint(checkpoint_path, epoch, model, optimizer, avg_loss, args)
            print(f" → Saved checkpoint: {checkpoint_path}")

    print(f"\nTraining complete! Total time: {time.perf_counter() - start_time:.1f}s")
    return model


def main():
    """Parse command-line arguments and run training."""
    parser = argparse.ArgumentParser(description='Train CNN v2 with parametric static features')
    parser.add_argument('--input', type=str, required=True, help='Input images directory')
    parser.add_argument('--target', type=str, required=True, help='Target images directory')
    parser.add_argument('--image-size', type=int, default=256,
                        help='Resize images to this size (default: 256)')
    parser.add_argument('--kernel-sizes', type=int, nargs=3, default=[1, 3, 5],
                        help='Odd kernel sizes for 3 layers (default: 1 3 5)')
    parser.add_argument('--channels', type=int, nargs=3, default=[16, 8, 4],
                        help='Output channels for layers 0 and 1; the third value is '
                             'ignored since the output layer is always 4 RGBA (default: 16 8 4)')
    parser.add_argument('--epochs', type=int, default=5000, help='Training epochs')
    parser.add_argument('--batch-size', type=int, default=16, help='Batch size')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
    parser.add_argument('--checkpoint-dir', type=str, default='checkpoints',
                        help='Checkpoint directory')
    parser.add_argument('--checkpoint-every', type=int, default=1000,
                        help='Save checkpoint every N epochs (0 = disable)')

    args = parser.parse_args()
    train(args)


if __name__ == '__main__':
    main()