Diffstat (limited to 'cnn_v3/docs/HOWTO.md')
| -rw-r--r-- | cnn_v3/docs/HOWTO.md | 70 |
1 file changed, 59 insertions, 11 deletions
diff --git a/cnn_v3/docs/HOWTO.md b/cnn_v3/docs/HOWTO.md
index 425a33b..0cf2fe5 100644
--- a/cnn_v3/docs/HOWTO.md
+++ b/cnn_v3/docs/HOWTO.md
@@ -135,20 +135,68 @@ Mix freely; the dataloader treats all sample directories uniformly.
 
 ## 3. Training
 
-*(Script not yet written — see TODO.md. Architecture spec in `CNN_V3.md` §Training.)*
+Two source files:
+- **`cnn_v3_utils.py`** — image I/O, feature assembly, channel dropout, salient-point
+  detection, `CNNv3Dataset`
+- **`train_cnn_v3.py`** — `CNNv3` model, training loop, CLI
+
+### Quick start
 
-**Planned command:**
 ```bash
-python3 cnn_v3/training/train_cnn_v3.py \
-    --dataset dataset/ \
-    --epochs 500 \
-    --output cnn_v3/weights/cnn_v3_weights.bin
+cd cnn_v3/training
+
+# Patch-based (default) — 64×64 patches around Harris corners
+python3 train_cnn_v3.py \
+    --input dataset/ \
+    --input-mode simple \
+    --epochs 200
+
+# Full-image mode (resizes to 256×256)
+python3 train_cnn_v3.py \
+    --input dataset/ \
+    --input-mode full \
+    --full-image --image-size 256 \
+    --epochs 500
+
+# Quick smoke test: 1 epoch, small patches, random detector
+python3 train_cnn_v3.py \
+    --input dataset/ --epochs 1 \
+    --patch-size 32 --detector random
 ```
 
-**FiLM conditioning** during training:
-- Beat/audio inputs randomized per sample
-- MLP: `Linear(5→16) → ReLU → Linear(16→40)` trained jointly with U-Net
-- Output: γ/β for enc0(4ch) + enc1(8ch) + dec1(4ch) + dec0(4ch) = 40 floats
+### Key flags
+
+| Flag | Default | Notes |
+|------|---------|-------|
+| `--input DIR` | `training/dataset` | Root with `full/` or `simple/` subdirs |
+| `--input-mode` | `simple` | `simple`=photos, `full`=Blender G-buffer |
+| `--patch-size N` | `64` | Patch crop size |
+| `--patches-per-image N` | `256` | Patches extracted per image per epoch |
+| `--detector` | `harris` | `harris` \| `shi-tomasi` \| `fast` \| `gradient` \| `random` |
+| `--channel-dropout-p F` | `0.3` | Dropout prob for geometric channels |
+| `--full-image` | off | Resize full image instead of cropping patches |
+| `--enc-channels C` | `4,8` | Encoder channel counts, comma-separated |
+| `--film-cond-dim N` | `5` | FiLM conditioning input size |
+| `--epochs N` | `200` | Training epochs |
+| `--batch-size N` | `16` | Batch size |
+| `--lr F` | `1e-3` | Adam learning rate |
+| `--checkpoint-dir DIR` | `checkpoints/` | Where to save `.pth` files |
+| `--checkpoint-every N` | `50` | Epoch interval for checkpoints (0=disable) |
+
+### FiLM conditioning during training
+
+- Conditioning vector `[beat_phase, beat_time/8, audio_intensity, style_p0, style_p1]`
+  is **randomised per sample** (uniform [0,1]) so the MLP trains jointly with the U-Net.
+- At inference, real beat/audio values are fed from `CNNv3Effect::set_film_params()`.
+
+### Channel dropout
+
+Applied per-sample in `cnn_v3_utils.apply_channel_dropout()`:
+- Geometric channels (normal, depth, depth_grad) zeroed with `p=channel_dropout_p`
+- Context channels (mat_id, shadow, transp) with `p≈0.2`
+- Temporal channels (prev.rgb) with `p=0.5`
+
+This ensures the network works for both full G-buffer and photo-only inputs.
 
 ---
 
@@ -202,7 +250,7 @@ Test vectors generated by `cnn_v3/training/gen_test_vectors.py` (PyTorch referen
 | 3 — WGSL U-Net shaders | ✅ Done | 5 compute shaders + cnn_v3/common snippet |
 | 4 — C++ CNNv3Effect | ✅ Done | FiLM uniform upload, 36/36 tests pass |
 | 5 — Parity validation | ✅ Done | test_cnn_v3_parity.cc, max_err=4.88e-4 |
-| 6 — FiLM MLP training | TODO | train_cnn_v3.py not yet written |
+| 6 — FiLM MLP training | ✅ Done | train_cnn_v3.py + cnn_v3_utils.py written |
 
 ---
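Editor's note on the patch pipeline above: the quick-start smoke test uses `--detector random`, where patch centres are sampled uniformly rather than from a corner detector. A minimal sketch of that path (function name and sampling details are assumptions; the real detectors, including `harris` and `shi-tomasi`, live in `cnn_v3_utils.py`):

```python
import numpy as np

def extract_patches(img, n=256, patch=64, rng=None):
    """Crop n patch×patch windows at random centres, the hypothetical
    '--detector random' behaviour. The corner detectors (harris,
    shi-tomasi, fast, gradient) would supply the centres instead."""
    rng = rng or np.random.default_rng()
    H, W = img.shape[:2]
    half = patch // 2
    # Keep centres far enough from the border that the crop fits.
    ys = rng.integers(half, H - half, size=n)
    xs = rng.integers(half, W - half, size=n)
    return np.stack([img[y - half:y + half, x - half:x + half]
                     for y, x in zip(ys, xs)])
```

With the defaults from the flag table (`--patch-size 64`, `--patches-per-image 256`) this yields a `(256, 64, 64, C)` batch per image per epoch.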
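Editor's note on the FiLM layout: per the removed spec lines, the MLP (`Linear(5→16) → ReLU → Linear(16→40)`) emits 40 floats, γ/β for enc0 (4ch), enc1 (8ch), dec1 (4ch), dec0 (4ch). A numpy sketch of how those 40 floats could be sliced and applied per channel; the packing order (γ before β within each stage, stages in that sequence) is an assumption, not confirmed by the source:

```python
import numpy as np

# Channel counts of the four FiLM-modulated stages (from the HOWTO):
# 4 + 8 + 4 + 4 = 20 channels → 40 γ/β floats.
STAGES = {"enc0": 4, "enc1": 8, "dec1": 4, "dec0": 4}

def split_film_params(flat):
    """Slice the flat 40-float MLP output into per-stage (gamma, beta)
    pairs. Packing order is an assumption for illustration."""
    assert flat.shape[-1] == 2 * sum(STAGES.values())
    out, i = {}, 0
    for name, ch in STAGES.items():
        out[name] = (flat[..., i:i + ch], flat[..., i + ch:i + 2 * ch])
        i += 2 * ch
    return out

def film(x, gamma, beta):
    """Per-channel affine modulation: y[c] = gamma[c] * x[c] + beta[c]."""
    return gamma[:, None, None] * x + beta[:, None, None]  # x: (C, H, W)
```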
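Editor's note on channel dropout: the HOWTO describes per-sample zeroing of three channel groups at different probabilities. A minimal sketch of that scheme; the channel indices, the group-wise (rather than per-channel) granularity, and the signature are assumptions, and the real logic lives in `cnn_v3_utils.apply_channel_dropout()`:

```python
import numpy as np

# Hypothetical indices into the stacked feature tensor; the actual
# channel layout is defined in cnn_v3_utils.py and may differ.
GEOMETRIC = [3, 4, 5, 6, 7]   # normal.xyz, depth, depth_grad
CONTEXT   = [8, 9, 10]        # mat_id, shadow, transp
TEMPORAL  = [11]              # prev-frame channels, collapsed here

def apply_channel_dropout(x, p_geom=0.3, p_ctx=0.2, p_temp=0.5, rng=None):
    """Zero whole channel groups per sample so the net learns to cope
    with photo-only inputs (no G-buffer) at inference time."""
    rng = rng or np.random.default_rng()
    x = x.copy()  # do not mutate the caller's sample
    for chans, p in ((GEOMETRIC, p_geom), (CONTEXT, p_ctx), (TEMPORAL, p_temp)):
        if rng.random() < p:
            x[chans] = 0.0
    return x
```

The default `p_geom=0.3` mirrors `--channel-dropout-p 0.3` from the flag table; `0.2` and `0.5` are the context and temporal probabilities the HOWTO states.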
