From 1e8ccfc67c264ce054c59257ee7c17ec4a584a9e Mon Sep 17 00:00:00 2001 From: skal Date: Sat, 21 Mar 2026 10:07:02 +0100 Subject: feat(cnn_v3): Phase 6 — training script (train_cnn_v3.py + cnn_v3_utils.py) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - train_cnn_v3.py: CNNv3 U-Net+FiLM model, training loop, CLI - cnn_v3_utils.py: image I/O, pyrdown, depth_gradient, assemble_features, apply_channel_dropout, detect_salient_points, CNNv3Dataset - Patch-based training (default 64×64) with salient-point extraction (harris/shi-tomasi/fast/gradient/random detectors, pre-cached at init) - Channel dropout for geometric/context/temporal channels - Random FiLM conditioning per sample for joint MLP+U-Net training - docs: HOWTO.md §3 updated with commands and flag reference - TODO.md: Phase 6 marked done, export script noted as next step Co-Authored-By: Claude Sonnet 4.6 --- cnn_v3/docs/HOWTO.md | 70 +++++++-- cnn_v3/training/cnn_v3_utils.py | 339 ++++++++++++++++++++++++++++++++++++++++ cnn_v3/training/train_cnn_v3.py | 222 ++++++++++++++++++++++++++ 3 files changed, 620 insertions(+), 11 deletions(-) create mode 100644 cnn_v3/training/cnn_v3_utils.py create mode 100644 cnn_v3/training/train_cnn_v3.py (limited to 'cnn_v3') diff --git a/cnn_v3/docs/HOWTO.md b/cnn_v3/docs/HOWTO.md index 425a33b..0cf2fe5 100644 --- a/cnn_v3/docs/HOWTO.md +++ b/cnn_v3/docs/HOWTO.md @@ -135,20 +135,68 @@ Mix freely; the dataloader treats all sample directories uniformly. ## 3. Training -*(Script not yet written — see TODO.md. 
Architecture spec in `CNN_V3.md` §Training.)* +Two source files: +- **`cnn_v3_utils.py`** — image I/O, feature assembly, channel dropout, salient-point + detection, `CNNv3Dataset` +- **`train_cnn_v3.py`** — `CNNv3` model, training loop, CLI + +### Quick start -**Planned command:** ```bash -python3 cnn_v3/training/train_cnn_v3.py \ - --dataset dataset/ \ - --epochs 500 \ - --output cnn_v3/weights/cnn_v3_weights.bin +cd cnn_v3/training + +# Patch-based (default) — 64×64 patches around Harris corners +python3 train_cnn_v3.py \ + --input dataset/ \ + --input-mode simple \ + --epochs 200 + +# Full-image mode (resizes to 256×256) +python3 train_cnn_v3.py \ + --input dataset/ \ + --input-mode full \ + --full-image --image-size 256 \ + --epochs 500 + +# Quick smoke test: 1 epoch, small patches, random detector +python3 train_cnn_v3.py \ + --input dataset/ --epochs 1 \ + --patch-size 32 --detector random ``` -**FiLM conditioning** during training: -- Beat/audio inputs randomized per sample -- MLP: `Linear(5→16) → ReLU → Linear(16→40)` trained jointly with U-Net -- Output: γ/β for enc0(4ch) + enc1(8ch) + dec1(4ch) + dec0(4ch) = 40 floats +### Key flags + +| Flag | Default | Notes | +|------|---------|-------| +| `--input DIR` | `training/dataset` | Root with `full/` or `simple/` subdirs | +| `--input-mode` | `simple` | `simple`=photos, `full`=Blender G-buffer | +| `--patch-size N` | `64` | Patch crop size | +| `--patches-per-image N` | `256` | Patches extracted per image per epoch | +| `--detector` | `harris` | `harris` \| `shi-tomasi` \| `fast` \| `gradient` \| `random` | +| `--channel-dropout-p F` | `0.3` | Dropout prob for geometric channels | +| `--full-image` | off | Resize full image instead of cropping patches | +| `--enc-channels C` | `4,8` | Encoder channel counts, comma-separated | +| `--film-cond-dim N` | `5` | FiLM conditioning input size | +| `--epochs N` | `200` | Training epochs | +| `--batch-size N` | `16` | Batch size | +| `--lr F` | `1e-3` | Adam learning 
rate | +| `--checkpoint-dir DIR` | `checkpoints/` | Where to save `.pth` files | +| `--checkpoint-every N` | `50` | Epoch interval for checkpoints (0=disable) | + +### FiLM conditioning during training + +- Conditioning vector `[beat_phase, beat_time/8, audio_intensity, style_p0, style_p1]` + is **randomised per sample** (uniform [0,1]) so the MLP trains jointly with the U-Net. +- At inference, real beat/audio values are fed from `CNNv3Effect::set_film_params()`. + +### Channel dropout + +Applied per-sample in `cnn_v3_utils.apply_channel_dropout()`: +- Geometric channels (normal, depth, depth_grad) zeroed with `p=channel_dropout_p` +- Context channels (mat_id, shadow, transp) with `p≈0.2` +- Temporal channels (prev.rgb) with `p=0.5` + +This ensures the network works for both full G-buffer and photo-only inputs. --- @@ -202,7 +250,7 @@ Test vectors generated by `cnn_v3/training/gen_test_vectors.py` (PyTorch referen | 3 — WGSL U-Net shaders | ✅ Done | 5 compute shaders + cnn_v3/common snippet | | 4 — C++ CNNv3Effect | ✅ Done | FiLM uniform upload, 36/36 tests pass | | 5 — Parity validation | ✅ Done | test_cnn_v3_parity.cc, max_err=4.88e-4 | -| 6 — FiLM MLP training | TODO | train_cnn_v3.py not yet written | +| 6 — FiLM MLP training | ✅ Done | train_cnn_v3.py + cnn_v3_utils.py written | --- diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py new file mode 100644 index 0000000..ecdbd6b --- /dev/null +++ b/cnn_v3/training/cnn_v3_utils.py @@ -0,0 +1,339 @@ +"""CNN v3 training utilities — image I/O, feature assembly, dataset. + +Imported by train_cnn_v3.py and export_cnn_v3_weights.py. 
def load_rgb(path: Path) -> np.ndarray:
    """Load an image file as RGB → (H,W,3) f32 [0,1].

    The image is converted to 8-bit RGB by PIL before scaling by 1/255.
    """
    return np.asarray(Image.open(path).convert('RGB'), dtype=np.float32) / 255.0


def load_rg(path: Path) -> np.ndarray:
    """Load the R and G channels of an image → (H,W,2) f32 [0,1].

    The file is decoded as 8-bit RGB and the third channel is dropped.
    """
    arr = np.asarray(Image.open(path).convert('RGB'), dtype=np.float32) / 255.0
    return arr[..., :2]


def load_depth16(path: Path) -> np.ndarray:
    """Load a greyscale depth image → f32 array scaled to [0,1].

    The scale factor is chosen from the decoded sample type instead of from
    the data range, so the result no longer depends on image content:

    * uint8 samples          → divided by 255
    * other integer samples  → divided by 65535 (16-bit PNG, whether PIL
      exposes it as uint16 or as a 32-bit integer image)
    * non-integer samples    → kept as-is, except that data exceeding 1.0 is
      divided by 65535 (the legacy heuristic, kept for backward compatibility)

    Previously a 16-bit map whose raw values were all <= 1 was returned
    unscaled, and an 8-bit map was divided by 65535.
    """
    raw = np.asarray(Image.open(path))
    depth = raw.astype(np.float32)
    if raw.dtype == np.uint8:
        return depth / 255.0
    if np.issubdtype(raw.dtype, np.integer):
        return depth / 65535.0
    # Float (or bilevel) input: only rescale when it is clearly not normalised.
    if depth.size and depth.max() > 1.0:
        depth = depth / 65535.0
    return depth
def load_gray(path: Path) -> np.ndarray:
    """Read an image as 8-bit luminance → (H,W) f32 [0,1]."""
    luminance = Image.open(path).convert('L')
    return np.asarray(luminance, dtype=np.float32) / 255.0


# ---------------------------------------------------------------------------
# Feature assembly helpers
# ---------------------------------------------------------------------------

def pyrdown(img: np.ndarray) -> np.ndarray:
    """Halve the resolution of img (H,W,C) by averaging each 2×2 cell.

    A trailing odd row/column is discarded, so the result is
    (H//2, W//2, C).
    """
    rows, cols, _ = img.shape
    even_rows = (rows // 2) * 2
    even_cols = (cols // 2) * 2
    cropped = img[:even_rows, :even_cols, :]

    # The four taps of every 2×2 cell, summed in a fixed order.
    top_left = cropped[0::2, 0::2]
    bottom_left = cropped[1::2, 0::2]
    top_right = cropped[0::2, 1::2]
    bottom_right = cropped[1::2, 1::2]
    return 0.25 * (top_left + bottom_left + top_right + bottom_right)


def depth_gradient(depth: np.ndarray) -> np.ndarray:
    """Central differences of a (H,W) depth map → (H,W,2) [dzdx, dzdy].

    Borders are handled by replicating the edge samples before differencing.
    """
    padded = np.pad(depth, 1, mode='edge')
    interior_rows = padded[1:-1, :]
    interior_cols = padded[:, 1:-1]
    dzdx = (interior_rows[:, 2:] - interior_rows[:, :-2]) * 0.5
    dzdy = (interior_cols[2:, :] - interior_cols[:-2, :]) * 0.5
    return np.stack([dzdx, dzdy], axis=-1)


def _upsample_nearest(a: np.ndarray, h: int, w: int) -> np.ndarray:
    """Resize (H,W,C) f32 [0,1] to (h,w,C) with nearest-neighbour sampling.

    The data makes a round trip through an 8-bit PIL image, so values are
    clipped to [0,1] and quantised to 1/255 steps.
    """
    as_bytes = (np.clip(a, 0, 1) * 255).astype(np.uint8)
    resized = Image.fromarray(as_bytes).resize((w, h), Image.NEAREST)
    return np.asarray(resized, dtype=np.float32) / 255.0
def assemble_features(albedo: np.ndarray, normal: np.ndarray,
                      depth: np.ndarray, matid: np.ndarray,
                      shadow: np.ndarray, transp: np.ndarray) -> np.ndarray:
    """Stack the per-pixel inputs into one (H,W,20) f32 feature array.

    The temporal `prev` slot is filled with zeros, the two mip levels are
    derived from albedo (downsampled, then brought back to H×W with
    nearest-neighbour sampling) and the depth gradient is computed from
    `depth` by finite differences.
    """
    height, width = albedo.shape[:2]

    half_res = pyrdown(albedo)
    quarter_res = pyrdown(half_res)
    mip1 = _upsample_nearest(half_res, height, width)
    mip2 = _upsample_nearest(quarter_res, height, width)

    channel_groups = [
        albedo,                                          # [0-2]   albedo.rgb
        normal,                                          # [3-4]   normal.xy
        np.expand_dims(depth, -1),                       # [5]     depth
        depth_gradient(depth),                           # [6-7]   depth_grad.xy
        np.expand_dims(matid, -1),                       # [8]     mat_id
        np.zeros((height, width, 3), dtype=np.float32),  # [9-11]  prev.rgb
        mip1,                                            # [12-14] mip1.rgb
        mip2,                                            # [15-17] mip2.rgb
        np.expand_dims(shadow, -1),                      # [18]    shadow
        np.expand_dims(transp, -1),                      # [19]    transp
    ]
    stacked = np.concatenate(channel_groups, axis=-1)
    return stacked.astype(np.float32)


def apply_channel_dropout(feat: np.ndarray,
                          p_geom: float = 0.3,
                          p_context: float = 0.2,
                          p_temporal: float = 0.5) -> np.ndarray:
    """Return a copy of feat with whole channel groups randomly zeroed.

    One independent draw is made per group (geometric, context, temporal, in
    that order); a group is cleared when its draw falls below the matching
    probability. The input array is never modified.
    """
    dropped = feat.copy()
    drop_geom = random.random() < p_geom
    drop_context = random.random() < p_context
    drop_temporal = random.random() < p_temporal

    if drop_geom:
        dropped[..., GEOMETRIC_CHANNELS] = 0.0
    if drop_context:
        dropped[..., CONTEXT_CHANNELS] = 0.0
    if drop_temporal:
        dropped[..., TEMPORAL_CHANNELS] = 0.0
    return dropped
def detect_salient_points(albedo: np.ndarray, n: int, detector: str,
                          patch_size: int) -> List[Tuple[int, int]]:
    """Return n (cx, cy) patch centres from albedo (H,W,3) f32 [0,1].

    Up to 2n candidates are produced by the chosen detector, filtered to
    centres whose patch_size×patch_size crop
    ``[c - patch_size//2, c - patch_size//2 + patch_size)`` lies fully inside
    the image, and the remainder is filled with uniformly random valid centres.

    detector: 'harris' | 'shi-tomasi' | 'fast' | 'gradient' | 'random'.
    Any other value behaves like 'random' (no candidates, random fill only).

    Raises:
        ValueError: if the image is smaller than patch_size in either
            dimension (no valid centre exists).
    """
    h, w = albedo.shape[:2]
    half = patch_size // 2
    # Largest centre (inclusive) for which the crop still fits. For an even
    # patch size this is w - half, so an image exactly patch_size wide has one
    # valid centre instead of none.
    max_cx = w - patch_size + half
    max_cy = h - patch_size + half
    if max_cx < half or max_cy < half:
        raise ValueError(
            f"image {w}x{h} is smaller than patch size {patch_size}")

    corners = None
    if detector in ('harris', 'shi-tomasi', 'fast', 'gradient'):
        # Greyscale is only needed by the real detectors; 'random' skips it.
        gray = cv2.cvtColor((albedo * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
        if detector == 'fast':
            kps = cv2.FastFeatureDetector_create(threshold=20).detect(gray, None)
            if kps:
                # Keep the strongest responses rather than the first 2n in
                # detection order.
                strongest = sorted(kps, key=lambda kp: kp.response,
                                   reverse=True)[:n * 2]
                corners = np.array(
                    [[kp.pt[0], kp.pt[1]] for kp in strongest]).reshape(-1, 1, 2)
        elif detector == 'gradient':
            gx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
            gy = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
            mag = np.sqrt(gx ** 2 + gy ** 2)
            # Candidates are pixels in the top 5% of gradient magnitude.
            ys, xs = np.where(mag > np.percentile(mag, 95))
            if len(xs) > n * 2:
                sel = np.random.choice(len(xs), n * 2, replace=False)
                xs, ys = xs[sel], ys[sel]
            if len(xs):
                corners = np.stack([xs, ys], axis=1).reshape(-1, 1, 2).astype(np.float32)
        else:
            corners = cv2.goodFeaturesToTrack(
                gray, n * 2, qualityLevel=0.01, minDistance=half,
                useHarrisDetector=(detector == 'harris'))

    centres: List[Tuple[int, int]] = []
    if corners is not None:
        for c in corners:
            cx, cy = int(c[0][0]), int(c[0][1])
            if half <= cx <= max_cx and half <= cy <= max_cy:
                centres.append((cx, cy))
                if len(centres) >= n:
                    break

    # Top up with random valid centres (randint bounds are inclusive).
    while len(centres) < n:
        centres.append((random.randint(half, max_cx),
                        random.randint(half, max_cy)))
    return centres[:n]
class CNNv3Dataset(Dataset):
    """Loads CNN v3 samples from dataset/full/ or dataset/simple/ directories.

    Patch mode (default): extracts patch_size×patch_size crops centred on
    salient points detected from albedo. Points are pre-cached at init.

    Full-image mode (full_image=True): resizes every map to
    image_size×image_size.

    Returns (feat, cond, target):
      feat:   (20, H, W) f32
      cond:   (film_cond_dim,) f32 FiLM conditioning — uniform random in [0,1)
              when augment=True, zeros otherwise
      target: (4, H, W) f32 RGBA [0,1]
    """

    def __init__(self, dataset_dir: str,
                 input_mode: str = 'simple',
                 patch_size: int = 64,
                 patches_per_image: int = 256,
                 image_size: int = 256,
                 full_image: bool = False,
                 channel_dropout_p: float = 0.3,
                 detector: str = 'harris',
                 augment: bool = True,
                 film_cond_dim: int = 5):
        """Index the sample directories and pre-compute patch centres.

        Samples are the sub-directories of ``dataset_dir/<full|simple>`` (or of
        ``dataset_dir`` itself when that sub-directory is missing) that contain
        an ``albedo.png``.

        film_cond_dim is the length of the conditioning vector returned by
        ``__getitem__``; it defaults to the previously hard-coded 5 and should
        match the model's FiLM input size.

        Raises:
            RuntimeError: if no sample directory is found.
        """
        self.patch_size = patch_size
        self.patches_per_image = patches_per_image
        self.image_size = image_size
        self.full_image = full_image
        self.channel_dropout_p = channel_dropout_p
        self.detector = detector
        self.augment = augment
        self.film_cond_dim = film_cond_dim

        root = Path(dataset_dir)
        subdir = 'full' if input_mode == 'full' else 'simple'
        search_dir = root / subdir
        if not search_dir.exists():
            search_dir = root

        self.samples = sorted([
            d for d in search_dir.iterdir()
            if d.is_dir() and (d / 'albedo.png').exists()
        ])
        if not self.samples:
            raise RuntimeError(f"No samples found in {search_dir}")

        # Pre-cache salient patch centres (albedo-only load — cheap)
        self._patch_centers: List[List[Tuple[int, int]]] = []
        if not full_image:
            print(f"[CNNv3Dataset] Detecting salient points "
                  f"(detector={detector}, patch={patch_size}×{patch_size}) …")
            for sd in self.samples:
                pts = detect_salient_points(
                    load_rgb(sd / 'albedo.png'),
                    patches_per_image, detector, patch_size)
                self._patch_centers.append(pts)

        print(f"[CNNv3Dataset] mode={input_mode} samples={len(self.samples)} "
              f"patch={patch_size} full_image={full_image}")

    def __len__(self):
        """One item per sample in full-image mode, else one per cached patch."""
        if self.full_image:
            return len(self.samples)
        return len(self.samples) * self.patches_per_image

    def _load_sample(self, sd: Path):
        """Read all seven maps of one sample directory as f32 arrays."""
        albedo = load_rgb(sd / 'albedo.png')
        normal = load_rg(sd / 'normal.png')
        depth = load_depth16(sd / 'depth.png')
        matid = load_gray(sd / 'matid.png')
        shadow = load_gray(sd / 'shadow.png')
        transp = load_gray(sd / 'transp.png')
        target = np.asarray(
            Image.open(sd / 'target.png').convert('RGBA'),
            dtype=np.float32) / 255.0
        return albedo, normal, depth, matid, shadow, transp, target

    @staticmethod
    def _resize_u8(a: np.ndarray, size: Tuple[int, int]) -> np.ndarray:
        """Lanczos-resize a multi-channel [0,1] map through an 8-bit image."""
        img = Image.fromarray((np.clip(a, 0, 1) * 255).astype(np.uint8))
        return np.asarray(img.resize(size, Image.LANCZOS), dtype=np.float32) / 255.0

    @staticmethod
    def _resize_f32(a: np.ndarray, size: Tuple[int, int], resample) -> np.ndarray:
        """Resize a single-channel map as a 32-bit float image.

        Staying in float avoids the 8-bit round trip, which would discard the
        extra precision of 16-bit depth. The result is clipped to [0,1]
        because Lanczos can overshoot.
        """
        img = Image.fromarray(np.ascontiguousarray(a, dtype=np.float32))
        out = np.asarray(img.resize(size, resample), dtype=np.float32)
        return np.clip(out, 0.0, 1.0)

    @staticmethod
    def _patch_slices(cx: int, cy: int, patch_size: int, h: int, w: int):
        """Return (row_slice, col_slice) of the patch centred near (cx, cy).

        The centre is clamped so that the patch_size-wide window stays inside
        an h×w image; the upper bound is ``dim - patch_size + half`` so that
        odd patch sizes do not overrun by one pixel.
        """
        half = patch_size // 2
        cx = max(half, min(cx, w - patch_size + half))
        cy = max(half, min(cy, h - patch_size + half))
        return (slice(cy - half, cy - half + patch_size),
                slice(cx - half, cx - half + patch_size))

    def _make_cond(self) -> np.ndarray:
        """Build the FiLM conditioning vector for one item."""
        if self.augment:
            return np.random.rand(self.film_cond_dim).astype(np.float32)
        return np.zeros(self.film_cond_dim, dtype=np.float32)

    def __getitem__(self, idx):
        """Load, crop or resize, and assemble one training item."""
        if self.full_image:
            sample_idx = idx
        else:
            sample_idx = idx // self.patches_per_image
        sd = self.samples[sample_idx]

        albedo, normal, depth, matid, shadow, transp, target = self._load_sample(sd)
        h, w = albedo.shape[:2]

        if self.full_image:
            size = (self.image_size, self.image_size)
            albedo = self._resize_u8(albedo, size)
            # Pad normal.xy with a zero third channel so PIL sees an RGB image.
            normal = self._resize_u8(np.concatenate(
                [normal, np.zeros_like(normal[..., :1])], -1), size)[..., :2]
            depth = self._resize_f32(depth, size, Image.LANCZOS)
            # Material ids are categorical: interpolation would invent ids.
            matid = self._resize_f32(matid, size, Image.NEAREST)
            shadow = self._resize_f32(shadow, size, Image.LANCZOS)
            transp = self._resize_f32(transp, size, Image.LANCZOS)
            target = self._resize_u8(target, size)
        else:
            cx, cy = self._patch_centers[sample_idx][idx % self.patches_per_image]
            sl = self._patch_slices(cx, cy, self.patch_size, h, w)

            albedo = albedo[sl]
            normal = normal[sl]
            depth = depth[sl]
            matid = matid[sl]
            shadow = shadow[sl]
            transp = transp[sl]
            target = target[sl]

        feat = assemble_features(albedo, normal, depth, matid, shadow, transp)

        if self.augment:
            feat = apply_channel_dropout(feat,
                                         p_geom=self.channel_dropout_p,
                                         p_context=self.channel_dropout_p * 0.67,
                                         p_temporal=0.5)
        cond = self._make_cond()

        return (torch.from_numpy(feat).permute(2, 0, 1),    # (20,H,W)
                torch.from_numpy(cond),                     # (film_cond_dim,)
                torch.from_numpy(target).permute(2, 0, 1))  # (4,H,W)
def film_apply(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
    """Per-channel affine: gamma*x + beta. gamma/beta: (B,C) broadcast over H,W."""
    return gamma[:, :, None, None] * x + beta[:, :, None, None]


class CNNv3(nn.Module):
    """U-Net + FiLM conditioning.

    enc_channels:  [c0, c1] channel counts per encoder level, default [4, 8]
    film_cond_dim: FiLM conditioning input size, default 5
    in_channels:   number of input feature channels; None (default) uses the
                   module-level N_FEATURES
    """

    def __init__(self, enc_channels=None, film_cond_dim: int = 5, in_channels=None):
        """Build the convolutions and the FiLM MLP.

        Raises:
            ValueError: if enc_channels does not hold exactly two entries.
        """
        super().__init__()
        if enc_channels is None:
            enc_channels = [4, 8]
        enc_channels = list(enc_channels)
        # Explicit check instead of `assert`, which is stripped under -O.
        if len(enc_channels) != 2:
            raise ValueError("Only 2-level U-Net supported")
        if in_channels is None:
            in_channels = N_FEATURES
        c0, c1 = enc_channels

        self.enc0 = nn.Conv2d(in_channels, c0, 3, padding=1)
        self.enc1 = nn.Conv2d(c0, c1, 3, padding=1)
        self.bottleneck = nn.Conv2d(c1, c1, 1)
        self.dec1 = nn.Conv2d(c1 * 2, c0, 3, padding=1)  # +skip enc1
        self.dec0 = nn.Conv2d(c0 * 2, 4, 3, padding=1)   # +skip enc0

        film_out = 2 * (c0 + c1 + c0 + 4)  # γ+β for enc0, enc1, dec1, dec0
        self.film_mlp = nn.Sequential(
            nn.Linear(film_cond_dim, 16),
            nn.ReLU(),
            nn.Linear(16, film_out),
        )
        self.enc_channels = enc_channels

    def _split_film(self, film: torch.Tensor):
        """Split the MLP output into (γ, β) pairs for enc0, enc1, dec1, dec0."""
        c0, c1 = self.enc_channels
        parts = torch.split(film, [c0, c0, c1, c1, c0, c0, 4, 4], dim=-1)
        return parts  # g_enc0, b_enc0, g_enc1, b_enc1, g_dec1, b_dec1, g_dec0, b_dec0

    def forward(self, feat: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """feat: (B,C_in,H,W)  cond: (B,film_cond_dim)  →  (B,4,H,W) in (0,1)."""
        g0, b0, g1, b1, gd1, bd1, gd0, bd0 = self._split_film(self.film_mlp(cond))

        skip0 = F.relu(film_apply(self.enc0(feat), g0, b0))

        x = F.avg_pool2d(skip0, 2)
        skip1 = F.relu(film_apply(self.enc1(x), g1, b1))

        x = F.relu(self.bottleneck(F.avg_pool2d(skip1, 2)))

        # Upsample to the skip tensor's exact size: identical to ×2 nearest
        # when H and W are multiples of 4, and still valid when pooling
        # rounded an odd dimension down.
        up1 = F.interpolate(x, size=skip1.shape[-2:], mode='nearest')
        x = F.relu(film_apply(self.dec1(torch.cat([up1, skip1], dim=1)), gd1, bd1))

        up0 = F.interpolate(x, size=skip0.shape[-2:], mode='nearest')
        # No ReLU on the last layer: sigmoid(relu(z)) >= 0.5, which would make
        # every output value below 0.5 unreachable.
        # NOTE(review): confirm the inference-side dec0 stage uses the same
        # activation before exporting weights.
        x = film_apply(self.dec0(torch.cat([up0, skip0], dim=1)), gd0, bd0)

        return torch.sigmoid(x)
def train(args):
    """Train a CNNv3 model on a CNNv3Dataset and write checkpoints.

    args is the argparse namespace built by the CLI; the attributes read here
    are input, input_mode, patch_size, patches_per_image, image_size,
    full_image, channel_dropout_p, detector, batch_size, enc_channels,
    film_cond_dim, lr, epochs, checkpoint_dir and checkpoint_every.

    A checkpoint is saved every `checkpoint_every` epochs (when > 0) and once
    more after the last epoch. Returns the trained model.

    Raises:
        ValueError: if the conditioning vector produced by the dataset does
            not have `film_cond_dim` elements.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    enc_channels = [int(c) for c in args.enc_channels.split(',')]
    print(f"Device: {device}")

    dataset = CNNv3Dataset(
        dataset_dir=args.input,
        input_mode=args.input_mode,
        patch_size=args.patch_size,
        patches_per_image=args.patches_per_image,
        image_size=args.image_size,
        full_image=args.full_image,
        channel_dropout_p=args.channel_dropout_p,
        detector=args.detector,
        augment=True,
    )

    # Fail fast with a readable message instead of a matmul shape error in the
    # first forward pass when --film-cond-dim disagrees with the dataset.
    cond_dim = int(dataset[0][1].shape[-1])
    if cond_dim != args.film_cond_dim:
        raise ValueError(
            f"--film-cond-dim={args.film_cond_dim} does not match the "
            f"{cond_dim}-element conditioning vector produced by the dataset")

    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                        num_workers=0, drop_last=False)

    model = CNNv3(enc_channels=enc_channels, film_cond_dim=args.film_cond_dim).to(device)
    nparams = sum(p.numel() for p in model.parameters())
    print(f"Model: enc={enc_channels} film_cond_dim={args.film_cond_dim} "
          f"params={nparams} (~{nparams*2/1024:.1f} KB f16)")

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.MSELoss()
    ckpt_dir = Path(args.checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    print(f"\nTraining {args.epochs} epochs batch={args.batch_size} lr={args.lr}")
    start = time.time()
    avg_loss = float('nan')
    last_saved_epoch = None

    for epoch in range(1, args.epochs + 1):
        model.train()
        epoch_loss = 0.0
        n_batches = 0

        for feat, cond, target in loader:
            feat, cond, target = feat.to(device), cond.to(device), target.to(device)
            optimizer.zero_grad()
            loss = criterion(model(feat, cond), target)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            n_batches += 1

        avg_loss = epoch_loss / max(n_batches, 1)
        # '\r' keeps the progress report on a single console line.
        print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | "
              f"{time.time()-start:.0f}s", end='', flush=True)

        if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0:
            print()
            ckpt = ckpt_dir / f"checkpoint_epoch_{epoch}.pth"
            torch.save(_checkpoint(model, optimizer, epoch, avg_loss, args), ckpt)
            print(f" → {ckpt}")
            last_saved_epoch = epoch

    print()
    final = ckpt_dir / f"checkpoint_epoch_{args.epochs}.pth"
    # Skip the write when the periodic checkpoint of the last epoch already
    # produced this exact file.
    if last_saved_epoch != args.epochs:
        torch.save(_checkpoint(model, optimizer, args.epochs, avg_loss, args), final)
    print(f"Final checkpoint: {final}")
    print(f"Done. {time.time()-start:.1f}s")
    return model
def _checkpoint(model, optimizer, epoch, loss, args):
    """Bundle model/optimizer state and the run configuration into a dict.

    The `config` entry records the parsed encoder channel list, the FiLM
    conditioning size and the input mode taken from `args`.
    """
    model_state = model.state_dict()
    optimizer_state = optimizer.state_dict()
    channels = [int(token) for token in args.enc_channels.split(',')]
    config = {
        'enc_channels': channels,
        'film_cond_dim': args.film_cond_dim,
        'input_mode': args.input_mode,
    }
    snapshot = {
        'epoch': epoch,
        'model_state_dict': model_state,
        'optimizer_state_dict': optimizer_state,
        'loss': loss,
        'config': config,
    }
    return snapshot


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def main():
    """Parse the command line and start training."""
    parser = argparse.ArgumentParser(description='Train CNN v3 (U-Net + FiLM)')
    add = parser.add_argument

    # Dataset
    add('--input', default='training/dataset',
        help='Dataset root (contains full/ or simple/ subdirs)')
    add('--input-mode', default='simple', choices=['simple', 'full'],
        help='simple=photo samples full=Blender G-buffer samples')
    add('--channel-dropout-p', type=float, default=0.3,
        help='Dropout prob for geometric channels (default 0.3)')

    # Patch / full-image mode
    add('--full-image', action='store_true',
        help='Use full-image mode (resize to --image-size)')
    add('--image-size', type=int, default=256,
        help='Full-image resize target (default 256)')
    add('--patch-size', type=int, default=64,
        help='Patch size (default 64)')
    add('--patches-per-image', type=int, default=256,
        help='Patches per image per epoch (default 256)')
    add('--detector', default='harris',
        choices=['harris', 'shi-tomasi', 'fast', 'gradient', 'random'],
        help='Salient point detector (default harris)')

    # Model
    add('--enc-channels', default='4,8',
        help='Encoder channels, comma-separated (default 4,8)')
    add('--film-cond-dim', type=int, default=5,
        help='FiLM conditioning input dim (default 5)')

    # Training
    add('--epochs', type=int, default=200)
    add('--batch-size', type=int, default=16)
    add('--lr', type=float, default=1e-3)
    add('--checkpoint-dir', default='checkpoints')
    add('--checkpoint-every', type=int, default=50,
        help='Save checkpoint every N epochs (0=disable)')

    cli_args = parser.parse_args()
    train(cli_args)


if __name__ == '__main__':
    main()