diff options
Diffstat (limited to 'cnn_v3/training/train_cnn_v3.py')
| -rw-r--r-- | cnn_v3/training/train_cnn_v3.py | 222 |
1 file changed, 222 insertions, 0 deletions
#!/usr/bin/env python3
"""CNN v3 Training Script — U-Net + FiLM

Architecture:
    enc0        Conv(20→4, 3×3) + FiLM + ReLU              H×W
    enc1        Conv(4→8, 3×3) + FiLM + ReLU + pool2       H/2×W/2
    bottleneck  Conv(8→8, 1×1) + ReLU                      H/4×W/4
    dec1        upsample×2 + cat(enc1) Conv(16→4) + FiLM   H/2×W/2
    dec0        upsample×2 + cat(enc0) Conv(8→4) + FiLM    H×W
    output      sigmoid → RGBA

FiLM MLP: Linear(5→16) → ReLU → Linear(16→40)
    40 = 2 × (γ+β) for enc0(4) enc1(8) dec1(4) dec0(4)

Weight budget: ~3.9 KB f16 (fits ≤6 KB target)
"""

import argparse
import time
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# NOTE(review): the project-local dataset utilities (CNNv3Dataset, N_FEATURES)
# are imported lazily inside train() so that the model class can be imported
# and unit-tested without cnn_v3_utils on the path. train() still uses
# N_FEATURES for the input width, so training behavior is unchanged.

# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------

def film_apply(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
    """Per-channel affine: gamma*x + beta.

    Args:
        x: (B, C, H, W) feature map.
        gamma, beta: (B, C) per-sample, per-channel scale and shift,
            broadcast over the spatial dimensions.
    """
    return gamma[:, :, None, None] * x + beta[:, :, None, None]


class CNNv3(nn.Module):
    """U-Net + FiLM conditioning.

    Args:
        enc_channels: [c0, c1] channel counts per encoder level, default [4, 8].
        film_cond_dim: FiLM conditioning input size, default 5.
        in_channels: number of input feature planes, default 20
            (matches the documented Conv(20→4) stem; train() passes
            cnn_v3_utils.N_FEATURES explicitly).
    """

    def __init__(self, enc_channels=None, film_cond_dim: int = 5,
                 in_channels: int = 20):
        super().__init__()
        if enc_channels is None:
            enc_channels = [4, 8]
        assert len(enc_channels) == 2, "Only 2-level U-Net supported"
        c0, c1 = enc_channels

        self.enc0 = nn.Conv2d(in_channels, c0, 3, padding=1)
        self.enc1 = nn.Conv2d(c0, c1, 3, padding=1)
        self.bottleneck = nn.Conv2d(c1, c1, 1)
        self.dec1 = nn.Conv2d(c1 * 2, c0, 3, padding=1)  # +skip enc1
        self.dec0 = nn.Conv2d(c0 * 2, 4, 3, padding=1)   # +skip enc0

        # One (γ, β) pair per FiLM site: enc0(c0), enc1(c1), dec1(c0), dec0(4)
        film_out = 2 * (c0 + c1 + c0 + 4)
        self.film_mlp = nn.Sequential(
            nn.Linear(film_cond_dim, 16),
            nn.ReLU(),
            nn.Linear(16, film_out),
        )
        self.enc_channels = enc_channels

    def _split_film(self, film: torch.Tensor):
        """Split the MLP output into (γ, β) pairs, one per FiLM site."""
        c0, c1 = self.enc_channels
        parts = torch.split(film, [c0, c0, c1, c1, c0, c0, 4, 4], dim=-1)
        return parts  # g_enc0, b_enc0, g_enc1, b_enc1, g_dec1, b_dec1, g_dec0, b_dec0

    def forward(self, feat: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """feat: (B,in_channels,H,W)  cond: (B,film_cond_dim) → (B,4,H,W) RGBA [0,1].

        H and W must be divisible by 4 (two 2× poolings).
        """
        g0, b0, g1, b1, gd1, bd1, gd0, bd0 = self._split_film(self.film_mlp(cond))

        skip0 = F.relu(film_apply(self.enc0(feat), g0, b0))

        x = F.avg_pool2d(skip0, 2)
        skip1 = F.relu(film_apply(self.enc1(x), g1, b1))

        x = F.relu(self.bottleneck(F.avg_pool2d(skip1, 2)))

        x = F.relu(film_apply(self.dec1(
            torch.cat([F.interpolate(x, scale_factor=2, mode='nearest'), skip1], dim=1)
        ), gd1, bd1))

        # BUG FIX: the original wrapped this last FiLM in F.relu before the
        # sigmoid. sigmoid of a non-negative input is ≥ 0.5, so every RGBA
        # channel was clamped to [0.5, 1] and the network could never emit
        # dark/transparent values. The final activation must be sigmoid alone.
        x = film_apply(self.dec0(
            torch.cat([F.interpolate(x, scale_factor=2, mode='nearest'), skip0], dim=1)
        ), gd0, bd0)

        return torch.sigmoid(x)


# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------

def train(args):
    """Train CNNv3 on the configured dataset; returns the trained model.

    Saves periodic checkpoints every `args.checkpoint_every` epochs and a
    final checkpoint at the end (skipped if the last epoch already produced
    the identical periodic checkpoint file).
    """
    # Deferred project import — see module-level NOTE.
    from cnn_v3_utils import CNNv3Dataset, N_FEATURES

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    enc_channels = [int(c) for c in args.enc_channels.split(',')]
    print(f"Device: {device}")

    dataset = CNNv3Dataset(
        dataset_dir=args.input,
        input_mode=args.input_mode,
        patch_size=args.patch_size,
        patches_per_image=args.patches_per_image,
        image_size=args.image_size,
        full_image=args.full_image,
        channel_dropout_p=args.channel_dropout_p,
        detector=args.detector,
        augment=True,
    )
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                        num_workers=0, drop_last=False)

    model = CNNv3(enc_channels=enc_channels, film_cond_dim=args.film_cond_dim,
                  in_channels=N_FEATURES).to(device)
    nparams = sum(p.numel() for p in model.parameters())
    print(f"Model: enc={enc_channels} film_cond_dim={args.film_cond_dim} "
          f"params={nparams} (~{nparams*2/1024:.1f} KB f16)")

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.MSELoss()
    ckpt_dir = Path(args.checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    print(f"\nTraining {args.epochs} epochs batch={args.batch_size} lr={args.lr}")
    start = time.time()
    avg_loss = float('nan')  # reported if the loader yields no batches

    for epoch in range(1, args.epochs + 1):
        model.train()
        epoch_loss = 0.0
        n_batches = 0

        for feat, cond, target in loader:
            feat, cond, target = feat.to(device), cond.to(device), target.to(device)
            optimizer.zero_grad()
            loss = criterion(model(feat, cond), target)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            n_batches += 1

        avg_loss = epoch_loss / max(n_batches, 1)
        print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | "
              f"{time.time()-start:.0f}s", end='', flush=True)

        if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0:
            print()
            ckpt = ckpt_dir / f"checkpoint_epoch_{epoch}.pth"
            torch.save(_checkpoint(model, optimizer, epoch, avg_loss, args), ckpt)
            print(f"  → {ckpt}")

    print()
    final = ckpt_dir / f"checkpoint_epoch_{args.epochs}.pth"
    # BUG FIX: avoid writing the final checkpoint twice when the last epoch
    # already hit the periodic checkpoint (same filename, identical content).
    already_saved = (args.epochs > 0 and args.checkpoint_every > 0
                     and args.epochs % args.checkpoint_every == 0)
    if not already_saved:
        torch.save(_checkpoint(model, optimizer, args.epochs, avg_loss, args), final)
    print(f"Final checkpoint: {final}")
    print(f"Done. {time.time()-start:.1f}s")
    return model


def _checkpoint(model, optimizer, epoch, loss, args):
    """Build a serializable checkpoint dict (weights + optimizer + config)."""
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': {
            'enc_channels': [int(c) for c in args.enc_channels.split(',')],
            'film_cond_dim': args.film_cond_dim,
            'input_mode': args.input_mode,
        },
    }


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def main():
    p = argparse.ArgumentParser(description='Train CNN v3 (U-Net + FiLM)')

    # Dataset
    p.add_argument('--input', default='training/dataset',
                   help='Dataset root (contains full/ or simple/ subdirs)')
    p.add_argument('--input-mode', default='simple', choices=['simple', 'full'],
                   help='simple=photo samples full=Blender G-buffer samples')
    p.add_argument('--channel-dropout-p', type=float, default=0.3,
                   help='Dropout prob for geometric channels (default 0.3)')

    # Patch / full-image mode
    p.add_argument('--full-image', action='store_true',
                   help='Use full-image mode (resize to --image-size)')
    p.add_argument('--image-size', type=int, default=256,
                   help='Full-image resize target (default 256)')
    p.add_argument('--patch-size', type=int, default=64,
                   help='Patch size (default 64)')
    p.add_argument('--patches-per-image', type=int, default=256,
                   help='Patches per image per epoch (default 256)')
    p.add_argument('--detector', default='harris',
                   choices=['harris', 'shi-tomasi', 'fast', 'gradient', 'random'],
                   help='Salient point detector (default harris)')

    # Model
    p.add_argument('--enc-channels', default='4,8',
                   help='Encoder channels, comma-separated (default 4,8)')
    p.add_argument('--film-cond-dim', type=int, default=5,
                   help='FiLM conditioning input dim (default 5)')

    # Training
    p.add_argument('--epochs', type=int, default=200)
    p.add_argument('--batch-size', type=int, default=16)
    p.add_argument('--lr', type=float, default=1e-3)
    p.add_argument('--checkpoint-dir', default='checkpoints')
    p.add_argument('--checkpoint-every', type=int, default=50,
                   help='Save checkpoint every N epochs (0=disable)')

    train(p.parse_args())


if __name__ == '__main__':
    main()
