diff options
| author | skal <pascal.massimino@gmail.com> | 2026-03-26 07:03:01 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-03-26 07:03:01 +0100 |
| commit | 8f14bdd66cb002b2f89265b2a578ad93249089c9 (patch) | |
| tree | 2ccdb3939b673ebc3a5df429160631240239cee2 /cnn_v3/training/train_cnn_v3.py | |
| parent | 4ca498277b033ae10134045dae9c8c249a8d2b2b (diff) | |
feat(cnn_v3): upgrade architecture to enc_channels=[8,16]
Double encoder capacity: enc0 4→8ch, enc1 8→16ch, bottleneck 16→16ch,
dec1 32→8ch, dec0 16→4ch. Total weights 2476→7828 f16 (~15.3 KB).
FiLM MLP output 40→72 params (L1: 16×40→16×72).
16-ch textures split into _lo/_hi rgba32uint pairs (enc1, bottleneck).
enc0 and dec1 textures changed from rgba16float to rgba32uint (8ch).
GBUF_RGBA32UINT node gains CopySrc for parity test readback.
- WGSL shaders: all 5 passes rewritten for new channel counts
- C++ CNNv3Effect: new weight offsets/sizes, 8ch uniform structs
- Web tool (shaders.js + tester.js): matching texture formats and bindings
- Parity test: readback_rgba32uint_8ch helper, updated vector counts
- Training scripts: default enc_channels=[8,16], updated docstrings
- Docs + architecture PNG regenerated
handoff(Gemini): CNN v3 [8,16] upgrade complete. All code, tests, web
tool, training scripts, and docs updated. Next: run training pass.
Diffstat (limited to 'cnn_v3/training/train_cnn_v3.py')
| -rw-r--r-- | cnn_v3/training/train_cnn_v3.py | 28 |
1 file changed, 17 insertions, 11 deletions
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py index c61c360..5b6a0be 100644 --- a/cnn_v3/training/train_cnn_v3.py +++ b/cnn_v3/training/train_cnn_v3.py @@ -5,18 +5,18 @@ # /// """CNN v3 Training Script — U-Net + FiLM -Architecture: - enc0 Conv(20→4, 3×3) + FiLM + ReLU H×W - enc1 Conv(4→8, 3×3) + FiLM + ReLU + pool2 H/2×W/2 - bottleneck Conv(8→8, 3×3, dilation=2) + ReLU H/4×W/4 - dec1 upsample×2 + cat(enc1) Conv(16→4) + FiLM H/2×W/2 - dec0 upsample×2 + cat(enc0) Conv(8→4) + FiLM H×W +Architecture (enc_channels=[8,16]): + enc0 Conv(20→8, 3×3) + FiLM + ReLU H×W rgba32uint (8ch) + enc1 Conv(8→16, 3×3) + FiLM + ReLU + pool2 H/2×W/2 2× rgba32uint (16ch split) + bottleneck Conv(16→16, 3×3, dilation=2) + ReLU H/4×W/4 2× rgba32uint (16ch split) + dec1 upsample×2 + cat(enc1) Conv(32→8) + FiLM H/2×W/2 rgba32uint (8ch) + dec0 upsample×2 + cat(enc0) Conv(16→4) + FiLM H×W rgba16float (4ch) output sigmoid → RGBA -FiLM MLP: Linear(5→16) → ReLU → Linear(16→40) - 40 = 2 × (γ+β) for enc0(4) enc1(8) dec1(4) dec0(4) +FiLM MLP: Linear(5→16) → ReLU → Linear(16→72) + 72 = 2 × (γ+β) for enc0(8) enc1(16) dec1(8) dec0(4) -Weight budget: ~4.84 KB conv f16 (fits ≤6 KB target) +Weight budget: ~15.3 KB conv f16 (7828 f16); total with MLP ~17.9 KB Training improvements: --edge-loss-weight Sobel edge loss alongside MSE (default 0.1) @@ -47,14 +47,14 @@ def film_apply(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor) -> torc class CNNv3(nn.Module): """U-Net + FiLM conditioning. 
- enc_channels: [c0, c1] channel counts per encoder level, default [4, 8] + enc_channels: [c0, c1] channel counts per encoder level, default [8, 16] film_cond_dim: FiLM conditioning input size, default 5 """ def __init__(self, enc_channels=None, film_cond_dim: int = 5): super().__init__() if enc_channels is None: - enc_channels = [4, 8] + enc_channels = [8, 16] assert len(enc_channels) == 2, "Only 2-level U-Net supported" c0, c1 = enc_channels @@ -227,6 +227,10 @@ def train(args): optimizer.zero_grad() pred = model(feat, cond) loss = criterion(pred, target) + if args.multiscale_weight > 0.0: + for scale in [2, 4]: + loss = loss + args.multiscale_weight * criterion( + F.avg_pool2d(pred, scale), F.avg_pool2d(target, scale)) if args.edge_loss_weight > 0.0: loss = loss + args.edge_loss_weight * sobel_loss(pred, target) loss.backward() @@ -321,6 +325,8 @@ def main(): help='Resume from checkpoint path; if path missing, use latest in --checkpoint-dir') p.add_argument('--edge-loss-weight', type=float, default=0.1, help='Weight for Sobel edge loss alongside MSE (default 0.1; 0=disable)') + p.add_argument('--multiscale-weight', type=float, default=0.5, + help='Weight per pyramid level for multi-scale MSE (default 0.5; 0=disable)') p.add_argument('--film-warmup-epochs', type=int, default=50, help='Epochs to train U-Net only before unfreezing FiLM MLP (default 50; 0=joint)') |
