summaryrefslogtreecommitdiff
path: root/cnn_v3/training
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v3/training')
-rw-r--r--cnn_v3/training/cnn_v3_utils.py339
-rw-r--r--cnn_v3/training/train_cnn_v3.py222
2 files changed, 561 insertions, 0 deletions
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py
new file mode 100644
index 0000000..ecdbd6b
--- /dev/null
+++ b/cnn_v3/training/cnn_v3_utils.py
@@ -0,0 +1,339 @@
+"""CNN v3 training utilities — image I/O, feature assembly, dataset.
+
+Imported by train_cnn_v3.py and export_cnn_v3_weights.py.
+
+20 feature channels assembled by CNNv3Dataset.__getitem__:
+ [0-2] albedo.rgb f32 [0,1]
+ [3-4] normal.xy f32 oct-encoded [0,1]
+ [5] depth f32 [0,1]
+ [6-7] depth_grad.xy finite diff, signed
+ [8] mat_id f32 [0,1]
+ [9-11] prev.rgb f32 (zero during training)
+ [12-14] mip1.rgb pyrdown(albedo)
+ [15-17] mip2.rgb pyrdown(mip1)
+ [18] shadow f32 [0,1]
+ [19] transp f32 [0,1]
+
+Sample directory layout (per sample_xxx/):
+ albedo.png RGB uint8
+ normal.png RG uint8 oct-encoded (128,128 = no normal)
+ depth.png R uint16 [0,65535] → [0,1]
+ matid.png R uint8 [0,255] → [0,1]
+ shadow.png R uint8 [0=dark, 255=lit]
+ transp.png R uint8 [0=opaque, 255=clear]
+ target.png RGBA uint8
+"""
+
+import random
+from pathlib import Path
+from typing import List, Tuple
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from torch.utils.data import Dataset
+
+# ---------------------------------------------------------------------------
+# Constants
+# ---------------------------------------------------------------------------
+
N_FEATURES = 20  # total channels in the assembled feature tensor (see module docstring)

# Channel-group index lists consumed by apply_channel_dropout().
GEOMETRIC_CHANNELS = [3, 4, 5, 6, 7]  # normal.xy, depth, depth_grad.xy
CONTEXT_CHANNELS = [8, 18, 19]        # mat_id, shadow, transp
TEMPORAL_CHANNELS = [9, 10, 11]       # prev.rgb
+
+# ---------------------------------------------------------------------------
+# Image I/O
+# ---------------------------------------------------------------------------
+
def load_rgb(path: Path) -> np.ndarray:
    """Load a PNG as an (H,W,3) float32 array scaled to [0,1]."""
    img = Image.open(path).convert('RGB')
    return np.asarray(img, dtype=np.float32) / 255.0
+
+
def load_rg(path: Path) -> np.ndarray:
    """Load a PNG and keep only its R and G planes → (H,W,2) f32 [0,1]."""
    rgb = np.asarray(Image.open(path).convert('RGB'), dtype=np.float32) / 255.0
    return rgb[..., :2]  # blue channel discarded (oct-encoded normals use RG)
+
+
def load_depth16(path: Path) -> np.ndarray:
    """Load a greyscale depth PNG → (H,W) f32 [0,1].

    Scale is chosen from the image's storage mode rather than the observed
    maximum value. The original heuristic (`arr.max() > 1.0` → /65535) had
    two defects: an 8-bit depth map was mis-scaled by 1/65535, and a nearly
    black 16-bit map (max ≤ 1) was returned unscaled.
    """
    img = Image.open(path)
    arr = np.asarray(img, dtype=np.float32)
    if img.mode in ('I', 'I;16', 'I;16B', 'I;16L', 'I;16N'):
        # 16-bit greyscale (PIL may report mode 'I' for 16-bit PNGs).
        arr = arr / 65535.0
    elif arr.max() > 1.0:
        # 8-bit storage ('L' and friends): full scale is 255.
        arr = arr / 255.0
    return arr
+
+
def load_gray(path: Path) -> np.ndarray:
    """Load an 8-bit greyscale PNG → (H,W) f32 [0,1]."""
    gray = Image.open(path).convert('L')
    return np.asarray(gray, dtype=np.float32) / 255.0
+
+
+# ---------------------------------------------------------------------------
+# Feature assembly helpers
+# ---------------------------------------------------------------------------
+
def pyrdown(img: np.ndarray) -> np.ndarray:
    """Half-resolution downsample of (H,W,C) f32 via 2×2 average pooling.

    Odd trailing rows/columns are dropped before pooling.
    """
    rows, cols, _ = img.shape
    rh, rw = rows // 2, cols // 2
    cropped = img[:rh * 2, :rw * 2, :]
    # Average the four pixels of every 2×2 cell.
    tl = cropped[0::2, 0::2]
    bl = cropped[1::2, 0::2]
    tr = cropped[0::2, 1::2]
    br = cropped[1::2, 1::2]
    return 0.25 * (tl + bl + tr + br)
+
+
def depth_gradient(depth: np.ndarray) -> np.ndarray:
    """Central finite difference of a depth map → (H,W,2) [dzdx, dzdy].

    Edges are replicated before differencing, so border gradients are
    half the one-sided difference.
    """
    def _central(arr: np.ndarray, axis: int) -> np.ndarray:
        # Edge-pad one pixel along `axis`, then take the centred difference.
        pad = [(0, 0), (0, 0)]
        pad[axis] = (1, 1)
        padded = np.pad(arr, pad, mode='edge')
        hi = [slice(None), slice(None)]
        lo = [slice(None), slice(None)]
        hi[axis] = slice(2, None)
        lo[axis] = slice(None, -2)
        return 0.5 * (padded[tuple(hi)] - padded[tuple(lo)])

    return np.stack([_central(depth, 1), _central(depth, 0)], axis=-1)
+
+
def _upsample_nearest(a: np.ndarray, h: int, w: int) -> np.ndarray:
    """Nearest-neighbour resize of (H,W,C) f32 [0,1] to (h,w,C).

    Round-trips through uint8, so values are quantized to 1/255 steps.
    """
    quantized = (np.clip(a, 0, 1) * 255).astype(np.uint8)
    resized = Image.fromarray(quantized).resize((w, h), Image.NEAREST)
    return np.asarray(resized, dtype=np.float32) / 255.0
+
+
def assemble_features(albedo: np.ndarray, normal: np.ndarray,
                      depth: np.ndarray, matid: np.ndarray,
                      shadow: np.ndarray, transp: np.ndarray) -> np.ndarray:
    """Build the (H,W,20) f32 feature tensor laid out per the module docstring.

    prev (channels 9-11) is zeroed — no temporal history exists at training
    time. mip1/mip2 are pyrdown levels of albedo upsampled back to full
    resolution; depth_grad comes from central finite differences.
    """
    height, width = albedo.shape[:2]

    level1 = pyrdown(albedo)
    level2 = pyrdown(level1)
    zero_prev = np.zeros((height, width, 3), dtype=np.float32)

    channels = [
        albedo,                                    # [0-2]   albedo.rgb
        normal,                                    # [3-4]   normal.xy
        depth[..., None],                          # [5]     depth
        depth_gradient(depth),                     # [6-7]   depth_grad.xy
        matid[..., None],                          # [8]     mat_id
        zero_prev,                                 # [9-11]  prev.rgb (zero)
        _upsample_nearest(level1, height, width),  # [12-14] mip1.rgb
        _upsample_nearest(level2, height, width),  # [15-17] mip2.rgb
        shadow[..., None],                         # [18]    shadow
        transp[..., None],                         # [19]    transp
    ]
    return np.concatenate(channels, axis=-1).astype(np.float32)
+
+
def apply_channel_dropout(feat: np.ndarray,
                          p_geom: float = 0.3,
                          p_context: float = 0.2,
                          p_temporal: float = 0.5) -> np.ndarray:
    """Return a copy of feat with whole channel groups randomly zeroed.

    Each group is dropped independently with its own probability; the input
    array itself is never modified. Draws happen in the fixed order
    geom → context → temporal so seeded runs stay reproducible.
    """
    out = feat.copy()
    groups = ((p_geom, GEOMETRIC_CHANNELS),
              (p_context, CONTEXT_CHANNELS),
              (p_temporal, TEMPORAL_CHANNELS))
    for prob, channel_idx in groups:
        if random.random() < prob:
            out[..., channel_idx] = 0.0
    return out
+
+
+# ---------------------------------------------------------------------------
+# Salient point detection
+# ---------------------------------------------------------------------------
+
def detect_salient_points(albedo: np.ndarray, n: int, detector: str,
                          patch_size: int) -> List[Tuple[int, int]]:
    """Return n (cx, cy) patch centres from albedo (H,W,3) f32 [0,1].

    Detects up to 2n candidates with the chosen method, keeps only those
    whose patch fits fully inside the image, then pads the remainder with
    uniformly random valid centres.

    detector: 'harris' | 'shi-tomasi' | 'fast' | 'gradient' | 'random'
    """
    h, w = albedo.shape[:2]
    half = patch_size // 2
    gray = cv2.cvtColor((albedo * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)

    candidates = None
    if detector in ('harris', 'shi-tomasi'):
        candidates = cv2.goodFeaturesToTrack(
            gray, n * 2, qualityLevel=0.01, minDistance=half,
            useHarrisDetector=(detector == 'harris'))
    elif detector == 'fast':
        keypoints = cv2.FastFeatureDetector_create(threshold=20).detect(gray, None)
        if keypoints:
            coords = np.array([[kp.pt[0], kp.pt[1]] for kp in keypoints[:n * 2]])
            candidates = coords.reshape(-1, 1, 2)
    elif detector == 'gradient':
        sobel_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        sobel_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
        magnitude = np.sqrt(sobel_x ** 2 + sobel_y ** 2)
        ys, xs = np.where(magnitude > np.percentile(magnitude, 95))
        if len(xs) > n * 2:
            keep = np.random.choice(len(xs), n * 2, replace=False)
            xs, ys = xs[keep], ys[keep]
        if len(xs):
            candidates = np.stack([xs, ys], axis=1).reshape(-1, 1, 2).astype(np.float32)
    # 'random' leaves candidates as None → pure random fill below.

    centres: List[Tuple[int, int]] = []
    if candidates is not None:
        for cand in candidates:
            cx, cy = int(cand[0][0]), int(cand[0][1])
            if not (half <= cx < w - half and half <= cy < h - half):
                continue  # patch would cross the image border
            centres.append((cx, cy))
            if len(centres) == n:
                break

    # Top up with uniformly random centres that keep the patch in bounds.
    while len(centres) < n:
        centres.append((random.randint(half, w - half - 1),
                        random.randint(half, h - half - 1)))
    return centres[:n]
+
+
+# ---------------------------------------------------------------------------
+# Dataset
+# ---------------------------------------------------------------------------
+
class CNNv3Dataset(Dataset):
    """Loads CNN v3 samples from dataset/full/ or dataset/simple/ directories.

    Patch mode (default): extracts patch_size×patch_size crops centred on
    salient points detected from albedo. Points are pre-cached at init.

    Full-image mode (--full-image): resizes entire image to image_size×image_size.

    Returns (feat, cond, target):
        feat: (20, H, W) f32
        cond: (5,) f32 FiLM conditioning (random when augment=True)
        target: (4, H, W) f32 RGBA [0,1]
    """

    def __init__(self, dataset_dir: str,
                 input_mode: str = 'simple',
                 patch_size: int = 64,
                 patches_per_image: int = 256,
                 image_size: int = 256,
                 full_image: bool = False,
                 channel_dropout_p: float = 0.3,
                 detector: str = 'harris',
                 augment: bool = True):
        # channel_dropout_p is the geometric-group probability; the context
        # group drops at 0.67× of it (see __getitem__).
        self.patch_size = patch_size
        self.patches_per_image = patches_per_image
        self.image_size = image_size
        self.full_image = full_image
        self.channel_dropout_p = channel_dropout_p
        self.detector = detector
        self.augment = augment

        root = Path(dataset_dir)
        subdir = 'full' if input_mode == 'full' else 'simple'
        search_dir = root / subdir
        if not search_dir.exists():
            # Fall back to the root itself when no full/ or simple/ subdir exists.
            search_dir = root

        # A sample is any subdirectory containing at least albedo.png.
        self.samples = sorted([
            d for d in search_dir.iterdir()
            if d.is_dir() and (d / 'albedo.png').exists()
        ])
        if not self.samples:
            raise RuntimeError(f"No samples found in {search_dir}")

        # Pre-cache salient patch centres (albedo-only load — cheap)
        self._patch_centers: List[List[Tuple[int, int]]] = []
        if not full_image:
            print(f"[CNNv3Dataset] Detecting salient points "
                  f"(detector={detector}, patch={patch_size}×{patch_size}) …")
            for sd in self.samples:
                pts = detect_salient_points(
                    load_rgb(sd / 'albedo.png'),
                    patches_per_image, detector, patch_size)
                self._patch_centers.append(pts)

        print(f"[CNNv3Dataset] mode={input_mode} samples={len(self.samples)} "
              f"patch={patch_size} full_image={full_image}")

    def __len__(self):
        # Patch mode exposes patches_per_image virtual items per sample.
        if self.full_image:
            return len(self.samples)
        return len(self.samples) * self.patches_per_image

    def _load_sample(self, sd: Path):
        """Load all raw planes of one sample dir (layout in module docstring)."""
        albedo = load_rgb(sd / 'albedo.png')
        normal = load_rg(sd / 'normal.png')
        depth = load_depth16(sd / 'depth.png')
        matid = load_gray(sd / 'matid.png')
        shadow = load_gray(sd / 'shadow.png')
        transp = load_gray(sd / 'transp.png')
        target = np.asarray(
            Image.open(sd / 'target.png').convert('RGBA'),
            dtype=np.float32) / 255.0
        return albedo, normal, depth, matid, shadow, transp, target

    def __getitem__(self, idx):
        """Return (feat (20,H,W), cond (5,), target (4,H,W)) float32 tensors."""
        if self.full_image:
            sample_idx = idx
            sd = self.samples[idx]
        else:
            # idx enumerates (sample, patch) pairs row-major.
            sample_idx = idx // self.patches_per_image
            sd = self.samples[sample_idx]

        albedo, normal, depth, matid, shadow, transp, target = self._load_sample(sd)
        h, w = albedo.shape[:2]

        if self.full_image:
            sz = self.image_size

            def _resize_rgb(a):
                # Also accepts 4-channel input: PIL keeps RGBA through resize.
                img = Image.fromarray((np.clip(a, 0, 1) * 255).astype(np.uint8))
                return np.asarray(img.resize((sz, sz), Image.LANCZOS), dtype=np.float32) / 255.0

            def _resize_gray(a):
                img = Image.fromarray((np.clip(a, 0, 1) * 255).astype(np.uint8), mode='L')
                return np.asarray(img.resize((sz, sz), Image.LANCZOS), dtype=np.float32) / 255.0

            albedo = _resize_rgb(albedo)
            # Pad oct-encoded RG with a zero B channel so PIL accepts it, drop B after.
            # NOTE(review): LANCZOS blends oct-encoded normals and categorical
            # mat-id values across edges — NEAREST may be more faithful; confirm intent.
            normal = _resize_rgb(np.concatenate(
                [normal, np.zeros_like(normal[..., :1])], -1))[..., :2]
            depth = _resize_gray(depth)
            matid = _resize_gray(matid)
            shadow = _resize_gray(shadow)
            transp = _resize_gray(transp)
            target = _resize_rgb(target)
        else:
            ps = self.patch_size
            half = ps // 2
            cx, cy = self._patch_centers[sample_idx][idx % self.patches_per_image]
            # Clamp centres so the crop stays inside the image.
            # NOTE(review): for odd patch_size a centre at the clamp limit puts the
            # slice end one past the image edge, yielding a short crop — assumes
            # even patch_size; confirm callers never pass odd sizes.
            cx = max(half, min(cx, w - half))
            cy = max(half, min(cy, h - half))
            sl = (slice(cy - half, cy - half + ps), slice(cx - half, cx - half + ps))

            albedo = albedo[sl]
            normal = normal[sl]
            depth = depth[sl]
            matid = matid[sl]
            shadow = shadow[sl]
            transp = transp[sl]
            target = target[sl]

        feat = assemble_features(albedo, normal, depth, matid, shadow, transp)

        if self.augment:
            # Context group drops at ~⅔ of the geometric probability; temporal
            # (prev.rgb, already zero during training) drops at a fixed 0.5.
            feat = apply_channel_dropout(feat,
                                         p_geom=self.channel_dropout_p,
                                         p_context=self.channel_dropout_p * 0.67,
                                         p_temporal=0.5)
            cond = np.random.rand(5).astype(np.float32)
        else:
            cond = np.zeros(5, dtype=np.float32)

        return (torch.from_numpy(feat).permute(2, 0, 1),    # (20,H,W)
                torch.from_numpy(cond),                     # (5,)
                torch.from_numpy(target).permute(2, 0, 1))  # (4,H,W)
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py
new file mode 100644
index 0000000..ed925e6
--- /dev/null
+++ b/cnn_v3/training/train_cnn_v3.py
@@ -0,0 +1,222 @@
+#!/usr/bin/env python3
+"""CNN v3 Training Script — U-Net + FiLM
+
+Architecture:
+ enc0 Conv(20→4, 3×3) + FiLM + ReLU H×W
+ enc1 Conv(4→8, 3×3) + FiLM + ReLU + pool2 H/2×W/2
+ bottleneck Conv(8→8, 1×1) + ReLU H/4×W/4
+ dec1 upsample×2 + cat(enc1) Conv(16→4) + FiLM H/2×W/2
+ dec0 upsample×2 + cat(enc0) Conv(8→4) + FiLM H×W
+ output sigmoid → RGBA
+
+FiLM MLP: Linear(5→16) → ReLU → Linear(16→40)
+ 40 = 2 × (γ+β) for enc0(4) enc1(8) dec1(4) dec0(4)
+
+Weight budget: ~3.9 KB f16 (fits ≤6 KB target)
+"""
+
+import argparse
+import time
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+
+from cnn_v3_utils import CNNv3Dataset, N_FEATURES
+
+# ---------------------------------------------------------------------------
+# Model
+# ---------------------------------------------------------------------------
+
def film_apply(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
    """FiLM affine modulation: gamma*x + beta per channel.

    gamma, beta: (B, C) — expanded to (B, C, 1, 1) so they broadcast over H, W.
    """
    scale = gamma.unsqueeze(-1).unsqueeze(-1)
    shift = beta.unsqueeze(-1).unsqueeze(-1)
    return scale * x + shift
+
+
class CNNv3(nn.Module):
    """U-Net + FiLM conditioning.

    enc_channels:  [c0, c1] channel counts per encoder level, default [4, 8]
    film_cond_dim: FiLM conditioning input size, default 5
    in_channels:   input feature-channel count; defaults to N_FEATURES (20)

    Fix vs. original: the final decoder stage no longer applies ReLU before
    the sigmoid. The trailing ReLU clamped the pre-activation to >= 0, which
    restricted every output value to [0.5, 1] and made dark RGBA targets
    unreachable. This also matches the architecture table in the module
    docstring, which specifies ReLU only for enc0/enc1/dec1.
    """

    def __init__(self, enc_channels=None, film_cond_dim: int = 5,
                 in_channels: int = None):
        super().__init__()
        if enc_channels is None:
            enc_channels = [4, 8]
        assert len(enc_channels) == 2, "Only 2-level U-Net supported"
        c0, c1 = enc_channels
        if in_channels is None:
            in_channels = N_FEATURES  # module default: 20 feature channels

        self.enc0 = nn.Conv2d(in_channels, c0, 3, padding=1)
        self.enc1 = nn.Conv2d(c0, c1, 3, padding=1)
        self.bottleneck = nn.Conv2d(c1, c1, 1)
        self.dec1 = nn.Conv2d(c1 * 2, c0, 3, padding=1)  # +skip enc1
        self.dec0 = nn.Conv2d(c0 * 2, 4, 3, padding=1)   # +skip enc0

        film_out = 2 * (c0 + c1 + c0 + 4)  # γ+β for enc0, enc1, dec1, dec0
        self.film_mlp = nn.Sequential(
            nn.Linear(film_cond_dim, 16),
            nn.ReLU(),
            nn.Linear(16, film_out),
        )
        self.enc_channels = enc_channels

    @staticmethod
    def _film(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
        """Per-channel affine gamma*x + beta; gamma/beta (B,C) broadcast over H,W."""
        return gamma[:, :, None, None] * x + beta[:, :, None, None]

    def _split_film(self, film: torch.Tensor):
        """Split the MLP output into per-site (γ, β) chunks along the last dim."""
        c0, c1 = self.enc_channels
        parts = torch.split(film, [c0, c0, c1, c1, c0, c0, 4, 4], dim=-1)
        return parts  # g_enc0, b_enc0, g_enc1, b_enc1, g_dec1, b_dec1, g_dec0, b_dec0

    def forward(self, feat: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """feat: (B,in_channels,H,W) cond: (B,film_cond_dim) → (B,4,H,W) RGBA [0,1].

        H and W must be divisible by 4 so the two pool/upsample pairs
        round-trip to the skip-connection sizes exactly.
        """
        g0, b0, g1, b1, gd1, bd1, gd0, bd0 = self._split_film(self.film_mlp(cond))

        skip0 = F.relu(self._film(self.enc0(feat), g0, b0))

        x = F.avg_pool2d(skip0, 2)
        skip1 = F.relu(self._film(self.enc1(x), g1, b1))

        x = F.relu(self.bottleneck(F.avg_pool2d(skip1, 2)))

        x = F.relu(self._film(self.dec1(
            torch.cat([F.interpolate(x, scale_factor=2, mode='nearest'), skip1], dim=1)
        ), gd1, bd1))

        # No ReLU here: sigmoid must see the full signed pre-activation so
        # outputs can cover the whole [0, 1] range.
        x = self._film(self.dec0(
            torch.cat([F.interpolate(x, scale_factor=2, mode='nearest'), skip0], dim=1)
        ), gd0, bd0)

        return torch.sigmoid(x)
+
+
+# ---------------------------------------------------------------------------
+# Training
+# ---------------------------------------------------------------------------
+
def train(args):
    """Train CNNv3 on CNNv3Dataset per the parsed CLI args; returns the model.

    Side effects: prints progress to stdout and writes checkpoint .pth files
    into args.checkpoint_dir (periodic per --checkpoint-every plus a final one).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    enc_channels = [int(c) for c in args.enc_channels.split(',')]
    print(f"Device: {device}")

    dataset = CNNv3Dataset(
        dataset_dir=args.input,
        input_mode=args.input_mode,
        patch_size=args.patch_size,
        patches_per_image=args.patches_per_image,
        image_size=args.image_size,
        full_image=args.full_image,
        channel_dropout_p=args.channel_dropout_p,
        detector=args.detector,
        augment=True,  # training always uses channel dropout + random FiLM cond
    )
    # num_workers=0: samples are decoded in-process (PIL/OpenCV in the dataset).
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                        num_workers=0, drop_last=False)

    model = CNNv3(enc_channels=enc_channels, film_cond_dim=args.film_cond_dim).to(device)
    nparams = sum(p.numel() for p in model.parameters())
    print(f"Model: enc={enc_channels} film_cond_dim={args.film_cond_dim} "
          f"params={nparams} (~{nparams*2/1024:.1f} KB f16)")

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.MSELoss()
    ckpt_dir = Path(args.checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    print(f"\nTraining {args.epochs} epochs batch={args.batch_size} lr={args.lr}")
    start = time.time()
    avg_loss = float('nan')  # reported even if the loader yields no batches

    for epoch in range(1, args.epochs + 1):
        model.train()
        epoch_loss = 0.0
        n_batches = 0

        for feat, cond, target in loader:
            feat, cond, target = feat.to(device), cond.to(device), target.to(device)
            optimizer.zero_grad()
            loss = criterion(model(feat, cond), target)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            n_batches += 1

        # max(..., 1) guards the empty-loader case (avoids ZeroDivisionError).
        avg_loss = epoch_loss / max(n_batches, 1)
        # \r with end='' keeps the progress report on a single console line.
        print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | "
              f"{time.time()-start:.0f}s", end='', flush=True)

        if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0:
            print()  # terminate the in-place progress line first
            ckpt = ckpt_dir / f"checkpoint_epoch_{epoch}.pth"
            torch.save(_checkpoint(model, optimizer, epoch, avg_loss, args), ckpt)
            print(f" → {ckpt}")

    print()
    # If epochs is a multiple of checkpoint_every this overwrites the last
    # periodic checkpoint with identical contents.
    final = ckpt_dir / f"checkpoint_epoch_{args.epochs}.pth"
    torch.save(_checkpoint(model, optimizer, args.epochs, avg_loss, args), final)
    print(f"Final checkpoint: {final}")
    print(f"Done. {time.time()-start:.1f}s")
    return model
+
+
+def _checkpoint(model, optimizer, epoch, loss, args):
+ return {
+ 'epoch': epoch,
+ 'model_state_dict': model.state_dict(),
+ 'optimizer_state_dict': optimizer.state_dict(),
+ 'loss': loss,
+ 'config': {
+ 'enc_channels': [int(c) for c in args.enc_channels.split(',')],
+ 'film_cond_dim': args.film_cond_dim,
+ 'input_mode': args.input_mode,
+ },
+ }
+
+
+# ---------------------------------------------------------------------------
+# CLI
+# ---------------------------------------------------------------------------
+
def main():
    """CLI entry point: build the argument parser and launch training."""
    parser = argparse.ArgumentParser(description='Train CNN v3 (U-Net + FiLM)')

    # Dataset
    parser.add_argument('--input', default='training/dataset',
                        help='Dataset root (contains full/ or simple/ subdirs)')
    parser.add_argument('--input-mode', default='simple', choices=['simple', 'full'],
                        help='simple=photo samples full=Blender G-buffer samples')
    parser.add_argument('--channel-dropout-p', type=float, default=0.3,
                        help='Dropout prob for geometric channels (default 0.3)')

    # Patch / full-image mode
    parser.add_argument('--full-image', action='store_true',
                        help='Use full-image mode (resize to --image-size)')
    parser.add_argument('--image-size', type=int, default=256,
                        help='Full-image resize target (default 256)')
    parser.add_argument('--patch-size', type=int, default=64,
                        help='Patch size (default 64)')
    parser.add_argument('--patches-per-image', type=int, default=256,
                        help='Patches per image per epoch (default 256)')
    parser.add_argument('--detector', default='harris',
                        choices=['harris', 'shi-tomasi', 'fast', 'gradient', 'random'],
                        help='Salient point detector (default harris)')

    # Model
    parser.add_argument('--enc-channels', default='4,8',
                        help='Encoder channels, comma-separated (default 4,8)')
    parser.add_argument('--film-cond-dim', type=int, default=5,
                        help='FiLM conditioning input dim (default 5)')

    # Training
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--checkpoint-dir', default='checkpoints')
    parser.add_argument('--checkpoint-every', type=int, default=50,
                        help='Save checkpoint every N epochs (0=disable)')

    args = parser.parse_args()
    train(args)


if __name__ == '__main__':
    main()