From 7a0d9970c9b283b957f8b8df0b940813afb28ec2 Mon Sep 17 00:00:00 2001 From: skal Date: Fri, 13 Feb 2026 16:41:35 +0100 Subject: CNN v2: Add --mip-level option for parametric features Add mip level control for p0-p3 features (0=original, 1=half, 2=quarter, 3=eighth). Uses pyrDown/pyrUp for proper Gaussian filtering during mip generation. Changes: - compute_static_features(): Accept mip_level param, generate mip via cv2 pyramid - PatchDataset/ImagePairDataset: Pass mip_level to feature computation - CLI: Add --mip-level arg with choices [0,1,2,3] - Save mip_level in checkpoint config for tracking - Doc updates: HOWTO.md and CNN_V2.md Co-Authored-By: Claude Sonnet 4.5 --- training/train_cnn_v2.py | 49 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 14 deletions(-) (limited to 'training/train_cnn_v2.py') diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py index dc087c6..3d49d13 100755 --- a/training/train_cnn_v2.py +++ b/training/train_cnn_v2.py @@ -21,26 +21,40 @@ import time import cv2 -def compute_static_features(rgb, depth=None): +def compute_static_features(rgb, depth=None, mip_level=0): """Generate 8D static features (parametric + spatial). Args: rgb: (H, W, 3) RGB image [0, 1] depth: (H, W) depth map [0, 1], optional + mip_level: Mip level for p0-p3 (0=original, 1=half, 2=quarter, 3=eighth) Returns: (H, W, 8) static features: [p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias] - Note: p0-p3 are parametric features (can be mips, gradients, etc.) - For training, we use RGBD as default, but could use mip1/2 + Note: p0-p3 are parametric features generated from specified mip level """ h, w = rgb.shape[:2] - # Parametric features (p0-p3) - using RGBD as default - # TODO: Experiment with mip1 grayscale, gradients, etc. - p0 = rgb[:, :, 0].astype(np.float32) - p1 = rgb[:, :, 1].astype(np.float32) - p2 = rgb[:, :, 2].astype(np.float32) + # Generate mip level for p0-p3 + if mip_level > 0: + # Downsample to mip level + mip_rgb = rgb.copy() + for _ in range(mip_level): + mip_rgb = cv2.pyrDown(mip_rgb) + # Upsample back to original size + for _ in range(mip_level): + mip_rgb = cv2.pyrUp(mip_rgb) + # Crop/pad to exact original size if needed + if mip_rgb.shape[:2] != (h, w): + mip_rgb = cv2.resize(mip_rgb, (w, h), interpolation=cv2.INTER_LINEAR) + else: + mip_rgb = rgb + + # Parametric features (p0-p3) from mip level + p0 = mip_rgb[:, :, 0].astype(np.float32) + p1 = mip_rgb[:, :, 1].astype(np.float32) + p2 = mip_rgb[:, :, 2].astype(np.float32) p3 = depth if depth is not None else np.zeros((h, w), dtype=np.float32) # UV coordinates (normalized [0, 1]) @@ -119,12 +133,13 @@ class PatchDataset(Dataset): """Patch-based dataset extracting salient regions from images.""" def __init__(self, input_dir, target_dir, patch_size=32, patches_per_image=64, - detector='harris'): + detector='harris', mip_level=0): self.input_paths = sorted(Path(input_dir).glob("*.png")) self.target_paths = sorted(Path(target_dir).glob("*.png")) self.patch_size = patch_size self.patches_per_image = patches_per_image self.detector = detector + self.mip_level = mip_level assert len(self.input_paths) == len(self.target_paths), \ f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets" @@ -224,7 +239,7 @@ class PatchDataset(Dataset): target_patch = target_img[y1:y2, x1:x2] # RGBA # Compute static features for patch - static_feat = compute_static_features(input_patch.astype(np.float32)) + static_feat = compute_static_features(input_patch.astype(np.float32), mip_level=self.mip_level) # Input RGBD (mip 0) - add depth channel input_rgbd = np.concatenate([input_patch, np.zeros((self.patch_size, self.patch_size, 1))], axis=-1) @@ -240,10 +255,11 @@ class PatchDataset(Dataset): class ImagePairDataset(Dataset): """Dataset of input/target image pairs (full-image mode).""" - def __init__(self, input_dir, target_dir, target_size=(256, 256)): + def __init__(self, input_dir, target_dir, target_size=(256, 256), mip_level=0): self.input_paths = sorted(Path(input_dir).glob("*.png")) self.target_paths = sorted(Path(target_dir).glob("*.png")) self.target_size = target_size + self.mip_level = mip_level assert len(self.input_paths) == len(self.target_paths), \ f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets" @@ -263,7 +279,7 @@ class ImagePairDataset(Dataset): target_img = np.array(target_pil.convert('RGBA')) / 255.0 # Preserve alpha # Compute static features - static_feat = compute_static_features(input_img.astype(np.float32)) + static_feat = compute_static_features(input_img.astype(np.float32), mip_level=self.mip_level) # Input RGBD (mip 0) - add depth channel h, w = input_img.shape[:2] @@ -286,14 +302,15 @@ def train(args): if args.full_image: print(f"Mode: Full-image (resized to {args.image_size}x{args.image_size})") target_size = (args.image_size, args.image_size) - dataset = ImagePairDataset(args.input, args.target, target_size=target_size) + dataset = ImagePairDataset(args.input, args.target, target_size=target_size, mip_level=args.mip_level) dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True) else: print(f"Mode: Patch-based ({args.patch_size}x{args.patch_size} patches)") dataset = PatchDataset(args.input, args.target, patch_size=args.patch_size, patches_per_image=args.patches_per_image, - detector=args.detector) + detector=args.detector, + mip_level=args.mip_level) dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True) # Parse kernel sizes @@ -306,6 +323,7 @@ def train(args): total_params = sum(p.numel() for p in model.parameters()) kernel_desc = ','.join(map(str, kernel_sizes)) print(f"Model: {args.num_layers} layers, kernel sizes [{kernel_desc}], {total_params} weights") + print(f"Using mip level {args.mip_level} for p0-p3 features") # Optimizer and loss optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) @@ -351,6 +369,7 @@ def train(args): 'config': { 'kernel_sizes': kernel_sizes, 'num_layers': args.num_layers, + 'mip_level': args.mip_level, 'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin10_x', 'bias'] } }, checkpoint_path) @@ -387,6 +406,8 @@ def main(): help='Comma-separated kernel sizes per layer (e.g., "3,5,3"), single value replicates (default: 3)') parser.add_argument('--num-layers', type=int, default=3, help='Number of CNN layers (default: 3)') + parser.add_argument('--mip-level', type=int, default=0, choices=[0, 1, 2, 3], + help='Mip level for p0-p3 features: 0=original, 1=half, 2=quarter, 3=eighth (default: 0)') # Training parameters parser.add_argument('--epochs', type=int, default=5000, help='Training epochs') -- cgit v1.2.3