summaryrefslogtreecommitdiff
path: root/cnn_v3/training/train_cnn_v3.py
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v3/training/train_cnn_v3.py')
-rw-r--r--cnn_v3/training/train_cnn_v3.py222
1 file changed, 222 insertions, 0 deletions
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py
new file mode 100644
index 0000000..ed925e6
--- /dev/null
+++ b/cnn_v3/training/train_cnn_v3.py
@@ -0,0 +1,222 @@
+#!/usr/bin/env python3
+"""CNN v3 Training Script — U-Net + FiLM
+
+Architecture:
+ enc0 Conv(20→4, 3×3) + FiLM + ReLU H×W
+ enc1 Conv(4→8, 3×3) + FiLM + ReLU + pool2 H/2×W/2
+ bottleneck Conv(8→8, 1×1) + ReLU H/4×W/4
+ dec1 upsample×2 + cat(enc1) Conv(16→4) + FiLM H/2×W/2
+ dec0 upsample×2 + cat(enc0) Conv(8→4) + FiLM H×W
+ output sigmoid → RGBA
+
+FiLM MLP: Linear(5→16) → ReLU → Linear(16→40)
+ 40 = 2 × (γ+β) for enc0(4) enc1(8) dec1(4) dec0(4)
+
+Weight budget: ~3.9 KB f16 (fits ≤6 KB target)
+"""
+
+import argparse
+import time
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+
+from cnn_v3_utils import CNNv3Dataset, N_FEATURES
+
+# ---------------------------------------------------------------------------
+# Model
+# ---------------------------------------------------------------------------
+
def film_apply(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
    """Feature-wise linear modulation: scale then shift each channel.

    gamma/beta have shape (B, C); they are expanded to (B, C, 1, 1) so the
    affine transform broadcasts over the spatial dimensions of x.
    """
    scale = gamma.unsqueeze(-1).unsqueeze(-1)
    shift = beta.unsqueeze(-1).unsqueeze(-1)
    return x * scale + shift
+
+
class CNNv3(nn.Module):
    """Small two-level U-Net whose activations are modulated by FiLM.

    A conditioning vector is passed through a tiny MLP to produce one
    (gamma, beta) pair per channel of each FiLM-ed stage; those pairs
    scale/shift the feature maps at the encoder and decoder levels.

    enc_channels: [c0, c1] channel counts per encoder level, default [4, 8]
    film_cond_dim: FiLM conditioning input size, default 5
    """

    def __init__(self, enc_channels=None, film_cond_dim: int = 5):
        super().__init__()
        enc_channels = [4, 8] if enc_channels is None else enc_channels
        assert len(enc_channels) == 2, "Only 2-level U-Net supported"
        c0, c1 = enc_channels

        # NOTE: module registration order fixes the RNG consumption order,
        # so parameter init stays reproducible across restyles.
        self.enc0 = nn.Conv2d(N_FEATURES, c0, 3, padding=1)
        self.enc1 = nn.Conv2d(c0, c1, 3, padding=1)
        self.bottleneck = nn.Conv2d(c1, c1, 1)
        self.dec1 = nn.Conv2d(c1 * 2, c0, 3, padding=1)  # input = upsample + enc1 skip
        self.dec0 = nn.Conv2d(c0 * 2, 4, 3, padding=1)   # input = upsample + enc0 skip

        # One gamma and one beta per channel of enc0, enc1, dec1, dec0.
        film_out = 2 * (c0 + c1 + c0 + 4)
        self.film_mlp = nn.Sequential(
            nn.Linear(film_cond_dim, 16),
            nn.ReLU(),
            nn.Linear(16, film_out),
        )
        self.enc_channels = enc_channels

    def _split_film(self, film: torch.Tensor):
        """Slice the FiLM MLP output into per-stage (gamma, beta) chunks."""
        c0, c1 = self.enc_channels
        widths = [c0, c0, c1, c1, c0, c0, 4, 4]
        # Order: g_enc0, b_enc0, g_enc1, b_enc1, g_dec1, b_dec1, g_dec0, b_dec0
        return torch.split(film, widths, dim=-1)

    def forward(self, feat: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """feat: (B,20,H,W) cond: (B,5) → (B,4,H,W) RGBA [0,1]"""
        g0, b0, g1, b1, gd1, bd1, gd0, bd0 = self._split_film(self.film_mlp(cond))

        # Encoder path — keep the activations for the skip connections.
        e0 = F.relu(film_apply(self.enc0(feat), g0, b0))                  # H × W
        e1 = F.relu(film_apply(self.enc1(F.avg_pool2d(e0, 2)), g1, b1))   # H/2 × W/2

        mid = F.relu(self.bottleneck(F.avg_pool2d(e1, 2)))                # H/4 × W/4

        # Decoder path — nearest-neighbor upsample, concat skip, conv + FiLM.
        up1 = torch.cat([F.interpolate(mid, scale_factor=2, mode='nearest'), e1], dim=1)
        d1 = F.relu(film_apply(self.dec1(up1), gd1, bd1))                 # H/2 × W/2

        up0 = torch.cat([F.interpolate(d1, scale_factor=2, mode='nearest'), e0], dim=1)
        d0 = F.relu(film_apply(self.dec0(up0), gd0, bd0))                 # H × W

        return torch.sigmoid(d0)
+
+
+# ---------------------------------------------------------------------------
+# Training
+# ---------------------------------------------------------------------------
+
def train(args):
    """Run the CNN v3 training loop.

    Builds the dataset/loader and the model from *args*, optimizes with
    Adam against an MSE target, prints one progress line per epoch, and
    writes periodic plus final checkpoints to args.checkpoint_dir.

    Returns the trained model.
    """
    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    enc_channels = [int(c) for c in args.enc_channels.split(',')]
    print(f"Device: {dev}")

    dataset = CNNv3Dataset(
        dataset_dir=args.input,
        input_mode=args.input_mode,
        patch_size=args.patch_size,
        patches_per_image=args.patches_per_image,
        image_size=args.image_size,
        full_image=args.full_image,
        channel_dropout_p=args.channel_dropout_p,
        detector=args.detector,
        augment=True,
    )
    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                             num_workers=0, drop_last=False)

    model = CNNv3(enc_channels=enc_channels, film_cond_dim=args.film_cond_dim).to(dev)
    nparams = sum(p.numel() for p in model.parameters())
    # ×2 bytes per parameter: the deployed weights are stored as float16.
    print(f"Model: enc={enc_channels} film_cond_dim={args.film_cond_dim} "
          f"params={nparams} (~{nparams*2/1024:.1f} KB f16)")

    opt = torch.optim.Adam(model.parameters(), lr=args.lr)
    loss_fn = nn.MSELoss()
    out_dir = Path(args.checkpoint_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"\nTraining {args.epochs} epochs batch={args.batch_size} lr={args.lr}")
    t0 = time.time()
    mean_loss = float('nan')  # stays NaN if the loader yields no batches

    for epoch in range(1, args.epochs + 1):
        model.train()
        running = 0.0
        seen = 0

        for feat, cond, target in data_loader:
            feat = feat.to(dev)
            cond = cond.to(dev)
            target = target.to(dev)
            opt.zero_grad()
            batch_loss = loss_fn(model(feat, cond), target)
            batch_loss.backward()
            opt.step()
            running += batch_loss.item()
            seen += 1

        mean_loss = running / max(seen, 1)
        # \r + end='' keeps the progress on a single console line.
        print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {mean_loss:.6f} | "
              f"{time.time()-t0:.0f}s", end='', flush=True)

        if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0:
            print()
            ckpt_path = out_dir / f"checkpoint_epoch_{epoch}.pth"
            torch.save(_checkpoint(model, opt, epoch, mean_loss, args), ckpt_path)
            print(f" → {ckpt_path}")

    print()
    final_path = out_dir / f"checkpoint_epoch_{args.epochs}.pth"
    torch.save(_checkpoint(model, opt, args.epochs, mean_loss, args), final_path)
    print(f"Final checkpoint: {final_path}")
    print(f"Done. {time.time()-t0:.1f}s")
    return model
+
+
def _checkpoint(model, optimizer, epoch, loss, args):
    """Assemble a serializable checkpoint dict (model/optimizer state plus
    the model-construction config needed to rebuild it at load time)."""
    config = {
        'enc_channels': [int(c) for c in args.enc_channels.split(',')],
        'film_cond_dim': args.film_cond_dim,
        'input_mode': args.input_mode,
    }
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
    }
+
+
+# ---------------------------------------------------------------------------
+# CLI
+# ---------------------------------------------------------------------------
+
def main():
    """Parse CLI arguments and launch training."""
    parser = argparse.ArgumentParser(description='Train CNN v3 (U-Net + FiLM)')

    # --- Dataset ---------------------------------------------------------
    parser.add_argument('--input', default='training/dataset',
                        help='Dataset root (contains full/ or simple/ subdirs)')
    parser.add_argument('--input-mode', default='simple', choices=['simple', 'full'],
                        help='simple=photo samples full=Blender G-buffer samples')
    parser.add_argument('--channel-dropout-p', type=float, default=0.3,
                        help='Dropout prob for geometric channels (default 0.3)')

    # --- Patch / full-image mode -----------------------------------------
    parser.add_argument('--full-image', action='store_true',
                        help='Use full-image mode (resize to --image-size)')
    parser.add_argument('--image-size', type=int, default=256,
                        help='Full-image resize target (default 256)')
    parser.add_argument('--patch-size', type=int, default=64,
                        help='Patch size (default 64)')
    parser.add_argument('--patches-per-image', type=int, default=256,
                        help='Patches per image per epoch (default 256)')
    parser.add_argument('--detector', default='harris',
                        choices=['harris', 'shi-tomasi', 'fast', 'gradient', 'random'],
                        help='Salient point detector (default harris)')

    # --- Model -----------------------------------------------------------
    parser.add_argument('--enc-channels', default='4,8',
                        help='Encoder channels, comma-separated (default 4,8)')
    parser.add_argument('--film-cond-dim', type=int, default=5,
                        help='FiLM conditioning input dim (default 5)')

    # --- Training --------------------------------------------------------
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--checkpoint-dir', default='checkpoints')
    parser.add_argument('--checkpoint-every', type=int, default=50,
                        help='Save checkpoint every N epochs (0=disable)')

    train(parser.parse_args())


if __name__ == '__main__':
    main()