From 7a0d9970c9b283b957f8b8df0b940813afb28ec2 Mon Sep 17 00:00:00 2001
From: skal
Date: Fri, 13 Feb 2026 16:41:35 +0100
Subject: CNN v2: Add --mip-level option for parametric features

Add mip level control for p0-p3 features (0=original, 1=half,
2=quarter, 3=eighth). Uses pyrDown/pyrUp for proper Gaussian filtering
during mip generation.

Changes:
- compute_static_features(): Accept mip_level param, generate mip via cv2 pyramid
- PatchDataset/ImagePairDataset: Pass mip_level to feature computation
- CLI: Add --mip-level arg with choices [0,1,2,3]
- Save mip_level in checkpoint config for tracking
- Doc updates: HOWTO.md and CNN_V2.md

Co-Authored-By: Claude Sonnet 4.5
---
 doc/CNN_V2.md            | 30 ++++++++++++++++++++++++------
 doc/HOWTO.md             |  6 ++++++
 training/train_cnn_v2.py | 49 +++++++++++++++++++++++++++++++++++--------------
 3 files changed, 65 insertions(+), 20 deletions(-)

diff --git a/doc/CNN_V2.md b/doc/CNN_V2.md
index e56b022..49086ca 100644
--- a/doc/CNN_V2.md
+++ b/doc/CNN_V2.md
@@ -245,12 +245,28 @@ fn pack_channels(values: vec4) -> vec4 {
 **Static Feature Extraction:**
 
 ```python
-def compute_static_features(rgb, depth):
-    """Generate parametric features (8D: p0-p3 + spatial)."""
+def compute_static_features(rgb, depth, mip_level=0):
+    """Generate parametric features (8D: p0-p3 + spatial).
+
+    Args:
+        mip_level: 0=original, 1=half res, 2=quarter res, 3=eighth res
+    """
     h, w = rgb.shape[:2]
 
-    # Parametric features (example: use input RGBD, but could be mips/gradients)
-    p0, p1, p2, p3 = rgb[..., 0], rgb[..., 1], rgb[..., 2], depth
+    # Generate mip level for p0-p3 (downsample then upsample)
+    if mip_level > 0:
+        mip_rgb = rgb.copy()
+        for _ in range(mip_level):
+            mip_rgb = cv2.pyrDown(mip_rgb)
+        for _ in range(mip_level):
+            mip_rgb = cv2.pyrUp(mip_rgb)
+        if mip_rgb.shape[:2] != (h, w):
+            mip_rgb = cv2.resize(mip_rgb, (w, h), interpolation=cv2.INTER_LINEAR)
+    else:
+        mip_rgb = rgb
+
+    # Parametric features from mip level
+    p0, p1, p2, p3 = mip_rgb[..., 0], mip_rgb[..., 1], mip_rgb[..., 2], depth
 
     # UV coordinates (normalized)
     uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0)
@@ -308,6 +324,7 @@ class CNNv2(nn.Module):
 # Hyperparameters
 kernel_sizes = [3, 3, 3]  # Per-layer kernel sizes (e.g., [1,3,5])
 num_layers = 3  # Number of CNN layers
+mip_level = 0  # Mip level for p0-p3: 0=orig, 1=half, 2=quarter, 3=eighth
 learning_rate = 1e-3
 batch_size = 16
 epochs = 5000
@@ -318,8 +335,8 @@ epochs = 5000
 # Training loop (standard PyTorch f32)
 for epoch in range(epochs):
     for rgb_batch, depth_batch, target_batch in dataloader:
-        # Compute static features (8D)
-        static_feat = compute_static_features(rgb_batch, depth_batch)
+        # Compute static features (8D) with mip level
+        static_feat = compute_static_features(rgb_batch, depth_batch, mip_level)
 
         # Input RGBD (4D)
         input_rgbd = torch.cat([rgb_batch, depth_batch.unsqueeze(1)], dim=1)
@@ -342,6 +359,7 @@ torch.save({
     'config': {
         'kernel_sizes': [3, 3, 3],  # Per-layer kernel sizes
         'num_layers': 3,
+        'mip_level': 0,  # Mip level used for p0-p3
         'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin10_x', 'bias']
     },
     'epoch': epoch,
diff --git a/doc/HOWTO.md b/doc/HOWTO.md
index 9c67106..9003fe1 100644
--- a/doc/HOWTO.md
+++ b/doc/HOWTO.md
@@ -166,6 +166,12 @@ Config: 100 epochs, 3×3 kernels, 8→4→4 channels, patch-based (harris detect
     --input training/input/ --target training/target_2/ \
     --kernel-sizes 1,3,5 \
     --epochs 5000 --batch-size 16
+
+# Mip-level for p0-p3 features (0=original, 1=half, 2=quarter, 3=eighth)
+./training/train_cnn_v2.py \
+    --input training/input/ --target training/target_2/ \
+    --mip-level 1 \
+    --epochs 100 --batch-size 16
 ```
 
 **Export Binary Weights:**
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
index dc087c6..3d49d13 100755
--- a/training/train_cnn_v2.py
+++ b/training/train_cnn_v2.py
@@ -21,26 +21,40 @@ import time
 import cv2
 
 
-def compute_static_features(rgb, depth=None):
+def compute_static_features(rgb, depth=None, mip_level=0):
     """Generate 8D static features (parametric + spatial).
 
     Args:
         rgb: (H, W, 3) RGB image [0, 1]
         depth: (H, W) depth map [0, 1], optional
+        mip_level: Mip level for p0-p3 (0=original, 1=half, 2=quarter, 3=eighth)
 
     Returns:
         (H, W, 8) static features: [p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias]
 
-    Note: p0-p3 are parametric features (can be mips, gradients, etc.)
-          For training, we use RGBD as default, but could use mip1/2
+    Note: p0-p3 are parametric features generated from the specified mip level
     """
     h, w = rgb.shape[:2]
 
-    # Parametric features (p0-p3) - using RGBD as default
-    # TODO: Experiment with mip1 grayscale, gradients, etc.
-    p0 = rgb[:, :, 0].astype(np.float32)
-    p1 = rgb[:, :, 1].astype(np.float32)
-    p2 = rgb[:, :, 2].astype(np.float32)
+    # Generate mip level for p0-p3
+    if mip_level > 0:
+        # Downsample to mip level
+        mip_rgb = rgb.copy()
+        for _ in range(mip_level):
+            mip_rgb = cv2.pyrDown(mip_rgb)
+        # Upsample back to original size
+        for _ in range(mip_level):
+            mip_rgb = cv2.pyrUp(mip_rgb)
+        # Crop/pad to exact original size if needed
+        if mip_rgb.shape[:2] != (h, w):
+            mip_rgb = cv2.resize(mip_rgb, (w, h), interpolation=cv2.INTER_LINEAR)
+    else:
+        mip_rgb = rgb
+
+    # Parametric features (p0-p3) from mip level
+    p0 = mip_rgb[:, :, 0].astype(np.float32)
+    p1 = mip_rgb[:, :, 1].astype(np.float32)
+    p2 = mip_rgb[:, :, 2].astype(np.float32)
     p3 = depth if depth is not None else np.zeros((h, w), dtype=np.float32)
 
     # UV coordinates (normalized [0, 1])
@@ -119,12 +133,13 @@ class PatchDataset(Dataset):
     """Patch-based dataset extracting salient regions from images."""
 
     def __init__(self, input_dir, target_dir, patch_size=32, patches_per_image=64,
-                 detector='harris'):
+                 detector='harris', mip_level=0):
         self.input_paths = sorted(Path(input_dir).glob("*.png"))
         self.target_paths = sorted(Path(target_dir).glob("*.png"))
         self.patch_size = patch_size
         self.patches_per_image = patches_per_image
         self.detector = detector
+        self.mip_level = mip_level
 
         assert len(self.input_paths) == len(self.target_paths), \
             f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets"
@@ -224,7 +239,7 @@
             target_patch = target_img[y1:y2, x1:x2]  # RGBA
 
             # Compute static features for patch
-            static_feat = compute_static_features(input_patch.astype(np.float32))
+            static_feat = compute_static_features(input_patch.astype(np.float32), mip_level=self.mip_level)
 
             # Input RGBD (mip 0) - add depth channel
             input_rgbd = np.concatenate([input_patch, np.zeros((self.patch_size, self.patch_size, 1))], axis=-1)
@@ -240,10 +255,11 @@
 class ImagePairDataset(Dataset):
     """Dataset of input/target image pairs (full-image mode)."""
 
-    def __init__(self, input_dir, target_dir, target_size=(256, 256)):
+    def __init__(self, input_dir, target_dir, target_size=(256, 256), mip_level=0):
         self.input_paths = sorted(Path(input_dir).glob("*.png"))
         self.target_paths = sorted(Path(target_dir).glob("*.png"))
         self.target_size = target_size
+        self.mip_level = mip_level
 
         assert len(self.input_paths) == len(self.target_paths), \
             f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets"
@@ -263,7 +279,7 @@ class ImagePairDataset(Dataset):
         target_img = np.array(target_pil.convert('RGBA')) / 255.0  # Preserve alpha
 
         # Compute static features
-        static_feat = compute_static_features(input_img.astype(np.float32))
+        static_feat = compute_static_features(input_img.astype(np.float32), mip_level=self.mip_level)
 
         # Input RGBD (mip 0) - add depth channel
         h, w = input_img.shape[:2]
@@ -286,14 +302,15 @@ def train(args):
     if args.full_image:
         print(f"Mode: Full-image (resized to {args.image_size}x{args.image_size})")
         target_size = (args.image_size, args.image_size)
-        dataset = ImagePairDataset(args.input, args.target, target_size=target_size)
+        dataset = ImagePairDataset(args.input, args.target, target_size=target_size, mip_level=args.mip_level)
         dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
     else:
         print(f"Mode: Patch-based ({args.patch_size}x{args.patch_size} patches)")
         dataset = PatchDataset(args.input, args.target,
                                patch_size=args.patch_size,
                                patches_per_image=args.patches_per_image,
-                               detector=args.detector)
+                               detector=args.detector,
+                               mip_level=args.mip_level)
         dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
 
     # Parse kernel sizes
@@ -306,6 +323,7 @@ def train(args):
     total_params = sum(p.numel() for p in model.parameters())
     kernel_desc = ','.join(map(str, kernel_sizes))
     print(f"Model: {args.num_layers} layers, kernel sizes [{kernel_desc}], {total_params} weights")
+    print(f"Using mip level {args.mip_level} for p0-p3 features")
 
     # Optimizer and loss
     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
@@ -351,6 +369,7 @@ def train(args):
             'config': {
                 'kernel_sizes': kernel_sizes,
                 'num_layers': args.num_layers,
+                'mip_level': args.mip_level,
                 'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin10_x', 'bias']
             }
         }, checkpoint_path)
@@ -387,6 +406,8 @@ def main():
                         help='Comma-separated kernel sizes per layer (e.g., "3,5,3"), single value replicates (default: 3)')
     parser.add_argument('--num-layers', type=int, default=3,
                         help='Number of CNN layers (default: 3)')
+    parser.add_argument('--mip-level', type=int, default=0, choices=[0, 1, 2, 3],
+                        help='Mip level for p0-p3 features: 0=original, 1=half, 2=quarter, 3=eighth (default: 0)')
 
     # Training parameters
     parser.add_argument('--epochs', type=int, default=5000, help='Training epochs')
-- 
cgit v1.2.3
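
Note on the `cv2.resize` guard in `compute_static_features`: `pyrDown` rounds odd dimensions up, so a down/up round trip can come back slightly larger than the input. A minimal sketch demonstrating this, following OpenCV's documented `pyrDown`/`pyrUp` output-size rules:

```python
# Why the resize guard exists: odd dimensions don't round-trip exactly.
import numpy as np
import cv2

img = np.zeros((101, 77, 3), dtype=np.float32)
down = cv2.pyrDown(img)  # (51, 39, 3): pyrDown halves, rounding up
up = cv2.pyrUp(down)     # (102, 78, 3): pyrUp doubles -> != (101, 77)
print(down.shape, up.shape)  # hence the cv2.resize back to (w, h)
```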
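A quick smoke test for the new parameter — a sketch only, assuming `train_cnn_v2.py` is importable as a module (the module path and random stand-in image are illustrative, not part of the patch):

```python
# Sketch: the 8D feature shape is mip-independent; higher mips only blur p0-p2.
import numpy as np
from train_cnn_v2 import compute_static_features  # assumes script is on PYTHONPATH

rgb = np.random.rand(64, 64, 3).astype(np.float32)  # stand-in for a real image

for mip_level in range(4):
    feat = compute_static_features(rgb, mip_level=mip_level)
    assert feat.shape == (64, 64, 8)  # shape does not change with mip level
    print(mip_level, float(feat[..., :3].var()))  # variance drops as mip rises
```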
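Since `mip_level` changes what the network sees in p0-p3, inference must rebuild the static features at the same level the model was trained with. A sketch of reading it back, assuming only the `config` layout shown in this patch (`checkpoint.pth` is a placeholder path; the `.get` default covers checkpoints saved before this change):

```python
# Sketch: keep inference features consistent with training.
import torch

ckpt = torch.load('checkpoint.pth', map_location='cpu')
mip_level = ckpt['config'].get('mip_level', 0)  # 0 for pre-mip checkpoints
print(f"Checkpoint trained with mip_level={mip_level}")

# Then regenerate static features exactly as during training:
# static_feat = compute_static_features(rgb, depth, mip_level=mip_level)
```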