| author | skal <pascal.massimino@gmail.com> | 2026-02-10 22:54:38 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-02-10 22:54:38 +0100 |
| commit | 2adcf1bac1ec651861930eb2af00641eb23f6ef1 (patch) | |
| tree | 2208b753b68783be5a26906b5bc1690c23a267d0 /training | |
| parent | 58f276378735e0b51f4d1517a844357e45e376a7 (diff) | |
docs: Update CNN training documentation with patch extraction
Streamlined and updated all training docs for the new patch-based approach.
Changes:
- HOWTO.md: Updated training section with patch/full-image examples
- CNN_EFFECT.md: Streamlined training workflow, added detector info
- training/README.md: Complete rewrite with detector comparison table
New sections:
- Detector comparison (harris, fast, shi-tomasi, gradient)
- Practical examples for different use cases
- Tips for patch size and batch size selection
- Benefits of patch-based training
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'training')
| -rw-r--r-- | training/README.md | 193 |
1 file changed, 95 insertions, 98 deletions
diff --git a/training/README.md b/training/README.md
index 0a46718..e78b471 100644
--- a/training/README.md
+++ b/training/README.md
@@ -1,117 +1,109 @@
 # CNN Training Tools
-Tools for training and preparing data for the CNN post-processing effect.
+PyTorch-based training for image-to-image stylization with patch extraction.
 ---
-## train_cnn.py
-
-PyTorch-based training script for image-to-image stylization.
-
-### Basic Usage
+## Quick Start
 ```bash
-python3 train_cnn.py --input <input_dir> --target <target_dir> [options]
+# Patch-based (recommended)
+python3 train_cnn.py \
+  --input training/input --target training/output \
+  --patch-size 32 --patches-per-image 64 --detector harris \
+  --layers 3 --kernel-sizes 3,5,3 --epochs 5000 --batch-size 16
+
+# Full-image (legacy)
+python3 train_cnn.py \
+  --input training/input --target training/output \
+  --layers 3 --kernel-sizes 3,5,3 --epochs 10000 --batch-size 8
 ```
+---
+
+## Patch-Based Training (Recommended)
+
+Extracts patches at salient points, preserves natural pixel scale.
+
+### Detectors
+
+| Detector | Best For | Speed |
+|----------|----------|-------|
+| `harris` (default) | Corners, structured scenes | Medium |
+| `fast` | Dense features, textures | Fast |
+| `shi-tomasi` | High-quality corners | Medium |
+| `gradient` | Edges, high-contrast areas | Fast |
+
 ### Examples
-**Single layer, 3×3 kernel:**
+**Single layer, Harris corners:**
 ```bash
 python3 train_cnn.py --input training/input --target training/output \
-  --layers 1 --kernel-sizes 3 --epochs 500
+  --patch-size 32 --patches-per-image 64 --detector harris \
+  --layers 1 --kernel-sizes 3 --epochs 2000
 ```
-**Multi-layer, mixed kernels:**
+**Multi-layer, FAST features:**
 ```bash
 python3 train_cnn.py --input training/input --target training/output \
-  --layers 3 --kernel-sizes 3,5,3 --epochs 1000
+  --patch-size 32 --patches-per-image 128 --detector fast \
+  --layers 3 --kernel-sizes 3,5,3 --epochs 5000 --batch-size 16
 ```
-**With checkpointing:**
+**Edge-focused (gradient detector):**
 ```bash
 python3 train_cnn.py --input training/input --target training/output \
-  --epochs 500 --checkpoint-every 50
+  --patch-size 16 --patches-per-image 96 --detector gradient \
+  --layers 2 --kernel-sizes 3,3 --epochs 3000
 ```
-**Resume from checkpoint:**
-```bash
-python3 train_cnn.py --input training/input --target training/output \
-  --resume training/checkpoints/checkpoint_epoch_200.pth
-```
+### Benefits
+
+- **Preserves scale:** No resize distortion
+- **More samples:** 64 patches × 10 images = 640 samples vs 10
+- **Focused learning:** Trains on interesting features, not flat areas
+- **Better generalization:** Network sees diverse local patterns
-### Options
+---
+
+## Options
 | Option | Default | Description |
 |--------|---------|-------------|
 | `--input` | *required* | Input image directory |
 | `--target` | *required* | Target image directory |
+| `--patch-size` | None | Patch size (e.g., 32). Omit for full-image mode |
+| `--patches-per-image` | 64 | Patches to extract per image |
+| `--detector` | harris | harris\|fast\|shi-tomasi\|gradient |
 | `--layers` | 1 | Number of CNN layers |
-| `--kernel-sizes` | 3 | Comma-separated kernel sizes (auto-repeats if single value) |
+| `--kernel-sizes` | 3 | Comma-separated (e.g., 3,5,3) |
 | `--epochs` | 100 | Training epochs |
-| `--batch-size` | 4 | Batch size |
+| `--batch-size` | 4 | Batch size (use 16 for patches, 8 for full-image) |
 | `--learning-rate` | 0.001 | Learning rate |
-| `--output` | `workspaces/main/shaders/cnn/cnn_weights_generated.wgsl` | Output WGSL file |
-| `--checkpoint-every` | 0 | Save checkpoint every N epochs (0=disabled) |
-| `--checkpoint-dir` | `training/checkpoints` | Checkpoint directory |
-| `--resume` | None | Resume from checkpoint file |
-
-### Architecture
-
-- **Layer 0:** `CoordConv2d` - accepts (x,y) patch center + 3×3 RGBA samples
-- **Layers 1+:** Standard `Conv2d` - 3×3 RGBA samples only
-- **Activation:** Tanh between layers
-- **Output:** Residual connection (30% stylization blend)
-
-### Requirements
-
-```bash
-pip install torch torchvision pillow
-```
+| `--checkpoint-every` | 0 | Save every N epochs (0=off) |
+| `--resume` | None | Resume from checkpoint |
+| `--export-only` | None | Export WGSL without training |
+| `--infer` | None | Generate ground truth PNG |
 ---
-## image_style_processor.py
-
-Generates stylized target images from raw renders.
-
-### Usage
+## Export & Validation
+**Export shaders from checkpoint:**
 ```bash
-python3 image_style_processor.py <input_dir> <output_dir> <style>
+python3 train_cnn.py --export-only checkpoints/checkpoint_epoch_5000.pth
 ```
-### Available Styles
-
-**Sketch:**
-- `pencil_sketch` - Dense cross-hatching
-- `ink_drawing` - Bold outlines, comic style
-- `charcoal_pastel` - Soft, dramatic contrasts
-- `conte_crayon` - Directional strokes
-- `gesture_sketch` - Loose, energetic lines
-
-**Futuristic:**
-- `circuit_board` - Tech blueprint
-- `glitch_art` - Digital corruption
-- `wireframe_topo` - Topographic contours
-- `data_mosaic` - Voronoi fragmentation
-- `holographic_scan` - CRT/HUD aesthetic
-
-### Examples
-
+**Generate ground truth for comparison:**
 ```bash
-# Generate pencil sketch targets
-python3 image_style_processor.py input/ output/ pencil_sketch
-
-# Generate glitch art targets
-python3 image_style_processor.py input/ output/ glitch_art
+python3 train_cnn.py --infer input.png \
+  --export-only checkpoints/checkpoint_epoch_5000.pth \
+  --output ground_truth.png
 ```
-### Requirements
-
-```bash
-pip install opencv-python numpy
-```
+**Auto-generates:**
+- `cnn_weights_generated.wgsl` - Weight arrays
+- `cnn_layer.wgsl` - Layer shader with correct architecture
 ---
@@ -119,64 +111,69 @@ pip install opencv-python numpy
 ### 1. Render Raw Frames
-Generate raw 3D renders as input:
 ```bash
 ./build/demo64k --headless --duration 5 --output training/input/
 ```
 ### 2. Generate Stylized Targets
-Apply artistic style:
 ```bash
 python3 training/image_style_processor.py training/input/ training/output/ pencil_sketch
 ```
+**Available styles:** pencil_sketch, ink_drawing, charcoal_pastel, glitch_art, circuit_board, wireframe_topo
+
 ### 3. Train CNN
-Train network to reproduce the style:
 ```bash
-python3 training/train_cnn.py \
-  --input training/input \
-  --target training/output \
-  --epochs 500 \
-  --checkpoint-every 50
+python3 train_cnn.py \
+  --input training/input --target training/output \
+  --patch-size 32 --patches-per-image 64 \
+  --layers 3 --kernel-sizes 3,5,3 --epochs 5000 --checkpoint-every 1000
 ```
 ### 4. Rebuild Demo
-Weights auto-exported to `cnn_weights_generated.wgsl`:
 ```bash
-cmake --build build -j4
-./build/demo64k
+cmake --build build -j4 && ./build/demo64k
 ```
 ---
-## Tips
+## Architecture
-- **Training data:** 10-50 image pairs recommended
-- **Resolution:** 256×256 (auto-resized during training)
-- **Checkpoints:** Save every 50-100 epochs for long runs
-- **Loss plateaus:** Try lower learning rate (0.0001) or more layers
-- **Residual connection:** Prevents catastrophic divergence (input always blended in)
+**Input:** 7 channels = [RGBD, UV coords, grayscale] normalized to [-1,1]
+**Output:** Grayscale [0,1]
+
+**Layers:**
+- **Inner (0..N-2):** Conv2d(7→4) + tanh → RGBD output [-1,1]
+- **Final (N-1):** Conv2d(7→1) + clamp(0,1) → Grayscale output
+
+**Coordinate awareness:** Layer 0 receives UV coords for position-dependent effects (vignetting, radial gradients).
 ---
-## Coordinate-Aware Layer 0
+## Tips
-Layer 0 receives normalized (x,y) patch center coordinates, enabling position-dependent effects:
+- **Training data:** 10-50 image pairs recommended
+- **Patch size:** 32×32 good balance (16×16 for detail, 64×64 for context)
+- **Patches per image:** 64-128 for good coverage
+- **Batch size:** Higher for patches (16) vs full-image (8)
+- **Checkpoints:** Save every 500-1000 epochs
+- **Loss plateaus:** Lower learning rate (0.0001) or add layers
-- **Vignetting:** Darker edges
-- **Radial gradients:** Center-focused stylization
-- **Corner effects:** Edge-specific treatments
+---
-Training coordinate grid is auto-generated during forward pass. No manual intervention needed.
+## Requirements
-Size impact: +32B coord weights (kernel-agnostic).
+```bash
+pip install torch torchvision pillow opencv-python numpy
+```
 ---
 ## References
-- **CNN Effect Documentation:** `doc/CNN_EFFECT.md`
-- **Training Architecture:** See `train_cnn.py` (CoordConv2d class)
+- **CNN Effect:** `doc/CNN_EFFECT.md`
+- **Timeline:** `doc/SEQUENCE.md`
+- **HOWTO:** `doc/HOWTO.md`
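
---

For readers skimming the diff: the patch-based pipeline the new README describes (salient-point detection, then fixed-size crops) can be sketched in a few lines. This is an illustrative reconstruction, not code from `train_cnn.py`: the function name `extract_patches` and its defaults are hypothetical, and only standard OpenCV calls (`cv2.goodFeaturesToTrack` with `useHarrisDetector=True`) are used.

```python
import cv2

def extract_patches(img, patch_size=32, patches_per_image=64):
    """Crop fixed-size patches centered on Harris corners (hypothetical helper)."""
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    half = patch_size // 2
    # Harris response via goodFeaturesToTrack; minDistance spreads the
    # keypoints out so patches don't all cluster on one strong corner.
    corners = cv2.goodFeaturesToTrack(
        gray, maxCorners=patches_per_image, qualityLevel=0.01,
        minDistance=half, useHarrisDetector=True, k=0.04)
    patches = []
    if corners is None:  # flat image: no corners found
        return patches
    for x, y in corners.reshape(-1, 2).astype(int):
        # Keep only patches that fit entirely inside the image.
        if half <= y <= gray.shape[0] - half and half <= x <= gray.shape[1] - half:
            patches.append(img[y - half:y + half, x - half:x + half])
    return patches

# Hypothetical usage: input and target patches would be cut at the same
# coordinates to form training pairs.
patches = extract_patches(cv2.imread("training/input/frame_0001.png"))
```

With the defaults above, 64 patches per image across 10 training pairs yields 640 samples, which is the "64 patches × 10 images = 640 samples" figure quoted under Benefits.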
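Likewise, a minimal PyTorch sketch of the 7-channel layer stack from the new Architecture section. One assumption is flagged loudly: the diff does not say how inner layers get back to 7 input channels, so this sketch re-concatenates the UV and grayscale planes with each layer's 4-channel RGBD output; `StylizeCNN` is a hypothetical name, not the class in `train_cnn.py`.

```python
import torch
import torch.nn as nn

class StylizeCNN(nn.Module):
    """Hypothetical sketch: inner Conv2d(7->4)+tanh, final Conv2d(7->1)+clamp."""

    def __init__(self, kernel_sizes=(3, 5, 3)):
        super().__init__()
        # Inner layers 0..N-2: 7 channels in, RGBD (4 channels) out.
        self.inner = nn.ModuleList(
            nn.Conv2d(7, 4, k, padding=k // 2) for k in kernel_sizes[:-1])
        # Final layer N-1: 7 channels in, grayscale (1 channel) out.
        k = kernel_sizes[-1]
        self.final = nn.Conv2d(7, 1, k, padding=k // 2)

    def forward(self, rgbd, uv, gray):
        # rgbd: (B,4,H,W) in [-1,1]; uv: (B,2,H,W) coords; gray: (B,1,H,W).
        x = rgbd
        for conv in self.inner:
            # Assumption: UV + grayscale are re-concatenated at every layer
            # so each Conv2d sees a full 7-channel input.
            x = torch.tanh(conv(torch.cat((x, uv, gray), dim=1)))
        out = self.final(torch.cat((x, uv, gray), dim=1))
        return out.clamp(0.0, 1.0)  # grayscale output in [0,1]
```

With `--layers 3 --kernel-sizes 3,5,3` this gives two inner Conv2d(7→4) layers (kernels 3 and 5) and a final Conv2d(7→1) with kernel 3, matching the Quick Start example.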
