#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = ["torch", "numpy", "pillow", "opencv-python"]
# ///
"""CNN v3 PyTorch inference — compare with cnn_test (WGSL/GPU output).

Simple mode (single PNG): albedo = photo, geometry channels set to neutral
defaults (see load_simple).
Full mode (sample dir): loads all G-buffer files via assemble_features.

Usage:
    python3 infer_cnn_v3.py photo.png out.png --checkpoint checkpoints/ckpt.pth
    python3 infer_cnn_v3.py sample_000/ out.png --checkpoint ckpt.pth
    python3 infer_cnn_v3.py photo.png out.png --checkpoint ckpt.pth --identity-film
    python3 infer_cnn_v3.py photo.png out.png --checkpoint ckpt.pth --cond 0.5 0.0 0.8 0.0 0.0
"""

import argparse
import sys
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

# Make sibling modules (train_cnn_v3, cnn_v3_utils) importable regardless of CWD.
sys.path.insert(0, str(Path(__file__).parent))
from train_cnn_v3 import CNNv3  # noqa: E402
from cnn_v3_utils import assemble_features, load_rgb, load_rg, load_depth16, load_gray  # noqa: E402


# ---------------------------------------------------------------------------
# Feature loading
# ---------------------------------------------------------------------------

def load_sample_dir(sample_dir: Path) -> np.ndarray:
    """Load all G-buffer files from a sample directory → (H,W,20) f32.

    Reads albedo.png, normal.png, depth.png, matid.png, shadow.png and
    transp.png from sample_dir and passes them, in that order, to
    assemble_features.
    """
    return assemble_features(
        load_rgb(sample_dir / 'albedo.png'),
        load_rg(sample_dir / 'normal.png'),
        load_depth16(sample_dir / 'depth.png'),
        load_gray(sample_dir / 'matid.png'),
        load_gray(sample_dir / 'shadow.png'),
        load_gray(sample_dir / 'transp.png'),
    )


def load_simple(image_path: Path) -> np.ndarray:
    """Photo → (H,W,20) f32 with geometry channels set to neutral defaults.

    normal=(0.5,0.5) is the oct-encoded "no normal" (decodes to ~(0,0,1)).
    depth=0, matid=0, shadow=1.0 (fully lit), transp=0.0 (opaque).
    """
    albedo = load_rgb(image_path)
    h, w = albedo.shape[:2]
    normal = np.full((h, w, 2), 0.5, dtype=np.float32)
    depth = np.zeros((h, w), dtype=np.float32)
    matid = np.zeros((h, w), dtype=np.float32)
    shadow = np.ones((h, w), dtype=np.float32)
    transp = np.zeros((h, w), dtype=np.float32)
    return assemble_features(albedo, normal, depth, matid, shadow, transp)


def _load_features(inp: Path) -> np.ndarray:
    """Dispatch to full (directory) or simple (single image) feature loading."""
    if inp.is_dir():
        print(f'Mode: full ({inp})')
        return load_sample_dir(inp)
    print(f'Mode: simple ({inp})')
    return load_simple(inp)


def _load_albedo(inp: Path) -> np.ndarray:
    """Load the albedo image matching _load_features' choice of input mode."""
    return load_rgb(inp / 'albedo.png') if inp.is_dir() else load_rgb(inp)


# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------

def pad_to_multiple(feat: np.ndarray, m: int = 4) -> tuple:
    """Pad (H,W,C) so H and W are multiples of m. Returns (padded, (ph, pw)).

    Zeros are appended on the bottom/right edges only, so cropping
    [:H, :W] afterwards recovers the original region.
    """
    h, w = feat.shape[:2]
    ph = (m - h % m) % m
    pw = (m - w % m) % m
    if ph == 0 and pw == 0:
        return feat, (0, 0)
    return np.pad(feat, ((0, ph), (0, pw), (0, 0))), (ph, pw)


def run_identity_film(model: CNNv3, feat: torch.Tensor) -> torch.Tensor:
    """Forward with identity FiLM (γ=1, β=0). Matches C++ cnn_test default.

    With γ=1, β=0 the FiLM modulation is a no-op, so the conditioning MLP is
    bypassed and the encoder/bottleneck/decoder convs are chained directly.

    Args:
        model: CNNv3 whose enc0, enc1, bottleneck, dec1 and dec0 layers are used.
        feat: (B,C,H,W) features. H and W must be multiples of 4: there are two
            2× avg-pool downsamples, and the nearest-upsampled tensors are
            concatenated with the skip tensors.

    Returns:
        Output of the final sigmoid, same spatial size as feat.
    """
    skip0 = F.relu(model.enc0(feat))
    skip1 = F.relu(model.enc1(F.avg_pool2d(skip0, 2)))
    x = F.relu(model.bottleneck(F.avg_pool2d(skip1, 2)))
    x = F.relu(model.dec1(
        torch.cat([F.interpolate(x, scale_factor=2, mode='nearest'), skip1], dim=1)
    ))
    # NOTE(review): ReLU before the sigmoid confines outputs to [0.5, 1];
    # presumably this mirrors CNNv3.forward and the WGSL shader — confirm.
    x = F.relu(model.dec0(
        torch.cat([F.interpolate(x, scale_factor=2, mode='nearest'), skip0], dim=1)
    ))
    return torch.sigmoid(x)


# ---------------------------------------------------------------------------
# Output helpers
# ---------------------------------------------------------------------------

def _to_u8(x: np.ndarray) -> np.ndarray:
    """Quantize floats to uint8: clip to [0,1], scale by 255, round half up."""
    return (np.clip(x, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8)


def save_png(path: Path, out: np.ndarray) -> None:
    """Save (H,W,4) f32 [0,1] RGBA as PNG.

    Raises:
        ValueError: if out is not a (H,W,4) array.
    """
    if out.ndim != 3 or out.shape[-1] != 4:
        raise ValueError(f'expected (H,W,4) RGBA array, got shape {out.shape}')
    # Pillow infers RGBA from a (H,W,4) uint8 array; passing mode= to
    # fromarray is deprecated in recent Pillow releases.
    Image.fromarray(_to_u8(out)).save(path)


def print_debug_hex(out: np.ndarray, n: int = 8) -> None:
    """Print first n pixels as hex RGBA + float values.

    Bytes are quantized exactly as save_png does (clip + round half up), so
    the hex codes match the written PNG even for values outside [0,1].
    """
    flat = out.reshape(-1, 4)[:max(n, 0)]
    for i, (px, code) in enumerate(zip(flat, _to_u8(flat))):
        r, g, b, a = px
        hex_rgba = ''.join(f'{int(c):02X}' for c in code)
        print(f' [{i}] 0x{hex_rgba}'
              f' ({r:.4f} {g:.4f} {b:.4f} {a:.4f})')


# ---------------------------------------------------------------------------
# Checkpoint helpers
# ---------------------------------------------------------------------------

def _find_latest_checkpoint(ckpt_dir: Path = Path('checkpoints')) -> Path | None:
    """Return checkpoint_epoch_<N>.pth in ckpt_dir with the largest N, or None.

    Files whose epoch suffix is not a plain integer (e.g.
    checkpoint_epoch_final.pth) are skipped instead of crashing the sort.
    """
    numbered = []
    for f in ckpt_dir.glob('checkpoint_epoch_*.pth'):
        epoch = f.stem.rsplit('_', 1)[-1]
        if epoch.isdecimal():
            numbered.append((int(epoch), f))
    if not numbered:
        return None
    return max(numbered, key=lambda item: item[0])[1]


def _load_model(ckpt_path: Path, enc_channels_arg: str) -> tuple:
    """Build CNNv3 from a checkpoint. Returns (model in eval mode, film_cond_dim).

    Architecture values stored in the checkpoint's 'config' take precedence;
    enc_channels_arg (comma-separated ints) is only parsed as a fallback.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects — only
    # load checkpoints from a trusted source (e.g. produced by train_cnn_v3).
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    cfg = ckpt.get('config', {})
    if 'enc_channels' in cfg:
        enc_channels = cfg['enc_channels']
    else:
        enc_channels = [int(c) for c in enc_channels_arg.split(',')]
    film_cond_dim = cfg.get('film_cond_dim', 5)
    print(f'Architecture: enc={enc_channels} film_cond_dim={film_cond_dim}')

    model = CNNv3(enc_channels=enc_channels, film_cond_dim=film_cond_dim)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()
    return model, film_cond_dim


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """Parse CLI args, run CNN v3 inference on one input, and write a PNG."""
    p = argparse.ArgumentParser(description='CNN v3 PyTorch inference')
    p.add_argument('input', help='Input PNG or sample directory')
    p.add_argument('output', help='Output PNG')
    p.add_argument('--checkpoint', '-c', metavar='CKPT',
                   help='Path to .pth checkpoint (auto-finds latest if omitted)')
    p.add_argument('--enc-channels', default='4,8',
                   help='Encoder channels (default: 4,8 — must match checkpoint)')
    p.add_argument('--cond', nargs=5, type=float, metavar='F', default=[0.0] * 5,
                   help='FiLM conditioning: 5 floats (beat_phase beat_norm audio style0 style1)')
    p.add_argument('--identity-film', action='store_true',
                   help='Bypass FiLM MLP, use γ=1 β=0 (matches C++ cnn_test default)')
    p.add_argument('--blend', type=float, default=1.0,
                   help='Blend with input albedo: 0=input 1=CNN (default 1.0)')
    p.add_argument('--debug-hex', action='store_true',
                   help='Print first 8 output pixels as hex')
    args = p.parse_args()

    # Previously blend > 1 was silently ignored and blend < 0 extrapolated.
    if not 0.0 <= args.blend <= 1.0:
        p.error(f'--blend must be in [0, 1], got {args.blend}')

    # --- Feature loading ---
    inp = Path(args.input)
    feat = _load_features(inp)

    orig_h, orig_w = feat.shape[:2]
    # Two 2× downsamples in the network → spatial dims must be multiples of 4.
    feat_padded, (ph, pw) = pad_to_multiple(feat, 4)
    H, W = feat_padded.shape[:2]
    if ph or pw:
        print(f'Padded {orig_w}×{orig_h} → {W}×{H}')
    else:
        print(f'Resolution: {W}×{H}')

    # --- Load checkpoint ---
    if args.checkpoint:
        ckpt_path = Path(args.checkpoint)
    else:
        ckpt_path = _find_latest_checkpoint()
        if ckpt_path is None:
            print('Error: no checkpoint found; use --checkpoint', file=sys.stderr)
            sys.exit(1)
    print(f'Checkpoint: {ckpt_path}')
    model, film_cond_dim = _load_model(ckpt_path, args.enc_channels)

    # Fail early with a clear message instead of a shape error inside the model.
    if not args.identity_film and len(args.cond) != film_cond_dim:
        print(f'Error: --cond has {len(args.cond)} values but checkpoint expects '
              f'film_cond_dim={film_cond_dim}; use --identity-film', file=sys.stderr)
        sys.exit(1)

    # --- Inference ---
    feat_t = torch.from_numpy(feat_padded).permute(2, 0, 1).unsqueeze(0)  # (1,20,H,W)
    cond_t = torch.tensor([args.cond], dtype=torch.float32)               # (1,5)

    with torch.no_grad():
        if args.identity_film:
            print('FiLM: identity (γ=1, β=0)')
            out_t = run_identity_film(model, feat_t)
        else:
            print(f'FiLM cond: {args.cond}')
            out_t = model(feat_t, cond_t)

    # (1,4,H,W) → crop padding → (orig_h, orig_w, 4)
    out = out_t[0].permute(1, 2, 0).numpy()[:orig_h, :orig_w, :]

    # Optional blend with albedo (opaque alpha); albedo is only loaded if needed.
    if args.blend < 1.0:
        albedo = _load_albedo(inp)[:orig_h, :orig_w]
        alpha = np.ones((orig_h, orig_w, 1), dtype=np.float32)
        src_rgba = np.concatenate([albedo, alpha], axis=-1)
        out = src_rgba * (1.0 - args.blend) + out * args.blend

    # --- Save ---
    out_path = Path(args.output)
    save_png(out_path, out)
    print(f'Saved: {out_path}')

    if args.debug_hex:
        print('First 8 output pixels (RGBA):')
        print_debug_hex(out)


if __name__ == '__main__':
    main()