summaryrefslogtreecommitdiff
path: root/cnn_v3
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-03-25 08:07:53 +0100
committerskal <pascal.massimino@gmail.com>2026-03-25 08:07:53 +0100
commit64095c683f15e8bd7c19d32041fcc81b1bd6c214 (patch)
tree91fa6100377d1deb66ac21b3860e15b3dd4958b5 /cnn_v3
parenta71c95c8caf7e570c3f484ce1a53b7acb5ef2006 (diff)
feat(cnn_v3): add infer_cnn_v3.py + rewrite cnn_test for v3 parity
- cnn_v3/training/infer_cnn_v3.py: PyTorch inference tool; simple mode (single PNG, zeroed geometry) and full mode (sample directory); supports --identity-film (γ=1 β=0) to match C++ default, --cond for FiLM MLP, --blend, --debug-hex for pixel comparison - tools/cnn_test.cc: full rewrite, v3 only; packs 20-channel features on CPU (training format: [0,1] oct normals, pyrdown mip), uploads to GPU, runs CNNv3Effect, reads back RGBA16Float, saves PNG; --sample-dir for full G-buffer input, --weights for .bin override, --debug-hex - cmake/DemoTests.cmake: add cnn_v3/src include path, drop unused offscreen_render_target.cc from cnn_test sources - cnn_v3/docs/HOWTO.md: new §10 documenting both tools, comparison workflow, and feature-format convention (training vs runtime) handoff(Gemini): cnn_test + infer_cnn_v3.py ready for parity testing. Run both with --identity-film / --debug-hex on same image to compare.
Diffstat (limited to 'cnn_v3')
-rw-r--r--cnn_v3/docs/HOWTO.md138
-rw-r--r--cnn_v3/training/infer_cnn_v3.py219
2 files changed, 356 insertions, 1 deletions
diff --git a/cnn_v3/docs/HOWTO.md b/cnn_v3/docs/HOWTO.md
index 5cfc371..58f09ed 100644
--- a/cnn_v3/docs/HOWTO.md
+++ b/cnn_v3/docs/HOWTO.md
@@ -587,9 +587,145 @@ Visualization panel still works.
---
-## 10. See Also
+## 10. Python / WGSL Parity Check (infer_cnn_v3 + cnn_test)
+
+Two complementary tools for comparing PyTorch inference against the live WGSL
+compute shaders on the same input image.
+
+### 10a. infer_cnn_v3.py — PyTorch reference inference
+
+**Location:** `cnn_v3/training/infer_cnn_v3.py`
+
+Runs the trained `CNNv3` model in Python and saves the RGBA output as PNG.
+
+**Simple mode** (single PNG, geometry zeroed):
+```bash
+cd cnn_v3/training
+python3 infer_cnn_v3.py photo.png out_python.png \
+ --checkpoint checkpoints/checkpoint_epoch_200.pth
+```
+
+**Full mode** (sample directory with all G-buffer files):
+```bash
+python3 infer_cnn_v3.py dataset/simple/sample_000/ out_python.png \
+ --checkpoint checkpoints/checkpoint_epoch_200.pth
+```
+
+**Identity FiLM** — bypass MLP, use γ=1 β=0 (matches C++ `cnn_test` default):
+```bash
+python3 infer_cnn_v3.py photo.png out_python.png \
+ --checkpoint checkpoints/checkpoint_epoch_200.pth \
+ --identity-film
+```
+
+**Options:**
+
+| Flag | Default | Description |
+|------|---------|-------------|
+| `--checkpoint CKPT` | auto-find latest | Path to `.pth` checkpoint |
+| `--enc-channels C` | from checkpoint | `4,8` — must match training config |
+| `--cond F F F F F` | `0 0 0 0 0` | FiLM conditioning (beat_phase, beat_norm, audio, style0, style1) |
+| `--identity-film` | off | Bypass FiLM MLP, use γ=1 β=0 |
+| `--blend F` | `1.0` | Blend with albedo: 0=input, 1=CNN |
+| `--debug-hex` | off | Print first 8 output pixels as hex |
+
+In **simple mode**, geometry channels are zeroed: `normal=(0.5,0.5)` (oct-encodes
+to ≈(0,0,1)), `depth=0`, `matid=0`, `shadow=1`, `transp=0`.
+
+The checkpoint `config` dict (saved by `train_cnn_v3.py`) sets `enc_channels`
+and `film_cond_dim` automatically; `--enc-channels` is only needed if the
+checkpoint lacks a config key.
+
+---
+
+### 10b. cnn_test — WGSL / GPU reference inference
+
+**Location:** `tools/cnn_test.cc` **Binary:** `build/cnn_test`
+
+Packs the same 20-channel feature tensor as `infer_cnn_v3.py`, uploads it to
+GPU, runs the five `CNNv3Effect` compute passes, and saves the RGBA16Float
+output as PNG.
+
+**Build** (requires `DEMO_BUILD_TESTS=ON` or `DEMO_WORKSPACE=main`):
+```bash
+cmake -B build -DDEMO_BUILD_TESTS=ON && cmake --build build -j4 --target cnn_test
+```
+
+**Simple mode:**
+```bash
+./build/cnn_test photo.png out_gpu.png --weights workspaces/main/weights/cnn_v3_weights.bin
+```
+
+**Full mode** (sample directory):
+```bash
+./build/cnn_test dataset/simple/sample_000/albedo.png out_gpu.png \
+ --sample-dir dataset/simple/sample_000/ \
+ --weights workspaces/main/weights/cnn_v3_weights.bin
+```
+
+**Options:**
+
+| Flag | Description |
+|------|-------------|
+| `--sample-dir DIR` | Load all G-buffer files (albedo/normal/depth/matid/shadow/transp) |
+| `--weights FILE` | `cnn_v3_weights.bin` (uses asset-embedded weights if omitted) |
+| `--debug-hex` | Print first 8 output pixels as hex |
+| `--help` | Show usage |
+
+FiLM is always **identity** (γ=1, β=0) — matching the C++ `CNNv3Effect` default
+until GPU-side FiLM MLP evaluation is added.
+
+---
+
+### 10c. Side-by-side comparison
+
+For a pixel-accurate comparison, use `--identity-film` in Python and `--debug-hex`
+in both tools:
+
+```bash
+cd cnn_v3/training
+
+# 1. Python inference (identity FiLM)
+python3 infer_cnn_v3.py photo.png out_python.png \
+ --checkpoint checkpoints/checkpoint_epoch_200.pth \
+ --identity-film --debug-hex
+
+# 2. GPU inference (always identity FiLM)
+./build/cnn_test photo.png out_gpu.png \
+ --weights workspaces/main/weights/cnn_v3_weights.bin \
+ --debug-hex
+```
+
+Both tools print first 8 pixels in the same format:
+```
+ [0] 0x7F804000 (0.4980 0.5020 0.2510 0.0000)
+```
+
+**Expected delta:** ≤ 1/255 (≈ 4e-3) per channel, matching the parity test
+(`test_cnn_v3_parity`). Larger deltas indicate a weight mismatch — re-export
+with `export_cnn_v3_weights.py` and verify the `.bin` size is 3928 bytes.
+
+---
+
+### 10d. Feature format note
+
+Both tools pack features in **training format** ([0,1] oct-encoded normals),
+not the runtime `gbuf_pack.wgsl` format (which remaps normals to [-1,1]).
+This makes `infer_cnn_v3.py` ↔ `cnn_test` directly comparable.
+
+The live pipeline (`GBufferEffect → gbuf_pack.wgsl → CNNv3Effect`) uses [-1,1]
+normals — that is the intended inference distribution after a full training run
+with `--input-mode full` (Blender renders). For training on photos
+(`--input-mode simple`), [0,1] normals are correct since channel dropout
+teaches the network to handle absent geometry.
+
+---
+
+## 11. See Also
- `cnn_v3/docs/CNN_V3.md` — Full architecture design (U-Net, FiLM, feature layout)
- `doc/EFFECT_WORKFLOW.md` — General effect integration guide
- `cnn_v2/docs/CNN_V2.md` — Reference implementation (simpler, operational)
- `src/tests/gpu/test_demo_effects.cc` — GBufferEffect + GBufViewEffect tests
+- `src/tests/gpu/test_cnn_v3_parity.cc` — Zero/random weight parity tests
+- `cnn_v3/training/export_cnn_v3_weights.py` — Export trained checkpoint → `.bin`
diff --git a/cnn_v3/training/infer_cnn_v3.py b/cnn_v3/training/infer_cnn_v3.py
new file mode 100644
index 0000000..ca1c72a
--- /dev/null
+++ b/cnn_v3/training/infer_cnn_v3.py
@@ -0,0 +1,219 @@
+#!/usr/bin/env python3
+# /// script
+# requires-python = ">=3.10"
+# dependencies = ["torch", "numpy", "pillow", "opencv-python"]
+# ///
+"""CNN v3 PyTorch inference — compare with cnn_test (WGSL/GPU output).
+
+Simple mode (single PNG): albedo = photo, geometry channels zeroed.
+Full mode (sample dir): loads all G-buffer files via assemble_features.
+
+Usage:
+ python3 infer_cnn_v3.py photo.png out.png --checkpoint checkpoints/ckpt.pth
+ python3 infer_cnn_v3.py sample_000/ out.png --checkpoint ckpt.pth
+ python3 infer_cnn_v3.py photo.png out.png --checkpoint ckpt.pth --identity-film
+ python3 infer_cnn_v3.py photo.png out.png --checkpoint ckpt.pth --cond 0.5 0.0 0.8 0.0 0.0
+"""
+
+import argparse
+import sys
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+
+sys.path.insert(0, str(Path(__file__).parent))
+from train_cnn_v3 import CNNv3
+from cnn_v3_utils import assemble_features, load_rgb, load_rg, load_depth16, load_gray
+
+
+# ---------------------------------------------------------------------------
+# Feature loading
+# ---------------------------------------------------------------------------
+
+def load_sample_dir(sample_dir: Path) -> np.ndarray:
+ """Load all G-buffer files from a sample directory → (H,W,20) f32."""
+ return assemble_features(
+ load_rgb(sample_dir / 'albedo.png'),
+ load_rg(sample_dir / 'normal.png'),
+ load_depth16(sample_dir / 'depth.png'),
+ load_gray(sample_dir / 'matid.png'),
+ load_gray(sample_dir / 'shadow.png'),
+ load_gray(sample_dir / 'transp.png'),
+ )
+
+
+def load_simple(image_path: Path) -> np.ndarray:
+ """Photo → (H,W,20) f32 with geometry channels zeroed.
+
+ normal=(0.5,0.5) is the oct-encoded "no normal" (decodes to ~(0,0,1)).
+ shadow=1.0 (fully lit), transp=0.0 (opaque).
+ """
+ albedo = load_rgb(image_path)
+ h, w = albedo.shape[:2]
+ normal = np.full((h, w, 2), 0.5, dtype=np.float32)
+ depth = np.zeros((h, w), dtype=np.float32)
+ matid = np.zeros((h, w), dtype=np.float32)
+ shadow = np.ones((h, w), dtype=np.float32)
+ transp = np.zeros((h, w), dtype=np.float32)
+ return assemble_features(albedo, normal, depth, matid, shadow, transp)
+
+
+# ---------------------------------------------------------------------------
+# Inference
+# ---------------------------------------------------------------------------
+
+def pad_to_multiple(feat: np.ndarray, m: int = 4) -> tuple:
+ """Pad (H,W,C) so H and W are multiples of m. Returns (padded, (ph, pw))."""
+ h, w = feat.shape[:2]
+ ph = (m - h % m) % m
+ pw = (m - w % m) % m
+ if ph == 0 and pw == 0:
+ return feat, (0, 0)
+ return np.pad(feat, ((0, ph), (0, pw), (0, 0))), (ph, pw)
+
+
+def run_identity_film(model: CNNv3, feat: torch.Tensor) -> torch.Tensor:
+ """Forward with identity FiLM (γ=1, β=0). Matches C++ cnn_test default."""
+ c0, c1 = model.enc_channels
+ B = feat.shape[0]
+ dev = feat.device
+
+ skip0 = F.relu(model.enc0(feat))
+
+ x = F.avg_pool2d(skip0, 2)
+ skip1 = F.relu(model.enc1(x))
+
+ x = F.relu(model.bottleneck(F.avg_pool2d(skip1, 2)))
+
+ x = F.relu(model.dec1(
+ torch.cat([F.interpolate(x, scale_factor=2, mode='nearest'), skip1], dim=1)
+ ))
+
+ x = F.relu(model.dec0(
+ torch.cat([F.interpolate(x, scale_factor=2, mode='nearest'), skip0], dim=1)
+ ))
+
+ return torch.sigmoid(x)
+
+
+# ---------------------------------------------------------------------------
+# Output helpers
+# ---------------------------------------------------------------------------
+
+def save_png(path: Path, out: np.ndarray) -> None:
+ """Save (H,W,4) f32 [0,1] RGBA as PNG."""
+ rgba8 = (np.clip(out, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8)
+ Image.fromarray(rgba8, 'RGBA').save(path)
+
+
+def print_debug_hex(out: np.ndarray, n: int = 8) -> None:
+ """Print first n pixels as hex RGBA + float values."""
+ flat = out.reshape(-1, 4)
+ for i in range(min(n, flat.shape[0])):
+ r, g, b, a = flat[i]
+ ri, gi, bi, ai = int(r*255+.5), int(g*255+.5), int(b*255+.5), int(a*255+.5)
+ print(f' [{i}] 0x{ri:02X}{gi:02X}{bi:02X}{ai:02X}'
+ f' ({r:.4f} {g:.4f} {b:.4f} {a:.4f})')
+
+
+# ---------------------------------------------------------------------------
+# Main
+# ---------------------------------------------------------------------------
+
+def main():
+ p = argparse.ArgumentParser(description='CNN v3 PyTorch inference')
+ p.add_argument('input', help='Input PNG or sample directory')
+ p.add_argument('output', help='Output PNG')
+ p.add_argument('--checkpoint', '-c', metavar='CKPT',
+ help='Path to .pth checkpoint (auto-finds latest if omitted)')
+ p.add_argument('--enc-channels', default='4,8',
+ help='Encoder channels (default: 4,8 — must match checkpoint)')
+ p.add_argument('--cond', nargs=5, type=float, metavar='F', default=[0.0]*5,
+ help='FiLM conditioning: 5 floats (beat_phase beat_norm audio style0 style1)')
+ p.add_argument('--identity-film', action='store_true',
+ help='Bypass FiLM MLP, use γ=1 β=0 (matches C++ cnn_test default)')
+ p.add_argument('--blend', type=float, default=1.0,
+ help='Blend with input albedo: 0=input 1=CNN (default 1.0)')
+ p.add_argument('--debug-hex', action='store_true',
+ help='Print first 8 output pixels as hex')
+ args = p.parse_args()
+
+ # --- Feature loading ---
+ inp = Path(args.input)
+ if inp.is_dir():
+ print(f'Mode: full ({inp})')
+ feat = load_sample_dir(inp)
+ albedo_rgb = load_rgb(inp / 'albedo.png')
+ else:
+ print(f'Mode: simple ({inp})')
+ feat = load_simple(inp)
+ albedo_rgb = load_rgb(inp)
+ orig_h, orig_w = feat.shape[:2]
+
+ feat_padded, (ph, pw) = pad_to_multiple(feat, 4)
+ H, W = feat_padded.shape[:2]
+ if ph or pw:
+ print(f'Padded {orig_w}×{orig_h} → {W}×{H}')
+ else:
+ print(f'Resolution: {W}×{H}')
+
+ # --- Load checkpoint ---
+ if args.checkpoint:
+ ckpt_path = Path(args.checkpoint)
+ else:
+ ckpts = sorted(Path('checkpoints').glob('checkpoint_epoch_*.pth'),
+ key=lambda f: int(f.stem.split('_')[-1]))
+ if not ckpts:
+ print('Error: no checkpoint found; use --checkpoint', file=sys.stderr)
+ sys.exit(1)
+ ckpt_path = ckpts[-1]
+ print(f'Checkpoint: {ckpt_path}')
+
+ ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
+ cfg = ckpt.get('config', {})
+ enc_channels = cfg.get('enc_channels', [int(c) for c in args.enc_channels.split(',')])
+ film_cond_dim = cfg.get('film_cond_dim', 5)
+ print(f'Architecture: enc={enc_channels} film_cond_dim={film_cond_dim}')
+
+ model = CNNv3(enc_channels=enc_channels, film_cond_dim=film_cond_dim)
+ model.load_state_dict(ckpt['model_state_dict'])
+ model.eval()
+
+ # --- Inference ---
+ feat_t = torch.from_numpy(feat_padded).permute(2, 0, 1).unsqueeze(0) # (1,20,H,W)
+ cond_t = torch.tensor([args.cond], dtype=torch.float32) # (1,5)
+
+ with torch.no_grad():
+ if args.identity_film:
+ print('FiLM: identity (γ=1, β=0)')
+ out_t = run_identity_film(model, feat_t)
+ else:
+ print(f'FiLM cond: {args.cond}')
+ out_t = model(feat_t, cond_t)
+
+ # (1,4,H,W) → crop padding → (orig_h, orig_w, 4)
+ out = out_t[0].permute(1, 2, 0).numpy()[:orig_h, :orig_w, :]
+
+ # Optional blend with albedo
+ if args.blend < 1.0:
+ h_in, w_in = albedo_rgb.shape[:2]
+ ab = albedo_rgb[:orig_h, :orig_w]
+ ones = np.ones((orig_h, orig_w, 1), dtype=np.float32)
+ src_rgba = np.concatenate([ab, ones], axis=-1)
+ out = src_rgba * (1.0 - args.blend) + out * args.blend
+
+ # --- Save ---
+ out_path = Path(args.output)
+ save_png(out_path, out)
+ print(f'Saved: {out_path}')
+
+ if args.debug_hex:
+ print('First 8 output pixels (RGBA):')
+ print_debug_hex(out)
+
+
+if __name__ == '__main__':
+ main()