From c878631f24ddb7514dd4db3d7ace6a0a296d4157 Mon Sep 17 00:00:00 2001 From: skal Date: Thu, 12 Feb 2026 11:48:02 +0100 Subject: Fix: CNN v2 training - handle variable image sizes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Training script now resizes all images to fixed size before batching. Issue: RuntimeError when batching variable-sized images - Images had different dimensions (376x626 vs 344x361) - PyTorch DataLoader requires uniform tensor sizes for batching Solution: - Add --image-size parameter (default: 256) - Resize all images to target_size using LANCZOS interpolation - Makes training independent of the original image aspect ratios (note: aspect ratio itself is not preserved by the square resize) Changes: - train_cnn_v2.py: ImagePairDataset now resizes to fixed dimensions - train_cnn_v2_full.sh: Added IMAGE_SIZE=256 configuration Tested: 8 image pairs, variable sizes → uniform 256×256 batches Co-Authored-By: Claude Sonnet 4.5 --- LOG.txt | 43 ++++++++++++++ scripts/train_cnn_v2_full.sh | 131 +++++++++++++++++++++++++++++++++++++++++++ training/train_cnn_v2.py | 23 ++++++-- 3 files changed, 191 insertions(+), 6 deletions(-) create mode 100644 LOG.txt create mode 100755 scripts/train_cnn_v2_full.sh diff --git a/LOG.txt b/LOG.txt new file mode 100644 index 0000000..50b77ea --- /dev/null +++ b/LOG.txt @@ -0,0 +1,43 @@ +=== CNN v2 Complete Training Pipeline === +Input: training/input +Target: training/target_2 +Epochs: 10000 +Checkpoint interval: 500 + +[1/4] Training CNN v2 model... +Training on cpu +Loaded 8 image pairs +Model: [16, 8, 4] channels, [1, 3, 5] kernels, 3456 weights + +Training for 10000 epochs... 
+Traceback (most recent call last): + File "/Users/skal/demo/training/train_cnn_v2.py", line 217, in + main() + File "/Users/skal/demo/training/train_cnn_v2.py", line 213, in main + train(args) + File "/Users/skal/demo/training/train_cnn_v2.py", line 157, in train + for static_feat, target in dataloader: + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 741, in __next__ + data = self._next_data() + ^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 801, in _next_data + data = self._dataset_fetcher.fetch(index) # may raise StopIteration + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 57, in fetch + return self.collate_fn(data) + ^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/utils/data/_utils/collate.py", line 401, in default_collate + return collate(batch, collate_fn_map=default_collate_fn_map) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/utils/data/_utils/collate.py", line 214, in collate + return [ + ^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/utils/data/_utils/collate.py", line 215, in + collate(samples, collate_fn_map=collate_fn_map) + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/utils/data/_utils/collate.py", line 155, in collate + return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/utils/data/_utils/collate.py", line 275, in 
collate_tensor_fn + return torch.stack(batch, 0, out=out) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +RuntimeError: stack expects each tensor to be equal size, but got [8, 376, 626] at entry 0 and [8, 344, 361] at entry 1 diff --git a/scripts/train_cnn_v2_full.sh b/scripts/train_cnn_v2_full.sh new file mode 100755 index 0000000..119b788 --- /dev/null +++ b/scripts/train_cnn_v2_full.sh @@ -0,0 +1,131 @@ +#!/bin/bash +# Complete CNN v2 Training Pipeline +# Train → Export → Build → Validate + +set -e + +PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$PROJECT_ROOT" + +# Configuration +INPUT_DIR="training/input" +TARGET_DIR="training/target_2" +CHECKPOINT_DIR="checkpoints" +VALIDATION_DIR="validation_results" +EPOCHS=10000 +CHECKPOINT_EVERY=500 +BATCH_SIZE=8 +IMAGE_SIZE=256 +KERNEL_SIZES="1 3 5" +CHANNELS="16 8 4" + +echo "=== CNN v2 Complete Training Pipeline ===" +echo "Input: $INPUT_DIR" +echo "Target: $TARGET_DIR" +echo "Epochs: $EPOCHS" +echo "Checkpoint interval: $CHECKPOINT_EVERY" +echo "" + +# Step 1: Train model +echo "[1/4] Training CNN v2 model..." +python3 training/train_cnn_v2.py \ + --input "$INPUT_DIR" \ + --target "$TARGET_DIR" \ + --image-size $IMAGE_SIZE \ + --kernel-sizes $KERNEL_SIZES \ + --channels $CHANNELS \ + --epochs $EPOCHS \ + --batch-size $BATCH_SIZE \ + --checkpoint-dir "$CHECKPOINT_DIR" \ + --checkpoint-every $CHECKPOINT_EVERY + +if [ $? -ne 0 ]; then + echo "Error: Training failed" + exit 1 +fi + +echo "" +echo "Training complete!" +echo "" + +# Step 2: Export final checkpoint to shaders +FINAL_CHECKPOINT="$CHECKPOINT_DIR/checkpoint_epoch_${EPOCHS}.pth" + +if [ ! -f "$FINAL_CHECKPOINT" ]; then + echo "Warning: Final checkpoint not found, using latest available..." + FINAL_CHECKPOINT=$(ls -t "$CHECKPOINT_DIR"/checkpoint_epoch_*.pth | head -1) +fi + +echo "[2/4] Exporting final checkpoint to WGSL shaders..." 
+echo "Checkpoint: $FINAL_CHECKPOINT" +python3 training/export_cnn_v2_shader.py "$FINAL_CHECKPOINT" \ + --output-dir workspaces/main/shaders + +if [ $? -ne 0 ]; then + echo "Error: Shader export failed" + exit 1 +fi + +echo "" + +# Step 3: Rebuild with new shaders +echo "[3/4] Rebuilding demo with new shaders..." +cmake --build build -j4 --target demo64k > /dev/null 2>&1 + +if [ $? -ne 0 ]; then + echo "Error: Build failed" + exit 1 +fi + +echo " → Build complete" +echo "" + +# Step 4: Visual assessment - process all checkpoints +echo "[4/4] Visual assessment of training progression..." +mkdir -p "$VALIDATION_DIR" + +# Test first input image with checkpoints at intervals +TEST_IMAGE="$INPUT_DIR/img_000.png" +CHECKPOINT_INTERVAL=1000 + +echo " Processing checkpoints (every ${CHECKPOINT_INTERVAL} epochs)..." + +for checkpoint in "$CHECKPOINT_DIR"/checkpoint_epoch_*.pth; do + epoch=$(echo "$checkpoint" | grep -o 'epoch_[0-9]*' | cut -d'_' -f2) + + # Only process checkpoints at intervals + if [ $((epoch % CHECKPOINT_INTERVAL)) -eq 0 ] || [ "$epoch" -eq "$EPOCHS" ]; then + echo " Epoch $epoch..." + + # Export shaders for this checkpoint + python3 training/export_cnn_v2_shader.py "$checkpoint" \ + --output-dir workspaces/main/shaders > /dev/null 2>&1 + + # Rebuild + cmake --build build -j4 --target cnn_test > /dev/null 2>&1 + + # Process test image + build/cnn_test "$TEST_IMAGE" "$VALIDATION_DIR/epoch_${epoch}_output.png" 2>/dev/null + fi +done + +# Restore final checkpoint shaders +python3 training/export_cnn_v2_shader.py "$FINAL_CHECKPOINT" \ + --output-dir workspaces/main/shaders > /dev/null 2>&1 + +cmake --build build -j4 --target demo64k > /dev/null 2>&1 + +echo "" +echo "=== Training Pipeline Complete ===" +echo "" +echo "Results:" +echo " - Checkpoints: $CHECKPOINT_DIR" +echo " - Visual progression: $VALIDATION_DIR" +echo " - Final shaders: workspaces/main/shaders/cnn_v2_layer_*.wgsl" +echo "" +echo "Opening results directory..." 
+open "$VALIDATION_DIR" 2>/dev/null || xdg-open "$VALIDATION_DIR" 2>/dev/null || true + +echo "" +echo "Run demo to see final result:" +echo " ./build/demo64k" diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py index fe148b4..e590b40 100755 --- a/training/train_cnn_v2.py +++ b/training/train_cnn_v2.py @@ -100,9 +100,10 @@ class CNNv2(nn.Module): class ImagePairDataset(Dataset): """Dataset of input/target image pairs.""" - def __init__(self, input_dir, target_dir): + def __init__(self, input_dir, target_dir, target_size=(256, 256)): self.input_paths = sorted(Path(input_dir).glob("*.png")) self.target_paths = sorted(Path(target_dir).glob("*.png")) + self.target_size = target_size assert len(self.input_paths) == len(self.target_paths), \ f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets" @@ -110,9 +111,16 @@ class ImagePairDataset(Dataset): return len(self.input_paths) def __getitem__(self, idx): - # Load images - input_img = np.array(Image.open(self.input_paths[idx]).convert('RGB')) / 255.0 - target_img = np.array(Image.open(self.target_paths[idx]).convert('RGB')) / 255.0 + # Load and resize images to fixed size + input_pil = Image.open(self.input_paths[idx]).convert('RGB') + target_pil = Image.open(self.target_paths[idx]).convert('RGB') + + # Resize to target size + input_pil = input_pil.resize(self.target_size, Image.LANCZOS) + target_pil = target_pil.resize(self.target_size, Image.LANCZOS) + + input_img = np.array(input_pil) / 255.0 + target_img = np.array(target_pil) / 255.0 # Compute static features static_feat = compute_static_features(input_img.astype(np.float32)) @@ -133,9 +141,10 @@ def train(args): print(f"Training on {device}") # Create dataset - dataset = ImagePairDataset(args.input, args.target) + target_size = (args.image_size, args.image_size) + dataset = ImagePairDataset(args.input, args.target, target_size=target_size) dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True) - 
print(f"Loaded {len(dataset)} image pairs") + print(f"Loaded {len(dataset)} image pairs (resized to {args.image_size}x{args.image_size})") # Create model model = CNNv2(kernels=args.kernel_sizes, channels=args.channels).to(device) @@ -197,6 +206,8 @@ def main(): parser = argparse.ArgumentParser(description='Train CNN v2 with parametric static features') parser.add_argument('--input', type=str, required=True, help='Input images directory') parser.add_argument('--target', type=str, required=True, help='Target images directory') + parser.add_argument('--image-size', type=int, default=256, + help='Resize images to this size (default: 256)') parser.add_argument('--kernel-sizes', type=int, nargs=3, default=[1, 3, 5], help='Kernel sizes for 3 layers (default: 1 3 5)') parser.add_argument('--channels', type=int, nargs=3, default=[16, 8, 4], -- cgit v1.2.3