summary | refs | log | tree | commit | diff
path: root/cnn_v3
diff options
context:
space:
mode:
author skal <pascal.massimino@gmail.com> 2026-03-25 08:27:39 +0100
committer skal <pascal.massimino@gmail.com> 2026-03-25 08:27:39 +0100
commit 3e4fece8fce11b368b4c7bab284242bf18e6a0b1 (patch)
tree 108682f727f7668e1df346563f4576d7e567dcd2 /cnn_v3
parent 64095c683f15e8bd7c19d32041fcc81b1bd6c214 (diff)
feat(cnn_v3/training): add --single-sample option + doc fixes
- train_cnn_v3.py: --single-sample <dir> implies --full-image + --batch-size 1
- cnn_v3_utils.py: CNNv3Dataset accepts single_sample= kwarg (explicit override)
- HOWTO.md: document --single-sample workflow, fix pack_photo_sample.py usage (--target required)
- HOW_TO_CNN.md: fix GBufferEffect seq input (prev_cnn→source), fix binary name (demo→demo64k), add --resume to flag table, remove stale "pack without target" block

handoff(Gemini): --single-sample <dir> added to train_cnn_v3.py; docs audited and corrected
Diffstat (limited to 'cnn_v3')
-rw-r--r-- cnn_v3/docs/HOWTO.md 28
-rw-r--r-- cnn_v3/docs/HOW_TO_CNN.md 30
-rw-r--r-- cnn_v3/training/cnn_v3_utils.py 25
-rw-r--r-- cnn_v3/training/train_cnn_v3.py 7
4 files changed, 64 insertions, 26 deletions
diff --git a/cnn_v3/docs/HOWTO.md b/cnn_v3/docs/HOWTO.md
index 58f09ed..1aead68 100644
--- a/cnn_v3/docs/HOWTO.md
+++ b/cnn_v3/docs/HOWTO.md
@@ -233,12 +233,13 @@ channel-dropout training.
```bash
python3 cnn_v3/training/pack_photo_sample.py \
- --photo cnn_v3/training/input/photo1.jpg \
+ --photo input/photo1.jpg \
+ --target target/photo1_styled.png \
--output dataset/photos/sample_001/
```
-The output `target.png` defaults to the input photo (no style). Copy in
-your stylized version as `target.png` before training.
+`--target` is required and must be a stylized ground-truth image at the same
+resolution as the photo. The script writes it as `target.png` in the sample dir.
### Dataset layout
@@ -285,10 +286,31 @@ python3 train_cnn_v3.py \
--patch-size 32 --detector random
```
+### Single-sample training
+
+Use `--single-sample <dir>` to train on one specific sample directory.
+Implies `--full-image` and `--batch-size 1` automatically.
+
+```bash
+# Pack input/target pair into a sample directory first
+python3 pack_photo_sample.py \
+ --photo input/photo1.png \
+ --target target/photo1_styled.png \
+ --output dataset/simple/sample_001/
+
+# Train on that sample only
+python3 train_cnn_v3.py \
+ --single-sample dataset/simple/sample_001/ \
+ --epochs 500
+```
+
+All other flags (`--epochs`, `--lr`, `--checkpoint-dir`, `--enc-channels`, etc.) work normally.
+
### Key flags
| Flag | Default | Notes |
|------|---------|-------|
+| `--single-sample DIR` | — | Train on one sample dir; implies `--full-image`, `--batch-size 1` |
| `--input DIR` | `training/dataset` | Root with `full/` or `simple/` subdirs |
| `--input-mode` | `simple` | `simple`=photos, `full`=Blender G-buffer |
| `--patch-size N` | `64` | Patch crop size |
diff --git a/cnn_v3/docs/HOW_TO_CNN.md b/cnn_v3/docs/HOW_TO_CNN.md
index 624deaa..f5f1b1a 100644
--- a/cnn_v3/docs/HOW_TO_CNN.md
+++ b/cnn_v3/docs/HOW_TO_CNN.md
@@ -107,15 +107,6 @@ The network learns the mapping `albedo → target`. If you pass the same image a
input and target, the network learns identity (useful as sanity check, not for real
training). Confirm `target.png` looks correct before running training.
-**Alternative — pack without a target yet:**
-```bash
-python3 pack_photo_sample.py \
- --photo /path/to/photo.png \
- --output dataset/simple/sample_001/
-# target.png defaults to a copy of the input; replace it before training:
-cp my_stylized_version.png dataset/simple/sample_001/target.png
-```
-
**Batch packing:**
```bash
for f in photos/*.png; do
@@ -336,10 +327,22 @@ uv run train_cnn_v3.py \
--epochs 500
```
+**Single-sample training (overfit on one input/target pair):**
+```bash
+# Pack first
+./gen_sample.sh input/photo.png target/photo_styled.png dataset/simple/sample_001/
+
+# Train — --full-image and --batch-size 1 are implied
+uv run train_cnn_v3.py \
+ --single-sample dataset/simple/sample_001/ \
+ --epochs 500
+```
+
### Flag reference
| Flag | Default | Notes |
|------|---------|-------|
+| `--single-sample DIR` | — | Train on one sample dir; implies `--full-image`, `--batch-size 1` |
| `--input DIR` | `training/dataset` | Dataset root; always set explicitly |
| `--input-mode` | `simple` | `simple`=photos, `full`=Blender G-buffer |
| `--epochs N` | 200 | 500 recommended for full-image mode |
@@ -356,6 +359,7 @@ uv run train_cnn_v3.py \
| `--film-cond-dim N` | 5 | Must match `CNNv3FiLMParams` field count in C++ |
| `--checkpoint-dir DIR` | `checkpoints/` | Set per-experiment |
| `--checkpoint-every N` | 50 | 0 to disable intermediate checkpoints |
+| `--resume [CKPT]` | — | Resume from checkpoint path; if path missing, uses latest in `--checkpoint-dir` |
### Architecture at startup
@@ -564,10 +568,12 @@ It owns:
```
SEQUENCE 0 0 "Scene with CNN v3"
- EFFECT + GBufferEffect prev_cnn -> gbuf_feat0 gbuf_feat1 0 60
- EFFECT + CNNv3Effect gbuf_feat0 gbuf_feat1 -> sink 0 60
+ EFFECT + GBufferEffect source -> gbuf_feat0 gbuf_feat1 0 60
+ EFFECT + CNNv3Effect gbuf_feat0 gbuf_feat1 -> sink 0 60
```
+Temporal feedback (`prev_cnn`) is wired automatically by `wire_dag()` — no explicit input needed in the `.seq` file.
+
Or direct C++:
```cpp
#include "cnn_v3/src/cnn_v3_effect.h"
@@ -658,7 +664,7 @@ Do not reference them from outside the effect unless debugging.
```bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build -j4
-./build/demo
+./build/demo64k
```
### Expected visual output
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py
index bef4091..50707a2 100644
--- a/cnn_v3/training/cnn_v3_utils.py
+++ b/cnn_v3/training/cnn_v3_utils.py
@@ -286,7 +286,8 @@ class CNNv3Dataset(Dataset):
channel_dropout_p: float = 0.3,
detector: str = 'harris',
augment: bool = True,
- patch_search_window: int = 0):
+ patch_search_window: int = 0,
+ single_sample: str = ''):
self.patch_size = patch_size
self.patches_per_image = patches_per_image
self.image_size = image_size
@@ -296,16 +297,18 @@ class CNNv3Dataset(Dataset):
self.augment = augment
self.patch_search_window = patch_search_window
- root = Path(dataset_dir)
- subdir = 'full' if input_mode == 'full' else 'simple'
- search_dir = root / subdir
- if not search_dir.exists():
- search_dir = root
-
- self.samples = sorted([
- d for d in search_dir.iterdir()
- if d.is_dir() and (d / 'albedo.png').exists()
- ])
+ if single_sample:
+ self.samples = [Path(single_sample)]
+ else:
+ root = Path(dataset_dir)
+ subdir = 'full' if input_mode == 'full' else 'simple'
+ search_dir = root / subdir
+ if not search_dir.exists():
+ search_dir = root
+ self.samples = sorted([
+ d for d in search_dir.iterdir()
+ if d.is_dir() and (d / 'albedo.png').exists()
+ ])
if not self.samples:
raise RuntimeError(f"No samples found in {search_dir}")
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py
index de10d6a..31cfd9d 100644
--- a/cnn_v3/training/train_cnn_v3.py
+++ b/cnn_v3/training/train_cnn_v3.py
@@ -104,6 +104,10 @@ def train(args):
enc_channels = [int(c) for c in args.enc_channels.split(',')]
print(f"Device: {device}")
+ if args.single_sample:
+ args.full_image = True
+ args.batch_size = 1
+
dataset = CNNv3Dataset(
dataset_dir=args.input,
input_mode=args.input_mode,
@@ -115,6 +119,7 @@ def train(args):
detector=args.detector,
augment=True,
patch_search_window=args.patch_search_window,
+ single_sample=args.single_sample,
)
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
num_workers=0, drop_last=False)
@@ -222,6 +227,8 @@ def main():
p = argparse.ArgumentParser(description='Train CNN v3 (U-Net + FiLM)')
# Dataset
+ p.add_argument('--single-sample', default='', metavar='DIR',
+ help='Train on a single sample directory; implies --full-image and --batch-size 1')
p.add_argument('--input', default='training/dataset',
help='Dataset root (contains full/ or simple/ subdirs)')
p.add_argument('--input-mode', default='simple', choices=['simple', 'full'],