summaryrefslogtreecommitdiff
path: root/cnn_v3/training
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v3/training')
-rw-r--r--cnn_v3/training/export_cnn_v3_weights.py51
-rw-r--r--cnn_v3/training/gen_test_vectors.py91
-rw-r--r--cnn_v3/training/infer_cnn_v3.py4
-rw-r--r--cnn_v3/training/train_cnn_v3.py28
4 files changed, 94 insertions, 80 deletions
diff --git a/cnn_v3/training/export_cnn_v3_weights.py b/cnn_v3/training/export_cnn_v3_weights.py
index 78f5f25..2fa83d1 100644
--- a/cnn_v3/training/export_cnn_v3_weights.py
+++ b/cnn_v3/training/export_cnn_v3_weights.py
@@ -15,12 +15,12 @@ Outputs
<output_dir>/cnn_v3_weights.bin
Conv+bias weights for all 5 passes, packed as f16-pairs-in-u32.
Matches the format expected by CNNv3Effect::upload_weights().
- Layout: enc0 (724) | enc1 (296) | bottleneck (584) | dec1 (580) | dec0 (292)
- = 2476 f16 values = 1238 u32 = 4952 bytes.
+ Layout: enc0 (1448) | enc1 (1168) | bottleneck (2320) | dec1 (2312) | dec0 (580)
+ = 7828 f16 values = 3914 u32 = 15656 bytes.
<output_dir>/cnn_v3_film_mlp.bin
- FiLM MLP weights as raw f32: L0_W (5×16) L0_b (16) L1_W (16×40) L1_b (40).
- = 5*16 + 16 + 16*40 + 40 = 80 + 16 + 640 + 40 = 776 f32 = 3104 bytes.
+ FiLM MLP weights as raw f32: L0_W (5×16) L0_b (16) L1_W (16×72) L1_b (72).
+ = 5*16 + 16 + 16*72 + 72 = 80 + 16 + 1152 + 72 = 1320 f32 = 5280 bytes.
For future CPU-side MLP inference in CNNv3Effect::set_film_params().
Usage
@@ -44,17 +44,19 @@ sys.path.insert(0, str(Path(__file__).parent))
from train_cnn_v3 import CNNv3
# ---------------------------------------------------------------------------
-# Weight layout constants — must stay in sync with:
-# cnn_v3/src/cnn_v3_effect.cc (kEnc0Weights, kEnc1Weights, …)
-# cnn_v3/training/gen_test_vectors.py (same constants)
+# Weight layout helpers — derived from enc_channels at runtime.
+# Must stay in sync with cnn_v3/src/cnn_v3_effect.cc and gen_test_vectors.py.
# ---------------------------------------------------------------------------
-ENC0_WEIGHTS = 20 * 4 * 9 + 4 # Conv(20→4,3×3)+bias = 724
-ENC1_WEIGHTS = 4 * 8 * 9 + 8 # Conv(4→8,3×3)+bias = 296
-BN_WEIGHTS = 8 * 8 * 9 + 8 # Conv(8→8,3×3,dil=2)+bias = 584
-DEC1_WEIGHTS = 16 * 4 * 9 + 4 # Conv(16→4,3×3)+bias = 580
-DEC0_WEIGHTS = 8 * 4 * 9 + 4 # Conv(8→4,3×3)+bias = 292
-TOTAL_F16 = ENC0_WEIGHTS + ENC1_WEIGHTS + BN_WEIGHTS + DEC1_WEIGHTS + DEC0_WEIGHTS
-# = 2476
+N_IN = 20 # feature input channels (fixed)
+
+def weight_counts(enc_channels):
+ c0, c1 = enc_channels
+ enc0 = N_IN * c0 * 9 + c0
+ enc1 = c0 * c1 * 9 + c1
+ bn = c1 * c1 * 9 + c1
+ dec1 = (c1 * 2) * c0 * 9 + c0
+ dec0 = (c0 * 2) * 4 * 9 + 4
+ return enc0, enc1, bn, dec1, dec0
def pack_weights_u32(w_f16: np.ndarray) -> np.ndarray:
@@ -86,7 +88,7 @@ def export_weights(checkpoint_path: str, output_dir: str) -> None:
ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
cfg = ckpt.get('config', {})
- enc_channels = cfg.get('enc_channels', [4, 8])
+ enc_channels = cfg.get('enc_channels', [8, 16])
film_cond_dim = cfg.get('film_cond_dim', 5)
model = CNNv3(enc_channels=enc_channels, film_cond_dim=film_cond_dim)
@@ -102,13 +104,18 @@ def export_weights(checkpoint_path: str, output_dir: str) -> None:
# -----------------------------------------------------------------------
# 1. CNN conv weights → cnn_v3_weights.bin
# -----------------------------------------------------------------------
+ enc0_w, enc1_w, bn_w, dec1_w, dec0_w = weight_counts(enc_channels)
+ total_f16 = enc0_w + enc1_w + bn_w + dec1_w + dec0_w
layers = [
- ('enc0', ENC0_WEIGHTS),
- ('enc1', ENC1_WEIGHTS),
- ('bottleneck', BN_WEIGHTS),
- ('dec1', DEC1_WEIGHTS),
- ('dec0', DEC0_WEIGHTS),
+ ('enc0', enc0_w),
+ ('enc1', enc1_w),
+ ('bottleneck', bn_w),
+ ('dec1', dec1_w),
+ ('dec0', dec0_w),
]
+ print(f" Weight layout: enc0={enc0_w} enc1={enc1_w} bn={bn_w} "
+ f"dec1={dec1_w} dec0={dec0_w} total={total_f16} f16 "
+ f"({total_f16*2/1024:.1f} KB)")
all_f16 = []
for name, expected in layers:
@@ -119,13 +126,13 @@ def export_weights(checkpoint_path: str, output_dir: str) -> None:
all_f16.append(chunk)
flat_f16 = np.concatenate(all_f16)
- assert len(flat_f16) == TOTAL_F16, f"total mismatch: {len(flat_f16)} != {TOTAL_F16}"
+ assert len(flat_f16) == total_f16, f"total mismatch: {len(flat_f16)} != {total_f16}"
packed_u32 = pack_weights_u32(flat_f16)
weights_path = out / 'cnn_v3_weights.bin'
packed_u32.astype('<u4').tofile(weights_path) # little-endian u32
print(f"\ncnn_v3_weights.bin")
- print(f" {TOTAL_F16} f16 values → {len(packed_u32)} u32 → {weights_path.stat().st_size} bytes")
+ print(f" {total_f16} f16 values → {len(packed_u32)} u32 → {weights_path.stat().st_size} bytes")
print(f" Upload via CNNv3Effect::upload_weights(queue, data, {len(packed_u32)*4})")
# -----------------------------------------------------------------------
diff --git a/cnn_v3/training/gen_test_vectors.py b/cnn_v3/training/gen_test_vectors.py
index 2eb889c..cdda5a5 100644
--- a/cnn_v3/training/gen_test_vectors.py
+++ b/cnn_v3/training/gen_test_vectors.py
@@ -15,17 +15,17 @@ import argparse
# Weight layout (f16 units, matching C++ cnn_v3_effect.cc constants)
# ---------------------------------------------------------------------------
-ENC0_IN, ENC0_OUT = 20, 4
-ENC1_IN, ENC1_OUT = 4, 8
-BN_IN, BN_OUT = 8, 8
-DEC1_IN, DEC1_OUT = 16, 4
-DEC0_IN, DEC0_OUT = 8, 4
+ENC0_IN, ENC0_OUT = 20, 8
+ENC1_IN, ENC1_OUT = 8, 16
+BN_IN, BN_OUT = 16, 16
+DEC1_IN, DEC1_OUT = 32, 8
+DEC0_IN, DEC0_OUT = 16, 4
-ENC0_WEIGHTS = ENC0_IN * ENC0_OUT * 9 + ENC0_OUT # 724
-ENC1_WEIGHTS = ENC1_IN * ENC1_OUT * 9 + ENC1_OUT # 296
-BN_WEIGHTS = BN_IN * BN_OUT * 9 + BN_OUT # 584 (3x3 dilation=2)
-DEC1_WEIGHTS = DEC1_IN * DEC1_OUT * 9 + DEC1_OUT # 580
-DEC0_WEIGHTS = DEC0_IN * DEC0_OUT * 9 + DEC0_OUT # 292
+ENC0_WEIGHTS = ENC0_IN * ENC0_OUT * 9 + ENC0_OUT # 1448
+ENC1_WEIGHTS = ENC1_IN * ENC1_OUT * 9 + ENC1_OUT # 1168
+BN_WEIGHTS = BN_IN * BN_OUT * 9 + BN_OUT # 2320 (3x3 dilation=2)
+DEC1_WEIGHTS = DEC1_IN * DEC1_OUT * 9 + DEC1_OUT # 2312
+DEC0_WEIGHTS = DEC0_IN * DEC0_OUT * 9 + DEC0_OUT # 580
ENC0_OFFSET = 0
ENC1_OFFSET = ENC0_OFFSET + ENC0_WEIGHTS
@@ -33,7 +33,7 @@ BN_OFFSET = ENC1_OFFSET + ENC1_WEIGHTS
DEC1_OFFSET = BN_OFFSET + BN_WEIGHTS
DEC0_OFFSET = DEC1_OFFSET + DEC1_WEIGHTS
TOTAL_F16 = DEC0_OFFSET + DEC0_WEIGHTS
-# 724 + 296 + 584 + 580 + 292 = 2476 (BN is now 3x3 dilation=2, was 72)
+# 1448 + 1168 + 2320 + 2312 + 580 = 7828
# ---------------------------------------------------------------------------
# Helpers
@@ -50,11 +50,11 @@ def get_w(w_f32, base, idx):
def enc0_forward(feat0, feat1, w, gamma, beta):
"""
- Conv(20->4, 3x3, zero-pad) + FiLM + ReLU → rgba16float (f16 stored).
+ Conv(20->8, 3x3, zero-pad) + FiLM + ReLU → rgba32uint (pack2x16float, f16 stored).
feat0: (H, W, 8) f32 — channels from unpack2x16float(feat_tex0)
feat1: (H, W, 12) f32 — channels from unpack4x8unorm(feat_tex1)
- gamma, beta: (ENC0_OUT,) f32 — FiLM params
- Returns: (H, W, 4) f32 — f16 precision (rgba16float texture boundary)
+ gamma, beta: (ENC0_OUT=8,) f32 — FiLM params
+ Returns: (H, W, 8) f32 — f16 precision (pack2x16float boundary)
"""
H, W = feat0.shape[:2]
wo = ENC0_OFFSET
@@ -72,14 +72,15 @@ def enc0_forward(feat0, feat1, w, gamma, beta):
s += wv * fp[ky:ky+H, kx:kx+W, i]
out[:, :, o] = np.maximum(0.0, gamma[o] * s + beta[o])
- return np.float16(out).astype(np.float32) # rgba16float texture boundary
+ return np.float16(out).astype(np.float32) # pack2x16float boundary (rgba32uint)
-def enc1_forward(enc0, w, gamma_lo, gamma_hi, beta_lo, beta_hi):
+def enc1_forward(enc0, w, gamma, beta):
"""
- AvgPool2x2(enc0, clamp-border) + Conv(4->8, 3x3, zero-pad) + FiLM + ReLU
- → rgba32uint (pack2x16float, f16 precision, half-res).
- enc0: (H, W, 4) f32 — rgba16float precision
+ AvgPool2x2(enc0, clamp-border) + Conv(8->16, 3x3, zero-pad) + FiLM + ReLU
+ → 2x rgba32uint (pack2x16float, f16 precision, half-res).
+ enc0: (H, W, 8) f32 — pack2x16float precision
+ gamma, beta: (ENC1_OUT=16,) f32 — FiLM params
"""
H, W = enc0.shape[:2]
hH, hW = H // 2, W // 2
@@ -99,8 +100,6 @@ def enc1_forward(enc0, w, gamma_lo, gamma_hi, beta_lo, beta_hi):
# 3x3 conv with zero-padding at half-res borders
ap = np.pad(avg, ((1, 1), (1, 1), (0, 0)), mode='constant')
- gamma = np.concatenate([gamma_lo, gamma_hi])
- beta = np.concatenate([beta_lo, beta_hi])
out = np.zeros((hH, hW, ENC1_OUT), dtype=np.float32)
for o in range(ENC1_OUT):
@@ -159,10 +158,11 @@ def bottleneck_forward(enc1, w):
def dec1_forward(bn, enc1, w, gamma, beta):
"""
- NearestUp2x(bn) + cat(enc1_skip) → Conv(16->4, 3x3, zero-pad) + FiLM + ReLU
- → rgba16float (half-res).
- bn: (qH, qW, 8) f32 — quarter-res bottleneck
- enc1: (hH, hW, 8) f32 — half-res skip connection
+ NearestUp2x(bn) + cat(enc1_skip) → Conv(32->8, 3x3, zero-pad) + FiLM + ReLU
+ → rgba32uint (pack2x16float, half-res).
+ bn: (qH, qW, 16) f32 — quarter-res bottleneck
+ enc1: (hH, hW, 16) f32 — half-res skip connection
+ gamma, beta: (DEC1_OUT=8,) f32 — FiLM params
"""
hH, hW = enc1.shape[:2]
qH, qW = bn.shape[:2]
@@ -188,15 +188,15 @@ def dec1_forward(bn, enc1, w, gamma, beta):
s += wv * fp[ky:ky+hH, kx:kx+hW, i]
out[:, :, o] = np.maximum(0.0, gamma[o] * s + beta[o])
- return np.float16(out).astype(np.float32) # rgba16float boundary
+ return np.float16(out).astype(np.float32) # pack2x16float boundary (rgba32uint)
def dec0_forward(dec1, enc0, w, gamma, beta):
"""
- NearestUp2x(dec1) + cat(enc0_skip) → Conv(8->4, 3x3, zero-pad) + FiLM + ReLU + sigmoid
+ NearestUp2x(dec1) + cat(enc0_skip) → Conv(16->4, 3x3, zero-pad) + FiLM + ReLU + sigmoid
→ rgba16float (full-res, final output).
- dec1: (hH, hW, 4) f32 — half-res
- enc0: (H, W, 4) f32 — full-res enc0 skip
+ dec1: (hH, hW, 8) f32 — half-res
+ enc0: (H, W, 8) f32 — full-res enc0 skip
"""
H, W = enc0.shape[:2]
hH, hW = dec1.shape[:2]
@@ -231,8 +231,7 @@ def forward_pass(feat0, feat1, w_f32, film):
enc0 = enc0_forward(feat0, feat1, w_f32,
film['enc0_gamma'], film['enc0_beta'])
enc1 = enc1_forward(enc0, w_f32,
- film['enc1_gamma_lo'], film['enc1_gamma_hi'],
- film['enc1_beta_lo'], film['enc1_beta_hi'])
+ film['enc1_gamma'], film['enc1_beta'])
bn = bottleneck_forward(enc1, w_f32)
dc1 = dec1_forward(bn, enc1, w_f32, film['dec1_gamma'], film['dec1_beta'])
dc0 = dec0_forward(dc1, enc0, w_f32, film['dec0_gamma'], film['dec0_beta'])
@@ -241,16 +240,14 @@ def forward_pass(feat0, feat1, w_f32, film):
def identity_film():
return {
- 'enc0_gamma': np.ones(ENC0_OUT, dtype=np.float32),
- 'enc0_beta': np.zeros(ENC0_OUT, dtype=np.float32),
- 'enc1_gamma_lo': np.ones(4, dtype=np.float32),
- 'enc1_gamma_hi': np.ones(4, dtype=np.float32),
- 'enc1_beta_lo': np.zeros(4, dtype=np.float32),
- 'enc1_beta_hi': np.zeros(4, dtype=np.float32),
- 'dec1_gamma': np.ones(DEC1_OUT, dtype=np.float32),
- 'dec1_beta': np.zeros(DEC1_OUT, dtype=np.float32),
- 'dec0_gamma': np.ones(DEC0_OUT, dtype=np.float32),
- 'dec0_beta': np.zeros(DEC0_OUT, dtype=np.float32),
+ 'enc0_gamma': np.ones(ENC0_OUT, dtype=np.float32), # 8
+ 'enc0_beta': np.zeros(ENC0_OUT, dtype=np.float32), # 8
+ 'enc1_gamma': np.ones(ENC1_OUT, dtype=np.float32), # 16
+ 'enc1_beta': np.zeros(ENC1_OUT, dtype=np.float32), # 16
+ 'dec1_gamma': np.ones(DEC1_OUT, dtype=np.float32), # 8
+ 'dec1_beta': np.zeros(DEC1_OUT, dtype=np.float32), # 8
+ 'dec0_gamma': np.ones(DEC0_OUT, dtype=np.float32), # 4
+ 'dec0_beta': np.zeros(DEC0_OUT, dtype=np.float32), # 4
}
@@ -324,8 +321,7 @@ def generate_vectors(W=8, H=8, seed=42):
enc0 = enc0_forward(feat0, feat1, w_f32,
film['enc0_gamma'], film['enc0_beta'])
enc1 = enc1_forward(enc0, w_f32,
- film['enc1_gamma_lo'], film['enc1_gamma_hi'],
- film['enc1_beta_lo'], film['enc1_beta_hi'])
+ film['enc1_gamma'], film['enc1_beta'])
bn = bottleneck_forward(enc1, w_f32)
dc1 = dec1_forward(bn, enc1, w_f32, film['dec1_gamma'], film['dec1_beta'])
out = dec0_forward(dc1, enc0, w_f32, film['dec0_gamma'], film['dec0_beta'])
@@ -333,8 +329,9 @@ def generate_vectors(W=8, H=8, seed=42):
feat0_u32 = pack_feat0_rgba32uint(feat0, H, W)
feat1_u32 = pack_feat1_rgba32uint(feat1_u8, H, W)
w_u32 = pack_weights_u32(w_f16)
+ # enc0: 8ch stored as pack2x16float → H*W*8 f16 values
enc0_u16 = np.float16(enc0.reshape(-1)).view(np.uint16)
- # dec1 is half-res (hH x hW x 4); store as-is
+ # dec1: 8ch half-res stored as pack2x16float → (H/2)*(W/2)*8 f16 values
dc1_u16 = np.float16(dc1.reshape(-1)).view(np.uint16)
out_u16 = np.float16(out.reshape(-1)).view(np.uint16) # raw f16 bits
@@ -386,11 +383,15 @@ def emit_c_header(v):
lines.append("};")
lines.append("")
+ lines.append(f"// ENC0_OUT={ENC0_OUT} ENC1_OUT={ENC1_OUT} BN={BN_OUT} DEC1_OUT={DEC1_OUT} DEC0_OUT={DEC0_OUT}")
+ lines.append(f"// TOTAL_F16={TOTAL_F16} (enc_channels=[{ENC0_OUT},{ENC1_OUT}])")
+ lines.append("")
array_u32("kCnnV3TestFeat0U32", v['feat0_u32'])
array_u32("kCnnV3TestFeat1U32", v['feat1_u32'])
array_u32("kCnnV3TestWeightsU32", v['w_u32'])
+ lines.append(f"// enc0: {ENC0_OUT}ch rgba32uint → W*H*{ENC0_OUT} f16 values")
array_u16("kCnnV3ExpectedEnc0U16", v['enc0_u16'])
- lines.append(f"// kCnnV3Dec1HW = (W/2) x (H/2) = {v['W']//2} x {v['H']//2}")
+ lines.append(f"// dec1: {DEC1_OUT}ch rgba32uint half-res → (W/2)*(H/2)*{DEC1_OUT} f16 values")
array_u16("kCnnV3ExpectedDec1U16", v['dc1_u16'])
array_u16("kCnnV3ExpectedOutputU16", v['out_u16'])
return "\n".join(lines)
diff --git a/cnn_v3/training/infer_cnn_v3.py b/cnn_v3/training/infer_cnn_v3.py
index ca1c72a..b0fe9e6 100644
--- a/cnn_v3/training/infer_cnn_v3.py
+++ b/cnn_v3/training/infer_cnn_v3.py
@@ -129,8 +129,8 @@ def main():
p.add_argument('output', help='Output PNG')
p.add_argument('--checkpoint', '-c', metavar='CKPT',
help='Path to .pth checkpoint (auto-finds latest if omitted)')
- p.add_argument('--enc-channels', default='4,8',
- help='Encoder channels (default: 4,8 — must match checkpoint)')
+ p.add_argument('--enc-channels', default='8,16',
+ help='Encoder channels (default: 8,16 — must match checkpoint)')
p.add_argument('--cond', nargs=5, type=float, metavar='F', default=[0.0]*5,
help='FiLM conditioning: 5 floats (beat_phase beat_norm audio style0 style1)')
p.add_argument('--identity-film', action='store_true',
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py
index c61c360..5b6a0be 100644
--- a/cnn_v3/training/train_cnn_v3.py
+++ b/cnn_v3/training/train_cnn_v3.py
@@ -5,18 +5,18 @@
# ///
"""CNN v3 Training Script — U-Net + FiLM
-Architecture:
- enc0 Conv(20→4, 3×3) + FiLM + ReLU H×W
- enc1 Conv(4→8, 3×3) + FiLM + ReLU + pool2 H/2×W/2
- bottleneck Conv(8→8, 3×3, dilation=2) + ReLU H/4×W/4
- dec1 upsample×2 + cat(enc1) Conv(16→4) + FiLM H/2×W/2
- dec0 upsample×2 + cat(enc0) Conv(8→4) + FiLM H×W
+Architecture (enc_channels=[8,16]):
+ enc0 Conv(20→8, 3×3) + FiLM + ReLU H×W rgba32uint (8ch)
+ enc1 Conv(8→16, 3×3) + FiLM + ReLU + pool2 H/2×W/2 2× rgba32uint (16ch split)
+ bottleneck Conv(16→16, 3×3, dilation=2) + ReLU H/4×W/4 2× rgba32uint (16ch split)
+ dec1 upsample×2 + cat(enc1) Conv(32→8) + FiLM H/2×W/2 rgba32uint (8ch)
+ dec0 upsample×2 + cat(enc0) Conv(16→4) + FiLM H×W rgba16float (4ch)
output sigmoid → RGBA
-FiLM MLP: Linear(5→16) → ReLU → Linear(16→40)
- 40 = 2 × (γ+β) for enc0(4) enc1(8) dec1(4) dec0(4)
+FiLM MLP: Linear(5→16) → ReLU → Linear(16→72)
+ 72 = 2 × (γ+β) for enc0(8) enc1(16) dec1(8) dec0(4)
-Weight budget: ~4.84 KB conv f16 (fits ≤6 KB target)
+Weight budget: ~15.3 KB conv f16 (7828 f16); total with f32 MLP (5280 B) ~20.4 KB
Training improvements:
--edge-loss-weight Sobel edge loss alongside MSE (default 0.1)
@@ -47,14 +47,14 @@ def film_apply(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor) -> torc
class CNNv3(nn.Module):
"""U-Net + FiLM conditioning.
- enc_channels: [c0, c1] channel counts per encoder level, default [4, 8]
+ enc_channels: [c0, c1] channel counts per encoder level, default [8, 16]
film_cond_dim: FiLM conditioning input size, default 5
"""
def __init__(self, enc_channels=None, film_cond_dim: int = 5):
super().__init__()
if enc_channels is None:
- enc_channels = [4, 8]
+ enc_channels = [8, 16]
assert len(enc_channels) == 2, "Only 2-level U-Net supported"
c0, c1 = enc_channels
@@ -227,6 +227,10 @@ def train(args):
optimizer.zero_grad()
pred = model(feat, cond)
loss = criterion(pred, target)
+ if args.multiscale_weight > 0.0:
+ for scale in [2, 4]:
+ loss = loss + args.multiscale_weight * criterion(
+ F.avg_pool2d(pred, scale), F.avg_pool2d(target, scale))
if args.edge_loss_weight > 0.0:
loss = loss + args.edge_loss_weight * sobel_loss(pred, target)
loss.backward()
@@ -321,6 +325,8 @@ def main():
help='Resume from checkpoint path; if path missing, use latest in --checkpoint-dir')
p.add_argument('--edge-loss-weight', type=float, default=0.1,
help='Weight for Sobel edge loss alongside MSE (default 0.1; 0=disable)')
+ p.add_argument('--multiscale-weight', type=float, default=0.5,
+ help='Weight per pyramid level for multi-scale MSE (default 0.5; 0=disable)')
p.add_argument('--film-warmup-epochs', type=int, default=50,
help='Epochs to train U-Net only before unfreezing FiLM MLP (default 50; 0=joint)')