#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = ["torch", "torchvision", "numpy", "pillow", "opencv-python"]
# ///
"""CNN v3 Training Script — U-Net + FiLM

Architecture:
  enc0        Conv(20→4, 3×3) + FiLM + ReLU                        H×W
  enc1        avgpool×2 → Conv(4→8, 3×3) + FiLM + ReLU             H/2×W/2
  bottleneck  avgpool×2 → Conv(8→8, 3×3, dilation=2) + ReLU        H/4×W/4
  dec1        upsample×2 + cat(enc1) → Conv(16→4) + FiLM + ReLU    H/2×W/2
  dec0        upsample×2 + cat(enc0) → Conv(8→4) + FiLM            H×W
  output      sigmoid → RGBA

FiLM MLP: Linear(5→16) → ReLU → Linear(16→40)
  40 = γ and β (×2) for enc0(4) + enc1(8) + dec1(4) + dec0(4) = 20 channels

Weight budget: ~4.84 KB conv f16 (fits ≤6 KB target)

Training improvements:
  --edge-loss-weight    Sobel edge loss alongside MSE (default 0.1)
  --film-warmup-epochs  Train U-Net only for N epochs before unfreezing FiLM MLP (default 50)
"""

import argparse
import signal
import time
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from cnn_v3_utils import CNNv3Dataset, N_FEATURES


# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------

def film_apply(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
    """Per-channel affine: gamma*x + beta. gamma/beta: (B,C) broadcast over H,W."""
    return gamma[:, :, None, None] * x + beta[:, :, None, None]


def _upsample_to(x: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Nearest-neighbour upsample *x* to the spatial size (H, W) of *ref*.

    When ref is exactly twice x's size this is identical to
    ``F.interpolate(x, scale_factor=2, mode='nearest')``; it additionally
    copes with odd sizes where ``avg_pool2d`` floored H/W on the way down.
    """
    return F.interpolate(x, size=tuple(ref.shape[-2:]), mode='nearest')


class CNNv3(nn.Module):
    """U-Net + FiLM conditioning.

    enc_channels:  [c0, c1] channel counts per encoder level, default [4, 8]
    film_cond_dim: FiLM conditioning input size, default 5

    Raises:
        ValueError: if enc_channels does not have exactly two entries.
    """

    def __init__(self, enc_channels: list[int] | None = None, film_cond_dim: int = 5):
        super().__init__()
        if enc_channels is None:
            enc_channels = [4, 8]
        if len(enc_channels) != 2:
            raise ValueError("Only 2-level U-Net supported")
        c0, c1 = enc_channels

        self.enc0 = nn.Conv2d(N_FEATURES, c0, 3, padding=1)
        self.enc1 = nn.Conv2d(c0, c1, 3, padding=1)
        # padding=2 with dilation=2 keeps the 3×3 conv size-preserving.
        self.bottleneck = nn.Conv2d(c1, c1, 3, padding=2, dilation=2)
        self.dec1 = nn.Conv2d(c1 * 2, c0, 3, padding=1)   # +skip enc1
        self.dec0 = nn.Conv2d(c0 * 2, 4, 3, padding=1)    # +skip enc0

        film_out = 2 * (c0 + c1 + c0 + 4)   # γ+β for enc0, enc1, dec1, dec0
        self.film_mlp = nn.Sequential(
            nn.Linear(film_cond_dim, 16),
            nn.ReLU(),
            nn.Linear(16, film_out),
        )
        # Copy so later mutation of the caller's list cannot desync _split_film.
        self.enc_channels = [c0, c1]

    def _split_film(self, film: torch.Tensor):
        """Split the FiLM MLP output into per-layer (γ, β) tensors.

        Returns, in order: g_enc0, b_enc0, g_enc1, b_enc1,
        g_dec1, b_dec1, g_dec0, b_dec0.
        """
        c0, c1 = self.enc_channels
        return torch.split(film, [c0, c0, c1, c1, c0, c0, 4, 4], dim=-1)

    def forward(self, feat: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """feat: (B,N_FEATURES,H,W)  cond: (B,film_cond_dim)  → (B,4,H,W) RGBA in (0,1)."""
        g0, b0, g1, b1, gd1, bd1, gd0, bd0 = self._split_film(self.film_mlp(cond))

        # Encoder
        skip0 = F.relu(film_apply(self.enc0(feat), g0, b0))                    # H×W
        skip1 = F.relu(film_apply(self.enc1(F.avg_pool2d(skip0, 2)), g1, b1))  # H/2×W/2
        x = F.relu(self.bottleneck(F.avg_pool2d(skip1, 2)))                    # H/4×W/4

        # Decoder
        x = self.dec1(torch.cat([_upsample_to(x, skip1), skip1], dim=1))
        x = F.relu(film_apply(x, gd1, bd1))                                    # H/2×W/2
        x = self.dec0(torch.cat([_upsample_to(x, skip0), skip0], dim=1))
        # No ReLU here: sigmoid(relu(z)) is confined to [0.5, 1), which would make
        # every output value below 0.5 unreachable. NOTE(review): any inference-side
        # port of this network must use the same activations — verify.
        x = film_apply(x, gd0, bd0)                                            # H×W
        return torch.sigmoid(x)


# ---------------------------------------------------------------------------
# Loss
# ---------------------------------------------------------------------------

def sobel_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Gradient loss via Sobel filters. No VGG dependency.

    pred, target: (B, C, H, W) in [0, 1]. Each channel is filtered
    independently. Returns MSE(∂x) + MSE(∂y) as a scalar on pred's device.
    """
    kx = torch.tensor([[-1, 0, 1],
                       [-2, 0, 2],
                       [-1, 0, 1]], dtype=pred.dtype, device=pred.device).view(1, 1, 3, 3)
    ky = torch.tensor([[-1, -2, -1],
                       [0, 0, 0],
                       [1, 2, 1]], dtype=pred.dtype, device=pred.device).view(1, 1, 3, 3)
    B, C, H, W = pred.shape
    # reshape (not view): also works for non-contiguous inputs.
    p = pred.reshape(B * C, 1, H, W)
    t = target.reshape(B * C, 1, H, W)
    return (F.mse_loss(F.conv2d(p, kx, padding=1), F.conv2d(t, kx, padding=1)) +
            F.mse_loss(F.conv2d(p, ky, padding=1), F.conv2d(t, ky, padding=1)))


# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------

class _SigintFlag:
    """Context manager that turns Ctrl-C into a flag while active.

    On exit the SIGINT handler that was installed before entry is restored, so
    Ctrl-C behaves normally again once training is over.
    """

    def __init__(self):
        self.raised = False
        self._previous = None

    def _handle(self, signum, frame):
        self.raised = True

    def __enter__(self):
        self._previous = signal.signal(signal.SIGINT, self._handle)
        return self

    def __exit__(self, *exc_info):
        # signal.signal returns None when the previous handler was not
        # installed from Python; it cannot be re-installed in that case.
        if self._previous is not None:
            signal.signal(signal.SIGINT, self._previous)
        return False


def _resolve_resume_path(resume: str, ckpt_dir: Path) -> Path:
    """Return *resume* if it exists, else the latest checkpoint in *ckpt_dir*.

    "Latest" is the checkpoint_epoch_<N>.pth with the largest integer N.

    Raises:
        FileNotFoundError: if *resume* is missing and no checkpoint is found.
    """
    path = Path(resume)
    if path.exists():
        return path
    ckpts = sorted(
        (p for p in ckpt_dir.glob('checkpoint_epoch_*.pth')
         if p.stem.rsplit('_', 1)[-1].isdigit()),
        key=lambda p: int(p.stem.rsplit('_', 1)[-1]))
    if not ckpts:
        raise FileNotFoundError(f"No checkpoints found in {ckpt_dir}")
    return ckpts[-1]


def _set_film_trainable(model: CNNv3, trainable: bool) -> None:
    """Freeze (False) or unfreeze (True) the FiLM MLP parameters."""
    for p in model.film_mlp.parameters():
        p.requires_grad = trainable


def train(args):
    """Train CNNv3 according to the parsed CLI *args* and return the model.

    With --film-warmup-epochs N > 0, the FiLM MLP is frozen for epochs 1..N
    (phase 1). At epoch N+1 it is unfrozen and the optimizer is rebuilt over
    all parameters at lr × 0.1 (phase 2). N <= 0 trains everything jointly at
    --lr from the start.

    Ctrl-C stops training after the current batch. In every exit path (normal,
    interrupted, or exception) a checkpoint is written for the last epoch that
    trained on at least one batch. A partially trained epoch counts as done.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    enc_channels = [int(c) for c in args.enc_channels.split(',')]
    print(f"Device: {device}")

    if args.single_sample:
        args.full_image = True
        args.batch_size = 1

    dataset = CNNv3Dataset(
        dataset_dir=args.input,
        input_mode=args.input_mode,
        patch_size=args.patch_size,
        patches_per_image=args.patches_per_image,
        image_size=args.image_size,
        full_image=args.full_image,
        channel_dropout_p=args.channel_dropout_p,
        detector=args.detector,
        augment=True,
        patch_search_window=args.patch_search_window,
        single_sample=args.single_sample,
    )
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                        num_workers=0, drop_last=False)

    model = CNNv3(enc_channels=enc_channels, film_cond_dim=args.film_cond_dim).to(device)
    nparams = sum(p.numel() for p in model.parameters())
    print(f"Model: enc={enc_channels} film_cond_dim={args.film_cond_dim} "
          f"params={nparams} (~{nparams*2/1024:.1f} KB f16)")

    ckpt_dir = Path(args.checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    # Load any checkpoint BEFORE building the optimizer. The optimizer must be
    # built over the same parameter set the checkpoint's optimizer had, or
    # load_state_dict rejects the mismatched parameter group.
    ckpt = None
    start_epoch = 1
    if args.resume:
        ckpt_path = _resolve_resume_path(args.resume, ckpt_dir)
        print(f"Resuming from {ckpt_path}")
        ckpt = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(ckpt['model_state_dict'])
        start_epoch = ckpt['epoch'] + 1

    film_warmup = args.film_warmup_epochs
    if ckpt is None:
        film_unfrozen = film_warmup <= 0
    else:
        # A phase-1 optimizer holds every parameter except the FiLM MLP's, so
        # the saved parameter count tells us which phase the checkpoint is in.
        n_saved = sum(len(g['params'])
                      for g in ckpt['optimizer_state_dict']['param_groups'])
        film_unfrozen = n_saved == sum(1 for _ in model.parameters())

    # Phase 1: freeze FiLM MLP so U-Net convolutions stabilise first.
    _set_film_trainable(model, film_unfrozen)
    if not film_unfrozen:
        print(f"FiLM MLP frozen for first {film_warmup} epochs (phase-1 warmup)")

    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad], lr=args.lr)
    if ckpt is not None:
        # Also restores the saved lr (e.g. the reduced phase-2 lr).
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        print(f"  Resumed at epoch {start_epoch} (last loss {ckpt['loss']:.6f})")
    criterion = nn.MSELoss()

    print(f"\nTraining epochs {start_epoch}–{args.epochs}  batch={args.batch_size}  "
          f"lr={optimizer.param_groups[0]['lr']}")
    start = time.time()
    avg_loss = float('nan')
    last_epoch = start_epoch - 1   # last epoch that trained on >= 1 batch

    with _SigintFlag() as sigint:
        try:
            for epoch in range(start_epoch, args.epochs + 1):
                if sigint.raised:
                    break

                # Phase 2: unfreeze FiLM MLP after warmup, rebuild optimizer at reduced LR.
                if not film_unfrozen and epoch > film_warmup:
                    _set_film_trainable(model, True)
                    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr * 0.1)
                    film_unfrozen = True
                    print(f"\nPhase 2: FiLM MLP unfrozen at epoch {epoch} "
                          f"(lr={args.lr*0.1:.2e})")

                model.train()
                epoch_loss = 0.0
                n_batches = 0
                for feat, cond, target in loader:
                    if sigint.raised:
                        break
                    feat, cond, target = feat.to(device), cond.to(device), target.to(device)
                    optimizer.zero_grad()
                    pred = model(feat, cond)
                    loss = criterion(pred, target)
                    if args.edge_loss_weight > 0.0:
                        loss = loss + args.edge_loss_weight * sobel_loss(pred, target)
                    loss.backward()
                    optimizer.step()
                    epoch_loss += loss.item()
                    n_batches += 1

                if n_batches == 0:
                    # Interrupted before this epoch's first batch: nothing new
                    # was trained, so it must not be recorded as completed.
                    break
                avg_loss = epoch_loss / n_batches
                last_epoch = epoch
                print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | "
                      f"{time.time()-start:.0f}s", end='', flush=True)

                # An interrupted epoch is saved once by the finally block below.
                if (not sigint.raised and args.checkpoint_every > 0
                        and epoch % args.checkpoint_every == 0):
                    print()
                    ckpt_file = ckpt_dir / f"checkpoint_epoch_{epoch}.pth"
                    torch.save(_checkpoint(model, optimizer, epoch, avg_loss, args), ckpt_file)
                    print(f"  → {ckpt_file}")
        finally:
            print()
            if last_epoch >= start_epoch:  # at least one epoch trained
                final = ckpt_dir / f"checkpoint_epoch_{last_epoch}.pth"
                torch.save(_checkpoint(model, optimizer, last_epoch, avg_loss, args), final)
                if sigint.raised:
                    print(f"Interrupted. Checkpoint saved: {final}")
                else:
                    print(f"Final checkpoint: {final}")

    print(f"Done. {time.time()-start:.1f}s")
    return model


def _checkpoint(model, optimizer, epoch, loss, args):
    """Build the checkpoint dict: model/optimizer state, epoch, loss, model/training config."""
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': {
            'enc_channels': [int(c) for c in args.enc_channels.split(',')],
            'film_cond_dim': args.film_cond_dim,
            'input_mode': args.input_mode,
            'edge_loss_weight': args.edge_loss_weight,
            'film_warmup_epochs': args.film_warmup_epochs,
        },
    }


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def main():
    """Parse command-line arguments and run train()."""
    p = argparse.ArgumentParser(description='Train CNN v3 (U-Net + FiLM)')

    # Dataset
    p.add_argument('--single-sample', default='', metavar='DIR',
                   help='Train on a single sample directory; implies --full-image and --batch-size 1')
    p.add_argument('--input', default='training/dataset',
                   help='Dataset root (contains full/ or simple/ subdirs)')
    p.add_argument('--input-mode', default='simple', choices=['simple', 'full'],
                   help='simple=photo samples full=Blender G-buffer samples')
    p.add_argument('--channel-dropout-p', type=float, default=0.3,
                   help='Dropout prob for geometric channels (default 0.3)')

    # Patch / full-image mode
    p.add_argument('--full-image', action='store_true',
                   help='Use full-image mode (resize to --image-size)')
    p.add_argument('--image-size', type=int, default=256,
                   help='Full-image resize target (default 256)')
    p.add_argument('--patch-size', type=int, default=64,
                   help='Patch size (default 64)')
    p.add_argument('--patches-per-image', type=int, default=256,
                   help='Patches per image per epoch (default 256)')
    p.add_argument('--detector', default='harris',
                   choices=['harris', 'shi-tomasi', 'fast', 'gradient', 'random'],
                   help='Salient point detector (default harris)')
    p.add_argument('--patch-search-window', type=int, default=0,
                   help='Search ±N px in target to minimise grayscale MSE (default 0=disabled)')

    # Model
    p.add_argument('--enc-channels', default='4,8',
                   help='Encoder channels, comma-separated (default 4,8)')
    p.add_argument('--film-cond-dim', type=int, default=5,
                   help='FiLM conditioning input dim (default 5)')

    # Training
    p.add_argument('--epochs', type=int, default=200)
    p.add_argument('--batch-size', type=int, default=16)
    p.add_argument('--lr', type=float, default=1e-3)
    p.add_argument('--checkpoint-dir', default='checkpoints')
    p.add_argument('--checkpoint-every', type=int, default=50,
                   help='Save checkpoint every N epochs (0=disable)')
    p.add_argument('--resume', default='', metavar='CKPT',
                   help='Resume from checkpoint path; if path missing, use latest in --checkpoint-dir')
    p.add_argument('--edge-loss-weight', type=float, default=0.1,
                   help='Weight for Sobel edge loss alongside MSE (default 0.1; 0=disable)')
    p.add_argument('--film-warmup-epochs', type=int, default=50,
                   help='Epochs to train U-Net only before unfreezing FiLM MLP (default 50; 0=joint)')

    train(p.parse_args())


if __name__ == '__main__':
    main()