# CNN Training Tools PyTorch-based training for image-to-image stylization with patch extraction. --- ## Quick Start ```bash # Patch-based (recommended) python3 train_cnn.py \ --input training/input --target training/output \ --patch-size 32 --patches-per-image 64 --detector harris \ --layers 3 --kernel-sizes 3,5,3 --epochs 5000 --batch-size 16 # Full-image (legacy) python3 train_cnn.py \ --input training/input --target training/output \ --layers 3 --kernel-sizes 3,5,3 --epochs 10000 --batch-size 8 ``` --- ## Patch-Based Training (Recommended) Extracts patches at salient points, preserves natural pixel scale. ### Detectors | Detector | Best For | Speed | |----------|----------|-------| | `harris` (default) | Corners, structured scenes | Medium | | `fast` | Dense features, textures | Fast | | `shi-tomasi` | High-quality corners | Medium | | `gradient` | Edges, high-contrast areas | Fast | ### Examples **Single layer, Harris corners:** ```bash python3 train_cnn.py --input training/input --target training/output \ --patch-size 32 --patches-per-image 64 --detector harris \ --layers 1 --kernel-sizes 3 --epochs 2000 ``` **Multi-layer, FAST features:** ```bash python3 train_cnn.py --input training/input --target training/output \ --patch-size 32 --patches-per-image 128 --detector fast \ --layers 3 --kernel-sizes 3,5,3 --epochs 5000 --batch-size 16 ``` **Edge-focused (gradient detector):** ```bash python3 train_cnn.py --input training/input --target training/output \ --patch-size 16 --patches-per-image 96 --detector gradient \ --layers 2 --kernel-sizes 3,3 --epochs 3000 ``` ### Benefits - **Preserves scale:** No resize distortion - **More samples:** 64 patches × 10 images = 640 samples vs 10 - **Focused learning:** Trains on interesting features, not flat areas - **Better generalization:** Network sees diverse local patterns --- ## Options | Option | Default | Description | |--------|---------|-------------| | `--input` | *required* | Input image directory | | `--target` | *required* | Target image directory | | `--patch-size` | None | Patch size (e.g., 32). Omit for full-image mode | | `--patches-per-image` | 64 | Patches to extract per image | | `--detector` | harris | harris\|fast\|shi-tomasi\|gradient | | `--layers` | 1 | Number of CNN layers | | `--kernel-sizes` | 3 | Comma-separated (e.g., 3,5,3) | | `--epochs` | 100 | Training epochs | | `--batch-size` | 4 | Batch size (use 16 for patches, 8 for full-image) | | `--learning-rate` | 0.001 | Learning rate | | `--checkpoint-every` | 0 | Save every N epochs (0=off) | | `--resume` | None | Resume from checkpoint | | `--export-only` | None | Export WGSL without training | | `--infer` | None | Generate ground truth PNG | --- ## Export & Validation **Export shaders from checkpoint:** ```bash python3 train_cnn.py --export-only checkpoints/checkpoint_epoch_5000.pth ``` **Generate ground truth for comparison:** ```bash python3 train_cnn.py --infer input.png \ --export-only checkpoints/checkpoint_epoch_5000.pth \ --output ground_truth.png ``` **Auto-generates:** - `cnn_weights_generated.wgsl` - Weight arrays - `cnn_layer.wgsl` - Layer shader with correct architecture --- ## Workflow ### 1. Render Raw Frames ```bash ./build/demo64k --headless --duration 5 --output training/input/ ``` ### 2. Generate Stylized Targets ```bash python3 training/image_style_processor.py training/input/ training/output/ pencil_sketch ``` **Available styles:** pencil_sketch, ink_drawing, charcoal_pastel, glitch_art, circuit_board, wireframe_topo ### 3. Train CNN ```bash python3 train_cnn.py \ --input training/input --target training/output \ --patch-size 32 --patches-per-image 64 \ --layers 3 --kernel-sizes 3,5,3 --epochs 5000 --checkpoint-every 1000 ``` ### 4. Rebuild Demo ```bash cmake --build build -j4 && ./build/demo64k ``` --- ## Architecture **Input:** 7 channels = [RGBD, UV coords, grayscale] normalized to [-1,1] **Output:** Grayscale [0,1] **Layers:** - **Inner (0..N-2):** Conv2d(7→4) + tanh → RGBD output [-1,1] - **Final (N-1):** Conv2d(7→1) + clamp(0,1) → Grayscale output **Coordinate awareness:** Layer 0 receives UV coords for position-dependent effects (vignetting, radial gradients). --- ## Tips - **Training data:** 10-50 image pairs recommended - **Patch size:** 32×32 good balance (16×16 for detail, 64×64 for context) - **Patches per image:** 64-128 for good coverage - **Batch size:** Higher for patches (16) vs full-image (8) - **Checkpoints:** Save every 500-1000 epochs - **Loss plateaus:** Lower learning rate (0.0001) or add layers --- ## Requirements ```bash pip install torch torchvision pillow opencv-python numpy ``` --- ## References - **CNN Effect:** `doc/CNN_EFFECT.md` - **Timeline:** `doc/SEQUENCE.md` - **HOWTO:** `doc/HOWTO.md`