summaryrefslogtreecommitdiff
path: root/cnn_v3/training
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v3/training')
-rw-r--r--cnn_v3/training/cnn_v3_utils.py45
-rw-r--r--cnn_v3/training/export_cnn_v3_weights.py44
-rw-r--r--cnn_v3/training/gen_test_vectors.py72
-rw-r--r--cnn_v3/training/infer_cnn_v3.py219
-rw-r--r--cnn_v3/training/train_cnn_v3.py74
5 files changed, 376 insertions, 78 deletions
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py
index bef4091..68c0798 100644
--- a/cnn_v3/training/cnn_v3_utils.py
+++ b/cnn_v3/training/cnn_v3_utils.py
@@ -128,10 +128,11 @@ def _upsample_nearest(a: np.ndarray, h: int, w: int) -> np.ndarray:
def assemble_features(albedo: np.ndarray, normal: np.ndarray,
depth: np.ndarray, matid: np.ndarray,
- shadow: np.ndarray, transp: np.ndarray) -> np.ndarray:
+ shadow: np.ndarray, transp: np.ndarray,
+ prev: np.ndarray | None = None) -> np.ndarray:
"""Build (H,W,20) f32 feature tensor.
- prev set to zero (no temporal history during training).
+ prev: (H,W,3) f32 [0,1] previous frame RGB, or None → zeros.
mip1/mip2 computed from albedo. depth_grad computed via finite diff.
dif (ch18) = max(0, dot(oct_decode(normal), KEY_LIGHT)) * shadow.
"""
@@ -140,7 +141,8 @@ def assemble_features(albedo: np.ndarray, normal: np.ndarray,
mip1 = _upsample_nearest(pyrdown(albedo), h, w)
mip2 = _upsample_nearest(pyrdown(pyrdown(albedo)), h, w)
dgrad = depth_gradient(depth)
- prev = np.zeros((h, w, 3), dtype=np.float32)
+ if prev is None:
+ prev = np.zeros((h, w, 3), dtype=np.float32)
nor3 = oct_decode(normal)
diffuse = np.maximum(0.0, (nor3 * _KEY_LIGHT).sum(-1))
dif = diffuse * shadow
@@ -286,7 +288,8 @@ class CNNv3Dataset(Dataset):
channel_dropout_p: float = 0.3,
detector: str = 'harris',
augment: bool = True,
- patch_search_window: int = 0):
+ patch_search_window: int = 0,
+ single_sample: str = ''):
self.patch_size = patch_size
self.patches_per_image = patches_per_image
self.image_size = image_size
@@ -296,16 +299,18 @@ class CNNv3Dataset(Dataset):
self.augment = augment
self.patch_search_window = patch_search_window
- root = Path(dataset_dir)
- subdir = 'full' if input_mode == 'full' else 'simple'
- search_dir = root / subdir
- if not search_dir.exists():
- search_dir = root
-
- self.samples = sorted([
- d for d in search_dir.iterdir()
- if d.is_dir() and (d / 'albedo.png').exists()
- ])
+ if single_sample:
+ self.samples = [Path(single_sample)]
+ else:
+ root = Path(dataset_dir)
+ subdir = 'full' if input_mode == 'full' else 'simple'
+ search_dir = root / subdir
+ if not search_dir.exists():
+ search_dir = root
+ self.samples = sorted([
+ d for d in search_dir.iterdir()
+ if d.is_dir() and (d / 'albedo.png').exists()
+ ])
if not self.samples:
raise RuntimeError(f"No samples found in {search_dir}")
@@ -345,11 +350,13 @@ class CNNv3Dataset(Dataset):
shadow = load_gray(sd / 'shadow.png')
transp = load_gray(sd / 'transp.png')
h, w = albedo.shape[:2]
+ prev_path = sd / 'prev.png'
+ prev = load_rgb(prev_path) if prev_path.exists() else None
target_img = Image.open(sd / 'target.png').convert('RGBA')
if target_img.size != (w, h):
target_img = target_img.resize((w, h), Image.LANCZOS)
target = np.asarray(target_img, dtype=np.float32) / 255.0
- return albedo, normal, depth, matid, shadow, transp, target
+ return albedo, normal, depth, matid, shadow, transp, prev, target
def __getitem__(self, idx):
if self.full_image:
@@ -357,7 +364,7 @@ class CNNv3Dataset(Dataset):
else:
sample_idx = idx // self.patches_per_image
- albedo, normal, depth, matid, shadow, transp, target = self._cache[sample_idx]
+ albedo, normal, depth, matid, shadow, transp, prev, target = self._cache[sample_idx]
h, w = albedo.shape[:2]
if self.full_image:
@@ -379,6 +386,8 @@ class CNNv3Dataset(Dataset):
matid = _resize_gray(matid)
shadow = _resize_gray(shadow)
transp = _resize_gray(transp)
+ if prev is not None:
+ prev = _resize_img(prev)
target = _resize_img(target)
else:
ps = self.patch_size
@@ -395,6 +404,8 @@ class CNNv3Dataset(Dataset):
matid = matid[sl]
shadow = shadow[sl]
transp = transp[sl]
+ if prev is not None:
+ prev = prev[sl]
# Apply cached target offset (if search was enabled at init).
if self._target_offsets:
@@ -405,7 +416,7 @@ class CNNv3Dataset(Dataset):
else:
target = target[sl]
- feat = assemble_features(albedo, normal, depth, matid, shadow, transp)
+ feat = assemble_features(albedo, normal, depth, matid, shadow, transp, prev)
if self.augment:
feat = apply_channel_dropout(feat,
diff --git a/cnn_v3/training/export_cnn_v3_weights.py b/cnn_v3/training/export_cnn_v3_weights.py
index 99f3a81..78f5f25 100644
--- a/cnn_v3/training/export_cnn_v3_weights.py
+++ b/cnn_v3/training/export_cnn_v3_weights.py
@@ -15,8 +15,8 @@ Outputs
<output_dir>/cnn_v3_weights.bin
Conv+bias weights for all 5 passes, packed as f16-pairs-in-u32.
Matches the format expected by CNNv3Effect::upload_weights().
- Layout: enc0 (724) | enc1 (296) | bottleneck (72) | dec1 (580) | dec0 (292)
- = 1964 f16 values = 982 u32 = 3928 bytes.
+ Layout: enc0 (724) | enc1 (296) | bottleneck (584) | dec1 (580) | dec0 (292)
+ = 2476 f16 values = 1238 u32 = 4952 bytes.
<output_dir>/cnn_v3_film_mlp.bin
FiLM MLP weights as raw f32: L0_W (5×16) L0_b (16) L1_W (16×40) L1_b (40).
@@ -31,6 +31,7 @@ Usage
"""
import argparse
+import base64
import struct
import sys
from pathlib import Path
@@ -47,13 +48,13 @@ from train_cnn_v3 import CNNv3
# cnn_v3/src/cnn_v3_effect.cc (kEnc0Weights, kEnc1Weights, …)
# cnn_v3/training/gen_test_vectors.py (same constants)
# ---------------------------------------------------------------------------
-ENC0_WEIGHTS = 20 * 4 * 9 + 4 # Conv(20→4,3×3)+bias = 724
-ENC1_WEIGHTS = 4 * 8 * 9 + 8 # Conv(4→8,3×3)+bias = 296
-BN_WEIGHTS = 8 * 8 * 1 + 8 # Conv(8→8,1×1)+bias = 72
-DEC1_WEIGHTS = 16 * 4 * 9 + 4 # Conv(16→4,3×3)+bias = 580
-DEC0_WEIGHTS = 8 * 4 * 9 + 4 # Conv(8→4,3×3)+bias = 292
+ENC0_WEIGHTS = 20 * 4 * 9 + 4 # Conv(20→4,3×3)+bias = 724
+ENC1_WEIGHTS = 4 * 8 * 9 + 8 # Conv(4→8,3×3)+bias = 296
+BN_WEIGHTS = 8 * 8 * 9 + 8 # Conv(8→8,3×3,dil=2)+bias = 584
+DEC1_WEIGHTS = 16 * 4 * 9 + 4 # Conv(16→4,3×3)+bias = 580
+DEC0_WEIGHTS = 8 * 4 * 9 + 4 # Conv(8→4,3×3)+bias = 292
TOTAL_F16 = ENC0_WEIGHTS + ENC1_WEIGHTS + BN_WEIGHTS + DEC1_WEIGHTS + DEC0_WEIGHTS
-# = 1964
+# = 2476
def pack_weights_u32(w_f16: np.ndarray) -> np.ndarray:
@@ -158,13 +159,40 @@ def export_weights(checkpoint_path: str, output_dir: str) -> None:
print(f"\nDone → {out}/")
+_WEIGHTS_JS_DEFAULT = Path(__file__).parent.parent / 'tools' / 'weights.js'
+
+
+def update_weights_js(weights_bin: Path, film_mlp_bin: Path,
+ js_path: Path = _WEIGHTS_JS_DEFAULT) -> None:
+ """Encode both .bin files as base64 and write cnn_v3/tools/weights.js."""
+ w_b64 = base64.b64encode(weights_bin.read_bytes()).decode('ascii')
+ f_b64 = base64.b64encode(film_mlp_bin.read_bytes()).decode('ascii')
+ js_path.write_text(
+ "'use strict';\n"
+ "// Auto-generated by export_cnn_v3_weights.py --html — do not edit by hand.\n"
+ f"const CNN_V3_WEIGHTS_B64='{w_b64}';\n"
+ f"const CNN_V3_FILM_MLP_B64='{f_b64}';\n"
+ )
+ print(f"\nweights.js → {js_path}")
+ print(f" CNN_V3_WEIGHTS_B64 {len(w_b64)} chars ({weights_bin.stat().st_size} bytes)")
+ print(f" CNN_V3_FILM_MLP_B64 {len(f_b64)} chars ({film_mlp_bin.stat().st_size} bytes)")
+
+
def main() -> None:
p = argparse.ArgumentParser(description='Export CNN v3 trained weights to .bin')
p.add_argument('checkpoint', help='Path to .pth checkpoint file')
p.add_argument('--output', default='export',
help='Output directory (default: export/)')
+ p.add_argument('--html', action='store_true',
+ help=f'Also update {_WEIGHTS_JS_DEFAULT} with base64-encoded weights')
+ p.add_argument('--html-output', default=None, metavar='PATH',
+ help='Override default weights.js path (implies --html)')
args = p.parse_args()
export_weights(args.checkpoint, args.output)
+ if args.html or args.html_output:
+ out = Path(args.output)
+ js_path = Path(args.html_output) if args.html_output else _WEIGHTS_JS_DEFAULT
+ update_weights_js(out / 'cnn_v3_weights.bin', out / 'cnn_v3_film_mlp.bin', js_path)
if __name__ == '__main__':
diff --git a/cnn_v3/training/gen_test_vectors.py b/cnn_v3/training/gen_test_vectors.py
index 640971c..2eb889c 100644
--- a/cnn_v3/training/gen_test_vectors.py
+++ b/cnn_v3/training/gen_test_vectors.py
@@ -23,7 +23,7 @@ DEC0_IN, DEC0_OUT = 8, 4
ENC0_WEIGHTS = ENC0_IN * ENC0_OUT * 9 + ENC0_OUT # 724
ENC1_WEIGHTS = ENC1_IN * ENC1_OUT * 9 + ENC1_OUT # 296
-BN_WEIGHTS = BN_IN * BN_OUT * 1 + BN_OUT # 72
+BN_WEIGHTS = BN_IN * BN_OUT * 9 + BN_OUT # 584 (3x3 dilation=2)
DEC1_WEIGHTS = DEC1_IN * DEC1_OUT * 9 + DEC1_OUT # 580
DEC0_WEIGHTS = DEC0_IN * DEC0_OUT * 9 + DEC0_OUT # 292
@@ -32,30 +32,8 @@ ENC1_OFFSET = ENC0_OFFSET + ENC0_WEIGHTS
BN_OFFSET = ENC1_OFFSET + ENC1_WEIGHTS
DEC1_OFFSET = BN_OFFSET + BN_WEIGHTS
DEC0_OFFSET = DEC1_OFFSET + DEC1_WEIGHTS
-TOTAL_F16 = DEC0_OFFSET + DEC0_WEIGHTS # 1964 + 292 = 2256? let me check
-# 724 + 296 + 72 + 580 + 292 = 1964 ... actually let me recount
-# ENC0: 20*4*9 + 4 = 720+4 = 724
-# ENC1: 4*8*9 + 8 = 288+8 = 296
-# BN: 8*8*1 + 8 = 64+8 = 72
-# DEC1: 16*4*9 + 4 = 576+4 = 580
-# DEC0: 8*4*9 + 4 = 288+4 = 292
-# Total = 724+296+72+580+292 = 1964 ... but HOWTO.md says 2064. Let me recheck.
-# DEC1: 16*4*9 = 576 ... but the shader says Conv(16->4) which is IN=16, OUT=4
-# weight idx: o * DEC1_IN * 9 + i * 9 + ki where o<DEC1_OUT, i<DEC1_IN
-# So total conv weights = DEC1_OUT * DEC1_IN * 9 = 4*16*9 = 576, bias = 4
-# Total DEC1 = 580. OK that's right.
-# Let me add: 724+296+72+580+292 = 1964. But HOWTO says 2064?
-# DEC1: Conv(16->4) = OUT*IN*K^2 = 4*16*9 = 576 + bias 4 = 580. HOWTO says 576+4=580 OK.
-# Total = 724+296+72+580+292 = let me sum: 724+296=1020, +72=1092, +580=1672, +292=1964.
-# Hmm, HOWTO.md says 2064. Let me recheck HOWTO weight table:
-# enc0: 20*4*9=720 +4 = 724
-# enc1: 4*8*9=288 +8 = 296
-# bottleneck: 8*8*1=64 +8 = 72
-# dec1: 16*4*9=576 +4 = 580
-# dec0: 8*4*9=288 +4 = 292
-# Total = 724+296+72+580+292 = 1964
-# The HOWTO says 2064 but I get 1964... 100 difference. Possible typo in doc.
-# I'll use the correct value derived from the formulas: 1964.
+TOTAL_F16 = DEC0_OFFSET + DEC0_WEIGHTS
+# 724 + 296 + 584 + 580 + 292 = 2476 (BN is now 3x3 dilation=2, was 72)
# ---------------------------------------------------------------------------
# Helpers
@@ -140,35 +118,41 @@ def enc1_forward(enc0, w, gamma_lo, gamma_hi, beta_lo, beta_hi):
def bottleneck_forward(enc1, w):
"""
- AvgPool2x2(enc1, clamp-border) + Conv(8->8, 1x1) + ReLU
+ AvgPool2x2(enc1, clamp-border) + Conv(8->8, 3x3, dilation=2) + ReLU
→ rgba32uint (f16, quarter-res). No FiLM.
enc1: (hH, hW, 8) f32 — half-res
+ Matches cnn_v3_bottleneck.wgsl exactly.
"""
hH, hW = enc1.shape[:2]
qH, qW = hH // 2, hW // 2
wo = BN_OFFSET
- # AvgPool2x2 with clamp (matches load_enc1_avg in WGSL)
- avg = np.zeros((qH, qW, BN_IN), dtype=np.float32)
- for qy in range(qH):
- for qx in range(qW):
- s = np.zeros(BN_IN, dtype=np.float32)
- for dy in range(2):
- for dx in range(2):
- hy = min(qy * 2 + dy, hH - 1)
- hx = min(qx * 2 + dx, hW - 1)
- s += enc1[hy, hx, :]
- avg[qy, qx, :] = s * 0.25
+ def load_enc1_avg(qy, qx):
+ """Avg-pool 2x2 from enc1 at quarter-res coord. Zero for OOB (matches WGSL)."""
+ if qy < 0 or qx < 0 or qy >= qH or qx >= qW:
+ return np.zeros(BN_IN, dtype=np.float32)
+ s = np.zeros(BN_IN, dtype=np.float32)
+ for dy in range(2):
+ for dx in range(2):
+ hy = min(qy * 2 + dy, hH - 1)
+ hx = min(qx * 2 + dx, hW - 1)
+ s += enc1[hy, hx, :]
+ return s * 0.25
- # 1x1 conv (no spatial loop, just channel dot-product)
+ # 3x3 conv with dilation=2 in quarter-res space
out = np.zeros((qH, qW, BN_OUT), dtype=np.float32)
for o in range(BN_OUT):
- bias = get_w(w, wo, BN_OUT * BN_IN + o)
- s = np.full((qH, qW), bias, dtype=np.float32)
- for i in range(BN_IN):
- wv = get_w(w, wo, o * BN_IN + i)
- s += wv * avg[:, :, i]
- out[:, :, o] = np.maximum(0.0, s)
+ bias = get_w(w, wo, BN_OUT * BN_IN * 9 + o)
+ for qy in range(qH):
+ for qx in range(qW):
+ s = bias
+ for ky in range(-1, 2):
+ for kx in range(-1, 2):
+ feat = load_enc1_avg(qy + ky * 2, qx + kx * 2) # dilation=2
+ ki = (ky + 1) * 3 + (kx + 1)
+ for i in range(BN_IN):
+ s += get_w(w, wo, o * BN_IN * 9 + i * 9 + ki) * feat[i]
+ out[qy, qx, o] = max(0.0, s)
return np.float16(out).astype(np.float32) # pack2x16float boundary
diff --git a/cnn_v3/training/infer_cnn_v3.py b/cnn_v3/training/infer_cnn_v3.py
new file mode 100644
index 0000000..ca1c72a
--- /dev/null
+++ b/cnn_v3/training/infer_cnn_v3.py
@@ -0,0 +1,219 @@
+#!/usr/bin/env python3
+# /// script
+# requires-python = ">=3.10"
+# dependencies = ["torch", "numpy", "pillow", "opencv-python"]
+# ///
+"""CNN v3 PyTorch inference — compare with cnn_test (WGSL/GPU output).
+
+Simple mode (single PNG): albedo = photo, geometry channels zeroed.
+Full mode (sample dir): loads all G-buffer files via assemble_features.
+
+Usage:
+ python3 infer_cnn_v3.py photo.png out.png --checkpoint checkpoints/ckpt.pth
+ python3 infer_cnn_v3.py sample_000/ out.png --checkpoint ckpt.pth
+ python3 infer_cnn_v3.py photo.png out.png --checkpoint ckpt.pth --identity-film
+ python3 infer_cnn_v3.py photo.png out.png --checkpoint ckpt.pth --cond 0.5 0.0 0.8 0.0 0.0
+"""
+
+import argparse
+import sys
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+
+sys.path.insert(0, str(Path(__file__).parent))
+from train_cnn_v3 import CNNv3
+from cnn_v3_utils import assemble_features, load_rgb, load_rg, load_depth16, load_gray
+
+
+# ---------------------------------------------------------------------------
+# Feature loading
+# ---------------------------------------------------------------------------
+
+def load_sample_dir(sample_dir: Path) -> np.ndarray:
+ """Load all G-buffer files from a sample directory → (H,W,20) f32."""
+ return assemble_features(
+ load_rgb(sample_dir / 'albedo.png'),
+ load_rg(sample_dir / 'normal.png'),
+ load_depth16(sample_dir / 'depth.png'),
+ load_gray(sample_dir / 'matid.png'),
+ load_gray(sample_dir / 'shadow.png'),
+ load_gray(sample_dir / 'transp.png'),
+ )
+
+
+def load_simple(image_path: Path) -> np.ndarray:
+ """Photo → (H,W,20) f32 with geometry channels zeroed.
+
+ normal=(0.5,0.5) is the oct-encoded "no normal" (decodes to ~(0,0,1)).
+ shadow=1.0 (fully lit), transp=0.0 (opaque).
+ """
+ albedo = load_rgb(image_path)
+ h, w = albedo.shape[:2]
+ normal = np.full((h, w, 2), 0.5, dtype=np.float32)
+ depth = np.zeros((h, w), dtype=np.float32)
+ matid = np.zeros((h, w), dtype=np.float32)
+ shadow = np.ones((h, w), dtype=np.float32)
+ transp = np.zeros((h, w), dtype=np.float32)
+ return assemble_features(albedo, normal, depth, matid, shadow, transp)
+
+
+# ---------------------------------------------------------------------------
+# Inference
+# ---------------------------------------------------------------------------
+
+def pad_to_multiple(feat: np.ndarray, m: int = 4) -> tuple:
+ """Pad (H,W,C) so H and W are multiples of m. Returns (padded, (ph, pw))."""
+ h, w = feat.shape[:2]
+ ph = (m - h % m) % m
+ pw = (m - w % m) % m
+ if ph == 0 and pw == 0:
+ return feat, (0, 0)
+ return np.pad(feat, ((0, ph), (0, pw), (0, 0))), (ph, pw)
+
+
+def run_identity_film(model: CNNv3, feat: torch.Tensor) -> torch.Tensor:
+ """Forward with identity FiLM (γ=1, β=0). Matches C++ cnn_test default."""
+ c0, c1 = model.enc_channels
+ B = feat.shape[0]
+ dev = feat.device
+
+ skip0 = F.relu(model.enc0(feat))
+
+ x = F.avg_pool2d(skip0, 2)
+ skip1 = F.relu(model.enc1(x))
+
+ x = F.relu(model.bottleneck(F.avg_pool2d(skip1, 2)))
+
+ x = F.relu(model.dec1(
+ torch.cat([F.interpolate(x, scale_factor=2, mode='nearest'), skip1], dim=1)
+ ))
+
+ x = F.relu(model.dec0(
+ torch.cat([F.interpolate(x, scale_factor=2, mode='nearest'), skip0], dim=1)
+ ))
+
+ return torch.sigmoid(x)
+
+
+# ---------------------------------------------------------------------------
+# Output helpers
+# ---------------------------------------------------------------------------
+
+def save_png(path: Path, out: np.ndarray) -> None:
+ """Save (H,W,4) f32 [0,1] RGBA as PNG."""
+ rgba8 = (np.clip(out, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8)
+ Image.fromarray(rgba8, 'RGBA').save(path)
+
+
+def print_debug_hex(out: np.ndarray, n: int = 8) -> None:
+ """Print first n pixels as hex RGBA + float values."""
+ flat = out.reshape(-1, 4)
+ for i in range(min(n, flat.shape[0])):
+ r, g, b, a = flat[i]
+ ri, gi, bi, ai = int(r*255+.5), int(g*255+.5), int(b*255+.5), int(a*255+.5)
+ print(f' [{i}] 0x{ri:02X}{gi:02X}{bi:02X}{ai:02X}'
+ f' ({r:.4f} {g:.4f} {b:.4f} {a:.4f})')
+
+
+# ---------------------------------------------------------------------------
+# Main
+# ---------------------------------------------------------------------------
+
+def main():
+ p = argparse.ArgumentParser(description='CNN v3 PyTorch inference')
+ p.add_argument('input', help='Input PNG or sample directory')
+ p.add_argument('output', help='Output PNG')
+ p.add_argument('--checkpoint', '-c', metavar='CKPT',
+ help='Path to .pth checkpoint (auto-finds latest if omitted)')
+ p.add_argument('--enc-channels', default='4,8',
+ help='Encoder channels (default: 4,8 — must match checkpoint)')
+ p.add_argument('--cond', nargs=5, type=float, metavar='F', default=[0.0]*5,
+ help='FiLM conditioning: 5 floats (beat_phase beat_norm audio style0 style1)')
+ p.add_argument('--identity-film', action='store_true',
+ help='Bypass FiLM MLP, use γ=1 β=0 (matches C++ cnn_test default)')
+ p.add_argument('--blend', type=float, default=1.0,
+ help='Blend with input albedo: 0=input 1=CNN (default 1.0)')
+ p.add_argument('--debug-hex', action='store_true',
+ help='Print first 8 output pixels as hex')
+ args = p.parse_args()
+
+ # --- Feature loading ---
+ inp = Path(args.input)
+ if inp.is_dir():
+ print(f'Mode: full ({inp})')
+ feat = load_sample_dir(inp)
+ albedo_rgb = load_rgb(inp / 'albedo.png')
+ else:
+ print(f'Mode: simple ({inp})')
+ feat = load_simple(inp)
+ albedo_rgb = load_rgb(inp)
+ orig_h, orig_w = feat.shape[:2]
+
+ feat_padded, (ph, pw) = pad_to_multiple(feat, 4)
+ H, W = feat_padded.shape[:2]
+ if ph or pw:
+ print(f'Padded {orig_w}×{orig_h} → {W}×{H}')
+ else:
+ print(f'Resolution: {W}×{H}')
+
+ # --- Load checkpoint ---
+ if args.checkpoint:
+ ckpt_path = Path(args.checkpoint)
+ else:
+ ckpts = sorted(Path('checkpoints').glob('checkpoint_epoch_*.pth'),
+ key=lambda f: int(f.stem.split('_')[-1]))
+ if not ckpts:
+ print('Error: no checkpoint found; use --checkpoint', file=sys.stderr)
+ sys.exit(1)
+ ckpt_path = ckpts[-1]
+ print(f'Checkpoint: {ckpt_path}')
+
+ ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
+ cfg = ckpt.get('config', {})
+ enc_channels = cfg.get('enc_channels', [int(c) for c in args.enc_channels.split(',')])
+ film_cond_dim = cfg.get('film_cond_dim', 5)
+ print(f'Architecture: enc={enc_channels} film_cond_dim={film_cond_dim}')
+
+ model = CNNv3(enc_channels=enc_channels, film_cond_dim=film_cond_dim)
+ model.load_state_dict(ckpt['model_state_dict'])
+ model.eval()
+
+ # --- Inference ---
+ feat_t = torch.from_numpy(feat_padded).permute(2, 0, 1).unsqueeze(0) # (1,20,H,W)
+ cond_t = torch.tensor([args.cond], dtype=torch.float32) # (1,5)
+
+ with torch.no_grad():
+ if args.identity_film:
+ print('FiLM: identity (γ=1, β=0)')
+ out_t = run_identity_film(model, feat_t)
+ else:
+ print(f'FiLM cond: {args.cond}')
+ out_t = model(feat_t, cond_t)
+
+ # (1,4,H,W) → crop padding → (orig_h, orig_w, 4)
+ out = out_t[0].permute(1, 2, 0).numpy()[:orig_h, :orig_w, :]
+
+ # Optional blend with albedo
+ if args.blend < 1.0:
+ h_in, w_in = albedo_rgb.shape[:2]
+ ab = albedo_rgb[:orig_h, :orig_w]
+ ones = np.ones((orig_h, orig_w, 1), dtype=np.float32)
+ src_rgba = np.concatenate([ab, ones], axis=-1)
+ out = src_rgba * (1.0 - args.blend) + out * args.blend
+
+ # --- Save ---
+ out_path = Path(args.output)
+ save_png(out_path, out)
+ print(f'Saved: {out_path}')
+
+ if args.debug_hex:
+ print('First 8 output pixels (RGBA):')
+ print_debug_hex(out)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py
index de10d6a..c790495 100644
--- a/cnn_v3/training/train_cnn_v3.py
+++ b/cnn_v3/training/train_cnn_v3.py
@@ -6,17 +6,21 @@
"""CNN v3 Training Script — U-Net + FiLM
Architecture:
- enc0 Conv(20→4, 3×3) + FiLM + ReLU H×W
- enc1 Conv(4→8, 3×3) + FiLM + ReLU + pool2 H/2×W/2
- bottleneck Conv(8→8, 1×1) + ReLU H/4×W/4
- dec1 upsample×2 + cat(enc1) Conv(16→4) + FiLM H/2×W/2
- dec0 upsample×2 + cat(enc0) Conv(8→4) + FiLM H×W
+ enc0 Conv(20→4, 3×3) + FiLM + ReLU H×W
+ enc1 Conv(4→8, 3×3) + FiLM + ReLU + pool2 H/2×W/2
+ bottleneck Conv(8→8, 3×3, dilation=2) + ReLU H/4×W/4
+ dec1 upsample×2 + cat(enc1) Conv(16→4) + FiLM H/2×W/2
+ dec0 upsample×2 + cat(enc0) Conv(8→4) + FiLM H×W
output sigmoid → RGBA
FiLM MLP: Linear(5→16) → ReLU → Linear(16→40)
40 = 2 × (γ+β) for enc0(4) enc1(8) dec1(4) dec0(4)
-Weight budget: ~5.4 KB f16 (fits ≤6 KB target)
+Weight budget: ~4.84 KB conv f16 (fits ≤6 KB target)
+
+Training improvements:
+ --edge-loss-weight Sobel edge loss alongside MSE (default 0.1)
+ --film-warmup-epochs Train U-Net only for N epochs before unfreezing FiLM MLP (default 50)
"""
import argparse
@@ -56,7 +60,7 @@ class CNNv3(nn.Module):
self.enc0 = nn.Conv2d(N_FEATURES, c0, 3, padding=1)
self.enc1 = nn.Conv2d(c0, c1, 3, padding=1)
- self.bottleneck = nn.Conv2d(c1, c1, 1)
+ self.bottleneck = nn.Conv2d(c1, c1, 3, padding=2, dilation=2)
self.dec1 = nn.Conv2d(c1 * 2, c0, 3, padding=1) # +skip enc1
self.dec0 = nn.Conv2d(c0 * 2, 4, 3, padding=1) # +skip enc0
@@ -96,6 +100,24 @@ class CNNv3(nn.Module):
# ---------------------------------------------------------------------------
+# Loss
+# ---------------------------------------------------------------------------
+
+def sobel_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ """Gradient loss via Sobel filters. No VGG dependency.
+ pred, target: (B, C, H, W) in [0, 1]. Returns scalar on same device."""
+ kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
+ dtype=pred.dtype, device=pred.device).view(1, 1, 3, 3)
+ ky = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
+ dtype=pred.dtype, device=pred.device).view(1, 1, 3, 3)
+ B, C, H, W = pred.shape
+ p = pred.view(B * C, 1, H, W)
+ t = target.view(B * C, 1, H, W)
+ return (F.mse_loss(F.conv2d(p, kx, padding=1), F.conv2d(t, kx, padding=1)) +
+ F.mse_loss(F.conv2d(p, ky, padding=1), F.conv2d(t, ky, padding=1)))
+
+
+# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
@@ -104,6 +126,10 @@ def train(args):
enc_channels = [int(c) for c in args.enc_channels.split(',')]
print(f"Device: {device}")
+ if args.single_sample:
+ args.full_image = True
+ args.batch_size = 1
+
dataset = CNNv3Dataset(
dataset_dir=args.input,
input_mode=args.input_mode,
@@ -115,6 +141,7 @@ def train(args):
detector=args.detector,
augment=True,
patch_search_window=args.patch_search_window,
+ single_sample=args.single_sample,
)
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
num_workers=0, drop_last=False)
@@ -124,11 +151,20 @@ def train(args):
print(f"Model: enc={enc_channels} film_cond_dim={args.film_cond_dim} "
f"params={nparams} (~{nparams*2/1024:.1f} KB f16)")
- optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+ # Phase 1: freeze FiLM MLP so U-Net convolutions stabilise first.
+ film_warmup = args.film_warmup_epochs
+ if film_warmup > 0:
+ for p in model.film_mlp.parameters():
+ p.requires_grad = False
+ print(f"FiLM MLP frozen for first {film_warmup} epochs (phase-1 warmup)")
+
+ optimizer = torch.optim.Adam(
+ filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
criterion = nn.MSELoss()
ckpt_dir = Path(args.checkpoint_dir)
ckpt_dir.mkdir(parents=True, exist_ok=True)
start_epoch = 1
+ film_unfrozen = (film_warmup == 0)
if args.resume:
ckpt_path = Path(args.resume)
@@ -163,6 +199,15 @@ def train(args):
for epoch in range(start_epoch, args.epochs + 1):
if interrupted:
break
+
+ # Phase 2: unfreeze FiLM MLP after warmup, rebuild optimizer at reduced LR.
+ if not film_unfrozen and epoch > film_warmup:
+ for p in model.film_mlp.parameters():
+ p.requires_grad = True
+ optimizer = torch.optim.Adam(model.parameters(), lr=args.lr * 0.1)
+ film_unfrozen = True
+ print(f"\nPhase 2: FiLM MLP unfrozen at epoch {epoch} (lr={args.lr*0.1:.2e})")
+
model.train()
epoch_loss = 0.0
n_batches = 0
@@ -172,7 +217,10 @@ def train(args):
break
feat, cond, target = feat.to(device), cond.to(device), target.to(device)
optimizer.zero_grad()
- loss = criterion(model(feat, cond), target)
+ pred = model(feat, cond)
+ loss = criterion(pred, target)
+ if args.edge_loss_weight > 0.0:
+ loss = loss + args.edge_loss_weight * sobel_loss(pred, target)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
@@ -210,6 +258,8 @@ def _checkpoint(model, optimizer, epoch, loss, args):
'enc_channels': [int(c) for c in args.enc_channels.split(',')],
'film_cond_dim': args.film_cond_dim,
'input_mode': args.input_mode,
+ 'edge_loss_weight': args.edge_loss_weight,
+ 'film_warmup_epochs': args.film_warmup_epochs,
},
}
@@ -222,6 +272,8 @@ def main():
p = argparse.ArgumentParser(description='Train CNN v3 (U-Net + FiLM)')
# Dataset
+ p.add_argument('--single-sample', default='', metavar='DIR',
+ help='Train on a single sample directory; implies --full-image and --batch-size 1')
p.add_argument('--input', default='training/dataset',
help='Dataset root (contains full/ or simple/ subdirs)')
p.add_argument('--input-mode', default='simple', choices=['simple', 'full'],
@@ -259,6 +311,10 @@ def main():
help='Save checkpoint every N epochs (0=disable)')
p.add_argument('--resume', default='', metavar='CKPT',
help='Resume from checkpoint path; if path missing, use latest in --checkpoint-dir')
+ p.add_argument('--edge-loss-weight', type=float, default=0.1,
+ help='Weight for Sobel edge loss alongside MSE (default 0.1; 0=disable)')
+ p.add_argument('--film-warmup-epochs', type=int, default=50,
+ help='Epochs to train U-Net only before unfreezing FiLM MLP (default 50; 0=joint)')
train(p.parse_args())