Diffstat (limited to 'training/train_cnn_v2.py')
-rwxr-xr-x  training/train_cnn_v2.py | 217
1 file changed, 217 insertions, 0 deletions
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
new file mode 100755
index 0000000..fe148b4
--- /dev/null
+++ b/training/train_cnn_v2.py
@@ -0,0 +1,217 @@
+#!/usr/bin/env python3
+"""CNN v2 Training Script - Parametric Static Features
+
+Trains a multi-layer CNN on an 8-channel static input (7 features + a bias channel):
+- RGBD (4D)
+- UV coordinates (2D)
+- sin(10*uv.x) position encoding (1D)
+- Bias dimension (1D, always 1.0)
+"""
+
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import Dataset, DataLoader
+from pathlib import Path
+from PIL import Image
+import time
+
+
+def compute_static_features(rgb, depth=None):
+ """Generate 7D static features + bias dimension.
+
+ Args:
+ rgb: (H, W, 3) RGB image [0, 1]
+ depth: (H, W) depth map [0, 1], optional
+
+ Returns:
+        (H, W, 8) float32 static feature array
+ """
+ h, w = rgb.shape[:2]
+
+ # RGBD channels
+ r, g, b = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
+ d = depth if depth is not None else np.zeros((h, w), dtype=np.float32)
+
+ # UV coordinates (normalized [0, 1])
+ uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0).astype(np.float32)
+ uv_y = np.linspace(0, 1, h)[:, None].repeat(w, axis=1).astype(np.float32)
+
+    # Sinusoidal position encoding (single frequency along x)
+ sin10_x = np.sin(10.0 * uv_x).astype(np.float32)
+
+ # Bias dimension (always 1.0)
+ bias = np.ones((h, w), dtype=np.float32)
+
+ # Stack: [R, G, B, D, uv.x, uv.y, sin10_x, bias]
+ features = np.stack([r, g, b, d, uv_x, uv_y, sin10_x, bias], axis=-1)
+ return features
+
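+# Illustrative layout check (sketch, not executed on import; dummy 4x4 image):
+#   >>> feats = compute_static_features(np.zeros((4, 4, 3), dtype=np.float32))
+#   >>> feats.shape
+#   (4, 4, 8)
+#   >>> feats[..., 7].max() == feats[..., 7].min() == 1.0  # bias channel
+#   True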
+
+class CNNv2(nn.Module):
+ """CNN v2 with parametric static features."""
+
+    def __init__(self, kernels=(1, 3, 5), channels=(16, 8, 4)):
+ super().__init__()
+ self.kernels = kernels
+ self.channels = channels
+
+ # Input layer: 8D (7 features + bias) → channels[0]
+ self.layer0 = nn.Conv2d(8, channels[0], kernel_size=kernels[0],
+ padding=kernels[0]//2, bias=False)
+
+ # Inner layers: (8 + C_prev) → C_next
+ in_ch_1 = 8 + channels[0]
+ self.layer1 = nn.Conv2d(in_ch_1, channels[1], kernel_size=kernels[1],
+ padding=kernels[1]//2, bias=False)
+
+ # Output layer: (8 + C_last) → 4 (RGBA)
+ in_ch_2 = 8 + channels[1]
+ self.layer2 = nn.Conv2d(in_ch_2, 4, kernel_size=kernels[2],
+ padding=kernels[2]//2, bias=False)
+
+ def forward(self, static_features):
+ """Forward pass with static feature concatenation.
+
+ Args:
+ static_features: (B, 8, H, W) static features
+
+ Returns:
+ (B, 4, H, W) RGBA output [0, 1]
+ """
+ # Layer 0: Use full 8D static features
+ x0 = self.layer0(static_features)
+ x0 = F.relu(x0)
+
+ # Layer 1: Concatenate static + layer0 output
+ x1_input = torch.cat([static_features, x0], dim=1)
+ x1 = self.layer1(x1_input)
+ x1 = F.relu(x1)
+
+ # Layer 2: Concatenate static + layer1 output
+ x2_input = torch.cat([static_features, x1], dim=1)
+ output = self.layer2(x2_input)
+
+ return torch.sigmoid(output)
+
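+# Shape smoke test (illustrative; run manually, not executed on import):
+#   model = CNNv2()
+#   out = model(torch.zeros(1, 8, 32, 32))
+#   assert out.shape == (1, 4, 32, 32)  # odd kernels + k//2 padding keep H, W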
+
+class ImagePairDataset(Dataset):
+ """Dataset of input/target image pairs."""
+
+ def __init__(self, input_dir, target_dir):
+ self.input_paths = sorted(Path(input_dir).glob("*.png"))
+ self.target_paths = sorted(Path(target_dir).glob("*.png"))
+ assert len(self.input_paths) == len(self.target_paths), \
+ f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets"
+
+ def __len__(self):
+ return len(self.input_paths)
+
+ def __getitem__(self, idx):
+ # Load images
+ input_img = np.array(Image.open(self.input_paths[idx]).convert('RGB')) / 255.0
+ target_img = np.array(Image.open(self.target_paths[idx]).convert('RGB')) / 255.0
+
+ # Compute static features
+ static_feat = compute_static_features(input_img.astype(np.float32))
+
+ # Convert to tensors (C, H, W)
+ static_feat = torch.from_numpy(static_feat).permute(2, 0, 1)
+ target = torch.from_numpy(target_img.astype(np.float32)).permute(2, 0, 1)
+
+ # Pad target to 4 channels (RGBA)
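+        # F.pad reads the pad tuple from the last dim backward: (0, 0) for W,
+        # (0, 0) for H, (0, 1) for C, appending one channel of 1.0 (opaque alpha).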
+ target = F.pad(target, (0, 0, 0, 0, 0, 1), value=1.0)
+
+ return static_feat, target
+
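+# Each item pairs (8, H, W) float32 static features with a (4, H, W) float32
+# RGBA target in [0, 1] whose alpha channel is fixed at 1.0.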
+
+def train(args):
+ """Train CNN v2 model."""
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ print(f"Training on {device}")
+
+ # Create dataset
+ dataset = ImagePairDataset(args.input, args.target)
+ dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
+ print(f"Loaded {len(dataset)} image pairs")
+
+ # Create model
+ model = CNNv2(kernels=args.kernel_sizes, channels=args.channels).to(device)
+ total_params = sum(p.numel() for p in model.parameters())
+ print(f"Model: {args.channels} channels, {args.kernel_sizes} kernels, {total_params} weights")
+
+ # Optimizer and loss
+ optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+ criterion = nn.MSELoss()
+
+ # Training loop
+ print(f"\nTraining for {args.epochs} epochs...")
+ start_time = time.time()
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ epoch_loss = 0.0
+
+ for static_feat, target in dataloader:
+ static_feat = static_feat.to(device)
+ target = target.to(device)
+
+ optimizer.zero_grad()
+ output = model(static_feat)
+ loss = criterion(output, target)
+ loss.backward()
+ optimizer.step()
+
+ epoch_loss += loss.item()
+
+ avg_loss = epoch_loss / len(dataloader)
+
+ if epoch % 100 == 0 or epoch == 1:
+ elapsed = time.time() - start_time
+ print(f"Epoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | Time: {elapsed:.1f}s")
+
+ # Save checkpoint
+ if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0:
+ checkpoint_path = Path(args.checkpoint_dir) / f"checkpoint_epoch_{epoch}.pth"
+ checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
+ torch.save({
+ 'epoch': epoch,
+ 'model_state_dict': model.state_dict(),
+ 'optimizer_state_dict': optimizer.state_dict(),
+ 'loss': avg_loss,
+ 'config': {
+ 'kernels': args.kernel_sizes,
+ 'channels': args.channels,
+ 'features': ['R', 'G', 'B', 'D', 'uv.x', 'uv.y', 'sin10_x', 'bias']
+ }
+ }, checkpoint_path)
+ print(f" → Saved checkpoint: {checkpoint_path}")
+
+ print(f"\nTraining complete! Total time: {time.time() - start_time:.1f}s")
+ return model
+
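+# Restoring a checkpoint (sketch; keys match the dict saved in train() above,
+# and the epoch-1000 filename assumes the default --checkpoint-every):
+#   ckpt = torch.load('checkpoints/checkpoint_epoch_1000.pth', map_location='cpu')
+#   model = CNNv2(kernels=ckpt['config']['kernels'],
+#                 channels=ckpt['config']['channels'])
+#   model.load_state_dict(ckpt['model_state_dict'])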
+
+def main():
+ parser = argparse.ArgumentParser(description='Train CNN v2 with parametric static features')
+ parser.add_argument('--input', type=str, required=True, help='Input images directory')
+ parser.add_argument('--target', type=str, required=True, help='Target images directory')
+ parser.add_argument('--kernel-sizes', type=int, nargs=3, default=[1, 3, 5],
+ help='Kernel sizes for 3 layers (default: 1 3 5)')
+ parser.add_argument('--channels', type=int, nargs=3, default=[16, 8, 4],
+ help='Output channels for 3 layers (default: 16 8 4)')
+ parser.add_argument('--epochs', type=int, default=5000, help='Training epochs')
+ parser.add_argument('--batch-size', type=int, default=16, help='Batch size')
+ parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
+ parser.add_argument('--checkpoint-dir', type=str, default='checkpoints',
+ help='Checkpoint directory')
+ parser.add_argument('--checkpoint-every', type=int, default=1000,
+ help='Save checkpoint every N epochs (0 = disable)')
+
+ args = parser.parse_args()
+ train(args)
+
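+# Example invocation (directory paths are placeholders):
+#   ./train_cnn_v2.py --input data/inputs --target data/targets \
+#       --epochs 2000 --checkpoint-every 500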
+
+if __name__ == '__main__':
+ main()