summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rwxr-xr-xtraining/train_cnn_v2.py23
1 files changed, 17 insertions, 6 deletions
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
index fe148b4..e590b40 100755
--- a/training/train_cnn_v2.py
+++ b/training/train_cnn_v2.py
@@ -100,9 +100,10 @@ class CNNv2(nn.Module):
class ImagePairDataset(Dataset):
"""Dataset of input/target image pairs."""
- def __init__(self, input_dir, target_dir):
+ def __init__(self, input_dir, target_dir, target_size=(256, 256)):
self.input_paths = sorted(Path(input_dir).glob("*.png"))
self.target_paths = sorted(Path(target_dir).glob("*.png"))
+ self.target_size = target_size
assert len(self.input_paths) == len(self.target_paths), \
f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets"
@@ -110,9 +111,16 @@ class ImagePairDataset(Dataset):
return len(self.input_paths)
def __getitem__(self, idx):
- # Load images
- input_img = np.array(Image.open(self.input_paths[idx]).convert('RGB')) / 255.0
- target_img = np.array(Image.open(self.target_paths[idx]).convert('RGB')) / 255.0
+ # Load and resize images to fixed size
+ input_pil = Image.open(self.input_paths[idx]).convert('RGB')
+ target_pil = Image.open(self.target_paths[idx]).convert('RGB')
+
+ # Resize to target size
+ input_pil = input_pil.resize(self.target_size, Image.LANCZOS)
+ target_pil = target_pil.resize(self.target_size, Image.LANCZOS)
+
+ input_img = np.array(input_pil) / 255.0
+ target_img = np.array(target_pil) / 255.0
# Compute static features
static_feat = compute_static_features(input_img.astype(np.float32))
@@ -133,9 +141,10 @@ def train(args):
print(f"Training on {device}")
# Create dataset
- dataset = ImagePairDataset(args.input, args.target)
+ target_size = (args.image_size, args.image_size)
+ dataset = ImagePairDataset(args.input, args.target, target_size=target_size)
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
- print(f"Loaded {len(dataset)} image pairs")
+ print(f"Loaded {len(dataset)} image pairs (resized to {args.image_size}x{args.image_size})")
# Create model
model = CNNv2(kernels=args.kernel_sizes, channels=args.channels).to(device)
@@ -197,6 +206,8 @@ def main():
parser = argparse.ArgumentParser(description='Train CNN v2 with parametric static features')
parser.add_argument('--input', type=str, required=True, help='Input images directory')
parser.add_argument('--target', type=str, required=True, help='Target images directory')
+ parser.add_argument('--image-size', type=int, default=256,
+ help='Resize images to this size (default: 256)')
parser.add_argument('--kernel-sizes', type=int, nargs=3, default=[1, 3, 5],
help='Kernel sizes for 3 layers (default: 1 3 5)')
parser.add_argument('--channels', type=int, nargs=3, default=[16, 8, 4],