Diffstat (limited to 'cnn_v3/docs/HOWTO.md')
-rw-r--r--  cnn_v3/docs/HOWTO.md  70
1 file changed, 59 insertions(+), 11 deletions(-)
diff --git a/cnn_v3/docs/HOWTO.md b/cnn_v3/docs/HOWTO.md
index 425a33b..0cf2fe5 100644
--- a/cnn_v3/docs/HOWTO.md
+++ b/cnn_v3/docs/HOWTO.md
@@ -135,20 +135,68 @@ Mix freely; the dataloader treats all sample directories uniformly.
## 3. Training
-*(Script not yet written — see TODO.md. Architecture spec in `CNN_V3.md` §Training.)*
+Two source files:
+- **`cnn_v3_utils.py`** — image I/O, feature assembly, channel dropout, salient-point
+ detection, `CNNv3Dataset`
+- **`train_cnn_v3.py`** — `CNNv3` model, training loop, CLI
+
+### Quick start
-**Planned command:**
```bash
-python3 cnn_v3/training/train_cnn_v3.py \
- --dataset dataset/ \
- --epochs 500 \
- --output cnn_v3/weights/cnn_v3_weights.bin
+cd cnn_v3/training
+
+# Patch-based (default) — 64×64 patches around Harris corners
+python3 train_cnn_v3.py \
+ --input dataset/ \
+ --input-mode simple \
+ --epochs 200
+
+# Full-image mode (resizes to 256×256)
+python3 train_cnn_v3.py \
+ --input dataset/ \
+ --input-mode full \
+ --full-image --image-size 256 \
+ --epochs 500
+
+# Quick smoke test: 1 epoch, small patches, random detector
+python3 train_cnn_v3.py \
+ --input dataset/ --epochs 1 \
+ --patch-size 32 --detector random
```
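The patch-based default crops fixed-size windows around detected salient points. As a rough sketch of how the simplest detector (`gradient`) could work — the helper names and NumPy implementation here are illustrative assumptions, not the actual code in `cnn_v3_utils.py`:

```python
import numpy as np

def gradient_salient_points(gray: np.ndarray, k: int, margin: int) -> np.ndarray:
    """Pick the k pixels with the largest gradient magnitude, keeping a
    `margin` border so a full patch fits around each point."""
    gy, gx = np.gradient(gray.astype(np.float32))
    mag = np.hypot(gx, gy)
    # Suppress the border so every selected point admits a full crop.
    mag[:margin, :] = mag[-margin:, :] = 0
    mag[:, :margin] = mag[:, -margin:] = 0
    idx = np.argsort(mag.ravel())[-k:]
    return np.stack(np.unravel_index(idx, mag.shape), axis=1)  # (k, 2) as (y, x)

def extract_patches(img: np.ndarray, patch_size: int = 64, k: int = 256) -> np.ndarray:
    """Crop k patch_size x patch_size windows centred on salient points."""
    half = patch_size // 2
    pts = gradient_salient_points(img.mean(axis=-1), k, margin=half)
    return np.stack([img[y - half:y + half, x - half:x + half] for y, x in pts])
```

The other detectors (`harris`, `shi-tomasi`, `fast`) follow the same select-then-crop shape, only with a different saliency score; `random` skips scoring entirely.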
-**FiLM conditioning** during training:
-- Beat/audio inputs randomized per sample
-- MLP: `Linear(5→16) → ReLU → Linear(16→40)` trained jointly with U-Net
-- Output: γ/β for enc0(4ch) + enc1(8ch) + dec1(4ch) + dec0(4ch) = 40 floats
+### Key flags
+
+| Flag | Default | Notes |
+|------|---------|-------|
+| `--input DIR` | `training/dataset` | Root with `full/` or `simple/` subdirs |
+| `--input-mode` | `simple` | `simple`=photos, `full`=Blender G-buffer |
+| `--patch-size N` | `64` | Patch crop size |
+| `--patches-per-image N` | `256` | Patches extracted per image per epoch |
+| `--detector` | `harris` | `harris` \| `shi-tomasi` \| `fast` \| `gradient` \| `random` |
+| `--channel-dropout-p F` | `0.3` | Dropout prob for geometric channels |
+| `--full-image` | off | Resize full image instead of cropping patches |
+| `--enc-channels C` | `4,8` | Encoder channel counts, comma-separated |
+| `--film-cond-dim N` | `5` | FiLM conditioning input size |
+| `--epochs N` | `200` | Training epochs |
+| `--batch-size N` | `16` | Batch size |
+| `--lr F` | `1e-3` | Adam learning rate |
+| `--checkpoint-dir DIR` | `checkpoints/` | Where to save `.pth` files |
+| `--checkpoint-every N` | `50` | Epoch interval for checkpoints (0=disable) |
+
+### FiLM conditioning during training
+
+- Conditioning vector `[beat_phase, beat_time/8, audio_intensity, style_p0, style_p1]`
+ is **randomised per sample** (uniform [0,1]) so the MLP trains jointly with the U-Net.
+- At inference, real beat/audio values are fed from `CNNv3Effect::set_film_params()`.
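A minimal PyTorch sketch of this scheme, assuming the MLP shape from the earlier spec (`Linear(5→16) → ReLU → Linear(16→40)`); the γ/β split convention shown is an assumption:

```python
import torch
import torch.nn as nn

# FiLM conditioning MLP: 5 conditioning inputs -> 40 scalars, i.e. gamma/beta
# for enc0(4ch) + enc1(8ch) + dec1(4ch) + dec0(4ch) = 20 channel pairs.
film_mlp = nn.Sequential(nn.Linear(5, 16), nn.ReLU(), nn.Linear(16, 40))

def random_cond(batch: int) -> torch.Tensor:
    # [beat_phase, beat_time/8, audio_intensity, style_p0, style_p1],
    # each drawn uniform in [0, 1) freshly per sample during training.
    return torch.rand(batch, 5)

gammas_betas = film_mlp(random_cond(16))            # (16, 40)
gamma, beta = gammas_betas.chunk(2, dim=1)          # assumed split: 20 + 20
```

Because the conditioning vector is resampled every sample, the MLP gets gradients jointly with the U-Net rather than being trained in a separate pass.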
+
+### Channel dropout
+
+Applied per-sample in `cnn_v3_utils.apply_channel_dropout()`:
+- Geometric channels (normal, depth, depth_grad) zeroed with `p=channel_dropout_p`
+- Context channels (mat_id, shadow, transp) with `p≈0.2`
+- Temporal channels (prev.rgb) with `p=0.5`
+
+This ensures the network works for both full G-buffer and photo-only inputs.
---
@@ -202,7 +250,7 @@ Test vectors generated by `cnn_v3/training/gen_test_vectors.py` (PyTorch referen
| 3 — WGSL U-Net shaders | ✅ Done | 5 compute shaders + cnn_v3/common snippet |
| 4 — C++ CNNv3Effect | ✅ Done | FiLM uniform upload, 36/36 tests pass |
| 5 — Parity validation | ✅ Done | test_cnn_v3_parity.cc, max_err=4.88e-4 |
-| 6 — FiLM MLP training | TODO | train_cnn_v3.py not yet written |
+| 6 — FiLM MLP training | ✅ Done | train_cnn_v3.py + cnn_v3_utils.py written |
---