summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-02-12 11:48:02 +0100
committerskal <pascal.massimino@gmail.com>2026-02-12 11:48:02 +0100
commitc878631f24ddb7514dd4db3d7ace6a0a296d4157 (patch)
treea24ccffc8997a7e0cc0270c59c599ef44d0086a8
parentf4ef706409ad44cac26abb46fe8b2ddb78ec6a9c (diff)
Fix: CNN v2 training - handle variable image sizes
Training script now resizes all images to fixed size before batching. Issue: RuntimeError when batching variable-sized images - Images had different dimensions (376x626 vs 344x361) - PyTorch DataLoader requires uniform tensor sizes for batching Solution: - Add --image-size parameter (default: 256) - Resize all images to target_size using LANCZOS interpolation - Preserves aspect ratio independent training Changes: - train_cnn_v2.py: ImagePairDataset now resizes to fixed dimensions - train_cnn_v2_full.sh: Added IMAGE_SIZE=256 configuration Tested: 8 image pairs, variable sizes → uniform 256×256 batches Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
-rw-r--r--LOG.txt43
-rwxr-xr-xscripts/train_cnn_v2_full.sh131
-rwxr-xr-xtraining/train_cnn_v2.py23
3 files changed, 191 insertions, 6 deletions
diff --git a/LOG.txt b/LOG.txt
new file mode 100644
index 0000000..50b77ea
--- /dev/null
+++ b/LOG.txt
@@ -0,0 +1,43 @@
+=== CNN v2 Complete Training Pipeline ===
+Input: training/input
+Target: training/target_2
+Epochs: 10000
+Checkpoint interval: 500
+
+[1/4] Training CNN v2 model...
+Training on cpu
+Loaded 8 image pairs
+Model: [16, 8, 4] channels, [1, 3, 5] kernels, 3456 weights
+
+Training for 10000 epochs...
+Traceback (most recent call last):
+ File "/Users/skal/demo/training/train_cnn_v2.py", line 217, in <module>
+ main()
+ File "/Users/skal/demo/training/train_cnn_v2.py", line 213, in main
+ train(args)
+ File "/Users/skal/demo/training/train_cnn_v2.py", line 157, in train
+ for static_feat, target in dataloader:
+ File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 741, in __next__
+ data = self._next_data()
+ ^^^^^^^^^^^^^^^^^
+ File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 801, in _next_data
+ data = self._dataset_fetcher.fetch(index) # may raise StopIteration
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 57, in fetch
+ return self.collate_fn(data)
+ ^^^^^^^^^^^^^^^^^^^^^
+ File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/utils/data/_utils/collate.py", line 401, in default_collate
+ return collate(batch, collate_fn_map=default_collate_fn_map)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/utils/data/_utils/collate.py", line 214, in collate
+ return [
+ ^
+ File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/utils/data/_utils/collate.py", line 215, in <listcomp>
+ collate(samples, collate_fn_map=collate_fn_map)
+ File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/utils/data/_utils/collate.py", line 155, in collate
+ return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/utils/data/_utils/collate.py", line 275, in collate_tensor_fn
+ return torch.stack(batch, 0, out=out)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+RuntimeError: stack expects each tensor to be equal size, but got [8, 376, 626] at entry 0 and [8, 344, 361] at entry 1
diff --git a/scripts/train_cnn_v2_full.sh b/scripts/train_cnn_v2_full.sh
new file mode 100755
index 0000000..119b788
--- /dev/null
+++ b/scripts/train_cnn_v2_full.sh
@@ -0,0 +1,131 @@
+#!/bin/bash
+# Complete CNN v2 Training Pipeline
+# Train → Export → Build → Validate
+
+set -e
+
+PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "$PROJECT_ROOT"
+
+# Configuration
+INPUT_DIR="training/input"
+TARGET_DIR="training/target_2"
+CHECKPOINT_DIR="checkpoints"
+VALIDATION_DIR="validation_results"
+EPOCHS=10000
+CHECKPOINT_EVERY=500
+BATCH_SIZE=8
+IMAGE_SIZE=256
+KERNEL_SIZES="1 3 5"
+CHANNELS="16 8 4"
+
+echo "=== CNN v2 Complete Training Pipeline ==="
+echo "Input: $INPUT_DIR"
+echo "Target: $TARGET_DIR"
+echo "Epochs: $EPOCHS"
+echo "Checkpoint interval: $CHECKPOINT_EVERY"
+echo ""
+
+# Step 1: Train model
+echo "[1/4] Training CNN v2 model..."
+python3 training/train_cnn_v2.py \
+ --input "$INPUT_DIR" \
+ --target "$TARGET_DIR" \
+ --image-size $IMAGE_SIZE \
+ --kernel-sizes $KERNEL_SIZES \
+ --channels $CHANNELS \
+ --epochs $EPOCHS \
+ --batch-size $BATCH_SIZE \
+ --checkpoint-dir "$CHECKPOINT_DIR" \
+ --checkpoint-every $CHECKPOINT_EVERY
+
+if [ $? -ne 0 ]; then
+ echo "Error: Training failed"
+ exit 1
+fi
+
+echo ""
+echo "Training complete!"
+echo ""
+
+# Step 2: Export final checkpoint to shaders
+FINAL_CHECKPOINT="$CHECKPOINT_DIR/checkpoint_epoch_${EPOCHS}.pth"
+
+if [ ! -f "$FINAL_CHECKPOINT" ]; then
+ echo "Warning: Final checkpoint not found, using latest available..."
+ FINAL_CHECKPOINT=$(ls -t "$CHECKPOINT_DIR"/checkpoint_epoch_*.pth | head -1)
+fi
+
+echo "[2/4] Exporting final checkpoint to WGSL shaders..."
+echo "Checkpoint: $FINAL_CHECKPOINT"
+python3 training/export_cnn_v2_shader.py "$FINAL_CHECKPOINT" \
+ --output-dir workspaces/main/shaders
+
+if [ $? -ne 0 ]; then
+ echo "Error: Shader export failed"
+ exit 1
+fi
+
+echo ""
+
+# Step 3: Rebuild with new shaders
+echo "[3/4] Rebuilding demo with new shaders..."
+cmake --build build -j4 --target demo64k > /dev/null 2>&1
+
+if [ $? -ne 0 ]; then
+ echo "Error: Build failed"
+ exit 1
+fi
+
+echo " → Build complete"
+echo ""
+
+# Step 4: Visual assessment - process all checkpoints
+echo "[4/4] Visual assessment of training progression..."
+mkdir -p "$VALIDATION_DIR"
+
+# Test first input image with checkpoints at intervals
+TEST_IMAGE="$INPUT_DIR/img_000.png"
+CHECKPOINT_INTERVAL=1000
+
+echo " Processing checkpoints (every ${CHECKPOINT_INTERVAL} epochs)..."
+
+for checkpoint in "$CHECKPOINT_DIR"/checkpoint_epoch_*.pth; do
+ epoch=$(echo "$checkpoint" | grep -o 'epoch_[0-9]*' | cut -d'_' -f2)
+
+ # Only process checkpoints at intervals
+ if [ $((epoch % CHECKPOINT_INTERVAL)) -eq 0 ] || [ "$epoch" -eq "$EPOCHS" ]; then
+ echo " Epoch $epoch..."
+
+ # Export shaders for this checkpoint
+ python3 training/export_cnn_v2_shader.py "$checkpoint" \
+ --output-dir workspaces/main/shaders > /dev/null 2>&1
+
+ # Rebuild
+ cmake --build build -j4 --target cnn_test > /dev/null 2>&1
+
+ # Process test image
+ build/cnn_test "$TEST_IMAGE" "$VALIDATION_DIR/epoch_${epoch}_output.png" 2>/dev/null
+ fi
+done
+
+# Restore final checkpoint shaders
+python3 training/export_cnn_v2_shader.py "$FINAL_CHECKPOINT" \
+ --output-dir workspaces/main/shaders > /dev/null 2>&1
+
+cmake --build build -j4 --target demo64k > /dev/null 2>&1
+
+echo ""
+echo "=== Training Pipeline Complete ==="
+echo ""
+echo "Results:"
+echo " - Checkpoints: $CHECKPOINT_DIR"
+echo " - Visual progression: $VALIDATION_DIR"
+echo " - Final shaders: workspaces/main/shaders/cnn_v2_layer_*.wgsl"
+echo ""
+echo "Opening results directory..."
+open "$VALIDATION_DIR" 2>/dev/null || xdg-open "$VALIDATION_DIR" 2>/dev/null || true
+
+echo ""
+echo "Run demo to see final result:"
+echo " ./build/demo64k"
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
index fe148b4..e590b40 100755
--- a/training/train_cnn_v2.py
+++ b/training/train_cnn_v2.py
@@ -100,9 +100,10 @@ class CNNv2(nn.Module):
class ImagePairDataset(Dataset):
"""Dataset of input/target image pairs."""
- def __init__(self, input_dir, target_dir):
+ def __init__(self, input_dir, target_dir, target_size=(256, 256)):
self.input_paths = sorted(Path(input_dir).glob("*.png"))
self.target_paths = sorted(Path(target_dir).glob("*.png"))
+ self.target_size = target_size
assert len(self.input_paths) == len(self.target_paths), \
f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets"
@@ -110,9 +111,16 @@ class ImagePairDataset(Dataset):
return len(self.input_paths)
def __getitem__(self, idx):
- # Load images
- input_img = np.array(Image.open(self.input_paths[idx]).convert('RGB')) / 255.0
- target_img = np.array(Image.open(self.target_paths[idx]).convert('RGB')) / 255.0
+ # Load and resize images to fixed size
+ input_pil = Image.open(self.input_paths[idx]).convert('RGB')
+ target_pil = Image.open(self.target_paths[idx]).convert('RGB')
+
+ # Resize to target size
+ input_pil = input_pil.resize(self.target_size, Image.LANCZOS)
+ target_pil = target_pil.resize(self.target_size, Image.LANCZOS)
+
+ input_img = np.array(input_pil) / 255.0
+ target_img = np.array(target_pil) / 255.0
# Compute static features
static_feat = compute_static_features(input_img.astype(np.float32))
@@ -133,9 +141,10 @@ def train(args):
print(f"Training on {device}")
# Create dataset
- dataset = ImagePairDataset(args.input, args.target)
+ target_size = (args.image_size, args.image_size)
+ dataset = ImagePairDataset(args.input, args.target, target_size=target_size)
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
- print(f"Loaded {len(dataset)} image pairs")
+ print(f"Loaded {len(dataset)} image pairs (resized to {args.image_size}x{args.image_size})")
# Create model
model = CNNv2(kernels=args.kernel_sizes, channels=args.channels).to(device)
@@ -197,6 +206,8 @@ def main():
parser = argparse.ArgumentParser(description='Train CNN v2 with parametric static features')
parser.add_argument('--input', type=str, required=True, help='Input images directory')
parser.add_argument('--target', type=str, required=True, help='Target images directory')
+ parser.add_argument('--image-size', type=int, default=256,
+ help='Resize images to this size (default: 256)')
parser.add_argument('--kernel-sizes', type=int, nargs=3, default=[1, 3, 5],
help='Kernel sizes for 3 layers (default: 1 3 5)')
parser.add_argument('--channels', type=int, nargs=3, default=[16, 8, 4],