diff options
Diffstat (limited to 'cnn_v3/training')
| -rw-r--r-- | cnn_v3/training/cnn_v3_utils.py | 45 | ||||
| -rw-r--r-- | cnn_v3/training/export_cnn_v3_weights.py | 44 | ||||
| -rw-r--r-- | cnn_v3/training/gen_test_vectors.py | 72 | ||||
| -rw-r--r-- | cnn_v3/training/infer_cnn_v3.py | 219 | ||||
| -rw-r--r-- | cnn_v3/training/train_cnn_v3.py | 74 |
5 files changed, 376 insertions, 78 deletions
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py index bef4091..68c0798 100644 --- a/cnn_v3/training/cnn_v3_utils.py +++ b/cnn_v3/training/cnn_v3_utils.py @@ -128,10 +128,11 @@ def _upsample_nearest(a: np.ndarray, h: int, w: int) -> np.ndarray: def assemble_features(albedo: np.ndarray, normal: np.ndarray, depth: np.ndarray, matid: np.ndarray, - shadow: np.ndarray, transp: np.ndarray) -> np.ndarray: + shadow: np.ndarray, transp: np.ndarray, + prev: np.ndarray | None = None) -> np.ndarray: """Build (H,W,20) f32 feature tensor. - prev set to zero (no temporal history during training). + prev: (H,W,3) f32 [0,1] previous frame RGB, or None → zeros. mip1/mip2 computed from albedo. depth_grad computed via finite diff. dif (ch18) = max(0, dot(oct_decode(normal), KEY_LIGHT)) * shadow. """ @@ -140,7 +141,8 @@ def assemble_features(albedo: np.ndarray, normal: np.ndarray, mip1 = _upsample_nearest(pyrdown(albedo), h, w) mip2 = _upsample_nearest(pyrdown(pyrdown(albedo)), h, w) dgrad = depth_gradient(depth) - prev = np.zeros((h, w, 3), dtype=np.float32) + if prev is None: + prev = np.zeros((h, w, 3), dtype=np.float32) nor3 = oct_decode(normal) diffuse = np.maximum(0.0, (nor3 * _KEY_LIGHT).sum(-1)) dif = diffuse * shadow @@ -286,7 +288,8 @@ class CNNv3Dataset(Dataset): channel_dropout_p: float = 0.3, detector: str = 'harris', augment: bool = True, - patch_search_window: int = 0): + patch_search_window: int = 0, + single_sample: str = ''): self.patch_size = patch_size self.patches_per_image = patches_per_image self.image_size = image_size @@ -296,16 +299,18 @@ class CNNv3Dataset(Dataset): self.augment = augment self.patch_search_window = patch_search_window - root = Path(dataset_dir) - subdir = 'full' if input_mode == 'full' else 'simple' - search_dir = root / subdir - if not search_dir.exists(): - search_dir = root - - self.samples = sorted([ - d for d in search_dir.iterdir() - if d.is_dir() and (d / 'albedo.png').exists() - ]) + if single_sample: + self.samples = [Path(single_sample)] + else: + root = Path(dataset_dir) + subdir = 'full' if input_mode == 'full' else 'simple' + search_dir = root / subdir + if not search_dir.exists(): + search_dir = root + self.samples = sorted([ + d for d in search_dir.iterdir() + if d.is_dir() and (d / 'albedo.png').exists() + ]) if not self.samples: raise RuntimeError(f"No samples found in {search_dir}") @@ -345,11 +350,13 @@ class CNNv3Dataset(Dataset): shadow = load_gray(sd / 'shadow.png') transp = load_gray(sd / 'transp.png') h, w = albedo.shape[:2] + prev_path = sd / 'prev.png' + prev = load_rgb(prev_path) if prev_path.exists() else None target_img = Image.open(sd / 'target.png').convert('RGBA') if target_img.size != (w, h): target_img = target_img.resize((w, h), Image.LANCZOS) target = np.asarray(target_img, dtype=np.float32) / 255.0 - return albedo, normal, depth, matid, shadow, transp, target + return albedo, normal, depth, matid, shadow, transp, prev, target def __getitem__(self, idx): if self.full_image: @@ -357,7 +364,7 @@ class CNNv3Dataset(Dataset): else: sample_idx = idx // self.patches_per_image - albedo, normal, depth, matid, shadow, transp, target = self._cache[sample_idx] + albedo, normal, depth, matid, shadow, transp, prev, target = self._cache[sample_idx] h, w = albedo.shape[:2] if self.full_image: @@ -379,6 +386,8 @@ class CNNv3Dataset(Dataset): matid = _resize_gray(matid) shadow = _resize_gray(shadow) transp = _resize_gray(transp) + if prev is not None: + prev = _resize_img(prev) target = _resize_img(target) else: ps = self.patch_size @@ -395,6 +404,8 @@ class CNNv3Dataset(Dataset): matid = matid[sl] shadow = shadow[sl] transp = transp[sl] + if prev is not None: + prev = prev[sl] # Apply cached target offset (if search was enabled at init). if self._target_offsets: @@ -405,7 +416,7 @@ class CNNv3Dataset(Dataset): else: target = target[sl] - feat = assemble_features(albedo, normal, depth, matid, shadow, transp) + feat = assemble_features(albedo, normal, depth, matid, shadow, transp, prev) if self.augment: feat = apply_channel_dropout(feat, diff --git a/cnn_v3/training/export_cnn_v3_weights.py b/cnn_v3/training/export_cnn_v3_weights.py index 99f3a81..78f5f25 100644 --- a/cnn_v3/training/export_cnn_v3_weights.py +++ b/cnn_v3/training/export_cnn_v3_weights.py @@ -15,8 +15,8 @@ Outputs <output_dir>/cnn_v3_weights.bin Conv+bias weights for all 5 passes, packed as f16-pairs-in-u32. Matches the format expected by CNNv3Effect::upload_weights(). - Layout: enc0 (724) | enc1 (296) | bottleneck (72) | dec1 (580) | dec0 (292) - = 1964 f16 values = 982 u32 = 3928 bytes. + Layout: enc0 (724) | enc1 (296) | bottleneck (584) | dec1 (580) | dec0 (292) + = 2476 f16 values = 1238 u32 = 4952 bytes. <output_dir>/cnn_v3_film_mlp.bin FiLM MLP weights as raw f32: L0_W (5×16) L0_b (16) L1_W (16×40) L1_b (40). @@ -31,6 +31,7 @@ Usage """ import argparse +import base64 import struct import sys from pathlib import Path @@ -47,13 +48,13 @@ from train_cnn_v3 import CNNv3 # cnn_v3/src/cnn_v3_effect.cc (kEnc0Weights, kEnc1Weights, …) # cnn_v3/training/gen_test_vectors.py (same constants) # --------------------------------------------------------------------------- -ENC0_WEIGHTS = 20 * 4 * 9 + 4 # Conv(20→4,3×3)+bias = 724 -ENC1_WEIGHTS = 4 * 8 * 9 + 8 # Conv(4→8,3×3)+bias = 296 -BN_WEIGHTS = 8 * 8 * 1 + 8 # Conv(8→8,1×1)+bias = 72 -DEC1_WEIGHTS = 16 * 4 * 9 + 4 # Conv(16→4,3×3)+bias = 580 -DEC0_WEIGHTS = 8 * 4 * 9 + 4 # Conv(8→4,3×3)+bias = 292 +ENC0_WEIGHTS = 20 * 4 * 9 + 4 # Conv(20→4,3×3)+bias = 724 +ENC1_WEIGHTS = 4 * 8 * 9 + 8 # Conv(4→8,3×3)+bias = 296 +BN_WEIGHTS = 8 * 8 * 9 + 8 # Conv(8→8,3×3,dil=2)+bias = 584 +DEC1_WEIGHTS = 16 * 4 * 9 + 4 # Conv(16→4,3×3)+bias = 580 +DEC0_WEIGHTS = 8 * 4 * 9 + 4 # Conv(8→4,3×3)+bias = 292 TOTAL_F16 = ENC0_WEIGHTS + ENC1_WEIGHTS + BN_WEIGHTS + DEC1_WEIGHTS + DEC0_WEIGHTS -# = 1964 +# = 2476 def pack_weights_u32(w_f16: np.ndarray) -> np.ndarray: @@ -158,13 +159,40 @@ def export_weights(checkpoint_path: str, output_dir: str) -> None: print(f"\nDone → {out}/") +_WEIGHTS_JS_DEFAULT = Path(__file__).parent.parent / 'tools' / 'weights.js' + + +def update_weights_js(weights_bin: Path, film_mlp_bin: Path, + js_path: Path = _WEIGHTS_JS_DEFAULT) -> None: + """Encode both .bin files as base64 and write cnn_v3/tools/weights.js.""" + w_b64 = base64.b64encode(weights_bin.read_bytes()).decode('ascii') + f_b64 = base64.b64encode(film_mlp_bin.read_bytes()).decode('ascii') + js_path.write_text( + "'use strict';\n" + "// Auto-generated by export_cnn_v3_weights.py --html — do not edit by hand.\n" + f"const CNN_V3_WEIGHTS_B64='{w_b64}';\n" + f"const CNN_V3_FILM_MLP_B64='{f_b64}';\n" + ) + print(f"\nweights.js → {js_path}") + print(f" CNN_V3_WEIGHTS_B64 {len(w_b64)} chars ({weights_bin.stat().st_size} bytes)") + print(f" CNN_V3_FILM_MLP_B64 {len(f_b64)} chars ({film_mlp_bin.stat().st_size} bytes)") + + def main() -> None: p = argparse.ArgumentParser(description='Export CNN v3 trained weights to .bin') p.add_argument('checkpoint', help='Path to .pth checkpoint file') p.add_argument('--output', default='export', help='Output directory (default: export/)') + p.add_argument('--html', action='store_true', + help=f'Also update {_WEIGHTS_JS_DEFAULT} with base64-encoded weights') + p.add_argument('--html-output', default=None, metavar='PATH', + help='Override default weights.js path (implies --html)') args = p.parse_args() export_weights(args.checkpoint, args.output) + if args.html or args.html_output: + out = Path(args.output) + js_path = Path(args.html_output) if args.html_output else _WEIGHTS_JS_DEFAULT + update_weights_js(out / 'cnn_v3_weights.bin', out / 'cnn_v3_film_mlp.bin', js_path) if __name__ == '__main__': diff --git a/cnn_v3/training/gen_test_vectors.py b/cnn_v3/training/gen_test_vectors.py index 640971c..2eb889c 100644 --- a/cnn_v3/training/gen_test_vectors.py +++ b/cnn_v3/training/gen_test_vectors.py @@ -23,7 +23,7 @@ DEC0_IN, DEC0_OUT = 8, 4 ENC0_WEIGHTS = ENC0_IN * ENC0_OUT * 9 + ENC0_OUT # 724 ENC1_WEIGHTS = ENC1_IN * ENC1_OUT * 9 + ENC1_OUT # 296 -BN_WEIGHTS = BN_IN * BN_OUT * 1 + BN_OUT # 72 +BN_WEIGHTS = BN_IN * BN_OUT * 9 + BN_OUT # 584 (3x3 dilation=2) DEC1_WEIGHTS = DEC1_IN * DEC1_OUT * 9 + DEC1_OUT # 580 DEC0_WEIGHTS = DEC0_IN * DEC0_OUT * 9 + DEC0_OUT # 292 @@ -32,30 +32,8 @@ ENC1_OFFSET = ENC0_OFFSET + ENC0_WEIGHTS BN_OFFSET = ENC1_OFFSET + ENC1_WEIGHTS DEC1_OFFSET = BN_OFFSET + BN_WEIGHTS DEC0_OFFSET = DEC1_OFFSET + DEC1_WEIGHTS -TOTAL_F16 = DEC0_OFFSET + DEC0_WEIGHTS # 1964 + 292 = 2256? let me check -# 724 + 296 + 72 + 580 + 292 = 1964 ... actually let me recount -# ENC0: 20*4*9 + 4 = 720+4 = 724 -# ENC1: 4*8*9 + 8 = 288+8 = 296 -# BN: 8*8*1 + 8 = 64+8 = 72 -# DEC1: 16*4*9 + 4 = 576+4 = 580 -# DEC0: 8*4*9 + 4 = 288+4 = 292 -# Total = 724+296+72+580+292 = 1964 ... but HOWTO.md says 2064. Let me recheck. -# DEC1: 16*4*9 = 576 ... but the shader says Conv(16->4) which is IN=16, OUT=4 -# weight idx: o * DEC1_IN * 9 + i * 9 + ki where o<DEC1_OUT, i<DEC1_IN -# So total conv weights = DEC1_OUT * DEC1_IN * 9 = 4*16*9 = 576, bias = 4 -# Total DEC1 = 580. OK that's right. -# Let me add: 724+296+72+580+292 = 1964. But HOWTO says 2064? -# DEC1: Conv(16->4) = OUT*IN*K^2 = 4*16*9 = 576 + bias 4 = 580. HOWTO says 576+4=580 OK. -# Total = 724+296+72+580+292 = let me sum: 724+296=1020, +72=1092, +580=1672, +292=1964. -# Hmm, HOWTO.md says 2064. Let me recheck HOWTO weight table: -# enc0: 20*4*9=720 +4 = 724 -# enc1: 4*8*9=288 +8 = 296 -# bottleneck: 8*8*1=64 +8 = 72 -# dec1: 16*4*9=576 +4 = 580 -# dec0: 8*4*9=288 +4 = 292 -# Total = 724+296+72+580+292 = 1964 -# The HOWTO says 2064 but I get 1964... 100 difference. Possible typo in doc. -# I'll use the correct value derived from the formulas: 1964. +TOTAL_F16 = DEC0_OFFSET + DEC0_WEIGHTS +# 724 + 296 + 584 + 580 + 292 = 2476 (BN is now 3x3 dilation=2, was 72) # --------------------------------------------------------------------------- # Helpers @@ -140,35 +118,41 @@ def enc1_forward(enc0, w, gamma_lo, gamma_hi, beta_lo, beta_hi): def bottleneck_forward(enc1, w): """ - AvgPool2x2(enc1, clamp-border) + Conv(8->8, 1x1) + ReLU + AvgPool2x2(enc1, clamp-border) + Conv(8->8, 3x3, dilation=2) + ReLU → rgba32uint (f16, quarter-res). No FiLM. enc1: (hH, hW, 8) f32 — half-res + Matches cnn_v3_bottleneck.wgsl exactly. """ hH, hW = enc1.shape[:2] qH, qW = hH // 2, hW // 2 wo = BN_OFFSET - # AvgPool2x2 with clamp (matches load_enc1_avg in WGSL) - avg = np.zeros((qH, qW, BN_IN), dtype=np.float32) - for qy in range(qH): - for qx in range(qW): - s = np.zeros(BN_IN, dtype=np.float32) - for dy in range(2): - for dx in range(2): - hy = min(qy * 2 + dy, hH - 1) - hx = min(qx * 2 + dx, hW - 1) - s += enc1[hy, hx, :] - avg[qy, qx, :] = s * 0.25 + def load_enc1_avg(qy, qx): + """Avg-pool 2x2 from enc1 at quarter-res coord. Zero for OOB (matches WGSL).""" + if qy < 0 or qx < 0 or qy >= qH or qx >= qW: + return np.zeros(BN_IN, dtype=np.float32) + s = np.zeros(BN_IN, dtype=np.float32) + for dy in range(2): + for dx in range(2): + hy = min(qy * 2 + dy, hH - 1) + hx = min(qx * 2 + dx, hW - 1) + s += enc1[hy, hx, :] + return s * 0.25 - # 1x1 conv (no spatial loop, just channel dot-product) + # 3x3 conv with dilation=2 in quarter-res space out = np.zeros((qH, qW, BN_OUT), dtype=np.float32) for o in range(BN_OUT): - bias = get_w(w, wo, BN_OUT * BN_IN + o) - s = np.full((qH, qW), bias, dtype=np.float32) - for i in range(BN_IN): - wv = get_w(w, wo, o * BN_IN + i) - s += wv * avg[:, :, i] - out[:, :, o] = np.maximum(0.0, s) + bias = get_w(w, wo, BN_OUT * BN_IN * 9 + o) + for qy in range(qH): + for qx in range(qW): + s = bias + for ky in range(-1, 2): + for kx in range(-1, 2): + feat = load_enc1_avg(qy + ky * 2, qx + kx * 2) # dilation=2 + ki = (ky + 1) * 3 + (kx + 1) + for i in range(BN_IN): + s += get_w(w, wo, o * BN_IN * 9 + i * 9 + ki) * feat[i] + out[qy, qx, o] = max(0.0, s) return np.float16(out).astype(np.float32) # pack2x16float boundary diff --git a/cnn_v3/training/infer_cnn_v3.py b/cnn_v3/training/infer_cnn_v3.py new file mode 100644 index 0000000..ca1c72a --- /dev/null +++ b/cnn_v3/training/infer_cnn_v3.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python3 +# /// script +# requires-python = ">=3.10" +# dependencies = ["torch", "numpy", "pillow", "opencv-python"] +# /// +"""CNN v3 PyTorch inference — compare with cnn_test (WGSL/GPU output). + +Simple mode (single PNG): albedo = photo, geometry channels zeroed. +Full mode (sample dir): loads all G-buffer files via assemble_features. + +Usage: + python3 infer_cnn_v3.py photo.png out.png --checkpoint checkpoints/ckpt.pth + python3 infer_cnn_v3.py sample_000/ out.png --checkpoint ckpt.pth + python3 infer_cnn_v3.py photo.png out.png --checkpoint ckpt.pth --identity-film + python3 infer_cnn_v3.py photo.png out.png --checkpoint ckpt.pth --cond 0.5 0.0 0.8 0.0 0.0 +""" + +import argparse +import sys +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image + +sys.path.insert(0, str(Path(__file__).parent)) +from train_cnn_v3 import CNNv3 +from cnn_v3_utils import assemble_features, load_rgb, load_rg, load_depth16, load_gray + + +# --------------------------------------------------------------------------- +# Feature loading +# --------------------------------------------------------------------------- + +def load_sample_dir(sample_dir: Path) -> np.ndarray: + """Load all G-buffer files from a sample directory → (H,W,20) f32.""" + return assemble_features( + load_rgb(sample_dir / 'albedo.png'), + load_rg(sample_dir / 'normal.png'), + load_depth16(sample_dir / 'depth.png'), + load_gray(sample_dir / 'matid.png'), + load_gray(sample_dir / 'shadow.png'), + load_gray(sample_dir / 'transp.png'), + ) + + +def load_simple(image_path: Path) -> np.ndarray: + """Photo → (H,W,20) f32 with geometry channels zeroed. + + normal=(0.5,0.5) is the oct-encoded "no normal" (decodes to ~(0,0,1)). + shadow=1.0 (fully lit), transp=0.0 (opaque). + """ + albedo = load_rgb(image_path) + h, w = albedo.shape[:2] + normal = np.full((h, w, 2), 0.5, dtype=np.float32) + depth = np.zeros((h, w), dtype=np.float32) + matid = np.zeros((h, w), dtype=np.float32) + shadow = np.ones((h, w), dtype=np.float32) + transp = np.zeros((h, w), dtype=np.float32) + return assemble_features(albedo, normal, depth, matid, shadow, transp) + + +# --------------------------------------------------------------------------- +# Inference +# --------------------------------------------------------------------------- + +def pad_to_multiple(feat: np.ndarray, m: int = 4) -> tuple: + """Pad (H,W,C) so H and W are multiples of m. Returns (padded, (ph, pw)).""" + h, w = feat.shape[:2] + ph = (m - h % m) % m + pw = (m - w % m) % m + if ph == 0 and pw == 0: + return feat, (0, 0) + return np.pad(feat, ((0, ph), (0, pw), (0, 0))), (ph, pw) + + +def run_identity_film(model: CNNv3, feat: torch.Tensor) -> torch.Tensor: + """Forward with identity FiLM (γ=1, β=0). Matches C++ cnn_test default.""" + c0, c1 = model.enc_channels + B = feat.shape[0] + dev = feat.device + + skip0 = F.relu(model.enc0(feat)) + + x = F.avg_pool2d(skip0, 2) + skip1 = F.relu(model.enc1(x)) + + x = F.relu(model.bottleneck(F.avg_pool2d(skip1, 2))) + + x = F.relu(model.dec1( + torch.cat([F.interpolate(x, scale_factor=2, mode='nearest'), skip1], dim=1) + )) + + x = F.relu(model.dec0( + torch.cat([F.interpolate(x, scale_factor=2, mode='nearest'), skip0], dim=1) + )) + + return torch.sigmoid(x) + + +# --------------------------------------------------------------------------- +# Output helpers +# --------------------------------------------------------------------------- + +def save_png(path: Path, out: np.ndarray) -> None: + """Save (H,W,4) f32 [0,1] RGBA as PNG.""" + rgba8 = (np.clip(out, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8) + Image.fromarray(rgba8, 'RGBA').save(path) + + +def print_debug_hex(out: np.ndarray, n: int = 8) -> None: + """Print first n pixels as hex RGBA + float values.""" + flat = out.reshape(-1, 4) + for i in range(min(n, flat.shape[0])): + r, g, b, a = flat[i] + ri, gi, bi, ai = int(r*255+.5), int(g*255+.5), int(b*255+.5), int(a*255+.5) + print(f' [{i}] 0x{ri:02X}{gi:02X}{bi:02X}{ai:02X}' + f' ({r:.4f} {g:.4f} {b:.4f} {a:.4f})') + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + p = argparse.ArgumentParser(description='CNN v3 PyTorch inference') + p.add_argument('input', help='Input PNG or sample directory') + p.add_argument('output', help='Output PNG') + p.add_argument('--checkpoint', '-c', metavar='CKPT', + help='Path to .pth checkpoint (auto-finds latest if omitted)') + p.add_argument('--enc-channels', default='4,8', + help='Encoder channels (default: 4,8 — must match checkpoint)') + p.add_argument('--cond', nargs=5, type=float, metavar='F', default=[0.0]*5, + help='FiLM conditioning: 5 floats (beat_phase beat_norm audio style0 style1)') + p.add_argument('--identity-film', action='store_true', + help='Bypass FiLM MLP, use γ=1 β=0 (matches C++ cnn_test default)') + p.add_argument('--blend', type=float, default=1.0, + help='Blend with input albedo: 0=input 1=CNN (default 1.0)') + p.add_argument('--debug-hex', action='store_true', + help='Print first 8 output pixels as hex') + args = p.parse_args() + + # --- Feature loading --- + inp = Path(args.input) + if inp.is_dir(): + print(f'Mode: full ({inp})') + feat = load_sample_dir(inp) + albedo_rgb = load_rgb(inp / 'albedo.png') + else: + print(f'Mode: simple ({inp})') + feat = load_simple(inp) + albedo_rgb = load_rgb(inp) + orig_h, orig_w = feat.shape[:2] + + feat_padded, (ph, pw) = pad_to_multiple(feat, 4) + H, W = feat_padded.shape[:2] + if ph or pw: + print(f'Padded {orig_w}×{orig_h} → {W}×{H}') + else: + print(f'Resolution: {W}×{H}') + + # --- Load checkpoint --- + if args.checkpoint: + ckpt_path = Path(args.checkpoint) + else: + ckpts = sorted(Path('checkpoints').glob('checkpoint_epoch_*.pth'), + key=lambda f: int(f.stem.split('_')[-1])) + if not ckpts: + print('Error: no checkpoint found; use --checkpoint', file=sys.stderr) + sys.exit(1) + ckpt_path = ckpts[-1] + print(f'Checkpoint: {ckpt_path}') + + ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False) + cfg = ckpt.get('config', {}) + enc_channels = cfg.get('enc_channels', [int(c) for c in args.enc_channels.split(',')]) + film_cond_dim = cfg.get('film_cond_dim', 5) + print(f'Architecture: enc={enc_channels} film_cond_dim={film_cond_dim}') + + model = CNNv3(enc_channels=enc_channels, film_cond_dim=film_cond_dim) + model.load_state_dict(ckpt['model_state_dict']) + model.eval() + + # --- Inference --- + feat_t = torch.from_numpy(feat_padded).permute(2, 0, 1).unsqueeze(0) # (1,20,H,W) + cond_t = torch.tensor([args.cond], dtype=torch.float32) # (1,5) + + with torch.no_grad(): + if args.identity_film: + print('FiLM: identity (γ=1, β=0)') + out_t = run_identity_film(model, feat_t) + else: + print(f'FiLM cond: {args.cond}') + out_t = model(feat_t, cond_t) + + # (1,4,H,W) → crop padding → (orig_h, orig_w, 4) + out = out_t[0].permute(1, 2, 0).numpy()[:orig_h, :orig_w, :] + + # Optional blend with albedo + if args.blend < 1.0: + h_in, w_in = albedo_rgb.shape[:2] + ab = albedo_rgb[:orig_h, :orig_w] + ones = np.ones((orig_h, orig_w, 1), dtype=np.float32) + src_rgba = np.concatenate([ab, ones], axis=-1) + out = src_rgba * (1.0 - args.blend) + out * args.blend + + # --- Save --- + out_path = Path(args.output) + save_png(out_path, out) + print(f'Saved: {out_path}') + + if args.debug_hex: + print('First 8 output pixels (RGBA):') + print_debug_hex(out) + + +if __name__ == '__main__': + main() diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py index de10d6a..c790495 100644 --- a/cnn_v3/training/train_cnn_v3.py +++ b/cnn_v3/training/train_cnn_v3.py @@ -6,17 +6,21 @@ """CNN v3 Training Script — U-Net + FiLM Architecture: - enc0 Conv(20→4, 3×3) + FiLM + ReLU H×W - enc1 Conv(4→8, 3×3) + FiLM + ReLU + pool2 H/2×W/2 - bottleneck Conv(8→8, 1×1) + ReLU H/4×W/4 - dec1 upsample×2 + cat(enc1) Conv(16→4) + FiLM H/2×W/2 - dec0 upsample×2 + cat(enc0) Conv(8→4) + FiLM H×W + enc0 Conv(20→4, 3×3) + FiLM + ReLU H×W + enc1 Conv(4→8, 3×3) + FiLM + ReLU + pool2 H/2×W/2 + bottleneck Conv(8→8, 3×3, dilation=2) + ReLU H/4×W/4 + dec1 upsample×2 + cat(enc1) Conv(16→4) + FiLM H/2×W/2 + dec0 upsample×2 + cat(enc0) Conv(8→4) + FiLM H×W output sigmoid → RGBA FiLM MLP: Linear(5→16) → ReLU → Linear(16→40) 40 = 2 × (γ+β) for enc0(4) enc1(8) dec1(4) dec0(4) -Weight budget: ~5.4 KB f16 (fits ≤6 KB target) +Weight budget: ~4.84 KB conv f16 (fits ≤6 KB target) + +Training improvements: + --edge-loss-weight Sobel edge loss alongside MSE (default 0.1) + --film-warmup-epochs Train U-Net only for N epochs before unfreezing FiLM MLP (default 50) """ import argparse @@ -56,7 +60,7 @@ class CNNv3(nn.Module): self.enc0 = nn.Conv2d(N_FEATURES, c0, 3, padding=1) self.enc1 = nn.Conv2d(c0, c1, 3, padding=1) - self.bottleneck = nn.Conv2d(c1, c1, 1) + self.bottleneck = nn.Conv2d(c1, c1, 3, padding=2, dilation=2) self.dec1 = nn.Conv2d(c1 * 2, c0, 3, padding=1) # +skip enc1 self.dec0 = nn.Conv2d(c0 * 2, 4, 3, padding=1) # +skip enc0 @@ -96,6 +100,24 @@ class CNNv3(nn.Module): # --------------------------------------------------------------------------- +# Loss +# --------------------------------------------------------------------------- + +def sobel_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Gradient loss via Sobel filters. No VGG dependency. + pred, target: (B, C, H, W) in [0, 1]. Returns scalar on same device.""" + kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], + dtype=pred.dtype, device=pred.device).view(1, 1, 3, 3) + ky = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], + dtype=pred.dtype, device=pred.device).view(1, 1, 3, 3) + B, C, H, W = pred.shape + p = pred.view(B * C, 1, H, W) + t = target.view(B * C, 1, H, W) + return (F.mse_loss(F.conv2d(p, kx, padding=1), F.conv2d(t, kx, padding=1)) + + F.mse_loss(F.conv2d(p, ky, padding=1), F.conv2d(t, ky, padding=1))) + + +# --------------------------------------------------------------------------- # Training # --------------------------------------------------------------------------- @@ -104,6 +126,10 @@ def train(args): enc_channels = [int(c) for c in args.enc_channels.split(',')] print(f"Device: {device}") + if args.single_sample: + args.full_image = True + args.batch_size = 1 + dataset = CNNv3Dataset( dataset_dir=args.input, input_mode=args.input_mode, @@ -115,6 +141,7 @@ def train(args): detector=args.detector, augment=True, patch_search_window=args.patch_search_window, + single_sample=args.single_sample, ) loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=False) @@ -124,11 +151,20 @@ def train(args): print(f"Model: enc={enc_channels} film_cond_dim={args.film_cond_dim} " f"params={nparams} (~{nparams*2/1024:.1f} KB f16)") - optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + # Phase 1: freeze FiLM MLP so U-Net convolutions stabilise first. + film_warmup = args.film_warmup_epochs + if film_warmup > 0: + for p in model.film_mlp.parameters(): + p.requires_grad = False + print(f"FiLM MLP frozen for first {film_warmup} epochs (phase-1 warmup)") + + optimizer = torch.optim.Adam( + filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr) criterion = nn.MSELoss() ckpt_dir = Path(args.checkpoint_dir) ckpt_dir.mkdir(parents=True, exist_ok=True) start_epoch = 1 + film_unfrozen = (film_warmup == 0) if args.resume: ckpt_path = Path(args.resume) @@ -163,6 +199,15 @@ def train(args): for epoch in range(start_epoch, args.epochs + 1): if interrupted: break + + # Phase 2: unfreeze FiLM MLP after warmup, rebuild optimizer at reduced LR. + if not film_unfrozen and epoch > film_warmup: + for p in model.film_mlp.parameters(): + p.requires_grad = True + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr * 0.1) + film_unfrozen = True + print(f"\nPhase 2: FiLM MLP unfrozen at epoch {epoch} (lr={args.lr*0.1:.2e})") + model.train() epoch_loss = 0.0 n_batches = 0 @@ -172,7 +217,10 @@ def train(args): break feat, cond, target = feat.to(device), cond.to(device), target.to(device) optimizer.zero_grad() - loss = criterion(model(feat, cond), target) + pred = model(feat, cond) + loss = criterion(pred, target) + if args.edge_loss_weight > 0.0: + loss = loss + args.edge_loss_weight * sobel_loss(pred, target) loss.backward() optimizer.step() epoch_loss += loss.item() @@ -210,6 +258,8 @@ def _checkpoint(model, optimizer, epoch, loss, args): 'enc_channels': [int(c) for c in args.enc_channels.split(',')], 'film_cond_dim': args.film_cond_dim, 'input_mode': args.input_mode, + 'edge_loss_weight': args.edge_loss_weight, + 'film_warmup_epochs': args.film_warmup_epochs, }, } @@ -222,6 +272,8 @@ def main(): p = argparse.ArgumentParser(description='Train CNN v3 (U-Net + FiLM)') # Dataset + p.add_argument('--single-sample', default='', metavar='DIR', + help='Train on a single sample directory; implies --full-image and --batch-size 1') p.add_argument('--input', default='training/dataset', help='Dataset root (contains full/ or simple/ subdirs)') p.add_argument('--input-mode', default='simple', choices=['simple', 'full'], @@ -259,6 +311,10 @@ def main(): help='Save checkpoint every N epochs (0=disable)') p.add_argument('--resume', default='', metavar='CKPT', help='Resume from checkpoint path; if path missing, use latest in --checkpoint-dir') + p.add_argument('--edge-loss-weight', type=float, default=0.1, + help='Weight for Sobel edge loss alongside MSE (default 0.1; 0=disable)') + p.add_argument('--film-warmup-epochs', type=int, default=50, + help='Epochs to train U-Net only before unfreezing FiLM MLP (default 50; 0=joint)') train(p.parse_args()) |
