| author | skal <pascal.massimino@gmail.com> | 2026-02-12 11:48:02 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-02-12 11:48:02 +0100 |
| commit | c878631f24ddb7514dd4db3d7ace6a0a296d4157 (patch) | |
| tree | a24ccffc8997a7e0cc0270c59c599ef44d0086a8 /training/train_cnn_v2.py | |
| parent | f4ef706409ad44cac26abb46fe8b2ddb78ec6a9c (diff) | |
Fix: CNN v2 training - handle variable image sizes
The training script now resizes all images to a fixed size before batching.
Issue: RuntimeError when batching variable-sized images
- Images had different dimensions (376x626 vs 344x361)
- PyTorch DataLoader requires uniform tensor sizes for batching (failure reproduced in the sketch below)
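A minimal, self-contained reproduction of the error, using dummy tensors with the mismatched sizes mentioned above (illustration only, not code from the repo):

```python
import torch

# Two RGB "images" with the mismatched sizes noted above (C x H x W).
a = torch.zeros(3, 626, 376)
b = torch.zeros(3, 361, 344)

try:
    # DataLoader's default collate does exactly this for each batch.
    torch.stack([a, b])
except RuntimeError as err:
    print(err)  # stack expects each tensor to be equal size, ...
```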
Solution:
- Add --image-size parameter (default: 256)
- Resize all images to target_size using LANCZOS interpolation
- Training becomes independent of the original aspect ratios (non-square images are stretched to the square target size; see the sketch below)
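A standalone sketch of the resize step (the function name and paths are illustrative; the actual change lives in ImagePairDataset, shown in the diff below):

```python
import numpy as np
from PIL import Image

def load_pair(input_path, target_path, target_size=(256, 256)):
    """Load an input/target pair and resize both to a fixed size with LANCZOS."""
    input_img = Image.open(input_path).convert('RGB').resize(target_size, Image.LANCZOS)
    target_img = Image.open(target_path).convert('RGB').resize(target_size, Image.LANCZOS)
    # Scale to [0, 1]; every pair now has identical dimensions and can be batched.
    return np.array(input_img) / 255.0, np.array(target_img) / 255.0
```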
Changes:
- train_cnn_v2.py: ImagePairDataset now resizes to fixed dimensions
- train_cnn_v2_full.sh: Added IMAGE_SIZE=256 configuration
Tested: 8 image pairs, variable sizes → uniform 256×256 batches
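As a rough stand-in for that check (dummy tensors only, not the real ImagePairDataset; shapes follow the 256×256 default):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 8 input/target pairs, all resized to 3 x 256 x 256, batch cleanly.
inputs = torch.zeros(8, 3, 256, 256)
targets = torch.zeros(8, 3, 256, 256)
loader = DataLoader(TensorDataset(inputs, targets), batch_size=4, shuffle=True)

for x, y in loader:
    print(x.shape, y.shape)  # torch.Size([4, 3, 256, 256]) each
```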
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'training/train_cnn_v2.py')
| -rwxr-xr-x | training/train_cnn_v2.py | 23 |
1 file changed, 17 insertions, 6 deletions
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
index fe148b4..e590b40 100755
--- a/training/train_cnn_v2.py
+++ b/training/train_cnn_v2.py
@@ -100,9 +100,10 @@ class CNNv2(nn.Module):
 class ImagePairDataset(Dataset):
     """Dataset of input/target image pairs."""
 
-    def __init__(self, input_dir, target_dir):
+    def __init__(self, input_dir, target_dir, target_size=(256, 256)):
         self.input_paths = sorted(Path(input_dir).glob("*.png"))
         self.target_paths = sorted(Path(target_dir).glob("*.png"))
+        self.target_size = target_size
         assert len(self.input_paths) == len(self.target_paths), \
             f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets"
 
@@ -110,9 +111,16 @@ class ImagePairDataset(Dataset):
         return len(self.input_paths)
 
     def __getitem__(self, idx):
-        # Load images
-        input_img = np.array(Image.open(self.input_paths[idx]).convert('RGB')) / 255.0
-        target_img = np.array(Image.open(self.target_paths[idx]).convert('RGB')) / 255.0
+        # Load and resize images to fixed size
+        input_pil = Image.open(self.input_paths[idx]).convert('RGB')
+        target_pil = Image.open(self.target_paths[idx]).convert('RGB')
+
+        # Resize to target size
+        input_pil = input_pil.resize(self.target_size, Image.LANCZOS)
+        target_pil = target_pil.resize(self.target_size, Image.LANCZOS)
+
+        input_img = np.array(input_pil) / 255.0
+        target_img = np.array(target_pil) / 255.0
 
         # Compute static features
         static_feat = compute_static_features(input_img.astype(np.float32))
@@ -133,9 +141,10 @@ def train(args):
     print(f"Training on {device}")
 
     # Create dataset
-    dataset = ImagePairDataset(args.input, args.target)
+    target_size = (args.image_size, args.image_size)
+    dataset = ImagePairDataset(args.input, args.target, target_size=target_size)
     dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
-    print(f"Loaded {len(dataset)} image pairs")
+    print(f"Loaded {len(dataset)} image pairs (resized to {args.image_size}x{args.image_size})")
 
     # Create model
     model = CNNv2(kernels=args.kernel_sizes, channels=args.channels).to(device)
@@ -197,6 +206,8 @@ def main():
     parser = argparse.ArgumentParser(description='Train CNN v2 with parametric static features')
     parser.add_argument('--input', type=str, required=True, help='Input images directory')
     parser.add_argument('--target', type=str, required=True, help='Target images directory')
+    parser.add_argument('--image-size', type=int, default=256,
+                        help='Resize images to this size (default: 256)')
     parser.add_argument('--kernel-sizes', type=int, nargs=3, default=[1, 3, 5],
                         help='Kernel sizes for 3 layers (default: 1 3 5)')
     parser.add_argument('--channels', type=int, nargs=3, default=[16, 8, 4],
