author    skal <pascal.massimino@gmail.com>  2026-03-21 10:07:02 +0100
committer skal <pascal.massimino@gmail.com>  2026-03-21 10:07:02 +0100
commit    1e8ccfc67c264ce054c59257ee7c17ec4a584a9e
tree      765c3e4392af87c86e9052c321c48a43fda0fac7
parent    5e740fc8f5f48fdd8ec4b84ae0c9a3c74e387d4f
feat(cnn_v3): Phase 6 — training script (train_cnn_v3.py + cnn_v3_utils.py)
- train_cnn_v3.py: CNNv3 U-Net+FiLM model, training loop, CLI
- cnn_v3_utils.py: image I/O, pyrdown, depth_gradient, assemble_features, apply_channel_dropout, detect_salient_points, CNNv3Dataset
- Patch-based training (default 64×64) with salient-point extraction (harris/shi-tomasi/fast/gradient/random detectors, pre-cached at init)
- Channel dropout for geometric/context/temporal channels
- Random FiLM conditioning per sample for joint MLP+U-Net training
- docs: HOWTO.md §3 updated with commands and flag reference
- TODO.md: Phase 6 marked done, export script noted as next step

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Diffstat (limited to 'cnn_v3')
-rw-r--r--  cnn_v3/docs/HOWTO.md              70
-rw-r--r--  cnn_v3/training/cnn_v3_utils.py  339
-rw-r--r--  cnn_v3/training/train_cnn_v3.py  222
3 files changed, 620 insertions, 11 deletions
diff --git a/cnn_v3/docs/HOWTO.md b/cnn_v3/docs/HOWTO.md
index 425a33b..0cf2fe5 100644
--- a/cnn_v3/docs/HOWTO.md
+++ b/cnn_v3/docs/HOWTO.md
@@ -135,20 +135,68 @@ Mix freely; the dataloader treats all sample directories uniformly.
## 3. Training
-*(Script not yet written — see TODO.md. Architecture spec in `CNN_V3.md` §Training.)*
+Two source files:
+- **`cnn_v3_utils.py`** — image I/O, feature assembly, channel dropout, salient-point
+ detection, `CNNv3Dataset`
+- **`train_cnn_v3.py`** — `CNNv3` model, training loop, CLI
+
+### Quick start
-**Planned command:**
```bash
-python3 cnn_v3/training/train_cnn_v3.py \
- --dataset dataset/ \
- --epochs 500 \
- --output cnn_v3/weights/cnn_v3_weights.bin
+cd cnn_v3/training
+
+# Patch-based (default) — 64×64 patches around Harris corners
+python3 train_cnn_v3.py \
+ --input dataset/ \
+ --input-mode simple \
+ --epochs 200
+
+# Full-image mode (resizes to 256×256)
+python3 train_cnn_v3.py \
+ --input dataset/ \
+ --input-mode full \
+ --full-image --image-size 256 \
+ --epochs 500
+
+# Quick smoke test: 1 epoch, small patches, random detector
+python3 train_cnn_v3.py \
+ --input dataset/ --epochs 1 \
+ --patch-size 32 --detector random
```
-**FiLM conditioning** during training:
-- Beat/audio inputs randomized per sample
-- MLP: `Linear(5→16) → ReLU → Linear(16→40)` trained jointly with U-Net
-- Output: γ/β for enc0(4ch) + enc1(8ch) + dec1(4ch) + dec0(4ch) = 40 floats
+### Key flags
+
+| Flag | Default | Notes |
+|------|---------|-------|
+| `--input DIR` | `training/dataset` | Root with `full/` or `simple/` subdirs |
+| `--input-mode` | `simple` | `simple`=photos, `full`=Blender G-buffer |
+| `--patch-size N` | `64` | Patch crop size |
+| `--patches-per-image N` | `256` | Patches extracted per image per epoch |
+| `--detector` | `harris` | `harris` \| `shi-tomasi` \| `fast` \| `gradient` \| `random` |
+| `--channel-dropout-p F` | `0.3` | Dropout prob for geometric channels |
+| `--full-image` | off | Resize full image instead of cropping patches |
+| `--enc-channels C` | `4,8` | Encoder channel counts, comma-separated |
+| `--film-cond-dim N` | `5` | FiLM conditioning input size |
+| `--epochs N` | `200` | Training epochs |
+| `--batch-size N` | `16` | Batch size |
+| `--lr F` | `1e-3` | Adam learning rate |
+| `--checkpoint-dir DIR` | `checkpoints/` | Where to save `.pth` files |
+| `--checkpoint-every N` | `50` | Epoch interval for checkpoints (0=disable) |
+
+### FiLM conditioning during training
+
+- Conditioning vector `[beat_phase, beat_time/8, audio_intensity, style_p0, style_p1]`
+ is **randomised per sample** (uniform [0,1]) so the MLP trains jointly with the U-Net.
+- At inference, real beat/audio values are fed from `CNNv3Effect::set_film_params()`.
+
+### Channel dropout
+
+Applied per-sample in `cnn_v3_utils.apply_channel_dropout()`:
+- Geometric channels (normal, depth, depth_grad) zeroed with `p = channel_dropout_p`
+- Context channels (mat_id, shadow, transp) with `p = 0.67 × channel_dropout_p` (≈0.2 at the default 0.3)
+- Temporal channels (prev.rgb) with `p = 0.5`
+
+This ensures the network works for both full G-buffer and photo-only inputs.
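The dropout scheme above can be sketched independently of the training code. This is a minimal standalone version (the group constants mirror the 20-channel layout documented in `cnn_v3_utils.py`; the function name `channel_dropout` here is illustrative):

```python
import random
import numpy as np

GEOMETRIC = [3, 4, 5, 6, 7]   # normal.xy, depth, depth_grad.xy
CONTEXT = [8, 18, 19]         # mat_id, shadow, transp
TEMPORAL = [9, 10, 11]        # prev.rgb

def channel_dropout(feat, p_geom=0.3, p_context=0.2, p_temporal=0.5):
    """Zero whole channel groups on a copy of feat (H, W, 20)."""
    feat = feat.copy()
    for group, p in ((GEOMETRIC, p_geom), (CONTEXT, p_context),
                     (TEMPORAL, p_temporal)):
        if random.random() < p:
            feat[..., group] = 0.0
    return feat

feat = np.ones((4, 4, 20), dtype=np.float32)
# Force geometric dropout, disable the others, to see the effect:
out = channel_dropout(feat, p_geom=1.0, p_context=0.0, p_temporal=0.0)
# → channels 3-7 are all zero; albedo/prev/mip channels are untouched
```

Because whole groups are zeroed (not individual pixels), the network sees exactly the input distribution it will face at inference when a G-buffer channel is simply absent.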
---
@@ -202,7 +250,7 @@ Test vectors generated by `cnn_v3/training/gen_test_vectors.py` (PyTorch referen
| 3 — WGSL U-Net shaders | ✅ Done | 5 compute shaders + cnn_v3/common snippet |
| 4 — C++ CNNv3Effect | ✅ Done | FiLM uniform upload, 36/36 tests pass |
| 5 — Parity validation | ✅ Done | test_cnn_v3_parity.cc, max_err=4.88e-4 |
-| 6 — FiLM MLP training | TODO | train_cnn_v3.py not yet written |
+| 6 — FiLM MLP training | ✅ Done | train_cnn_v3.py + cnn_v3_utils.py written |
---
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py
new file mode 100644
index 0000000..ecdbd6b
--- /dev/null
+++ b/cnn_v3/training/cnn_v3_utils.py
@@ -0,0 +1,339 @@
+"""CNN v3 training utilities — image I/O, feature assembly, dataset.
+
+Imported by train_cnn_v3.py and export_cnn_v3_weights.py.
+
+20 feature channels assembled by CNNv3Dataset.__getitem__:
+ [0-2] albedo.rgb f32 [0,1]
+ [3-4] normal.xy f32 oct-encoded [0,1]
+ [5] depth f32 [0,1]
+ [6-7] depth_grad.xy finite diff, signed
+ [8] mat_id f32 [0,1]
+ [9-11] prev.rgb f32 (zero during training)
+ [12-14] mip1.rgb pyrdown(albedo)
+ [15-17] mip2.rgb pyrdown(mip1)
+ [18] shadow f32 [0,1]
+ [19] transp f32 [0,1]
+
+Sample directory layout (per sample_xxx/):
+ albedo.png RGB uint8
+ normal.png RG uint8 oct-encoded (128,128 = no normal)
+ depth.png R uint16 [0,65535] → [0,1]
+ matid.png R uint8 [0,255] → [0,1]
+ shadow.png R uint8 [0=dark, 255=lit]
+ transp.png R uint8 [0=opaque, 255=clear]
+ target.png RGBA uint8
+"""
+
+import random
+from pathlib import Path
+from typing import List, Tuple
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from torch.utils.data import Dataset
+
+# ---------------------------------------------------------------------------
+# Constants
+# ---------------------------------------------------------------------------
+
+N_FEATURES = 20
+
+GEOMETRIC_CHANNELS = [3, 4, 5, 6, 7] # normal.xy, depth, depth_grad.xy
+CONTEXT_CHANNELS = [8, 18, 19] # mat_id, shadow, transp
+TEMPORAL_CHANNELS = [9, 10, 11] # prev.rgb
+
+# ---------------------------------------------------------------------------
+# Image I/O
+# ---------------------------------------------------------------------------
+
+def load_rgb(path: Path) -> np.ndarray:
+ """Load PNG → (H,W,3) f32 [0,1]."""
+ return np.asarray(Image.open(path).convert('RGB'), dtype=np.float32) / 255.0
+
+
+def load_rg(path: Path) -> np.ndarray:
+ """Load PNG RG channels → (H,W,2) f32 [0,1]."""
+ arr = np.asarray(Image.open(path).convert('RGB'), dtype=np.float32) / 255.0
+ return arr[..., :2]
+
+
+def load_depth16(path: Path) -> np.ndarray:
+ """Load 16-bit greyscale depth PNG → (H,W) f32 [0,1]."""
+ arr = np.asarray(Image.open(path), dtype=np.float32)
+ if arr.max() > 1.0:
+ arr = arr / 65535.0
+ return arr
+
+
+def load_gray(path: Path) -> np.ndarray:
+ """Load 8-bit greyscale PNG → (H,W) f32 [0,1]."""
+ return np.asarray(Image.open(path).convert('L'), dtype=np.float32) / 255.0
+
+
+# ---------------------------------------------------------------------------
+# Feature assembly helpers
+# ---------------------------------------------------------------------------
+
+def pyrdown(img: np.ndarray) -> np.ndarray:
+ """2×2 average pool → half resolution. img: (H,W,C) f32."""
+ h, w, _ = img.shape
+ h2, w2 = h // 2, w // 2
+ t = img[:h2 * 2, :w2 * 2, :]
+ return 0.25 * (t[0::2, 0::2] + t[1::2, 0::2] + t[0::2, 1::2] + t[1::2, 1::2])
+
+
+def depth_gradient(depth: np.ndarray) -> np.ndarray:
+ """Central finite difference of depth map → (H,W,2) [dzdx, dzdy]."""
+ px = np.pad(depth, ((0, 0), (1, 1)), mode='edge')
+ py = np.pad(depth, ((1, 1), (0, 0)), mode='edge')
+ dzdx = (px[:, 2:] - px[:, :-2]) * 0.5
+ dzdy = (py[2:, :] - py[:-2, :]) * 0.5
+ return np.stack([dzdx, dzdy], axis=-1)
+
+
+def _upsample_nearest(a: np.ndarray, h: int, w: int) -> np.ndarray:
+ """Nearest-neighbour upsample (H,W,C) f32 [0,1] to (h,w,C)."""
+ img = Image.fromarray((np.clip(a, 0, 1) * 255).astype(np.uint8))
+ img = img.resize((w, h), Image.NEAREST)
+ return np.asarray(img, dtype=np.float32) / 255.0
+
+
+def assemble_features(albedo: np.ndarray, normal: np.ndarray,
+ depth: np.ndarray, matid: np.ndarray,
+ shadow: np.ndarray, transp: np.ndarray) -> np.ndarray:
+ """Build (H,W,20) f32 feature tensor.
+
+ prev set to zero (no temporal history during training).
+ mip1/mip2 computed from albedo. depth_grad computed via finite diff.
+ """
+ h, w = albedo.shape[:2]
+
+ mip1 = _upsample_nearest(pyrdown(albedo), h, w)
+ mip2 = _upsample_nearest(pyrdown(pyrdown(albedo)), h, w)
+ dgrad = depth_gradient(depth)
+ prev = np.zeros((h, w, 3), dtype=np.float32)
+
+ return np.concatenate([
+ albedo, # [0-2] albedo.rgb
+ normal, # [3-4] normal.xy
+ depth[..., None], # [5] depth
+ dgrad, # [6-7] depth_grad.xy
+ matid[..., None], # [8] mat_id
+ prev, # [9-11] prev.rgb
+ mip1, # [12-14] mip1.rgb
+ mip2, # [15-17] mip2.rgb
+ shadow[..., None], # [18] shadow
+ transp[..., None], # [19] transp
+ ], axis=-1).astype(np.float32)
+
+
+def apply_channel_dropout(feat: np.ndarray,
+ p_geom: float = 0.3,
+ p_context: float = 0.2,
+ p_temporal: float = 0.5) -> np.ndarray:
+ """Zero out channel groups with given probabilities (in-place copy)."""
+ feat = feat.copy()
+ if random.random() < p_geom:
+ feat[..., GEOMETRIC_CHANNELS] = 0.0
+ if random.random() < p_context:
+ feat[..., CONTEXT_CHANNELS] = 0.0
+ if random.random() < p_temporal:
+ feat[..., TEMPORAL_CHANNELS] = 0.0
+ return feat
+
+
+# ---------------------------------------------------------------------------
+# Salient point detection
+# ---------------------------------------------------------------------------
+
+def detect_salient_points(albedo: np.ndarray, n: int, detector: str,
+ patch_size: int) -> List[Tuple[int, int]]:
+ """Return n (cx, cy) patch centres from albedo (H,W,3) f32 [0,1].
+
+ Detects up to 2n candidates via the chosen method, filters to valid patch
+ bounds, fills remainder with random points.
+
+ detector: 'harris' | 'shi-tomasi' | 'fast' | 'gradient' | 'random'
+ """
+ h, w = albedo.shape[:2]
+ half = patch_size // 2
+ gray = cv2.cvtColor((albedo * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
+
+ corners = None
+ if detector in ('harris', 'shi-tomasi'):
+ corners = cv2.goodFeaturesToTrack(
+ gray, n * 2, qualityLevel=0.01, minDistance=half,
+ useHarrisDetector=(detector == 'harris'))
+ elif detector == 'fast':
+ kps = cv2.FastFeatureDetector_create(threshold=20).detect(gray, None)
+ if kps:
+ pts = np.array([[kp.pt[0], kp.pt[1]] for kp in kps[:n * 2]])
+ corners = pts.reshape(-1, 1, 2)
+ elif detector == 'gradient':
+ gx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
+ gy = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
+ mag = np.sqrt(gx ** 2 + gy ** 2)
+ ys, xs = np.where(mag > np.percentile(mag, 95))
+ if len(xs) > n * 2:
+ sel = np.random.choice(len(xs), n * 2, replace=False)
+ xs, ys = xs[sel], ys[sel]
+ if len(xs):
+ corners = np.stack([xs, ys], axis=1).reshape(-1, 1, 2).astype(np.float32)
+ # 'random' → corners stays None, falls through to random fill
+
+ pts: List[Tuple[int, int]] = []
+ if corners is not None:
+ for c in corners:
+ cx, cy = int(c[0][0]), int(c[0][1])
+ if half <= cx < w - half and half <= cy < h - half:
+ pts.append((cx, cy))
+ if len(pts) >= n:
+ break
+
+ while len(pts) < n:
+ pts.append((random.randint(half, w - half - 1),
+ random.randint(half, h - half - 1)))
+ return pts[:n]
+
+
+# ---------------------------------------------------------------------------
+# Dataset
+# ---------------------------------------------------------------------------
+
+class CNNv3Dataset(Dataset):
+ """Loads CNN v3 samples from dataset/full/ or dataset/simple/ directories.
+
+ Patch mode (default): extracts patch_size×patch_size crops centred on
+ salient points detected from albedo. Points are pre-cached at init.
+
+ Full-image mode (--full-image): resizes entire image to image_size×image_size.
+
+ Returns (feat, cond, target):
+ feat: (20, H, W) f32
+ cond: (5,) f32 FiLM conditioning (random when augment=True)
+ target: (4, H, W) f32 RGBA [0,1]
+ """
+
+ def __init__(self, dataset_dir: str,
+ input_mode: str = 'simple',
+ patch_size: int = 64,
+ patches_per_image: int = 256,
+ image_size: int = 256,
+ full_image: bool = False,
+ channel_dropout_p: float = 0.3,
+ detector: str = 'harris',
+ augment: bool = True):
+ self.patch_size = patch_size
+ self.patches_per_image = patches_per_image
+ self.image_size = image_size
+ self.full_image = full_image
+ self.channel_dropout_p = channel_dropout_p
+ self.detector = detector
+ self.augment = augment
+
+ root = Path(dataset_dir)
+ subdir = 'full' if input_mode == 'full' else 'simple'
+ search_dir = root / subdir
+ if not search_dir.exists():
+ search_dir = root
+
+ self.samples = sorted([
+ d for d in search_dir.iterdir()
+ if d.is_dir() and (d / 'albedo.png').exists()
+ ])
+ if not self.samples:
+ raise RuntimeError(f"No samples found in {search_dir}")
+
+ # Pre-cache salient patch centres (albedo-only load — cheap)
+ self._patch_centers: List[List[Tuple[int, int]]] = []
+ if not full_image:
+ print(f"[CNNv3Dataset] Detecting salient points "
+ f"(detector={detector}, patch={patch_size}×{patch_size}) …")
+ for sd in self.samples:
+ pts = detect_salient_points(
+ load_rgb(sd / 'albedo.png'),
+ patches_per_image, detector, patch_size)
+ self._patch_centers.append(pts)
+
+ print(f"[CNNv3Dataset] mode={input_mode} samples={len(self.samples)} "
+ f"patch={patch_size} full_image={full_image}")
+
+ def __len__(self):
+ if self.full_image:
+ return len(self.samples)
+ return len(self.samples) * self.patches_per_image
+
+ def _load_sample(self, sd: Path):
+ albedo = load_rgb(sd / 'albedo.png')
+ normal = load_rg(sd / 'normal.png')
+ depth = load_depth16(sd / 'depth.png')
+ matid = load_gray(sd / 'matid.png')
+ shadow = load_gray(sd / 'shadow.png')
+ transp = load_gray(sd / 'transp.png')
+ target = np.asarray(
+ Image.open(sd / 'target.png').convert('RGBA'),
+ dtype=np.float32) / 255.0
+ return albedo, normal, depth, matid, shadow, transp, target
+
+ def __getitem__(self, idx):
+ if self.full_image:
+ sample_idx = idx
+ sd = self.samples[idx]
+ else:
+ sample_idx = idx // self.patches_per_image
+ sd = self.samples[sample_idx]
+
+ albedo, normal, depth, matid, shadow, transp, target = self._load_sample(sd)
+ h, w = albedo.shape[:2]
+
+ if self.full_image:
+ sz = self.image_size
+
+ def _resize_rgb(a):
+ img = Image.fromarray((np.clip(a, 0, 1) * 255).astype(np.uint8))
+ return np.asarray(img.resize((sz, sz), Image.LANCZOS), dtype=np.float32) / 255.0
+
+ def _resize_gray(a):
+ img = Image.fromarray((np.clip(a, 0, 1) * 255).astype(np.uint8), mode='L')
+ return np.asarray(img.resize((sz, sz), Image.LANCZOS), dtype=np.float32) / 255.0
+
+ albedo = _resize_rgb(albedo)
+ normal = _resize_rgb(np.concatenate(
+ [normal, np.zeros_like(normal[..., :1])], -1))[..., :2]
+ depth = _resize_gray(depth)
+ matid = _resize_gray(matid)
+ shadow = _resize_gray(shadow)
+ transp = _resize_gray(transp)
+ target = _resize_rgb(target)
+ else:
+ ps = self.patch_size
+ half = ps // 2
+ cx, cy = self._patch_centers[sample_idx][idx % self.patches_per_image]
+ cx = max(half, min(cx, w - half))
+ cy = max(half, min(cy, h - half))
+ sl = (slice(cy - half, cy - half + ps), slice(cx - half, cx - half + ps))
+
+ albedo = albedo[sl]
+ normal = normal[sl]
+ depth = depth[sl]
+ matid = matid[sl]
+ shadow = shadow[sl]
+ transp = transp[sl]
+ target = target[sl]
+
+ feat = assemble_features(albedo, normal, depth, matid, shadow, transp)
+
+ if self.augment:
+ feat = apply_channel_dropout(feat,
+ p_geom=self.channel_dropout_p,
+ p_context=self.channel_dropout_p * 0.67,
+ p_temporal=0.5)
+ cond = np.random.rand(5).astype(np.float32)
+ else:
+ cond = np.zeros(5, dtype=np.float32)
+
+ return (torch.from_numpy(feat).permute(2, 0, 1), # (20,H,W)
+ torch.from_numpy(cond), # (5,)
+ torch.from_numpy(target).permute(2, 0, 1)) # (4,H,W)
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py
new file mode 100644
index 0000000..ed925e6
--- /dev/null
+++ b/cnn_v3/training/train_cnn_v3.py
@@ -0,0 +1,222 @@
+#!/usr/bin/env python3
+"""CNN v3 Training Script — U-Net + FiLM
+
+Architecture:
+ enc0 Conv(20→4, 3×3) + FiLM + ReLU H×W
+ enc1 Conv(4→8, 3×3) + FiLM + ReLU + pool2 H/2×W/2
+ bottleneck Conv(8→8, 1×1) + ReLU H/4×W/4
+ dec1 upsample×2 + cat(enc1) Conv(16→4) + FiLM H/2×W/2
+ dec0 upsample×2 + cat(enc0) Conv(8→4) + FiLM H×W
+ output sigmoid → RGBA
+
+FiLM MLP: Linear(5→16) → ReLU → Linear(16→40)
+ 40 = 2 × (γ+β) for enc0(4) enc1(8) dec1(4) dec0(4)
+
+Weight budget: ~5.4 KB f16 total (~3.9 KB conv layers only; fits ≤6 KB target)
+"""
+
+import argparse
+import time
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+
+from cnn_v3_utils import CNNv3Dataset, N_FEATURES
+
+# ---------------------------------------------------------------------------
+# Model
+# ---------------------------------------------------------------------------
+
+def film_apply(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
+ """Per-channel affine: gamma*x + beta. gamma/beta: (B,C) broadcast over H,W."""
+ return gamma[:, :, None, None] * x + beta[:, :, None, None]
+
+
+class CNNv3(nn.Module):
+ """U-Net + FiLM conditioning.
+
+ enc_channels: [c0, c1] channel counts per encoder level, default [4, 8]
+ film_cond_dim: FiLM conditioning input size, default 5
+ """
+
+ def __init__(self, enc_channels=None, film_cond_dim: int = 5):
+ super().__init__()
+ if enc_channels is None:
+ enc_channels = [4, 8]
+ assert len(enc_channels) == 2, "Only 2-level U-Net supported"
+ c0, c1 = enc_channels
+
+ self.enc0 = nn.Conv2d(N_FEATURES, c0, 3, padding=1)
+ self.enc1 = nn.Conv2d(c0, c1, 3, padding=1)
+ self.bottleneck = nn.Conv2d(c1, c1, 1)
+ self.dec1 = nn.Conv2d(c1 * 2, c0, 3, padding=1) # +skip enc1
+ self.dec0 = nn.Conv2d(c0 * 2, 4, 3, padding=1) # +skip enc0
+
+ film_out = 2 * (c0 + c1 + c0 + 4) # γ+β for enc0, enc1, dec1, dec0
+ self.film_mlp = nn.Sequential(
+ nn.Linear(film_cond_dim, 16),
+ nn.ReLU(),
+ nn.Linear(16, film_out),
+ )
+ self.enc_channels = enc_channels
+
+ def _split_film(self, film: torch.Tensor):
+ c0, c1 = self.enc_channels
+ parts = torch.split(film, [c0, c0, c1, c1, c0, c0, 4, 4], dim=-1)
+ return parts # g_enc0, b_enc0, g_enc1, b_enc1, g_dec1, b_dec1, g_dec0, b_dec0
+
+ def forward(self, feat: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
+ """feat: (B,20,H,W) cond: (B,5) → (B,4,H,W) RGBA [0,1]"""
+ g0, b0, g1, b1, gd1, bd1, gd0, bd0 = self._split_film(self.film_mlp(cond))
+
+ skip0 = F.relu(film_apply(self.enc0(feat), g0, b0))
+
+ x = F.avg_pool2d(skip0, 2)
+ skip1 = F.relu(film_apply(self.enc1(x), g1, b1))
+
+ x = F.relu(self.bottleneck(F.avg_pool2d(skip1, 2)))
+
+ x = F.relu(film_apply(self.dec1(
+ torch.cat([F.interpolate(x, scale_factor=2, mode='nearest'), skip1], dim=1)
+ ), gd1, bd1))
+
+ x = F.relu(film_apply(self.dec0(
+ torch.cat([F.interpolate(x, scale_factor=2, mode='nearest'), skip0], dim=1)
+ ), gd0, bd0))
+
+ return torch.sigmoid(x)
+
+
+# ---------------------------------------------------------------------------
+# Training
+# ---------------------------------------------------------------------------
+
+def train(args):
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ enc_channels = [int(c) for c in args.enc_channels.split(',')]
+ print(f"Device: {device}")
+
+ dataset = CNNv3Dataset(
+ dataset_dir=args.input,
+ input_mode=args.input_mode,
+ patch_size=args.patch_size,
+ patches_per_image=args.patches_per_image,
+ image_size=args.image_size,
+ full_image=args.full_image,
+ channel_dropout_p=args.channel_dropout_p,
+ detector=args.detector,
+ augment=True,
+ )
+ loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
+ num_workers=0, drop_last=False)
+
+ model = CNNv3(enc_channels=enc_channels, film_cond_dim=args.film_cond_dim).to(device)
+ nparams = sum(p.numel() for p in model.parameters())
+ print(f"Model: enc={enc_channels} film_cond_dim={args.film_cond_dim} "
+ f"params={nparams} (~{nparams*2/1024:.1f} KB f16)")
+
+ optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+ criterion = nn.MSELoss()
+ ckpt_dir = Path(args.checkpoint_dir)
+ ckpt_dir.mkdir(parents=True, exist_ok=True)
+
+ print(f"\nTraining {args.epochs} epochs batch={args.batch_size} lr={args.lr}")
+ start = time.time()
+ avg_loss = float('nan')
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ epoch_loss = 0.0
+ n_batches = 0
+
+ for feat, cond, target in loader:
+ feat, cond, target = feat.to(device), cond.to(device), target.to(device)
+ optimizer.zero_grad()
+ loss = criterion(model(feat, cond), target)
+ loss.backward()
+ optimizer.step()
+ epoch_loss += loss.item()
+ n_batches += 1
+
+ avg_loss = epoch_loss / max(n_batches, 1)
+ print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | "
+ f"{time.time()-start:.0f}s", end='', flush=True)
+
+ if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0:
+ print()
+ ckpt = ckpt_dir / f"checkpoint_epoch_{epoch}.pth"
+ torch.save(_checkpoint(model, optimizer, epoch, avg_loss, args), ckpt)
+ print(f" → {ckpt}")
+
+ print()
+ final = ckpt_dir / f"checkpoint_epoch_{args.epochs}.pth"
+ torch.save(_checkpoint(model, optimizer, args.epochs, avg_loss, args), final)
+ print(f"Final checkpoint: {final}")
+ print(f"Done. {time.time()-start:.1f}s")
+ return model
+
+
+def _checkpoint(model, optimizer, epoch, loss, args):
+ return {
+ 'epoch': epoch,
+ 'model_state_dict': model.state_dict(),
+ 'optimizer_state_dict': optimizer.state_dict(),
+ 'loss': loss,
+ 'config': {
+ 'enc_channels': [int(c) for c in args.enc_channels.split(',')],
+ 'film_cond_dim': args.film_cond_dim,
+ 'input_mode': args.input_mode,
+ },
+ }
+
+
+# ---------------------------------------------------------------------------
+# CLI
+# ---------------------------------------------------------------------------
+
+def main():
+ p = argparse.ArgumentParser(description='Train CNN v3 (U-Net + FiLM)')
+
+ # Dataset
+ p.add_argument('--input', default='training/dataset',
+ help='Dataset root (contains full/ or simple/ subdirs)')
+ p.add_argument('--input-mode', default='simple', choices=['simple', 'full'],
+ help='simple=photo samples full=Blender G-buffer samples')
+ p.add_argument('--channel-dropout-p', type=float, default=0.3,
+ help='Dropout prob for geometric channels (default 0.3)')
+
+ # Patch / full-image mode
+ p.add_argument('--full-image', action='store_true',
+ help='Use full-image mode (resize to --image-size)')
+ p.add_argument('--image-size', type=int, default=256,
+ help='Full-image resize target (default 256)')
+ p.add_argument('--patch-size', type=int, default=64,
+ help='Patch size (default 64)')
+ p.add_argument('--patches-per-image', type=int, default=256,
+ help='Patches per image per epoch (default 256)')
+ p.add_argument('--detector', default='harris',
+ choices=['harris', 'shi-tomasi', 'fast', 'gradient', 'random'],
+ help='Salient point detector (default harris)')
+
+ # Model
+ p.add_argument('--enc-channels', default='4,8',
+ help='Encoder channels, comma-separated (default 4,8)')
+ p.add_argument('--film-cond-dim', type=int, default=5,
+ help='FiLM conditioning input dim (default 5)')
+
+ # Training
+ p.add_argument('--epochs', type=int, default=200)
+ p.add_argument('--batch-size', type=int, default=16)
+ p.add_argument('--lr', type=float, default=1e-3)
+ p.add_argument('--checkpoint-dir', default='checkpoints')
+ p.add_argument('--checkpoint-every', type=int, default=50,
+ help='Save checkpoint every N epochs (0=disable)')
+
+ train(p.parse_args())
+
+
+if __name__ == '__main__':
+ main()
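The FiLM path in `forward()` reduces to a per-channel affine transform on each feature map. A numpy sketch of the split-and-apply step for the default `enc_channels=[4, 8]` (standalone re-implementation of `_split_film` and `film_apply`; the `sizes` list matches the `torch.split` call above):

```python
import numpy as np

c0, c1 = 4, 8
sizes = [c0, c0, c1, c1, c0, c0, 4, 4]  # γ/β pairs for enc0, enc1, dec1, dec0
assert sum(sizes) == 40                  # matches the Linear(16→40) output

def split_film(film):
    """Split a (B, 40) FiLM vector into the eight γ/β chunks."""
    out, off = [], 0
    for s in sizes:
        out.append(film[:, off:off + s])
        off += s
    return out

def film_apply(x, gamma, beta):
    """Per-channel affine: broadcast (B, C) over a (B, C, H, W) map."""
    return gamma[:, :, None, None] * x + beta[:, :, None, None]

film = np.arange(40, dtype=np.float32)[None, :]  # one fake MLP output
g0, b0, *_ = split_film(film)                    # g0=[0..3], b0=[4..7]
x = np.ones((1, c0, 2, 2), dtype=np.float32)
y = film_apply(x, g0, b0)  # channel c → g0[c]*1 + b0[c] = 2c + 4
```

Since γ and β depend only on the 5-dim conditioning vector, the whole FiLM MLP runs once per frame at inference; the convolutions never see the beat/audio state directly.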