From 8f14bdd66cb002b2f89265b2a578ad93249089c9 Mon Sep 17 00:00:00 2001
From: skal
Date: Thu, 26 Mar 2026 07:03:01 +0100
Subject: feat(cnn_v3): upgrade architecture to enc_channels=[8,16]
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Double encoder capacity: enc0 4→8ch, enc1 8→16ch, bottleneck 16→16ch,
dec1 32→8ch, dec0 16→4ch. Total weights 2476→7828 f16 (~15.3 KB).
FiLM MLP output 40→72 params (L1: 16×40→16×72).

16-ch textures split into _lo/_hi rgba32uint pairs (enc1, bottleneck).
enc0 and dec1 textures changed from rgba16float to rgba32uint (8ch).
GBUF_RGBA32UINT node gains CopySrc for parity test readback.

- WGSL shaders: all 5 passes rewritten for new channel counts
- C++ CNNv3Effect: new weight offsets/sizes, 8ch uniform structs
- Web tool (shaders.js + tester.js): matching texture formats and bindings
- Parity test: readback_rgba32uint_8ch helper, updated vector counts
- Training scripts: default enc_channels=[8,16], updated docstrings
- Docs + architecture PNG regenerated

handoff(Gemini): CNN v3 [8,16] upgrade complete. All code, tests, web tool,
training scripts, and docs updated. Next: run training pass.
---
 cnn_v3/training/train_cnn_v3.py | 28 +++++++++++++++++-----------
 1 file changed, 17 insertions(+), 11 deletions(-)

(limited to 'cnn_v3/training/train_cnn_v3.py')

diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py
index c61c360..5b6a0be 100644
--- a/cnn_v3/training/train_cnn_v3.py
+++ b/cnn_v3/training/train_cnn_v3.py
@@ -5,18 +5,18 @@
 # ///
 """CNN v3 Training Script — U-Net + FiLM
 
-Architecture:
-    enc0        Conv(20→4, 3×3) + FiLM + ReLU                 H×W
-    enc1        Conv(4→8, 3×3) + FiLM + ReLU + pool2          H/2×W/2
-    bottleneck  Conv(8→8, 3×3, dilation=2) + ReLU             H/4×W/4
-    dec1        upsample×2 + cat(enc1)  Conv(16→4) + FiLM     H/2×W/2
-    dec0        upsample×2 + cat(enc0)  Conv(8→4) + FiLM      H×W
+Architecture (enc_channels=[8,16]):
+    enc0        Conv(20→8, 3×3) + FiLM + ReLU                 H×W       rgba32uint (8ch)
+    enc1        Conv(8→16, 3×3) + FiLM + ReLU + pool2         H/2×W/2   2× rgba32uint (16ch split)
+    bottleneck  Conv(16→16, 3×3, dilation=2) + ReLU           H/4×W/4   2× rgba32uint (16ch split)
+    dec1        upsample×2 + cat(enc1)  Conv(32→8) + FiLM     H/2×W/2   rgba32uint (8ch)
+    dec0        upsample×2 + cat(enc0)  Conv(16→4) + FiLM     H×W       rgba16float (4ch)
     output      sigmoid → RGBA
 
-FiLM MLP: Linear(5→16) → ReLU → Linear(16→40)
-    40 = 2 × (γ+β) for enc0(4) enc1(8) dec1(4) dec0(4)
+FiLM MLP: Linear(5→16) → ReLU → Linear(16→72)
+    72 = 2 × (γ+β) for enc0(8) enc1(16) dec1(8) dec0(4)
 
-Weight budget: ~4.84 KB conv f16 (fits ≤6 KB target)
+Weight budget: ~15.3 KB conv f16 (7828 f16); total with MLP ~17.9 KB
 
 Training improvements:
     --edge-loss-weight     Sobel edge loss alongside MSE (default 0.1)
@@ -47,14 +47,14 @@ def film_apply(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor) -> torc
 class CNNv3(nn.Module):
     """U-Net + FiLM conditioning.
 
-    enc_channels: [c0, c1] channel counts per encoder level, default [4, 8]
+    enc_channels: [c0, c1] channel counts per encoder level, default [8, 16]
     film_cond_dim: FiLM conditioning input size, default 5
     """
 
     def __init__(self, enc_channels=None, film_cond_dim: int = 5):
         super().__init__()
         if enc_channels is None:
-            enc_channels = [4, 8]
+            enc_channels = [8, 16]
         assert len(enc_channels) == 2, "Only 2-level U-Net supported"
         c0, c1 = enc_channels
 
@@ -227,6 +227,10 @@ def train(args):
         optimizer.zero_grad()
         pred = model(feat, cond)
         loss = criterion(pred, target)
+        if args.multiscale_weight > 0.0:
+            for scale in [2, 4]:
+                loss = loss + args.multiscale_weight * criterion(
+                    F.avg_pool2d(pred, scale), F.avg_pool2d(target, scale))
         if args.edge_loss_weight > 0.0:
             loss = loss + args.edge_loss_weight * sobel_loss(pred, target)
         loss.backward()
@@ -321,6 +325,8 @@ def main():
                    help='Resume from checkpoint path; if path missing, use latest in --checkpoint-dir')
     p.add_argument('--edge-loss-weight', type=float, default=0.1,
                    help='Weight for Sobel edge loss alongside MSE (default 0.1; 0=disable)')
+    p.add_argument('--multiscale-weight', type=float, default=0.5,
+                   help='Weight per pyramid level for multi-scale MSE (default 0.5; 0=disable)')
     p.add_argument('--film-warmup-epochs', type=int, default=50,
                    help='Epochs to train U-Net only before unfreezing FiLM MLP (default 50; 0=joint)')
--
cgit v1.2.3
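
The 72-wide FiLM MLP output is just 2 × (γ+β) over the per-stage channel counts
8 + 16 + 8 + 4 = 36. A minimal PyTorch sketch of that layout, for orientation only:
the names (STAGE_CHANNELS, film_mlp, split_film) and the per-stage ordering of γ/β
inside the 72-vector are assumptions for illustration, not read from train_cnn_v3.py.

import torch
import torch.nn as nn

# Channels of each FiLM-modulated stage after the [8, 16] upgrade.
STAGE_CHANNELS = {"enc0": 8, "enc1": 16, "dec1": 8, "dec0": 4}  # sums to 36
FILM_OUT = 2 * sum(STAGE_CHANNELS.values())                     # 72 = gamma + beta

film_mlp = nn.Sequential(
    nn.Linear(5, 16),          # 5-dim conditioning vector
    nn.ReLU(),
    nn.Linear(16, FILM_OUT),   # the 16x40 -> 16x72 layer from the commit message
)

def split_film(params: torch.Tensor) -> dict:
    """Split the 72-wide MLP output into per-stage (gamma, beta) pairs."""
    stages, offset = {}, 0
    for name, ch in STAGE_CHANNELS.items():
        gamma = params[:, offset:offset + ch]
        beta = params[:, offset + ch:offset + 2 * ch]
        stages[name] = (gamma, beta)
        offset += 2 * ch
    return stages

cond = torch.randn(1, 5)
per_stage = split_film(film_mlp(cond))
assert per_stage["enc1"][0].shape == (1, 16)  # enc1 now gets 16 gammas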
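
On the texture side, one rgba32uint texel carries four u32 components, i.e. eight
packed f16 values, which is why the 8-channel activations (enc0, dec1) fit a single
rgba32uint texture while the 16-channel ones (enc1, bottleneck) need a _lo/_hi pair.
A NumPy sketch of one plausible packing, only to make the layout concrete; the actual
channel-to-component interleaving used by the WGSL passes and the
readback_rgba32uint_8ch parity helper is not visible in this patch, so treat it as an
assumption.

import numpy as np

def pack_8ch_to_rgba32uint(feat: np.ndarray) -> np.ndarray:
    """(H, W, 8) float16 -> (H, W, 4) uint32, two f16 values per u32 component."""
    bits = feat.astype(np.float16).view(np.uint16).astype(np.uint32)  # raw f16 bits
    lo = bits[..., 0::2]         # even channels in the low 16 bits
    hi = bits[..., 1::2] << 16   # odd channels in the high 16 bits
    return lo | hi

def unpack_rgba32uint_to_8ch(texel: np.ndarray) -> np.ndarray:
    """Inverse of pack_8ch_to_rgba32uint, as a CPU readback helper might do it."""
    lo = (texel & 0xFFFF).astype(np.uint16).view(np.float16)
    hi = (texel >> 16).astype(np.uint16).view(np.float16)
    out = np.empty(texel.shape[:-1] + (8,), dtype=np.float16)
    out[..., 0::2] = lo
    out[..., 1::2] = hi
    return out

x = np.random.rand(4, 4, 8).astype(np.float16)
assert np.array_equal(unpack_rgba32uint_to_8ch(pack_8ch_to_rgba32uint(x)), x)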
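
For reference, the training objective after this patch is plain MSE plus the new
multi-scale term (MSE on 2× and 4× average-pooled pyramids, weighted by
--multiscale-weight) plus the existing Sobel edge term weighted by --edge-loss-weight.
A standalone sketch of that composition under two assumptions: criterion is MSE, as
the docstring implies, and sobel_loss is a generic depthwise-Sobel stand-in for the
script's own helper, whose definition is not part of this diff.

import torch
import torch.nn.functional as F

def sobel_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Illustrative stand-in: L1 distance between depthwise Sobel gradients."""
    c = pred.shape[1]
    kx = torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]], device=pred.device)
    ky = kx.t().contiguous()
    kx = kx.view(1, 1, 3, 3).repeat(c, 1, 1, 1)
    ky = ky.view(1, 1, 3, 3).repeat(c, 1, 1, 1)
    pgx = F.conv2d(pred, kx, padding=1, groups=c)
    pgy = F.conv2d(pred, ky, padding=1, groups=c)
    tgx = F.conv2d(target, kx, padding=1, groups=c)
    tgy = F.conv2d(target, ky, padding=1, groups=c)
    return (pgx - tgx).abs().mean() + (pgy - tgy).abs().mean()

def total_loss(pred, target, multiscale_weight=0.5, edge_loss_weight=0.1):
    loss = F.mse_loss(pred, target)              # criterion(pred, target)
    if multiscale_weight > 0.0:                  # term added by this patch
        for scale in (2, 4):
            loss = loss + multiscale_weight * F.mse_loss(
                F.avg_pool2d(pred, scale), F.avg_pool2d(target, scale))
    if edge_loss_weight > 0.0:                   # existing --edge-loss-weight term
        loss = loss + edge_loss_weight * sobel_loss(pred, target)
    return loss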