path: root/training/train_cnn_v2.py
author  skal <pascal.massimino@gmail.com>  2026-02-15 18:44:17 +0100
committer  skal <pascal.massimino@gmail.com>  2026-02-15 18:44:17 +0100
commit  161a59fa50bb92e3664c389fa03b95aefe349b3f (patch)
tree  71548f64b2bdea958388f9063b74137659d70306 /training/train_cnn_v2.py
parent  9c3b72c710bf1ffa7e18f7c7390a425d57487eba (diff)
refactor(cnn): isolate CNN v2 to cnn_v2/ subdirectory
Move all CNN v2 files to dedicated cnn_v2/ directory to prepare for CNN v3 development. Zero functional changes.

Structure:
- cnn_v2/src/ - C++ effect implementation
- cnn_v2/shaders/ - WGSL shaders (6 files)
- cnn_v2/weights/ - Binary weights (3 files)
- cnn_v2/training/ - Python training scripts (4 files)
- cnn_v2/scripts/ - Shell scripts (train_cnn_v2_full.sh)
- cnn_v2/tools/ - Validation tools (HTML)
- cnn_v2/docs/ - Documentation (4 markdown files)

Changes:
- Update CMake source list to cnn_v2/src/cnn_v2_effect.cc
- Update assets.txt with relative paths to cnn_v2/
- Update includes to ../../cnn_v2/src/cnn_v2_effect.h
- Add PROJECT_ROOT resolution to Python/shell scripts
- Update doc references in HOWTO.md, TODO.md
- Add cnn_v2/README.md

Verification: 34/34 tests passing, demo runs correctly.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
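The "PROJECT_ROOT resolution" mentioned above is not part of this diff; as a rough illustration only, a Python script relocated under cnn_v2/training/ might anchor its paths to the repository root roughly like this (the parents[2] depth and the weights path are assumptions, not taken from the repository):

# Hypothetical sketch: resolve the repo root from the script's own location so
# relocated cnn_v2/training/ scripts work from any working directory.
from pathlib import Path

PROJECT_ROOT = Path(__file__).resolve().parents[2]   # <repo>/cnn_v2/training/x.py -> <repo>
WEIGHTS_DIR = PROJECT_ROOT / "cnn_v2" / "weights"    # example use, layout per the commit message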
Diffstat (limited to 'training/train_cnn_v2.py')
-rwxr-xr-x  training/train_cnn_v2.py  472
1 file changed, 0 insertions, 472 deletions
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
deleted file mode 100755
index 9e5df2f..0000000
--- a/training/train_cnn_v2.py
+++ /dev/null
@@ -1,472 +0,0 @@
-#!/usr/bin/env python3
-"""CNN v2 Training Script - Uniform 12D→4D Architecture
-
-Architecture:
-- Static features (8D): p0-p3 (parametric), uv_x, uv_y, sin(20×uv_y), bias
-- Input RGBD (4D): original image mip 0
-- All layers: input RGBD (4D) + static (8D) = 12D → 4 channels
-- Per-layer kernel sizes (e.g., 1×1, 3×3, 5×5)
-- Uniform layer structure with bias=False (bias in static features)
-"""
-
-import argparse
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.utils.data import Dataset, DataLoader
-from pathlib import Path
-from PIL import Image
-import time
-import cv2
-
-
-def compute_static_features(rgb, depth=None, mip_level=0):
- """Generate 8D static features (parametric + spatial).
-
- Args:
- rgb: (H, W, 3) RGB image [0, 1]
- depth: (H, W) depth map [0, 1], optional (defaults to 1.0 = far plane)
- mip_level: Mip level for p0-p3 (0=original, 1=half, 2=quarter, 3=eighth)
-
- Returns:
- (H, W, 8) static features: [p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias]
-
- Note: p0-p3 are parametric features from mip level. p3 uses depth (alpha channel) or 1.0
-
- TODO: Binary format should support arbitrary layout and ordering for feature vector (7D),
- alongside mip-level indication. Current layout is hardcoded as:
- [p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias]
- Future: Allow experimentation with different feature combinations without shader recompilation.
- Examples: [R, G, B, dx, dy, uv_x, bias] or [mip1.r, mip2.g, laplacian, uv_x, sin20_x, bias]
- """
- h, w = rgb.shape[:2]
-
- # Generate mip level for p0-p3
- if mip_level > 0:
- # Downsample to mip level
- mip_rgb = rgb.copy()
- for _ in range(mip_level):
- mip_rgb = cv2.pyrDown(mip_rgb)
- # Upsample back to original size
- for _ in range(mip_level):
- mip_rgb = cv2.pyrUp(mip_rgb)
- # Crop/pad to exact original size if needed
- if mip_rgb.shape[:2] != (h, w):
- mip_rgb = cv2.resize(mip_rgb, (w, h), interpolation=cv2.INTER_LINEAR)
- else:
- mip_rgb = rgb
-
- # Parametric features (p0-p3) from mip level
- p0 = mip_rgb[:, :, 0].astype(np.float32)
- p1 = mip_rgb[:, :, 1].astype(np.float32)
- p2 = mip_rgb[:, :, 2].astype(np.float32)
- p3 = depth.astype(np.float32) if depth is not None else np.ones((h, w), dtype=np.float32) # Default 1.0 = far plane
-
- # UV coordinates (normalized [0, 1])
- uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0).astype(np.float32)
- uv_y = np.linspace(0, 1, h)[:, None].repeat(w, axis=1).astype(np.float32)
-
- # Sinusoidal position encoding (single frequency)
- sin20_y = np.sin(20.0 * uv_y).astype(np.float32)
-
- # Bias dimension (always 1.0) - replaces Conv2d bias parameter
- bias = np.ones((h, w), dtype=np.float32)
-
- # Stack: [p0, p1, p2, p3, uv.x, uv.y, sin20_y, bias]
- features = np.stack([p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias], axis=-1)
- return features
-
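As a quick sanity check for the function above, a minimal usage sketch (random input, illustrative only; assumes compute_static_features is imported from this module):

# Hedged example: 64x64 RGB image in [0, 1]; no depth map, so p3 defaults to 1.0.
import numpy as np

rgb = np.random.rand(64, 64, 3).astype(np.float32)
feat = compute_static_features(rgb, depth=None, mip_level=1)
print(feat.shape)   # (64, 64, 8): [p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias]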
-
-class CNNv2(nn.Module):
- """CNN v2 - Uniform 12D→4D Architecture
-
- All layers: input RGBD (4D) + static (8D) = 12D → 4 channels
- Per-layer kernel sizes supported (e.g., [1, 3, 5])
- Uses bias=False (bias integrated in static features as 1.0)
-
- TODO: Add quantization-aware training (QAT) for 8-bit weights
- - Use torch.quantization.QuantStub/DeQuantStub
- - Train with fake quantization to adapt to 8-bit precision
- - Target: ~1.3 KB weights (vs 2.6 KB with f16)
- """
-
- def __init__(self, kernel_sizes, num_layers=3):
- super().__init__()
- if isinstance(kernel_sizes, int):
- kernel_sizes = [kernel_sizes] * num_layers
- assert len(kernel_sizes) == num_layers, "kernel_sizes must match num_layers"
-
- self.kernel_sizes = kernel_sizes
- self.num_layers = num_layers
- self.layers = nn.ModuleList()
-
- # All layers: 12D input (4 RGBD + 8 static) → 4D output
- for kernel_size in kernel_sizes:
- self.layers.append(
- nn.Conv2d(12, 4, kernel_size=kernel_size,
- padding=kernel_size//2, bias=False)
- )
-
- def forward(self, input_rgbd, static_features):
- """Forward pass with uniform 12D→4D layers.
-
- Args:
- input_rgbd: (B, 4, H, W) input image RGBD (mip 0)
- static_features: (B, 8, H, W) static features
-
- Returns:
- (B, 4, H, W) RGBA output [0, 1]
- """
- # Layer 0: input RGBD (4D) + static (8D) = 12D
- x = torch.cat([input_rgbd, static_features], dim=1)
- x = self.layers[0](x)
- x = torch.sigmoid(x) # Soft [0,1] for layer 0
-
- # Layer 1+: previous (4D) + static (8D) = 12D
- for i in range(1, self.num_layers):
- x_input = torch.cat([x, static_features], dim=1)
- x = self.layers[i](x_input)
- if i < self.num_layers - 1:
- x = F.relu(x)
- else:
- x = torch.sigmoid(x) # Soft [0,1] for final layer
-
- return x
-
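A shape-check sketch for the model above, using random tensors (illustrative only):

import torch

model = CNNv2(kernel_sizes=[1, 3, 5])                # 3 layers with per-layer kernels
rgbd = torch.rand(2, 4, 32, 32)                      # (B, 4, H, W) input RGBD, mip 0
static = torch.rand(2, 8, 32, 32)                    # (B, 8, H, W) static features
out = model(rgbd, static)
print(out.shape)                                     # torch.Size([2, 4, 32, 32])
print(sum(p.numel() for p in model.parameters()))    # 12*4*(1 + 9 + 25) = 1680 weights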
-
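The quantization-aware-training TODO in the class docstring is likewise unimplemented; a rough sketch of eager-mode fake quantization, assuming PyTorch's stock QAT workflow (QuantStub/DeQuantStub wrapping and the final int8 export are omitted here):

import torch.quantization as tq

qat_model = CNNv2(kernel_sizes=[3, 3, 3])
qat_model.qconfig = tq.get_default_qat_qconfig("fbgemm")   # fake-quant observers for weights/activations
tq.prepare_qat(qat_model.train(), inplace=True)
# ...then run the normal training loop from train() below so weights adapt to 8-bit precision...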
-class PatchDataset(Dataset):
- """Patch-based dataset extracting salient regions from images."""
-
- def __init__(self, input_dir, target_dir, patch_size=32, patches_per_image=64,
- detector='harris', mip_level=0):
- self.input_paths = sorted(Path(input_dir).glob("*.png"))
- self.target_paths = sorted(Path(target_dir).glob("*.png"))
- self.patch_size = patch_size
- self.patches_per_image = patches_per_image
- self.detector = detector
- self.mip_level = mip_level
-
- assert len(self.input_paths) == len(self.target_paths), \
- f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets"
-
- print(f"Found {len(self.input_paths)} image pairs")
- print(f"Extracting {patches_per_image} patches per image using {detector} detector")
- print(f"Total patches: {len(self.input_paths) * patches_per_image}")
-
- def __len__(self):
- return len(self.input_paths) * self.patches_per_image
-
- def _detect_salient_points(self, img_array):
- """Detect salient points on original image.
-
- TODO: Add random sampling to training vectors
- - In addition to salient points, incorporate randomly-located samples
- - Default: 10% random samples, 90% salient points
- - Prevents overfitting to only high-gradient regions
- - Improves generalization across entire image
- - Configurable via --random-sample-percent parameter
- """
- gray = cv2.cvtColor((img_array * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
- h, w = gray.shape
- half_patch = self.patch_size // 2
-
- corners = None
- if self.detector == 'harris':
- corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2,
- qualityLevel=0.01, minDistance=half_patch, useHarrisDetector=True)
- elif self.detector == 'fast':
- fast = cv2.FastFeatureDetector_create(threshold=20)
- keypoints = fast.detect(gray, None)
- corners = np.array([[kp.pt[0], kp.pt[1]] for kp in keypoints[:self.patches_per_image * 2]])
- corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None
- elif self.detector == 'shi-tomasi':
- corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2,
- qualityLevel=0.01, minDistance=half_patch,
- useHarrisDetector=False)
- elif self.detector == 'gradient':
- grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
- grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
- gradient_mag = np.sqrt(grad_x**2 + grad_y**2)
- threshold = np.percentile(gradient_mag, 95)
- y_coords, x_coords = np.where(gradient_mag > threshold)
-
- if len(x_coords) > self.patches_per_image * 2:
- indices = np.random.choice(len(x_coords), self.patches_per_image * 2, replace=False)
- x_coords = x_coords[indices]
- y_coords = y_coords[indices]
-
- corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)])
- corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None
-
- # Fallback to random if no corners found
- if corners is None or len(corners) == 0:
- x_coords = np.random.randint(half_patch, w - half_patch, self.patches_per_image)
- y_coords = np.random.randint(half_patch, h - half_patch, self.patches_per_image)
- corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)])
- corners = corners.reshape(-1, 1, 2)
-
- # Filter valid corners
- valid_corners = []
- for corner in corners:
- x, y = int(corner[0][0]), int(corner[0][1])
- if half_patch <= x < w - half_patch and half_patch <= y < h - half_patch:
- valid_corners.append((x, y))
- if len(valid_corners) >= self.patches_per_image:
- break
-
- # Fill with random if not enough
- while len(valid_corners) < self.patches_per_image:
- x = np.random.randint(half_patch, w - half_patch)
- y = np.random.randint(half_patch, h - half_patch)
- valid_corners.append((x, y))
-
- return valid_corners
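The random-sampling TODO above is not implemented in this version of the script; one possible sketch of the described 90/10 mix, as a helper applied to the result of _detect_salient_points (the method name and random_fraction parameter are hypothetical):

# Hypothetical helper: keep ~90% salient points, pad with ~10% random locations.
def _mix_with_random(self, salient_points, h, w, random_fraction=0.10):
    half_patch = self.patch_size // 2
    n_random = int(round(self.patches_per_image * random_fraction))
    points = list(salient_points[:self.patches_per_image - n_random])
    for _ in range(n_random):
        x = np.random.randint(half_patch, w - half_patch)
        y = np.random.randint(half_patch, h - half_patch)
        points.append((x, y))
    return points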
-
- def __getitem__(self, idx):
- img_idx = idx // self.patches_per_image
- patch_idx = idx % self.patches_per_image
-
- # Load original images (no resize)
- input_img = np.array(Image.open(self.input_paths[img_idx]).convert('RGB')) / 255.0
- target_pil = Image.open(self.target_paths[img_idx])
- target_img = np.array(target_pil.convert('RGBA')) / 255.0 # Preserve alpha
-
- # Detect salient points on original image (use RGB only)
- salient_points = self._detect_salient_points(input_img)
- cx, cy = salient_points[patch_idx]
-
- # Extract patch
- half_patch = self.patch_size // 2
- y1, y2 = cy - half_patch, cy + half_patch
- x1, x2 = cx - half_patch, cx + half_patch
-
- input_patch = input_img[y1:y2, x1:x2]
- target_patch = target_img[y1:y2, x1:x2] # RGBA
-
- # Extract depth from target alpha channel (or default to 1.0)
- depth = target_patch[:, :, 3] if target_patch.shape[2] == 4 else None
-
- # Compute static features for patch
- static_feat = compute_static_features(input_patch.astype(np.float32), depth=depth, mip_level=self.mip_level)
-
- # Input RGBD (mip 0) - add depth channel
- input_rgbd = np.concatenate([input_patch, np.zeros((self.patch_size, self.patch_size, 1))], axis=-1)
-
- # Convert to tensors (C, H, W)
- input_rgbd = torch.from_numpy(input_rgbd.astype(np.float32)).permute(2, 0, 1)
- static_feat = torch.from_numpy(static_feat).permute(2, 0, 1)
- target = torch.from_numpy(target_patch.astype(np.float32)).permute(2, 0, 1) # RGBA from image
-
- return input_rgbd, static_feat, target
-
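A minimal usage sketch for the patch dataset above (the directory paths are hypothetical):

from torch.utils.data import DataLoader

dataset = PatchDataset("data/input", "data/target", patch_size=32,
                       patches_per_image=64, detector="harris", mip_level=0)
loader = DataLoader(dataset, batch_size=16, shuffle=True)
input_rgbd, static_feat, target = next(iter(loader))
print(input_rgbd.shape, static_feat.shape, target.shape)
# torch.Size([16, 4, 32, 32]) torch.Size([16, 8, 32, 32]) torch.Size([16, 4, 32, 32])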
-
-class ImagePairDataset(Dataset):
- """Dataset of input/target image pairs (full-image mode)."""
-
- def __init__(self, input_dir, target_dir, target_size=(256, 256), mip_level=0):
- self.input_paths = sorted(Path(input_dir).glob("*.png"))
- self.target_paths = sorted(Path(target_dir).glob("*.png"))
- self.target_size = target_size
- self.mip_level = mip_level
- assert len(self.input_paths) == len(self.target_paths), \
- f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets"
-
- def __len__(self):
- return len(self.input_paths)
-
- def __getitem__(self, idx):
- # Load and resize images to fixed size
- input_pil = Image.open(self.input_paths[idx]).convert('RGB')
- target_pil = Image.open(self.target_paths[idx])
-
- # Resize to target size
- input_pil = input_pil.resize(self.target_size, Image.LANCZOS)
- target_pil = target_pil.resize(self.target_size, Image.LANCZOS)
-
- input_img = np.array(input_pil) / 255.0
- target_img = np.array(target_pil.convert('RGBA')) / 255.0 # Preserve alpha
-
- # Extract depth from target alpha channel (or default to 1.0)
- depth = target_img[:, :, 3] if target_img.shape[2] == 4 else None
-
- # Compute static features
- static_feat = compute_static_features(input_img.astype(np.float32), depth=depth, mip_level=self.mip_level)
-
- # Input RGBD (mip 0) - add depth channel
- h, w = input_img.shape[:2]
- input_rgbd = np.concatenate([input_img, np.zeros((h, w, 1))], axis=-1)
-
- # Convert to tensors (C, H, W)
- input_rgbd = torch.from_numpy(input_rgbd.astype(np.float32)).permute(2, 0, 1)
- static_feat = torch.from_numpy(static_feat).permute(2, 0, 1)
- target = torch.from_numpy(target_img.astype(np.float32)).permute(2, 0, 1) # RGBA from image
-
- return input_rgbd, static_feat, target
-
-
-def train(args):
- """Train CNN v2 model."""
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- print(f"Training on {device}")
-
- # Create dataset (patch-based or full-image)
- if args.full_image:
- print(f"Mode: Full-image (resized to {args.image_size}x{args.image_size})")
- target_size = (args.image_size, args.image_size)
- dataset = ImagePairDataset(args.input, args.target, target_size=target_size, mip_level=args.mip_level)
- dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
- else:
- print(f"Mode: Patch-based ({args.patch_size}x{args.patch_size} patches)")
- dataset = PatchDataset(args.input, args.target,
- patch_size=args.patch_size,
- patches_per_image=args.patches_per_image,
- detector=args.detector,
- mip_level=args.mip_level)
- dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
-
- # Parse kernel sizes
- kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
- if len(kernel_sizes) == 1:
- kernel_sizes = kernel_sizes * args.num_layers
- else:
- # When multiple kernel sizes provided, derive num_layers from list length
- args.num_layers = len(kernel_sizes)
-
- # Create model
- model = CNNv2(kernel_sizes=kernel_sizes, num_layers=args.num_layers).to(device)
- total_params = sum(p.numel() for p in model.parameters())
- kernel_desc = ','.join(map(str, kernel_sizes))
- print(f"Model: {args.num_layers} layers, kernel sizes [{kernel_desc}], {total_params} weights")
- print(f"Using mip level {args.mip_level} for p0-p3 features")
-
- # Optimizer and loss
- optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
- criterion = nn.MSELoss()
-
- # Training loop
- print(f"\nTraining for {args.epochs} epochs...")
- start_time = time.time()
-
- for epoch in range(1, args.epochs + 1):
- model.train()
- epoch_loss = 0.0
-
- for input_rgbd, static_feat, target in dataloader:
- input_rgbd = input_rgbd.to(device)
- static_feat = static_feat.to(device)
- target = target.to(device)
-
- optimizer.zero_grad()
- output = model(input_rgbd, static_feat)
-
- # Compute loss (grayscale or RGBA)
- if args.grayscale_loss:
- # Convert RGBA to grayscale: Y = 0.299*R + 0.587*G + 0.114*B
- output_gray = 0.299 * output[:, 0:1] + 0.587 * output[:, 1:2] + 0.114 * output[:, 2:3]
- target_gray = 0.299 * target[:, 0:1] + 0.587 * target[:, 1:2] + 0.114 * target[:, 2:3]
- loss = criterion(output_gray, target_gray)
- else:
- loss = criterion(output, target)
-
- loss.backward()
- optimizer.step()
-
- epoch_loss += loss.item()
-
- avg_loss = epoch_loss / len(dataloader)
-
- # Print loss at every epoch (overwrite line with \r)
- elapsed = time.time() - start_time
- print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | Time: {elapsed:.1f}s", end='', flush=True)
-
- # Save checkpoint
- if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0:
- print() # Newline before checkpoint message
- checkpoint_path = Path(args.checkpoint_dir) / f"checkpoint_epoch_{epoch}.pth"
- checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
- torch.save({
- 'epoch': epoch,
- 'model_state_dict': model.state_dict(),
- 'optimizer_state_dict': optimizer.state_dict(),
- 'loss': avg_loss,
- 'config': {
- 'kernel_sizes': kernel_sizes,
- 'num_layers': args.num_layers,
- 'mip_level': args.mip_level,
- 'grayscale_loss': args.grayscale_loss,
- 'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin20_y', 'bias']
- }
- }, checkpoint_path)
- print(f" → Saved checkpoint: {checkpoint_path}")
-
- # Always save final checkpoint
- print() # Newline after training
- final_checkpoint = Path(args.checkpoint_dir) / f"checkpoint_epoch_{args.epochs}.pth"
- final_checkpoint.parent.mkdir(parents=True, exist_ok=True)
- torch.save({
- 'epoch': args.epochs,
- 'model_state_dict': model.state_dict(),
- 'optimizer_state_dict': optimizer.state_dict(),
- 'loss': avg_loss,
- 'config': {
- 'kernel_sizes': kernel_sizes,
- 'num_layers': args.num_layers,
- 'mip_level': args.mip_level,
- 'grayscale_loss': args.grayscale_loss,
- 'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin20_y', 'bias']
- }
- }, final_checkpoint)
- print(f" → Saved final checkpoint: {final_checkpoint}")
-
- print(f"\nTraining complete! Total time: {time.time() - start_time:.1f}s")
- return model
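For completeness, checkpoints written by train() can be restored with the usual PyTorch pattern; a minimal sketch (the checkpoint path is hypothetical):

import torch

ckpt = torch.load("checkpoints/checkpoint_epoch_5000.pth", map_location="cpu")
cfg = ckpt["config"]
model = CNNv2(kernel_sizes=cfg["kernel_sizes"], num_layers=cfg["num_layers"])
model.load_state_dict(ckpt["model_state_dict"])
model.eval()
print(f"epoch {ckpt['epoch']}, loss {ckpt['loss']:.6f}, features: {cfg['features']}")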
-
-
-def main():
- parser = argparse.ArgumentParser(description='Train CNN v2 with parametric static features')
- parser.add_argument('--input', type=str, required=True, help='Input images directory')
- parser.add_argument('--target', type=str, required=True, help='Target images directory')
-
- # Training mode
- parser.add_argument('--full-image', action='store_true',
- help='Use full-image mode (resize all images)')
- parser.add_argument('--image-size', type=int, default=256,
- help='Full-image mode: resize to this size (default: 256)')
-
- # Patch-based mode (default)
- parser.add_argument('--patch-size', type=int, default=32,
- help='Patch mode: patch size (default: 32)')
- parser.add_argument('--patches-per-image', type=int, default=64,
- help='Patch mode: patches per image (default: 64)')
- parser.add_argument('--detector', type=str, default='harris',
- choices=['harris', 'fast', 'shi-tomasi', 'gradient'],
- help='Patch mode: salient point detector (default: harris)')
- # TODO: Add --random-sample-percent parameter (default: 10)
- # Mix salient points with random samples for better generalization
-
- # Model architecture
- parser.add_argument('--kernel-sizes', type=str, default='3',
- help='Comma-separated kernel sizes per layer (e.g., "3,5,3"), single value replicates (default: 3)')
- parser.add_argument('--num-layers', type=int, default=3,
- help='Number of CNN layers (default: 3)')
- parser.add_argument('--mip-level', type=int, default=0, choices=[0, 1, 2, 3],
- help='Mip level for p0-p3 features: 0=original, 1=half, 2=quarter, 3=eighth (default: 0)')
-
- # Training parameters
- parser.add_argument('--epochs', type=int, default=5000, help='Training epochs')
- parser.add_argument('--batch-size', type=int, default=16, help='Batch size')
- parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
- parser.add_argument('--grayscale-loss', action='store_true',
- help='Compute loss on grayscale (Y = 0.299*R + 0.587*G + 0.114*B) instead of RGBA')
- parser.add_argument('--checkpoint-dir', type=str, default='checkpoints',
- help='Checkpoint directory')
- parser.add_argument('--checkpoint-every', type=int, default=1000,
- help='Save checkpoint every N epochs (0 = disable)')
-
- args = parser.parse_args()
- train(args)
-
-
-if __name__ == '__main__':
- main()