author    skal <pascal.massimino@gmail.com>  2026-02-10 22:54:38 +0100
committer skal <pascal.massimino@gmail.com>  2026-02-10 22:54:38 +0100
commit    2adcf1bac1ec651861930eb2af00641eb23f6ef1 (patch)
tree      2208b753b68783be5a26906b5bc1690c23a267d0 /training
parent    58f276378735e0b51f4d1517a844357e45e376a7 (diff)
docs: Update CNN training documentation with patch extraction
Streamlined and updated all training docs with new patch-based approach.

Changes:
- HOWTO.md: Updated training section with patch/full-image examples
- CNN_EFFECT.md: Streamlined training workflow, added detector info
- training/README.md: Complete rewrite with detector comparison table

New sections:
- Detector comparison (harris, fast, shi-tomasi, gradient)
- Practical examples for different use cases
- Tips for patch size and batch size selection
- Benefits of patch-based training

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'training')
-rw-r--r--  training/README.md  193
1 file changed, 95 insertions(+), 98 deletions(-)
diff --git a/training/README.md b/training/README.md
index 0a46718..e78b471 100644
--- a/training/README.md
+++ b/training/README.md
@@ -1,117 +1,109 @@
# CNN Training Tools
-Tools for training and preparing data for the CNN post-processing effect.
+PyTorch-based training for image-to-image stylization with patch extraction.
---
-## train_cnn.py
-
-PyTorch-based training script for image-to-image stylization.
-
-### Basic Usage
+## Quick Start
```bash
-python3 train_cnn.py --input <input_dir> --target <target_dir> [options]
+# Patch-based (recommended)
+python3 train_cnn.py \
+ --input training/input --target training/output \
+ --patch-size 32 --patches-per-image 64 --detector harris \
+ --layers 3 --kernel-sizes 3,5,3 --epochs 5000 --batch-size 16
+
+# Full-image (legacy)
+python3 train_cnn.py \
+ --input training/input --target training/output \
+ --layers 3 --kernel-sizes 3,5,3 --epochs 10000 --batch-size 8
```
+---
+
+## Patch-Based Training (Recommended)
+
+Extracts patches at salient feature points and preserves the natural pixel scale.
+
+### Detectors
+
+| Detector | Best For | Speed |
+|----------|----------|-------|
+| `harris` (default) | Corners, structured scenes | Medium |
+| `fast` | Dense features, textures | Fast |
+| `shi-tomasi` | High-quality corners | Medium |
+| `gradient` | Edges, high-contrast areas | Fast |
+
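+A minimal sketch of how these detectors might select patch centers (assuming
+OpenCV; the function names and thresholds below are illustrative, not the
+script's actual implementation):
+
+```python
+import cv2
+import numpy as np
+
+def patch_centers(gray, detector="harris", n=64):
+    """Return up to n (x, y) centers for patch extraction (illustrative)."""
+    if detector in ("harris", "shi-tomasi"):
+        pts = cv2.goodFeaturesToTrack(
+            gray, maxCorners=n, qualityLevel=0.01, minDistance=8,
+            useHarrisDetector=(detector == "harris"))
+        return [] if pts is None else [tuple(p) for p in pts.reshape(-1, 2)]
+    if detector == "fast":
+        kps = cv2.FastFeatureDetector_create().detect(gray)
+        return [k.pt for k in sorted(kps, key=lambda k: -k.response)[:n]]
+    # "gradient": take the n strongest Sobel-magnitude pixels
+    mag = cv2.magnitude(cv2.Sobel(gray, cv2.CV_32F, 1, 0),
+                        cv2.Sobel(gray, cv2.CV_32F, 0, 1))
+    ys, xs = np.unravel_index(np.argsort(mag, axis=None)[-n:], mag.shape)
+    return list(zip(xs.tolist(), ys.tolist()))
+```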
### Examples
-**Single layer, 3×3 kernel:**
+**Single layer, Harris corners:**
```bash
python3 train_cnn.py --input training/input --target training/output \
- --layers 1 --kernel-sizes 3 --epochs 500
+ --patch-size 32 --patches-per-image 64 --detector harris \
+ --layers 1 --kernel-sizes 3 --epochs 2000
```
-**Multi-layer, mixed kernels:**
+**Multi-layer, FAST features:**
```bash
python3 train_cnn.py --input training/input --target training/output \
- --layers 3 --kernel-sizes 3,5,3 --epochs 1000
+ --patch-size 32 --patches-per-image 128 --detector fast \
+ --layers 3 --kernel-sizes 3,5,3 --epochs 5000 --batch-size 16
```
-**With checkpointing:**
+**Edge-focused (gradient detector):**
```bash
python3 train_cnn.py --input training/input --target training/output \
- --epochs 500 --checkpoint-every 50
+ --patch-size 16 --patches-per-image 96 --detector gradient \
+ --layers 2 --kernel-sizes 3,3 --epochs 3000
```
-**Resume from checkpoint:**
-```bash
-python3 train_cnn.py --input training/input --target training/output \
- --resume training/checkpoints/checkpoint_epoch_200.pth
-```
+### Benefits
+
+- **Preserves scale:** No resize distortion
+- **More samples:** 64 patches × 10 images = 640 samples vs 10
+- **Focused learning:** Trains on interesting features, not flat areas
+- **Better generalization:** Network sees diverse local patterns
-### Options
+---
+
+## Options
| Option | Default | Description |
|--------|---------|-------------|
| `--input` | *required* | Input image directory |
| `--target` | *required* | Target image directory |
+| `--patch-size` | None | Patch size (e.g., 32). Omit for full-image mode |
+| `--patches-per-image` | 64 | Patches to extract per image |
+| `--detector` | harris | harris\|fast\|shi-tomasi\|gradient |
| `--layers` | 1 | Number of CNN layers |
-| `--kernel-sizes` | 3 | Comma-separated kernel sizes (auto-repeats if single value) |
+| `--kernel-sizes` | 3 | Comma-separated (e.g., 3,5,3); a single value repeats for all layers |
| `--epochs` | 100 | Training epochs |
-| `--batch-size` | 4 | Batch size |
+| `--batch-size` | 4 | Batch size (use 16 for patches, 8 for full-image) |
| `--learning-rate` | 0.001 | Learning rate |
-| `--output` | `workspaces/main/shaders/cnn/cnn_weights_generated.wgsl` | Output WGSL file |
-| `--checkpoint-every` | 0 | Save checkpoint every N epochs (0=disabled) |
-| `--checkpoint-dir` | `training/checkpoints` | Checkpoint directory |
-| `--resume` | None | Resume from checkpoint file |
-
-### Architecture
-
-- **Layer 0:** `CoordConv2d` - accepts (x,y) patch center + 3×3 RGBA samples
-- **Layers 1+:** Standard `Conv2d` - 3×3 RGBA samples only
-- **Activation:** Tanh between layers
-- **Output:** Residual connection (30% stylization blend)
-
-### Requirements
-
-```bash
-pip install torch torchvision pillow
-```
+| `--checkpoint-every` | 0 | Save every N epochs (0=off) |
+| `--resume` | None | Resume from checkpoint |
+| `--export-only` | None | Export WGSL without training |
+| `--infer` | None | Generate ground truth PNG |
---
-## image_style_processor.py
-
-Generates stylized target images from raw renders.
-
-### Usage
+## Export & Validation
+**Export shaders from checkpoint:**
```bash
-python3 image_style_processor.py <input_dir> <output_dir> <style>
+python3 train_cnn.py --export-only checkpoints/checkpoint_epoch_5000.pth
```
-### Available Styles
-
-**Sketch:**
-- `pencil_sketch` - Dense cross-hatching
-- `ink_drawing` - Bold outlines, comic style
-- `charcoal_pastel` - Soft, dramatic contrasts
-- `conte_crayon` - Directional strokes
-- `gesture_sketch` - Loose, energetic lines
-
-**Futuristic:**
-- `circuit_board` - Tech blueprint
-- `glitch_art` - Digital corruption
-- `wireframe_topo` - Topographic contours
-- `data_mosaic` - Voronoi fragmentation
-- `holographic_scan` - CRT/HUD aesthetic
-
-### Examples
-
+**Generate ground truth for comparison:**
```bash
-# Generate pencil sketch targets
-python3 image_style_processor.py input/ output/ pencil_sketch
-
-# Generate glitch art targets
-python3 image_style_processor.py input/ output/ glitch_art
+python3 train_cnn.py --infer input.png \
+ --export-only checkpoints/checkpoint_epoch_5000.pth \
+ --output ground_truth.png
```
-### Requirements
-
-```bash
-pip install opencv-python numpy
-```
+**Auto-generates:**
+- `cnn_weights_generated.wgsl` - Weight arrays
+- `cnn_layer.wgsl` - Layer shader with correct architecture
---
@@ -119,64 +111,69 @@ pip install opencv-python numpy
### 1. Render Raw Frames
-Generate raw 3D renders as input:
```bash
./build/demo64k --headless --duration 5 --output training/input/
```
### 2. Generate Stylized Targets
-Apply artistic style:
```bash
python3 training/image_style_processor.py training/input/ training/output/ pencil_sketch
```
+**Available styles:** pencil_sketch, ink_drawing, charcoal_pastel, glitch_art, circuit_board, wireframe_topo
+
### 3. Train CNN
-Train network to reproduce the style:
```bash
-python3 training/train_cnn.py \
- --input training/input \
- --target training/output \
- --epochs 500 \
- --checkpoint-every 50
+python3 train_cnn.py \
+ --input training/input --target training/output \
+ --patch-size 32 --patches-per-image 64 \
+ --layers 3 --kernel-sizes 3,5,3 --epochs 5000 --checkpoint-every 1000
```
### 4. Rebuild Demo
-Weights auto-exported to `cnn_weights_generated.wgsl`:
```bash
-cmake --build build -j4
-./build/demo64k
+cmake --build build -j4 && ./build/demo64k
```
---
-## Tips
+## Architecture
-- **Training data:** 10-50 image pairs recommended
-- **Resolution:** 256×256 (auto-resized during training)
-- **Checkpoints:** Save every 50-100 epochs for long runs
-- **Loss plateaus:** Try lower learning rate (0.0001) or more layers
-- **Residual connection:** Prevents catastrophic divergence (input always blended in)
+- **Input:** 7 channels = [RGBD, UV coords, grayscale], normalized to [-1,1]
+- **Output:** Grayscale in [0,1]
+
+**Layers:**
+- **Inner (0..N-2):** Conv2d(7→4) + tanh → RGBD output [-1,1]
+- **Final (N-1):** Conv2d(7→1) + clamp(0,1) → Grayscale output
+
+**Coordinate awareness:** Layer 0 receives UV coords for position-dependent effects (vignetting, radial gradients).
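+
+A PyTorch sketch of this architecture (an assumption of how the layers are
+wired; in particular, it re-concatenates the UV and grayscale planes before
+every layer so that each Conv2d sees 7 input channels):
+
+```python
+import torch
+import torch.nn as nn
+
+class StyleCNN(nn.Module):
+    def __init__(self, kernel_sizes=(3, 5, 3)):
+        super().__init__()
+        last = len(kernel_sizes) - 1
+        # Inner layers emit 4 channels (RGBD); the final layer emits 1 (gray).
+        self.convs = nn.ModuleList(
+            nn.Conv2d(7, 1 if i == last else 4, k, padding=k // 2)
+            for i, k in enumerate(kernel_sizes))
+
+    def forward(self, rgbd, uv, gray):  # all inputs normalized to [-1, 1]
+        x = rgbd
+        for i, conv in enumerate(self.convs):
+            x = conv(torch.cat([x, uv, gray], dim=1))  # 4 + 2 + 1 = 7 channels
+            x = torch.tanh(x) if i < len(self.convs) - 1 else x.clamp(0, 1)
+        return x  # grayscale in [0, 1]
+```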
---
-## Coordinate-Aware Layer 0
+## Tips
-Layer 0 receives normalized (x,y) patch center coordinates, enabling position-dependent effects:
+- **Training data:** 10-50 image pairs recommended
+- **Patch size:** 32×32 good balance (16×16 for detail, 64×64 for context)
+- **Patches per image:** 64-128 for good coverage
+- **Batch size:** Higher for patches (16) vs full-image (8)
+- **Checkpoints:** Save every 500-1000 epochs
+- **Loss plateaus:** Lower learning rate (0.0001) or add layers
-- **Vignetting:** Darker edges
-- **Radial gradients:** Center-focused stylization
-- **Corner effects:** Edge-specific treatments
+---
-Training coordinate grid is auto-generated during forward pass. No manual intervention needed.
+## Requirements
-Size impact: +32B coord weights (kernel-agnostic).
+```bash
+pip install torch torchvision pillow opencv-python numpy
+```
---
## References
-- **CNN Effect Documentation:** `doc/CNN_EFFECT.md`
-- **Training Architecture:** See `train_cnn.py` (CoordConv2d class)
+- **CNN Effect:** `doc/CNN_EFFECT.md`
+- **Timeline:** `doc/SEQUENCE.md`
+- **HOWTO:** `doc/HOWTO.md`