summaryrefslogtreecommitdiff
path: root/cnn_v3/training/train_cnn_v3.py
diff options
context:
space:
mode:
author	skal <pascal.massimino@gmail.com>	2026-03-26 07:03:01 +0100
committer	skal <pascal.massimino@gmail.com>	2026-03-26 07:03:01 +0100
commit	8f14bdd66cb002b2f89265b2a578ad93249089c9 (patch)
tree	2ccdb3939b673ebc3a5df429160631240239cee2 /cnn_v3/training/train_cnn_v3.py
parent	4ca498277b033ae10134045dae9c8c249a8d2b2b (diff)
feat(cnn_v3): upgrade architecture to enc_channels=[8,16]
Double encoder capacity: enc0 4→8ch, enc1 8→16ch, bottleneck 16→16ch, dec1 32→8ch, dec0 16→4ch. Total weights 2476→7828 f16 (~15.3 KB). FiLM MLP output 40→72 params (L1: 16×40→16×72). 16-ch textures split into _lo/_hi rgba32uint pairs (enc1, bottleneck). enc0 and dec1 textures changed from rgba16float to rgba32uint (8ch). GBUF_RGBA32UINT node gains CopySrc for parity-test readback.

- WGSL shaders: all 5 passes rewritten for new channel counts
- C++ CNNv3Effect: new weight offsets/sizes, 8ch uniform structs
- Web tool (shaders.js + tester.js): matching texture formats and bindings
- Parity test: readback_rgba32uint_8ch helper, updated vector counts
- Training scripts: default enc_channels=[8,16], updated docstrings
- Docs + architecture PNG regenerated

handoff(Gemini): CNN v3 [8,16] upgrade complete. All code, tests, web tool, training scripts, and docs updated. Next: run training pass.
Diffstat (limited to 'cnn_v3/training/train_cnn_v3.py')
-rw-r--r--	cnn_v3/training/train_cnn_v3.py	28
1 file changed, 17 insertions(+), 11 deletions(-)
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py
index c61c360..5b6a0be 100644
--- a/cnn_v3/training/train_cnn_v3.py
+++ b/cnn_v3/training/train_cnn_v3.py
@@ -5,18 +5,18 @@
# ///
"""CNN v3 Training Script — U-Net + FiLM
-Architecture:
- enc0 Conv(20→4, 3×3) + FiLM + ReLU H×W
- enc1 Conv(4→8, 3×3) + FiLM + ReLU + pool2 H/2×W/2
- bottleneck Conv(8→8, 3×3, dilation=2) + ReLU H/4×W/4
- dec1 upsample×2 + cat(enc1) Conv(16→4) + FiLM H/2×W/2
- dec0 upsample×2 + cat(enc0) Conv(8→4) + FiLM H×W
+Architecture (enc_channels=[8,16]):
+ enc0 Conv(20→8, 3×3) + FiLM + ReLU H×W rgba32uint (8ch)
+ enc1 Conv(8→16, 3×3) + FiLM + ReLU + pool2 H/2×W/2 2× rgba32uint (16ch split)
+ bottleneck Conv(16→16, 3×3, dilation=2) + ReLU H/4×W/4 2× rgba32uint (16ch split)
+ dec1 upsample×2 + cat(enc1) Conv(32→8) + FiLM H/2×W/2 rgba32uint (8ch)
+ dec0 upsample×2 + cat(enc0) Conv(16→4) + FiLM H×W rgba16float (4ch)
output sigmoid → RGBA
-FiLM MLP: Linear(5→16) → ReLU → Linear(16→40)
- 40 = 2 × (γ+β) for enc0(4) enc1(8) dec1(4) dec0(4)
+FiLM MLP: Linear(5→16) → ReLU → Linear(16→72)
+ 72 = 2 × (γ+β) for enc0(8) enc1(16) dec1(8) dec0(4)
-Weight budget: ~4.84 KB conv f16 (fits ≤6 KB target)
+Weight budget: ~15.3 KB conv f16 (7828 f16); total with MLP ~17.9 KB
Training improvements:
--edge-loss-weight Sobel edge loss alongside MSE (default 0.1)
@@ -47,14 +47,14 @@ def film_apply(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor) -> torc
class CNNv3(nn.Module):
"""U-Net + FiLM conditioning.
- enc_channels: [c0, c1] channel counts per encoder level, default [4, 8]
+ enc_channels: [c0, c1] channel counts per encoder level, default [8, 16]
film_cond_dim: FiLM conditioning input size, default 5
"""
def __init__(self, enc_channels=None, film_cond_dim: int = 5):
super().__init__()
if enc_channels is None:
- enc_channels = [4, 8]
+ enc_channels = [8, 16]
assert len(enc_channels) == 2, "Only 2-level U-Net supported"
c0, c1 = enc_channels
@@ -227,6 +227,10 @@ def train(args):
optimizer.zero_grad()
pred = model(feat, cond)
loss = criterion(pred, target)
+ if args.multiscale_weight > 0.0:
+ for scale in [2, 4]:
+ loss = loss + args.multiscale_weight * criterion(
+ F.avg_pool2d(pred, scale), F.avg_pool2d(target, scale))
if args.edge_loss_weight > 0.0:
loss = loss + args.edge_loss_weight * sobel_loss(pred, target)
loss.backward()
@@ -321,6 +325,8 @@ def main():
help='Resume from checkpoint path; if path missing, use latest in --checkpoint-dir')
p.add_argument('--edge-loss-weight', type=float, default=0.1,
help='Weight for Sobel edge loss alongside MSE (default 0.1; 0=disable)')
+ p.add_argument('--multiscale-weight', type=float, default=0.5,
+ help='Weight per pyramid level for multi-scale MSE (default 0.5; 0=disable)')
p.add_argument('--film-warmup-epochs', type=int, default=50,
help='Epochs to train U-Net only before unfreezing FiLM MLP (default 50; 0=joint)')