diff options
Diffstat (limited to 'cnn_v3/training')
| -rw-r--r-- | cnn_v3/training/export_cnn_v3_weights.py | 51 | ||||
| -rw-r--r-- | cnn_v3/training/gen_test_vectors.py | 91 | ||||
| -rw-r--r-- | cnn_v3/training/infer_cnn_v3.py | 4 | ||||
| -rw-r--r-- | cnn_v3/training/train_cnn_v3.py | 28 |
4 files changed, 94 insertions, 80 deletions
diff --git a/cnn_v3/training/export_cnn_v3_weights.py b/cnn_v3/training/export_cnn_v3_weights.py index 78f5f25..2fa83d1 100644 --- a/cnn_v3/training/export_cnn_v3_weights.py +++ b/cnn_v3/training/export_cnn_v3_weights.py @@ -15,12 +15,12 @@ Outputs <output_dir>/cnn_v3_weights.bin Conv+bias weights for all 5 passes, packed as f16-pairs-in-u32. Matches the format expected by CNNv3Effect::upload_weights(). - Layout: enc0 (724) | enc1 (296) | bottleneck (584) | dec1 (580) | dec0 (292) - = 2476 f16 values = 1238 u32 = 4952 bytes. + Layout: enc0 (1448) | enc1 (1168) | bottleneck (2320) | dec1 (2312) | dec0 (580) + = 7828 f16 values = 3914 u32 = 15656 bytes. <output_dir>/cnn_v3_film_mlp.bin - FiLM MLP weights as raw f32: L0_W (5×16) L0_b (16) L1_W (16×40) L1_b (40). - = 5*16 + 16 + 16*40 + 40 = 80 + 16 + 640 + 40 = 776 f32 = 3104 bytes. + FiLM MLP weights as raw f32: L0_W (5×16) L0_b (16) L1_W (16×72) L1_b (72). + = 5*16 + 16 + 16*72 + 72 = 80 + 16 + 1152 + 72 = 1320 f32 = 5280 bytes. For future CPU-side MLP inference in CNNv3Effect::set_film_params(). Usage @@ -44,17 +44,19 @@ sys.path.insert(0, str(Path(__file__).parent)) from train_cnn_v3 import CNNv3 # --------------------------------------------------------------------------- -# Weight layout constants — must stay in sync with: -# cnn_v3/src/cnn_v3_effect.cc (kEnc0Weights, kEnc1Weights, …) -# cnn_v3/training/gen_test_vectors.py (same constants) +# Weight layout helpers — derived from enc_channels at runtime. +# Must stay in sync with cnn_v3/src/cnn_v3_effect.cc and gen_test_vectors.py. # --------------------------------------------------------------------------- -ENC0_WEIGHTS = 20 * 4 * 9 + 4 # Conv(20→4,3×3)+bias = 724 -ENC1_WEIGHTS = 4 * 8 * 9 + 8 # Conv(4→8,3×3)+bias = 296 -BN_WEIGHTS = 8 * 8 * 9 + 8 # Conv(8→8,3×3,dil=2)+bias = 584 -DEC1_WEIGHTS = 16 * 4 * 9 + 4 # Conv(16→4,3×3)+bias = 580 -DEC0_WEIGHTS = 8 * 4 * 9 + 4 # Conv(8→4,3×3)+bias = 292 -TOTAL_F16 = ENC0_WEIGHTS + ENC1_WEIGHTS + BN_WEIGHTS + DEC1_WEIGHTS + DEC0_WEIGHTS -# = 2476 +N_IN = 20 # feature input channels (fixed) + +def weight_counts(enc_channels): + c0, c1 = enc_channels + enc0 = N_IN * c0 * 9 + c0 + enc1 = c0 * c1 * 9 + c1 + bn = c1 * c1 * 9 + c1 + dec1 = (c1 * 2) * c0 * 9 + c0 + dec0 = (c0 * 2) * 4 * 9 + 4 + return enc0, enc1, bn, dec1, dec0 def pack_weights_u32(w_f16: np.ndarray) -> np.ndarray: @@ -86,7 +88,7 @@ def export_weights(checkpoint_path: str, output_dir: str) -> None: ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=True) cfg = ckpt.get('config', {}) - enc_channels = cfg.get('enc_channels', [4, 8]) + enc_channels = cfg.get('enc_channels', [8, 16]) film_cond_dim = cfg.get('film_cond_dim', 5) model = CNNv3(enc_channels=enc_channels, film_cond_dim=film_cond_dim) @@ -102,13 +104,18 @@ def export_weights(checkpoint_path: str, output_dir: str) -> None: # ----------------------------------------------------------------------- # 1. CNN conv weights → cnn_v3_weights.bin # ----------------------------------------------------------------------- + enc0_w, enc1_w, bn_w, dec1_w, dec0_w = weight_counts(enc_channels) + total_f16 = enc0_w + enc1_w + bn_w + dec1_w + dec0_w layers = [ - ('enc0', ENC0_WEIGHTS), - ('enc1', ENC1_WEIGHTS), - ('bottleneck', BN_WEIGHTS), - ('dec1', DEC1_WEIGHTS), - ('dec0', DEC0_WEIGHTS), + ('enc0', enc0_w), + ('enc1', enc1_w), + ('bottleneck', bn_w), + ('dec1', dec1_w), + ('dec0', dec0_w), ] + print(f" Weight layout: enc0={enc0_w} enc1={enc1_w} bn={bn_w} " + f"dec1={dec1_w} dec0={dec0_w} total={total_f16} f16 " + f"({total_f16*2/1024:.1f} KB)") all_f16 = [] for name, expected in layers: @@ -119,13 +126,13 @@ def export_weights(checkpoint_path: str, output_dir: str) -> None: all_f16.append(chunk) flat_f16 = np.concatenate(all_f16) - assert len(flat_f16) == TOTAL_F16, f"total mismatch: {len(flat_f16)} != {TOTAL_F16}" + assert len(flat_f16) == total_f16, f"total mismatch: {len(flat_f16)} != {total_f16}" packed_u32 = pack_weights_u32(flat_f16) weights_path = out / 'cnn_v3_weights.bin' packed_u32.astype('<u4').tofile(weights_path) # little-endian u32 print(f"\ncnn_v3_weights.bin") - print(f" {TOTAL_F16} f16 values → {len(packed_u32)} u32 → {weights_path.stat().st_size} bytes") + print(f" {total_f16} f16 values → {len(packed_u32)} u32 → {weights_path.stat().st_size} bytes") print(f" Upload via CNNv3Effect::upload_weights(queue, data, {len(packed_u32)*4})") # ----------------------------------------------------------------------- diff --git a/cnn_v3/training/gen_test_vectors.py b/cnn_v3/training/gen_test_vectors.py index 2eb889c..cdda5a5 100644 --- a/cnn_v3/training/gen_test_vectors.py +++ b/cnn_v3/training/gen_test_vectors.py @@ -15,17 +15,17 @@ import argparse # Weight layout (f16 units, matching C++ cnn_v3_effect.cc constants) # --------------------------------------------------------------------------- -ENC0_IN, ENC0_OUT = 20, 4 -ENC1_IN, ENC1_OUT = 4, 8 -BN_IN, BN_OUT = 8, 8 -DEC1_IN, DEC1_OUT = 16, 4 -DEC0_IN, DEC0_OUT = 8, 4 +ENC0_IN, ENC0_OUT = 20, 8 +ENC1_IN, ENC1_OUT = 8, 16 +BN_IN, BN_OUT = 16, 16 +DEC1_IN, DEC1_OUT = 32, 8 +DEC0_IN, DEC0_OUT = 16, 4 -ENC0_WEIGHTS = ENC0_IN * ENC0_OUT * 9 + ENC0_OUT # 724 -ENC1_WEIGHTS = ENC1_IN * ENC1_OUT * 9 + ENC1_OUT # 296 -BN_WEIGHTS = BN_IN * BN_OUT * 9 + BN_OUT # 584 (3x3 dilation=2) -DEC1_WEIGHTS = DEC1_IN * DEC1_OUT * 9 + DEC1_OUT # 580 -DEC0_WEIGHTS = DEC0_IN * DEC0_OUT * 9 + DEC0_OUT # 292 +ENC0_WEIGHTS = ENC0_IN * ENC0_OUT * 9 + ENC0_OUT # 1448 +ENC1_WEIGHTS = ENC1_IN * ENC1_OUT * 9 + ENC1_OUT # 1168 +BN_WEIGHTS = BN_IN * BN_OUT * 9 + BN_OUT # 2320 (3x3 dilation=2) +DEC1_WEIGHTS = DEC1_IN * DEC1_OUT * 9 + DEC1_OUT # 2312 +DEC0_WEIGHTS = DEC0_IN * DEC0_OUT * 9 + DEC0_OUT # 580 ENC0_OFFSET = 0 ENC1_OFFSET = ENC0_OFFSET + ENC0_WEIGHTS @@ -33,7 +33,7 @@ BN_OFFSET = ENC1_OFFSET + ENC1_WEIGHTS DEC1_OFFSET = BN_OFFSET + BN_WEIGHTS DEC0_OFFSET = DEC1_OFFSET + DEC1_WEIGHTS TOTAL_F16 = DEC0_OFFSET + DEC0_WEIGHTS -# 724 + 296 + 584 + 580 + 292 = 2476 (BN is now 3x3 dilation=2, was 72) +# 1448 + 1168 + 2320 + 2312 + 580 = 7828 # --------------------------------------------------------------------------- # Helpers @@ -50,11 +50,11 @@ def get_w(w_f32, base, idx): def enc0_forward(feat0, feat1, w, gamma, beta): """ - Conv(20->4, 3x3, zero-pad) + FiLM + ReLU → rgba16float (f16 stored). + Conv(20->8, 3x3, zero-pad) + FiLM + ReLU → rgba32uint (pack2x16float, f16 stored). feat0: (H, W, 8) f32 — channels from unpack2x16float(feat_tex0) feat1: (H, W, 12) f32 — channels from unpack4x8unorm(feat_tex1) - gamma, beta: (ENC0_OUT,) f32 — FiLM params - Returns: (H, W, 4) f32 — f16 precision (rgba16float texture boundary) + gamma, beta: (ENC0_OUT=8,) f32 — FiLM params + Returns: (H, W, 8) f32 — f16 precision (pack2x16float boundary) """ H, W = feat0.shape[:2] wo = ENC0_OFFSET @@ -72,14 +72,15 @@ def enc0_forward(feat0, feat1, w, gamma, beta): s += wv * fp[ky:ky+H, kx:kx+W, i] out[:, :, o] = np.maximum(0.0, gamma[o] * s + beta[o]) - return np.float16(out).astype(np.float32) # rgba16float texture boundary + return np.float16(out).astype(np.float32) # pack2x16float boundary (rgba32uint) -def enc1_forward(enc0, w, gamma_lo, gamma_hi, beta_lo, beta_hi): +def enc1_forward(enc0, w, gamma, beta): """ - AvgPool2x2(enc0, clamp-border) + Conv(4->8, 3x3, zero-pad) + FiLM + ReLU - → rgba32uint (pack2x16float, f16 precision, half-res). - enc0: (H, W, 4) f32 — rgba16float precision + AvgPool2x2(enc0, clamp-border) + Conv(8->16, 3x3, zero-pad) + FiLM + ReLU + → 2x rgba32uint (pack2x16float, f16 precision, half-res). + enc0: (H, W, 8) f32 — pack2x16float precision + gamma, beta: (ENC1_OUT=16,) f32 — FiLM params """ H, W = enc0.shape[:2] hH, hW = H // 2, W // 2 @@ -99,8 +100,6 @@ def enc1_forward(enc0, w, gamma_lo, gamma_hi, beta_lo, beta_hi): # 3x3 conv with zero-padding at half-res borders ap = np.pad(avg, ((1, 1), (1, 1), (0, 0)), mode='constant') - gamma = np.concatenate([gamma_lo, gamma_hi]) - beta = np.concatenate([beta_lo, beta_hi]) out = np.zeros((hH, hW, ENC1_OUT), dtype=np.float32) for o in range(ENC1_OUT): @@ -159,10 +158,11 @@ def bottleneck_forward(enc1, w): def dec1_forward(bn, enc1, w, gamma, beta): """ - NearestUp2x(bn) + cat(enc1_skip) → Conv(16->4, 3x3, zero-pad) + FiLM + ReLU - → rgba16float (half-res). - bn: (qH, qW, 8) f32 — quarter-res bottleneck - enc1: (hH, hW, 8) f32 — half-res skip connection + NearestUp2x(bn) + cat(enc1_skip) → Conv(32->8, 3x3, zero-pad) + FiLM + ReLU + → rgba32uint (pack2x16float, half-res). + bn: (qH, qW, 16) f32 — quarter-res bottleneck + enc1: (hH, hW, 16) f32 — half-res skip connection + gamma, beta: (DEC1_OUT=8,) f32 — FiLM params """ hH, hW = enc1.shape[:2] qH, qW = bn.shape[:2] @@ -188,15 +188,15 @@ def dec1_forward(bn, enc1, w, gamma, beta): s += wv * fp[ky:ky+hH, kx:kx+hW, i] out[:, :, o] = np.maximum(0.0, gamma[o] * s + beta[o]) - return np.float16(out).astype(np.float32) # rgba16float boundary + return np.float16(out).astype(np.float32) # pack2x16float boundary (rgba32uint) def dec0_forward(dec1, enc0, w, gamma, beta): """ - NearestUp2x(dec1) + cat(enc0_skip) → Conv(8->4, 3x3, zero-pad) + FiLM + ReLU + sigmoid + NearestUp2x(dec1) + cat(enc0_skip) → Conv(16->4, 3x3, zero-pad) + FiLM + ReLU + sigmoid → rgba16float (full-res, final output). - dec1: (hH, hW, 4) f32 — half-res - enc0: (H, W, 4) f32 — full-res enc0 skip + dec1: (hH, hW, 8) f32 — half-res + enc0: (H, W, 8) f32 — full-res enc0 skip """ H, W = enc0.shape[:2] hH, hW = dec1.shape[:2] @@ -231,8 +231,7 @@ def forward_pass(feat0, feat1, w_f32, film): enc0 = enc0_forward(feat0, feat1, w_f32, film['enc0_gamma'], film['enc0_beta']) enc1 = enc1_forward(enc0, w_f32, - film['enc1_gamma_lo'], film['enc1_gamma_hi'], - film['enc1_beta_lo'], film['enc1_beta_hi']) + film['enc1_gamma'], film['enc1_beta']) bn = bottleneck_forward(enc1, w_f32) dc1 = dec1_forward(bn, enc1, w_f32, film['dec1_gamma'], film['dec1_beta']) dc0 = dec0_forward(dc1, enc0, w_f32, film['dec0_gamma'], film['dec0_beta']) @@ -241,16 +240,14 @@ def forward_pass(feat0, feat1, w_f32, film): def identity_film(): return { - 'enc0_gamma': np.ones(ENC0_OUT, dtype=np.float32), - 'enc0_beta': np.zeros(ENC0_OUT, dtype=np.float32), - 'enc1_gamma_lo': np.ones(4, dtype=np.float32), - 'enc1_gamma_hi': np.ones(4, dtype=np.float32), - 'enc1_beta_lo': np.zeros(4, dtype=np.float32), - 'enc1_beta_hi': np.zeros(4, dtype=np.float32), - 'dec1_gamma': np.ones(DEC1_OUT, dtype=np.float32), - 'dec1_beta': np.zeros(DEC1_OUT, dtype=np.float32), - 'dec0_gamma': np.ones(DEC0_OUT, dtype=np.float32), - 'dec0_beta': np.zeros(DEC0_OUT, dtype=np.float32), + 'enc0_gamma': np.ones(ENC0_OUT, dtype=np.float32), # 8 + 'enc0_beta': np.zeros(ENC0_OUT, dtype=np.float32), # 8 + 'enc1_gamma': np.ones(ENC1_OUT, dtype=np.float32), # 16 + 'enc1_beta': np.zeros(ENC1_OUT, dtype=np.float32), # 16 + 'dec1_gamma': np.ones(DEC1_OUT, dtype=np.float32), # 8 + 'dec1_beta': np.zeros(DEC1_OUT, dtype=np.float32), # 8 + 'dec0_gamma': np.ones(DEC0_OUT, dtype=np.float32), # 4 + 'dec0_beta': np.zeros(DEC0_OUT, dtype=np.float32), # 4 } @@ -324,8 +321,7 @@ def generate_vectors(W=8, H=8, seed=42): enc0 = enc0_forward(feat0, feat1, w_f32, film['enc0_gamma'], film['enc0_beta']) enc1 = enc1_forward(enc0, w_f32, - film['enc1_gamma_lo'], film['enc1_gamma_hi'], - film['enc1_beta_lo'], film['enc1_beta_hi']) + film['enc1_gamma'], film['enc1_beta']) bn = bottleneck_forward(enc1, w_f32) dc1 = dec1_forward(bn, enc1, w_f32, film['dec1_gamma'], film['dec1_beta']) out = dec0_forward(dc1, enc0, w_f32, film['dec0_gamma'], film['dec0_beta']) @@ -333,8 +329,9 @@ def generate_vectors(W=8, H=8, seed=42): feat0_u32 = pack_feat0_rgba32uint(feat0, H, W) feat1_u32 = pack_feat1_rgba32uint(feat1_u8, H, W) w_u32 = pack_weights_u32(w_f16) + # enc0: 8ch stored as pack2x16float → H*W*8 f16 values enc0_u16 = np.float16(enc0.reshape(-1)).view(np.uint16) - # dec1 is half-res (hH x hW x 4); store as-is + # dec1: 8ch half-res stored as pack2x16float → (H/2)*(W/2)*8 f16 values dc1_u16 = np.float16(dc1.reshape(-1)).view(np.uint16) out_u16 = np.float16(out.reshape(-1)).view(np.uint16) # raw f16 bits @@ -386,11 +383,15 @@ def emit_c_header(v): lines.append("};") lines.append("") + lines.append(f"// ENC0_OUT={ENC0_OUT} ENC1_OUT={ENC1_OUT} BN={BN_OUT} DEC1_OUT={DEC1_OUT} DEC0_OUT={DEC0_OUT}") + lines.append(f"// TOTAL_F16={TOTAL_F16} (enc_channels=[{ENC0_OUT},{ENC1_OUT}])") + lines.append("") array_u32("kCnnV3TestFeat0U32", v['feat0_u32']) array_u32("kCnnV3TestFeat1U32", v['feat1_u32']) array_u32("kCnnV3TestWeightsU32", v['w_u32']) + lines.append(f"// enc0: {ENC0_OUT}ch rgba32uint → W*H*{ENC0_OUT} f16 values") array_u16("kCnnV3ExpectedEnc0U16", v['enc0_u16']) - lines.append(f"// kCnnV3Dec1HW = (W/2) x (H/2) = {v['W']//2} x {v['H']//2}") + lines.append(f"// dec1: {DEC1_OUT}ch rgba32uint half-res → (W/2)*(H/2)*{DEC1_OUT} f16 values") array_u16("kCnnV3ExpectedDec1U16", v['dc1_u16']) array_u16("kCnnV3ExpectedOutputU16", v['out_u16']) return "\n".join(lines) diff --git a/cnn_v3/training/infer_cnn_v3.py b/cnn_v3/training/infer_cnn_v3.py index ca1c72a..b0fe9e6 100644 --- a/cnn_v3/training/infer_cnn_v3.py +++ b/cnn_v3/training/infer_cnn_v3.py @@ -129,8 +129,8 @@ def main(): p.add_argument('output', help='Output PNG') p.add_argument('--checkpoint', '-c', metavar='CKPT', help='Path to .pth checkpoint (auto-finds latest if omitted)') - p.add_argument('--enc-channels', default='4,8', - help='Encoder channels (default: 4,8 — must match checkpoint)') + p.add_argument('--enc-channels', default='8,16', + help='Encoder channels (default: 8,16 — must match checkpoint)') p.add_argument('--cond', nargs=5, type=float, metavar='F', default=[0.0]*5, help='FiLM conditioning: 5 floats (beat_phase beat_norm audio style0 style1)') p.add_argument('--identity-film', action='store_true', diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py index c61c360..5b6a0be 100644 --- a/cnn_v3/training/train_cnn_v3.py +++ b/cnn_v3/training/train_cnn_v3.py @@ -5,18 +5,18 @@ # /// """CNN v3 Training Script — U-Net + FiLM -Architecture: - enc0 Conv(20→4, 3×3) + FiLM + ReLU H×W - enc1 Conv(4→8, 3×3) + FiLM + ReLU + pool2 H/2×W/2 - bottleneck Conv(8→8, 3×3, dilation=2) + ReLU H/4×W/4 - dec1 upsample×2 + cat(enc1) Conv(16→4) + FiLM H/2×W/2 - dec0 upsample×2 + cat(enc0) Conv(8→4) + FiLM H×W +Architecture (enc_channels=[8,16]): + enc0 Conv(20→8, 3×3) + FiLM + ReLU H×W rgba32uint (8ch) + enc1 Conv(8→16, 3×3) + FiLM + ReLU + pool2 H/2×W/2 2× rgba32uint (16ch split) + bottleneck Conv(16→16, 3×3, dilation=2) + ReLU H/4×W/4 2× rgba32uint (16ch split) + dec1 upsample×2 + cat(enc1) Conv(32→8) + FiLM H/2×W/2 rgba32uint (8ch) + dec0 upsample×2 + cat(enc0) Conv(16→4) + FiLM H×W rgba16float (4ch) output sigmoid → RGBA -FiLM MLP: Linear(5→16) → ReLU → Linear(16→40) - 40 = 2 × (γ+β) for enc0(4) enc1(8) dec1(4) dec0(4) +FiLM MLP: Linear(5→16) → ReLU → Linear(16→72) + 72 = 2 × (γ+β) for enc0(8) enc1(16) dec1(8) dec0(4) -Weight budget: ~4.84 KB conv f16 (fits ≤6 KB target) +Weight budget: ~15.3 KB conv f16 (7828 f16); total with MLP ~17.9 KB Training improvements: --edge-loss-weight Sobel edge loss alongside MSE (default 0.1) @@ -47,14 +47,14 @@ def film_apply(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor) -> torc class CNNv3(nn.Module): """U-Net + FiLM conditioning. - enc_channels: [c0, c1] channel counts per encoder level, default [4, 8] + enc_channels: [c0, c1] channel counts per encoder level, default [8, 16] film_cond_dim: FiLM conditioning input size, default 5 """ def __init__(self, enc_channels=None, film_cond_dim: int = 5): super().__init__() if enc_channels is None: - enc_channels = [4, 8] + enc_channels = [8, 16] assert len(enc_channels) == 2, "Only 2-level U-Net supported" c0, c1 = enc_channels @@ -227,6 +227,10 @@ def train(args): optimizer.zero_grad() pred = model(feat, cond) loss = criterion(pred, target) + if args.multiscale_weight > 0.0: + for scale in [2, 4]: + loss = loss + args.multiscale_weight * criterion( + F.avg_pool2d(pred, scale), F.avg_pool2d(target, scale)) if args.edge_loss_weight > 0.0: loss = loss + args.edge_loss_weight * sobel_loss(pred, target) loss.backward() @@ -321,6 +325,8 @@ def main(): help='Resume from checkpoint path; if path missing, use latest in --checkpoint-dir') p.add_argument('--edge-loss-weight', type=float, default=0.1, help='Weight for Sobel edge loss alongside MSE (default 0.1; 0=disable)') + p.add_argument('--multiscale-weight', type=float, default=0.5, + help='Weight per pyramid level for multi-scale MSE (default 0.5; 0=disable)') p.add_argument('--film-warmup-epochs', type=int, default=50, help='Epochs to train U-Net only before unfreezing FiLM MLP (default 50; 0=joint)') |
