diff options
| author | skal <pascal.massimino@gmail.com> | 2026-03-25 10:05:42 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-03-25 10:05:42 +0100 |
| commit | ce6e5b99f26e4e7c69a3cacf360bd0d492de928c (patch) | |
| tree | a8d64b33a7ea1109b6b7e1043ced946cac416756 /cnn_v3/training/train_cnn_v3.py | |
| parent | 8b4d7a49f038d7e849e6764dcc3abd1e1be01061 (diff) | |
feat(cnn_v3): 3×3 dilated bottleneck + Sobel loss + FiLM warmup + architecture PNG
- Replace 1×1 pointwise bottleneck with Conv(8→8, 3×3, dilation=2):
effective RF grows from ~13px to ~29px at ¼res (~+1 KB weights)
- Add Sobel edge loss in training (--edge-loss-weight, default 0.1)
- Add FiLM 2-phase training: freeze MLP for warmup epochs then
unfreeze at lr×0.1 (--film-warmup-epochs, default 50)
- Update weight layout: BN 72→584 f16, total 1964→2476 f16 (4952 B)
- Cascade offsets in C++ effect, JS tool, export/gen_test_vectors scripts
- Regenerate test_vectors.h (1238 u32); parity max_err=9.77e-04
- Generate dark-theme U-Net+FiLM architecture PNG (gen_architecture_png.py)
- Replace ASCII art in CNN_V3.md and HOW_TO_CNN.md with PNG embed
handoff(Gemini): bottleneck dilation + Sobel loss + FiLM warmup landed.
Next: run first real training pass (see cnn_v3/docs/HOWTO.md §3).
Diffstat (limited to 'cnn_v3/training/train_cnn_v3.py')
| -rw-r--r-- | cnn_v3/training/train_cnn_v3.py | 67 |
1 files changed, 58 insertions, 9 deletions
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py index 31cfd9d..c790495 100644 --- a/cnn_v3/training/train_cnn_v3.py +++ b/cnn_v3/training/train_cnn_v3.py @@ -6,17 +6,21 @@ """CNN v3 Training Script — U-Net + FiLM Architecture: - enc0 Conv(20→4, 3×3) + FiLM + ReLU H×W - enc1 Conv(4→8, 3×3) + FiLM + ReLU + pool2 H/2×W/2 - bottleneck Conv(8→8, 1×1) + ReLU H/4×W/4 - dec1 upsample×2 + cat(enc1) Conv(16→4) + FiLM H/2×W/2 - dec0 upsample×2 + cat(enc0) Conv(8→4) + FiLM H×W + enc0 Conv(20→4, 3×3) + FiLM + ReLU H×W + enc1 Conv(4→8, 3×3) + FiLM + ReLU + pool2 H/2×W/2 + bottleneck Conv(8→8, 3×3, dilation=2) + ReLU H/4×W/4 + dec1 upsample×2 + cat(enc1) Conv(16→4) + FiLM H/2×W/2 + dec0 upsample×2 + cat(enc0) Conv(8→4) + FiLM H×W output sigmoid → RGBA FiLM MLP: Linear(5→16) → ReLU → Linear(16→40) 40 = 2 × (γ+β) for enc0(4) enc1(8) dec1(4) dec0(4) -Weight budget: ~5.4 KB f16 (fits ≤6 KB target) +Weight budget: ~4.84 KB conv f16 (fits ≤6 KB target) + +Training improvements: + --edge-loss-weight Sobel edge loss alongside MSE (default 0.1) + --film-warmup-epochs Train U-Net only for N epochs before unfreezing FiLM MLP (default 50) """ import argparse @@ -56,7 +60,7 @@ class CNNv3(nn.Module): self.enc0 = nn.Conv2d(N_FEATURES, c0, 3, padding=1) self.enc1 = nn.Conv2d(c0, c1, 3, padding=1) - self.bottleneck = nn.Conv2d(c1, c1, 1) + self.bottleneck = nn.Conv2d(c1, c1, 3, padding=2, dilation=2) self.dec1 = nn.Conv2d(c1 * 2, c0, 3, padding=1) # +skip enc1 self.dec0 = nn.Conv2d(c0 * 2, 4, 3, padding=1) # +skip enc0 @@ -96,6 +100,24 @@ class CNNv3(nn.Module): # --------------------------------------------------------------------------- +# Loss +# --------------------------------------------------------------------------- + +def sobel_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Gradient loss via Sobel filters. No VGG dependency. + pred, target: (B, C, H, W) in [0, 1]. Returns scalar on same device.""" + kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], + dtype=pred.dtype, device=pred.device).view(1, 1, 3, 3) + ky = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], + dtype=pred.dtype, device=pred.device).view(1, 1, 3, 3) + B, C, H, W = pred.shape + p = pred.view(B * C, 1, H, W) + t = target.view(B * C, 1, H, W) + return (F.mse_loss(F.conv2d(p, kx, padding=1), F.conv2d(t, kx, padding=1)) + + F.mse_loss(F.conv2d(p, ky, padding=1), F.conv2d(t, ky, padding=1))) + + +# --------------------------------------------------------------------------- # Training # --------------------------------------------------------------------------- @@ -129,11 +151,20 @@ def train(args): print(f"Model: enc={enc_channels} film_cond_dim={args.film_cond_dim} " f"params={nparams} (~{nparams*2/1024:.1f} KB f16)") - optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + # Phase 1: freeze FiLM MLP so U-Net convolutions stabilise first. + film_warmup = args.film_warmup_epochs + if film_warmup > 0: + for p in model.film_mlp.parameters(): + p.requires_grad = False + print(f"FiLM MLP frozen for first {film_warmup} epochs (phase-1 warmup)") + + optimizer = torch.optim.Adam( + filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr) criterion = nn.MSELoss() ckpt_dir = Path(args.checkpoint_dir) ckpt_dir.mkdir(parents=True, exist_ok=True) start_epoch = 1 + film_unfrozen = (film_warmup == 0) if args.resume: ckpt_path = Path(args.resume) @@ -168,6 +199,15 @@ def train(args): for epoch in range(start_epoch, args.epochs + 1): if interrupted: break + + # Phase 2: unfreeze FiLM MLP after warmup, rebuild optimizer at reduced LR. + if not film_unfrozen and epoch > film_warmup: + for p in model.film_mlp.parameters(): + p.requires_grad = True + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr * 0.1) + film_unfrozen = True + print(f"\nPhase 2: FiLM MLP unfrozen at epoch {epoch} (lr={args.lr*0.1:.2e})") + model.train() epoch_loss = 0.0 n_batches = 0 @@ -177,7 +217,10 @@ def train(args): break feat, cond, target = feat.to(device), cond.to(device), target.to(device) optimizer.zero_grad() - loss = criterion(model(feat, cond), target) + pred = model(feat, cond) + loss = criterion(pred, target) + if args.edge_loss_weight > 0.0: + loss = loss + args.edge_loss_weight * sobel_loss(pred, target) loss.backward() optimizer.step() epoch_loss += loss.item() @@ -215,6 +258,8 @@ def _checkpoint(model, optimizer, epoch, loss, args): 'enc_channels': [int(c) for c in args.enc_channels.split(',')], 'film_cond_dim': args.film_cond_dim, 'input_mode': args.input_mode, + 'edge_loss_weight': args.edge_loss_weight, + 'film_warmup_epochs': args.film_warmup_epochs, }, } @@ -266,6 +311,10 @@ def main(): help='Save checkpoint every N epochs (0=disable)') p.add_argument('--resume', default='', metavar='CKPT', help='Resume from checkpoint path; if path missing, use latest in --checkpoint-dir') + p.add_argument('--edge-loss-weight', type=float, default=0.1, + help='Weight for Sobel edge loss alongside MSE (default 0.1; 0=disable)') + p.add_argument('--film-warmup-epochs', type=int, default=50, + help='Epochs to train U-Net only before unfreezing FiLM MLP (default 50; 0=joint)') train(p.parse_args()) |
