summaryrefslogtreecommitdiff
path: root/cnn_v3/training/train_cnn_v3.py
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v3/training/train_cnn_v3.py')
-rw-r--r--cnn_v3/training/train_cnn_v3.py74
1 files changed, 65 insertions, 9 deletions
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py
index de10d6a..c790495 100644
--- a/cnn_v3/training/train_cnn_v3.py
+++ b/cnn_v3/training/train_cnn_v3.py
@@ -6,17 +6,21 @@
"""CNN v3 Training Script — U-Net + FiLM
Architecture:
- enc0 Conv(20→4, 3×3) + FiLM + ReLU H×W
- enc1 Conv(4→8, 3×3) + FiLM + ReLU + pool2 H/2×W/2
- bottleneck Conv(8→8, 1×1) + ReLU H/4×W/4
- dec1 upsample×2 + cat(enc1) Conv(16→4) + FiLM H/2×W/2
- dec0 upsample×2 + cat(enc0) Conv(8→4) + FiLM H×W
+ enc0 Conv(20→4, 3×3) + FiLM + ReLU H×W
+ enc1 Conv(4→8, 3×3) + FiLM + ReLU + pool2 H/2×W/2
+ bottleneck Conv(8→8, 3×3, dilation=2) + ReLU H/4×W/4
+ dec1 upsample×2 + cat(enc1) Conv(16→4) + FiLM H/2×W/2
+ dec0 upsample×2 + cat(enc0) Conv(8→4) + FiLM H×W
output sigmoid → RGBA
FiLM MLP: Linear(5→16) → ReLU → Linear(16→40)
40 = 2 × (γ+β) for enc0(4) enc1(8) dec1(4) dec0(4)
-Weight budget: ~5.4 KB f16 (fits ≤6 KB target)
+Weight budget: ~4.84 KB conv f16 (fits ≤6 KB target)
+
+Training improvements:
+ --edge-loss-weight Sobel edge loss alongside MSE (default 0.1)
+ --film-warmup-epochs Train U-Net only for N epochs before unfreezing FiLM MLP (default 50)
"""
import argparse
@@ -56,7 +60,7 @@ class CNNv3(nn.Module):
self.enc0 = nn.Conv2d(N_FEATURES, c0, 3, padding=1)
self.enc1 = nn.Conv2d(c0, c1, 3, padding=1)
- self.bottleneck = nn.Conv2d(c1, c1, 1)
+ self.bottleneck = nn.Conv2d(c1, c1, 3, padding=2, dilation=2)
self.dec1 = nn.Conv2d(c1 * 2, c0, 3, padding=1) # +skip enc1
self.dec0 = nn.Conv2d(c0 * 2, 4, 3, padding=1) # +skip enc0
@@ -96,6 +100,24 @@ class CNNv3(nn.Module):
# ---------------------------------------------------------------------------
+# Loss
+# ---------------------------------------------------------------------------
+
+def sobel_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ """Gradient loss via Sobel filters. No VGG dependency.
+ pred, target: (B, C, H, W) in [0, 1]. Returns scalar on same device."""
+ kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
+ dtype=pred.dtype, device=pred.device).view(1, 1, 3, 3)
+ ky = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
+ dtype=pred.dtype, device=pred.device).view(1, 1, 3, 3)
+ B, C, H, W = pred.shape
+ p = pred.view(B * C, 1, H, W)
+ t = target.view(B * C, 1, H, W)
+ return (F.mse_loss(F.conv2d(p, kx, padding=1), F.conv2d(t, kx, padding=1)) +
+ F.mse_loss(F.conv2d(p, ky, padding=1), F.conv2d(t, ky, padding=1)))
+
+
+# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
@@ -104,6 +126,10 @@ def train(args):
enc_channels = [int(c) for c in args.enc_channels.split(',')]
print(f"Device: {device}")
+ if args.single_sample:
+ args.full_image = True
+ args.batch_size = 1
+
dataset = CNNv3Dataset(
dataset_dir=args.input,
input_mode=args.input_mode,
@@ -115,6 +141,7 @@ def train(args):
detector=args.detector,
augment=True,
patch_search_window=args.patch_search_window,
+ single_sample=args.single_sample,
)
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
num_workers=0, drop_last=False)
@@ -124,11 +151,20 @@ def train(args):
print(f"Model: enc={enc_channels} film_cond_dim={args.film_cond_dim} "
f"params={nparams} (~{nparams*2/1024:.1f} KB f16)")
- optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+ # Phase 1: freeze FiLM MLP so U-Net convolutions stabilise first.
+ film_warmup = args.film_warmup_epochs
+ if film_warmup > 0:
+ for p in model.film_mlp.parameters():
+ p.requires_grad = False
+ print(f"FiLM MLP frozen for first {film_warmup} epochs (phase-1 warmup)")
+
+ optimizer = torch.optim.Adam(
+ filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
criterion = nn.MSELoss()
ckpt_dir = Path(args.checkpoint_dir)
ckpt_dir.mkdir(parents=True, exist_ok=True)
start_epoch = 1
+ film_unfrozen = (film_warmup == 0)
if args.resume:
ckpt_path = Path(args.resume)
@@ -163,6 +199,15 @@ def train(args):
for epoch in range(start_epoch, args.epochs + 1):
if interrupted:
break
+
+ # Phase 2: unfreeze FiLM MLP after warmup, rebuild optimizer at reduced LR.
+ if not film_unfrozen and epoch > film_warmup:
+ for p in model.film_mlp.parameters():
+ p.requires_grad = True
+ optimizer = torch.optim.Adam(model.parameters(), lr=args.lr * 0.1)
+ film_unfrozen = True
+ print(f"\nPhase 2: FiLM MLP unfrozen at epoch {epoch} (lr={args.lr*0.1:.2e})")
+
model.train()
epoch_loss = 0.0
n_batches = 0
@@ -172,7 +217,10 @@ def train(args):
break
feat, cond, target = feat.to(device), cond.to(device), target.to(device)
optimizer.zero_grad()
- loss = criterion(model(feat, cond), target)
+ pred = model(feat, cond)
+ loss = criterion(pred, target)
+ if args.edge_loss_weight > 0.0:
+ loss = loss + args.edge_loss_weight * sobel_loss(pred, target)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
@@ -210,6 +258,8 @@ def _checkpoint(model, optimizer, epoch, loss, args):
'enc_channels': [int(c) for c in args.enc_channels.split(',')],
'film_cond_dim': args.film_cond_dim,
'input_mode': args.input_mode,
+ 'edge_loss_weight': args.edge_loss_weight,
+ 'film_warmup_epochs': args.film_warmup_epochs,
},
}
@@ -222,6 +272,8 @@ def main():
p = argparse.ArgumentParser(description='Train CNN v3 (U-Net + FiLM)')
# Dataset
+ p.add_argument('--single-sample', default='', metavar='DIR',
+ help='Train on a single sample directory; implies --full-image and --batch-size 1')
p.add_argument('--input', default='training/dataset',
help='Dataset root (contains full/ or simple/ subdirs)')
p.add_argument('--input-mode', default='simple', choices=['simple', 'full'],
@@ -259,6 +311,10 @@ def main():
help='Save checkpoint every N epochs (0=disable)')
p.add_argument('--resume', default='', metavar='CKPT',
help='Resume from checkpoint path; if path missing, use latest in --checkpoint-dir')
+ p.add_argument('--edge-loss-weight', type=float, default=0.1,
+ help='Weight for Sobel edge loss alongside MSE (default 0.1; 0=disable)')
+ p.add_argument('--film-warmup-epochs', type=int, default=50,
+ help='Epochs to train U-Net only before unfreezing FiLM MLP (default 50; 0=joint)')
train(p.parse_args())