diff options
| author | skal <pascal.massimino@gmail.com> | 2026-03-25 10:05:42 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-03-25 10:05:42 +0100 |
| commit | ce6e5b99f26e4e7c69a3cacf360bd0d492de928c (patch) | |
| tree | a8d64b33a7ea1109b6b7e1043ced946cac416756 /cnn_v3/training | |
| parent | 8b4d7a49f038d7e849e6764dcc3abd1e1be01061 (diff) | |
feat(cnn_v3): 3×3 dilated bottleneck + Sobel loss + FiLM warmup + architecture PNG
- Replace 1×1 pointwise bottleneck with Conv(8→8, 3×3, dilation=2):
  effective RF grows from ~13px to ~29px at ¼-res (~+1 KB of weights)
- Add Sobel edge loss in training (--edge-loss-weight, default 0.1)
- Add FiLM 2-phase training: freeze MLP for warmup epochs then
unfreeze at lr×0.1 (--film-warmup-epochs, default 50)
- Update weight layout: BN 72→584 f16, total 1964→2476 f16 (4952 bytes)
- Cascade offsets in C++ effect, JS tool, export/gen_test_vectors scripts
- Regenerate test_vectors.h (1238 u32); parity max_err=9.77e-04
- Generate dark-theme U-Net+FiLM architecture PNG (gen_architecture_png.py)
- Replace ASCII art in CNN_V3.md and HOW_TO_CNN.md with PNG embed
handoff(Gemini): bottleneck dilation + Sobel loss + FiLM warmup landed.
Next: run first real training pass (see cnn_v3/docs/HOWTO.md §3).
Diffstat (limited to 'cnn_v3/training')
| -rw-r--r-- | cnn_v3/training/export_cnn_v3_weights.py | 16 | ||||
| -rw-r--r-- | cnn_v3/training/gen_test_vectors.py | 72 | ||||
| -rw-r--r-- | cnn_v3/training/train_cnn_v3.py | 67 |
3 files changed, 94 insertions(+), 61 deletions(-)
diff --git a/cnn_v3/training/export_cnn_v3_weights.py b/cnn_v3/training/export_cnn_v3_weights.py index edf76e2..78f5f25 100644 --- a/cnn_v3/training/export_cnn_v3_weights.py +++ b/cnn_v3/training/export_cnn_v3_weights.py @@ -15,8 +15,8 @@ Outputs <output_dir>/cnn_v3_weights.bin Conv+bias weights for all 5 passes, packed as f16-pairs-in-u32. Matches the format expected by CNNv3Effect::upload_weights(). - Layout: enc0 (724) | enc1 (296) | bottleneck (72) | dec1 (580) | dec0 (292) - = 1964 f16 values = 982 u32 = 3928 bytes. + Layout: enc0 (724) | enc1 (296) | bottleneck (584) | dec1 (580) | dec0 (292) + = 2476 f16 values = 1238 u32 = 4952 bytes. <output_dir>/cnn_v3_film_mlp.bin FiLM MLP weights as raw f32: L0_W (5×16) L0_b (16) L1_W (16×40) L1_b (40). @@ -48,13 +48,13 @@ from train_cnn_v3 import CNNv3 # cnn_v3/src/cnn_v3_effect.cc (kEnc0Weights, kEnc1Weights, …) # cnn_v3/training/gen_test_vectors.py (same constants) # --------------------------------------------------------------------------- -ENC0_WEIGHTS = 20 * 4 * 9 + 4 # Conv(20→4,3×3)+bias = 724 -ENC1_WEIGHTS = 4 * 8 * 9 + 8 # Conv(4→8,3×3)+bias = 296 -BN_WEIGHTS = 8 * 8 * 1 + 8 # Conv(8→8,1×1)+bias = 72 -DEC1_WEIGHTS = 16 * 4 * 9 + 4 # Conv(16→4,3×3)+bias = 580 -DEC0_WEIGHTS = 8 * 4 * 9 + 4 # Conv(8→4,3×3)+bias = 292 +ENC0_WEIGHTS = 20 * 4 * 9 + 4 # Conv(20→4,3×3)+bias = 724 +ENC1_WEIGHTS = 4 * 8 * 9 + 8 # Conv(4→8,3×3)+bias = 296 +BN_WEIGHTS = 8 * 8 * 9 + 8 # Conv(8→8,3×3,dil=2)+bias = 584 +DEC1_WEIGHTS = 16 * 4 * 9 + 4 # Conv(16→4,3×3)+bias = 580 +DEC0_WEIGHTS = 8 * 4 * 9 + 4 # Conv(8→4,3×3)+bias = 292 TOTAL_F16 = ENC0_WEIGHTS + ENC1_WEIGHTS + BN_WEIGHTS + DEC1_WEIGHTS + DEC0_WEIGHTS -# = 1964 +# = 2476 def pack_weights_u32(w_f16: np.ndarray) -> np.ndarray: diff --git a/cnn_v3/training/gen_test_vectors.py b/cnn_v3/training/gen_test_vectors.py index 640971c..2eb889c 100644 --- a/cnn_v3/training/gen_test_vectors.py +++ b/cnn_v3/training/gen_test_vectors.py @@ -23,7 +23,7 @@ DEC0_IN, DEC0_OUT = 8, 4 ENC0_WEIGHTS 
= ENC0_IN * ENC0_OUT * 9 + ENC0_OUT # 724 ENC1_WEIGHTS = ENC1_IN * ENC1_OUT * 9 + ENC1_OUT # 296 -BN_WEIGHTS = BN_IN * BN_OUT * 1 + BN_OUT # 72 +BN_WEIGHTS = BN_IN * BN_OUT * 9 + BN_OUT # 584 (3x3 dilation=2) DEC1_WEIGHTS = DEC1_IN * DEC1_OUT * 9 + DEC1_OUT # 580 DEC0_WEIGHTS = DEC0_IN * DEC0_OUT * 9 + DEC0_OUT # 292 @@ -32,30 +32,8 @@ ENC1_OFFSET = ENC0_OFFSET + ENC0_WEIGHTS BN_OFFSET = ENC1_OFFSET + ENC1_WEIGHTS DEC1_OFFSET = BN_OFFSET + BN_WEIGHTS DEC0_OFFSET = DEC1_OFFSET + DEC1_WEIGHTS -TOTAL_F16 = DEC0_OFFSET + DEC0_WEIGHTS # 1964 + 292 = 2256? let me check -# 724 + 296 + 72 + 580 + 292 = 1964 ... actually let me recount -# ENC0: 20*4*9 + 4 = 720+4 = 724 -# ENC1: 4*8*9 + 8 = 288+8 = 296 -# BN: 8*8*1 + 8 = 64+8 = 72 -# DEC1: 16*4*9 + 4 = 576+4 = 580 -# DEC0: 8*4*9 + 4 = 288+4 = 292 -# Total = 724+296+72+580+292 = 1964 ... but HOWTO.md says 2064. Let me recheck. -# DEC1: 16*4*9 = 576 ... but the shader says Conv(16->4) which is IN=16, OUT=4 -# weight idx: o * DEC1_IN * 9 + i * 9 + ki where o<DEC1_OUT, i<DEC1_IN -# So total conv weights = DEC1_OUT * DEC1_IN * 9 = 4*16*9 = 576, bias = 4 -# Total DEC1 = 580. OK that's right. -# Let me add: 724+296+72+580+292 = 1964. But HOWTO says 2064? -# DEC1: Conv(16->4) = OUT*IN*K^2 = 4*16*9 = 576 + bias 4 = 580. HOWTO says 576+4=580 OK. -# Total = 724+296+72+580+292 = let me sum: 724+296=1020, +72=1092, +580=1672, +292=1964. -# Hmm, HOWTO.md says 2064. Let me recheck HOWTO weight table: -# enc0: 20*4*9=720 +4 = 724 -# enc1: 4*8*9=288 +8 = 296 -# bottleneck: 8*8*1=64 +8 = 72 -# dec1: 16*4*9=576 +4 = 580 -# dec0: 8*4*9=288 +4 = 292 -# Total = 724+296+72+580+292 = 1964 -# The HOWTO says 2064 but I get 1964... 100 difference. Possible typo in doc. -# I'll use the correct value derived from the formulas: 1964. 
+TOTAL_F16 = DEC0_OFFSET + DEC0_WEIGHTS +# 724 + 296 + 584 + 580 + 292 = 2476 (BN is now 3x3 dilation=2, was 72) # --------------------------------------------------------------------------- # Helpers @@ -140,35 +118,41 @@ def enc1_forward(enc0, w, gamma_lo, gamma_hi, beta_lo, beta_hi): def bottleneck_forward(enc1, w): """ - AvgPool2x2(enc1, clamp-border) + Conv(8->8, 1x1) + ReLU + AvgPool2x2(enc1, clamp-border) + Conv(8->8, 3x3, dilation=2) + ReLU → rgba32uint (f16, quarter-res). No FiLM. enc1: (hH, hW, 8) f32 — half-res + Matches cnn_v3_bottleneck.wgsl exactly. """ hH, hW = enc1.shape[:2] qH, qW = hH // 2, hW // 2 wo = BN_OFFSET - # AvgPool2x2 with clamp (matches load_enc1_avg in WGSL) - avg = np.zeros((qH, qW, BN_IN), dtype=np.float32) - for qy in range(qH): - for qx in range(qW): - s = np.zeros(BN_IN, dtype=np.float32) - for dy in range(2): - for dx in range(2): - hy = min(qy * 2 + dy, hH - 1) - hx = min(qx * 2 + dx, hW - 1) - s += enc1[hy, hx, :] - avg[qy, qx, :] = s * 0.25 + def load_enc1_avg(qy, qx): + """Avg-pool 2x2 from enc1 at quarter-res coord. 
Zero for OOB (matches WGSL).""" + if qy < 0 or qx < 0 or qy >= qH or qx >= qW: + return np.zeros(BN_IN, dtype=np.float32) + s = np.zeros(BN_IN, dtype=np.float32) + for dy in range(2): + for dx in range(2): + hy = min(qy * 2 + dy, hH - 1) + hx = min(qx * 2 + dx, hW - 1) + s += enc1[hy, hx, :] + return s * 0.25 - # 1x1 conv (no spatial loop, just channel dot-product) + # 3x3 conv with dilation=2 in quarter-res space out = np.zeros((qH, qW, BN_OUT), dtype=np.float32) for o in range(BN_OUT): - bias = get_w(w, wo, BN_OUT * BN_IN + o) - s = np.full((qH, qW), bias, dtype=np.float32) - for i in range(BN_IN): - wv = get_w(w, wo, o * BN_IN + i) - s += wv * avg[:, :, i] - out[:, :, o] = np.maximum(0.0, s) + bias = get_w(w, wo, BN_OUT * BN_IN * 9 + o) + for qy in range(qH): + for qx in range(qW): + s = bias + for ky in range(-1, 2): + for kx in range(-1, 2): + feat = load_enc1_avg(qy + ky * 2, qx + kx * 2) # dilation=2 + ki = (ky + 1) * 3 + (kx + 1) + for i in range(BN_IN): + s += get_w(w, wo, o * BN_IN * 9 + i * 9 + ki) * feat[i] + out[qy, qx, o] = max(0.0, s) return np.float16(out).astype(np.float32) # pack2x16float boundary diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py index 31cfd9d..c790495 100644 --- a/cnn_v3/training/train_cnn_v3.py +++ b/cnn_v3/training/train_cnn_v3.py @@ -6,17 +6,21 @@ """CNN v3 Training Script — U-Net + FiLM Architecture: - enc0 Conv(20→4, 3×3) + FiLM + ReLU H×W - enc1 Conv(4→8, 3×3) + FiLM + ReLU + pool2 H/2×W/2 - bottleneck Conv(8→8, 1×1) + ReLU H/4×W/4 - dec1 upsample×2 + cat(enc1) Conv(16→4) + FiLM H/2×W/2 - dec0 upsample×2 + cat(enc0) Conv(8→4) + FiLM H×W + enc0 Conv(20→4, 3×3) + FiLM + ReLU H×W + enc1 Conv(4→8, 3×3) + FiLM + ReLU + pool2 H/2×W/2 + bottleneck Conv(8→8, 3×3, dilation=2) + ReLU H/4×W/4 + dec1 upsample×2 + cat(enc1) Conv(16→4) + FiLM H/2×W/2 + dec0 upsample×2 + cat(enc0) Conv(8→4) + FiLM H×W output sigmoid → RGBA FiLM MLP: Linear(5→16) → ReLU → Linear(16→40) 40 = 2 × (γ+β) for enc0(4) enc1(8) dec1(4) 
dec0(4) -Weight budget: ~5.4 KB f16 (fits ≤6 KB target) +Weight budget: ~4.84 KB conv f16 (fits ≤6 KB target) + +Training improvements: + --edge-loss-weight Sobel edge loss alongside MSE (default 0.1) + --film-warmup-epochs Train U-Net only for N epochs before unfreezing FiLM MLP (default 50) """ import argparse @@ -56,7 +60,7 @@ class CNNv3(nn.Module): self.enc0 = nn.Conv2d(N_FEATURES, c0, 3, padding=1) self.enc1 = nn.Conv2d(c0, c1, 3, padding=1) - self.bottleneck = nn.Conv2d(c1, c1, 1) + self.bottleneck = nn.Conv2d(c1, c1, 3, padding=2, dilation=2) self.dec1 = nn.Conv2d(c1 * 2, c0, 3, padding=1) # +skip enc1 self.dec0 = nn.Conv2d(c0 * 2, 4, 3, padding=1) # +skip enc0 @@ -96,6 +100,24 @@ class CNNv3(nn.Module): # --------------------------------------------------------------------------- +# Loss +# --------------------------------------------------------------------------- + +def sobel_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Gradient loss via Sobel filters. No VGG dependency. + pred, target: (B, C, H, W) in [0, 1]. 
Returns scalar on same device.""" + kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], + dtype=pred.dtype, device=pred.device).view(1, 1, 3, 3) + ky = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], + dtype=pred.dtype, device=pred.device).view(1, 1, 3, 3) + B, C, H, W = pred.shape + p = pred.view(B * C, 1, H, W) + t = target.view(B * C, 1, H, W) + return (F.mse_loss(F.conv2d(p, kx, padding=1), F.conv2d(t, kx, padding=1)) + + F.mse_loss(F.conv2d(p, ky, padding=1), F.conv2d(t, ky, padding=1))) + + +# --------------------------------------------------------------------------- # Training # --------------------------------------------------------------------------- @@ -129,11 +151,20 @@ def train(args): print(f"Model: enc={enc_channels} film_cond_dim={args.film_cond_dim} " f"params={nparams} (~{nparams*2/1024:.1f} KB f16)") - optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + # Phase 1: freeze FiLM MLP so U-Net convolutions stabilise first. + film_warmup = args.film_warmup_epochs + if film_warmup > 0: + for p in model.film_mlp.parameters(): + p.requires_grad = False + print(f"FiLM MLP frozen for first {film_warmup} epochs (phase-1 warmup)") + + optimizer = torch.optim.Adam( + filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr) criterion = nn.MSELoss() ckpt_dir = Path(args.checkpoint_dir) ckpt_dir.mkdir(parents=True, exist_ok=True) start_epoch = 1 + film_unfrozen = (film_warmup == 0) if args.resume: ckpt_path = Path(args.resume) @@ -168,6 +199,15 @@ def train(args): for epoch in range(start_epoch, args.epochs + 1): if interrupted: break + + # Phase 2: unfreeze FiLM MLP after warmup, rebuild optimizer at reduced LR. 
+ if not film_unfrozen and epoch > film_warmup: + for p in model.film_mlp.parameters(): + p.requires_grad = True + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr * 0.1) + film_unfrozen = True + print(f"\nPhase 2: FiLM MLP unfrozen at epoch {epoch} (lr={args.lr*0.1:.2e})") + model.train() epoch_loss = 0.0 n_batches = 0 @@ -177,7 +217,10 @@ def train(args): break feat, cond, target = feat.to(device), cond.to(device), target.to(device) optimizer.zero_grad() - loss = criterion(model(feat, cond), target) + pred = model(feat, cond) + loss = criterion(pred, target) + if args.edge_loss_weight > 0.0: + loss = loss + args.edge_loss_weight * sobel_loss(pred, target) loss.backward() optimizer.step() epoch_loss += loss.item() @@ -215,6 +258,8 @@ def _checkpoint(model, optimizer, epoch, loss, args): 'enc_channels': [int(c) for c in args.enc_channels.split(',')], 'film_cond_dim': args.film_cond_dim, 'input_mode': args.input_mode, + 'edge_loss_weight': args.edge_loss_weight, + 'film_warmup_epochs': args.film_warmup_epochs, }, } @@ -266,6 +311,10 @@ def main(): help='Save checkpoint every N epochs (0=disable)') p.add_argument('--resume', default='', metavar='CKPT', help='Resume from checkpoint path; if path missing, use latest in --checkpoint-dir') + p.add_argument('--edge-loss-weight', type=float, default=0.1, + help='Weight for Sobel edge loss alongside MSE (default 0.1; 0=disable)') + p.add_argument('--film-warmup-epochs', type=int, default=50, + help='Epochs to train U-Net only before unfreezing FiLM MLP (default 50; 0=joint)') train(p.parse_args()) |
