#!/usr/bin/env python3
# CNN v3 parity reference — numpy forward pass matching WGSL shaders exactly.
# Generates test vectors for C++ GPU parity validation.
#
# Usage:
#   python3 cnn_v3/training/gen_test_vectors.py            # self-test only
#   python3 cnn_v3/training/gen_test_vectors.py --header   # emit C header to stdout

import numpy as np
import struct
import sys
import argparse

# ---------------------------------------------------------------------------
# Weight layout (f16 units, matching C++ cnn_v3_effect.cc constants)
# ---------------------------------------------------------------------------
# Each conv layer stores OUT*IN*K^2 kernel weights followed by OUT biases.
ENC0_IN, ENC0_OUT = 20, 4
ENC1_IN, ENC1_OUT = 4, 8
BN_IN, BN_OUT = 8, 8
DEC1_IN, DEC1_OUT = 16, 4
DEC0_IN, DEC0_OUT = 8, 4

ENC0_WEIGHTS = ENC0_IN * ENC0_OUT * 9 + ENC0_OUT   # 20*4*9 + 4 = 724
ENC1_WEIGHTS = ENC1_IN * ENC1_OUT * 9 + ENC1_OUT   # 4*8*9  + 8 = 296
BN_WEIGHTS = BN_IN * BN_OUT * 1 + BN_OUT           # 8*8*1  + 8 = 72
DEC1_WEIGHTS = DEC1_IN * DEC1_OUT * 9 + DEC1_OUT   # 16*4*9 + 4 = 580
DEC0_WEIGHTS = DEC0_IN * DEC0_OUT * 9 + DEC0_OUT   # 8*4*9  + 4 = 292

ENC0_OFFSET = 0
ENC1_OFFSET = ENC0_OFFSET + ENC0_WEIGHTS
BN_OFFSET = ENC1_OFFSET + ENC1_WEIGHTS
DEC1_OFFSET = BN_OFFSET + BN_WEIGHTS
DEC0_OFFSET = DEC1_OFFSET + DEC1_WEIGHTS
# 724 + 296 + 72 + 580 + 292 = 1964 f16 weights total.
# NOTE(review): HOWTO.md quotes 2064, but the per-layer formulas above sum to
# 1964; the doc figure appears to be a typo. The derived value is used here.
TOTAL_F16 = DEC0_OFFSET + DEC0_WEIGHTS  # 1964
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def get_w(w_f32, base, idx):
    """Read one f16-precision weight. Matches WGSL get_w()."""
    return float(w_f32[base + idx])


# ---------------------------------------------------------------------------
# Layer forward passes — each matches the corresponding WGSL compute shader
# ---------------------------------------------------------------------------

def enc0_forward(feat0, feat1, w, gamma, beta):
    """
    Conv(20->4, 3x3, zero-pad) + FiLM + ReLU → rgba16float (f16 stored).

    feat0: (H, W, 8) f32  — channels from unpack2x16float(feat_tex0)
    feat1: (H, W, 12) f32 — channels from unpack4x8unorm(feat_tex1)
    gamma, beta: (ENC0_OUT,) f32 — FiLM params
    Returns: (H, W, 4) f32 — f16 precision (rgba16float texture boundary)
    """
    H, W = feat0.shape[:2]
    base = ENC0_OFFSET
    # 20-channel input, zero-padded by 1 on each spatial edge for the 3x3 taps.
    padded = np.pad(np.concatenate([feat0, feat1], axis=2),
                    ((1, 1), (1, 1), (0, 0)), mode='constant')
    bias_base = ENC0_OUT * ENC0_IN * 9  # biases follow all kernel weights
    result = np.zeros((H, W, ENC0_OUT), dtype=np.float32)
    for o in range(ENC0_OUT):
        acc = np.full((H, W), get_w(w, base, bias_base + o), dtype=np.float32)
        for i in range(ENC0_IN):
            for k in range(9):
                ky, kx = divmod(k, 3)
                wt = get_w(w, base, o * ENC0_IN * 9 + i * 9 + k)
                acc += wt * padded[ky:ky + H, kx:kx + W, i]
        # FiLM modulation, then ReLU.
        result[:, :, o] = np.maximum(0.0, gamma[o] * acc + beta[o])
    # Round-trip through f16 to model the rgba16float texture boundary.
    return np.float16(result).astype(np.float32)
def enc1_forward(enc0, w, gamma_lo, gamma_hi, beta_lo, beta_hi):
    """
    AvgPool2x2(enc0, clamp-border) + Conv(4->8, 3x3, zero-pad) + FiLM + ReLU
    → rgba32uint (pack2x16float, f16 precision, half-res).

    enc0: (H, W, 4) f32 — rgba16float precision
    """
    H, W = enc0.shape[:2]
    hH, hW = H // 2, W // 2
    base = ENC1_OFFSET

    # AvgPool2x2 with clamp at borders (matches load_enc0_avg in WGSL).
    pooled = np.zeros((hH, hW, ENC1_IN), dtype=np.float32)
    for py in range(hH):
        for px in range(hW):
            acc = np.zeros(ENC1_IN, dtype=np.float32)
            for dy in range(2):
                for dx in range(2):
                    sy = min(py * 2 + dy, H - 1)
                    sx = min(px * 2 + dx, W - 1)
                    acc += enc0[sy, sx, :]
            pooled[py, px, :] = acc * 0.25

    # 3x3 conv with zero-padding at the half-res borders.
    padded = np.pad(pooled, ((1, 1), (1, 1), (0, 0)), mode='constant')
    gamma = np.concatenate([gamma_lo, gamma_hi])
    beta = np.concatenate([beta_lo, beta_hi])
    bias_base = ENC1_OUT * ENC1_IN * 9
    result = np.zeros((hH, hW, ENC1_OUT), dtype=np.float32)
    for o in range(ENC1_OUT):
        acc = np.full((hH, hW), get_w(w, base, bias_base + o), dtype=np.float32)
        for i in range(ENC1_IN):
            for ky in range(3):
                for kx in range(3):
                    wt = get_w(w, base, o * ENC1_IN * 9 + i * 9 + ky * 3 + kx)
                    acc += wt * padded[ky:ky + hH, kx:kx + hW, i]
        result[:, :, o] = np.maximum(0.0, gamma[o] * acc + beta[o])
    # f16 round-trip models the pack2x16float storage boundary.
    return np.float16(result).astype(np.float32)
def bottleneck_forward(enc1, w):
    """
    AvgPool2x2(enc1, clamp-border) + Conv(8->8, 1x1) + ReLU → rgba32uint
    (f16, quarter-res). No FiLM.

    enc1: (hH, hW, 8) f32 — half-res
    """
    hH, hW = enc1.shape[:2]
    qH, qW = hH // 2, hW // 2
    base = BN_OFFSET

    # AvgPool2x2 with clamp (matches load_enc1_avg in WGSL).
    pooled = np.zeros((qH, qW, BN_IN), dtype=np.float32)
    for qy in range(qH):
        for qx in range(qW):
            acc = np.zeros(BN_IN, dtype=np.float32)
            for dy in range(2):
                for dx in range(2):
                    sy = min(qy * 2 + dy, hH - 1)
                    sx = min(qx * 2 + dx, hW - 1)
                    acc += enc1[sy, sx, :]
            pooled[qy, qx, :] = acc * 0.25

    # 1x1 conv: a pure channel mix — no spatial taps, no FiLM.
    result = np.zeros((qH, qW, BN_OUT), dtype=np.float32)
    for o in range(BN_OUT):
        acc = np.full((qH, qW), get_w(w, base, BN_OUT * BN_IN + o),
                      dtype=np.float32)
        for i in range(BN_IN):
            acc += get_w(w, base, o * BN_IN + i) * pooled[:, :, i]
        result[:, :, o] = np.maximum(0.0, acc)
    # f16 round-trip models the pack2x16float storage boundary.
    return np.float16(result).astype(np.float32)


def dec1_forward(bn, enc1, w, gamma, beta):
    """
    NearestUp2x(bn) + cat(enc1_skip) → Conv(16->4, 3x3, zero-pad) + FiLM +
    ReLU → rgba16float (half-res).

    bn:   (qH, qW, 8) f32 — quarter-res bottleneck
    enc1: (hH, hW, 8) f32 — half-res skip connection
    """
    hH, hW = enc1.shape[:2]
    qH, qW = bn.shape[:2]
    base = DEC1_OFFSET

    # 16-channel input [nearest_up(bn), enc1_skip]; the zero border models
    # load_dec1_concat, where OOB taps read zeros.
    padded = np.zeros((hH + 2, hW + 2, DEC1_IN), dtype=np.float32)
    for hy in range(hH):
        for hx in range(hW):
            qy = min(hy // 2, qH - 1)
            qx = min(hx // 2, qW - 1)
            padded[hy + 1, hx + 1, :] = np.concatenate(
                [bn[qy, qx, :], enc1[hy, hx, :]])

    bias_base = DEC1_OUT * DEC1_IN * 9
    result = np.zeros((hH, hW, DEC1_OUT), dtype=np.float32)
    for o in range(DEC1_OUT):
        acc = np.full((hH, hW), get_w(w, base, bias_base + o), dtype=np.float32)
        for i in range(DEC1_IN):
            for ky in range(3):
                for kx in range(3):
                    wt = get_w(w, base, o * DEC1_IN * 9 + i * 9 + ky * 3 + kx)
                    acc += wt * padded[ky:ky + hH, kx:kx + hW, i]
        result[:, :, o] = np.maximum(0.0, gamma[o] * acc + beta[o])
    return np.float16(result).astype(np.float32)  # rgba16float boundary
def dec0_forward(dec1, enc0, w, gamma, beta):
    """
    NearestUp2x(dec1) + cat(enc0_skip) → Conv(8->4, 3x3, zero-pad) + FiLM +
    ReLU + sigmoid → rgba16float (full-res, final output).

    dec1: (hH, hW, 4) f32 — half-res
    enc0: (H, W, 4) f32 — full-res enc0 skip
    """
    H, W = enc0.shape[:2]
    hH, hW = dec1.shape[:2]
    base = DEC0_OFFSET

    # 8-channel input [nearest_up(dec1), enc0_skip], zero border for 3x3 taps.
    padded = np.zeros((H + 2, W + 2, DEC0_IN), dtype=np.float32)
    for y in range(H):
        for x in range(W):
            sy = min(y // 2, hH - 1)
            sx = min(x // 2, hW - 1)
            padded[y + 1, x + 1, :] = np.concatenate(
                [dec1[sy, sx, :], enc0[y, x, :]])

    bias_base = DEC0_OUT * DEC0_IN * 9
    result = np.zeros((H, W, DEC0_OUT), dtype=np.float32)
    for o in range(DEC0_OUT):
        acc = np.full((H, W), get_w(w, base, bias_base + o), dtype=np.float32)
        for i in range(DEC0_IN):
            for ky in range(3):
                for kx in range(3):
                    wt = get_w(w, base, o * DEC0_IN * 9 + i * 9 + ky * 3 + kx)
                    acc += wt * padded[ky:ky + H, kx:kx + W, i]
        # FiLM + ReLU + sigmoid (matches WGSL dec0 shader); exp is taken in
        # float64 before narrowing back to f32.
        relu = np.maximum(0.0, gamma[o] * acc + beta[o])
        result[:, :, o] = 1.0 / (1.0 + np.exp(-relu.astype(np.float64))).astype(np.float32)
    return np.float16(result).astype(np.float32)  # rgba16float boundary
def forward_pass(feat0, feat1, w_f32, film):
    """Full U-Net forward pass. film is a dict of gamma/beta arrays."""
    enc0 = enc0_forward(feat0, feat1, w_f32,
                        film['enc0_gamma'], film['enc0_beta'])
    enc1 = enc1_forward(enc0, w_f32,
                        film['enc1_gamma_lo'], film['enc1_gamma_hi'],
                        film['enc1_beta_lo'], film['enc1_beta_hi'])
    bn = bottleneck_forward(enc1, w_f32)
    dc1 = dec1_forward(bn, enc1, w_f32, film['dec1_gamma'], film['dec1_beta'])
    return dec0_forward(dc1, enc0, w_f32, film['dec0_gamma'], film['dec0_beta'])


def identity_film():
    """FiLM params that leave every layer unmodulated (gamma=1, beta=0)."""
    film = {}
    # enc1 splits its 8-channel FiLM params into lo/hi halves of 4 each.
    for gamma_key, n in (('enc0_gamma', ENC0_OUT),
                         ('enc1_gamma_lo', 4), ('enc1_gamma_hi', 4),
                         ('dec1_gamma', DEC1_OUT), ('dec0_gamma', DEC0_OUT)):
        film[gamma_key] = np.ones(n, dtype=np.float32)
        film[gamma_key.replace('gamma', 'beta')] = np.zeros(n, dtype=np.float32)
    return film


# ---------------------------------------------------------------------------
# Self-test: zero weights → output must be exactly 0.5
# ---------------------------------------------------------------------------

def test_zero_weights():
    """All-zero weights drive every pre-sigmoid activation to 0, so the
    final sigmoid must yield exactly 0.5 everywhere."""
    H, W = 8, 8
    weights = np.zeros(TOTAL_F16, dtype=np.float32)
    feat0 = np.zeros((H, W, 8), dtype=np.float32)
    feat1 = np.zeros((H, W, 12), dtype=np.float32)
    out = forward_pass(feat0, feat1, weights, identity_film())
    max_err = float(np.max(np.abs(out - 0.5)))
    ok = max_err < 1e-5
    print(f"[test_zero_weights] max_err={max_err:.2e} {'OK' if ok else 'FAIL'}", file=sys.stderr)
    return ok
# ---------------------------------------------------------------------------
# Test vector generation and C header emission
# ---------------------------------------------------------------------------

def pack_feat0_rgba32uint(feat0_f32, H, W):
    """Pack (H*W, 8) f16-precision values as H*W*4 u32 (pack2x16float layout)."""
    halves = np.float16(feat0_f32.reshape(H * W, 8)).view(np.uint16)
    packed = np.zeros((H * W, 4), dtype=np.uint32)
    for j in range(4):
        lo = halves[:, j * 2].astype(np.uint32)
        hi = halves[:, j * 2 + 1].astype(np.uint32)
        packed[:, j] = lo | (hi << 16)
    return packed.flatten()  # H*W*4 u32


def pack_feat1_rgba32uint(feat1_u8, H, W):
    """Pack (H*W, 12) u8 values as H*W*4 u32 (pack4x8unorm, 4th u32 = 0)."""
    raw = feat1_u8.reshape(H * W, 12)
    packed = np.zeros((H * W, 4), dtype=np.uint32)
    for j in range(3):
        for b in range(4):
            packed[:, j] |= raw[:, j * 4 + b].astype(np.uint32) << (b * 8)
    return packed.flatten()  # H*W*4 u32


def pack_weights_u32(w_f16):
    """Pack flat f16 array as u32 pairs matching WGSL get_w() layout."""
    if len(w_f16) % 2:
        # Pad to an even count so the low/high halves line up.
        w_f16 = np.append(w_f16, np.float16(0))
    halves = w_f16.view(np.uint16)
    return halves[::2].astype(np.uint32) | (halves[1::2].astype(np.uint32) << 16)


def generate_vectors(W=8, H=8, seed=42):
    """Run the reference forward pass on seeded random inputs and return a
    dict of packed input/expected-output arrays for the C++ parity test."""
    rng = np.random.default_rng(seed)

    # Random f16 weights (small range to avoid NaN/Inf cascading).
    w_f16 = rng.uniform(-0.3, 0.3, TOTAL_F16).astype(np.float16)
    w_f32 = w_f16.astype(np.float32)

    # Random feat0: 8 f16-precision channels in [0, 1).
    feat0 = rng.uniform(0.0, 1.0, (H, W, 8)).astype(np.float16).astype(np.float32)
    # Random feat1: 12 u8 channels (unpacked as unorm [0, 1]).
    feat1_u8 = rng.integers(0, 256, (H, W, 12), dtype=np.uint8)
    feat1 = feat1_u8.astype(np.float32) / 255.0

    film = identity_film()
    enc0 = enc0_forward(feat0, feat1, w_f32, film['enc0_gamma'], film['enc0_beta'])
    enc1 = enc1_forward(enc0, w_f32,
                        film['enc1_gamma_lo'], film['enc1_gamma_hi'],
                        film['enc1_beta_lo'], film['enc1_beta_hi'])
    bn = bottleneck_forward(enc1, w_f32)
    dc1 = dec1_forward(bn, enc1, w_f32, film['dec1_gamma'], film['dec1_beta'])
    out = dec0_forward(dc1, enc0, w_f32, film['dec0_gamma'], film['dec0_beta'])

    return {
        'W': W, 'H': H, 'seed': seed,
        'feat0_u32': pack_feat0_rgba32uint(feat0, H, W),
        'feat1_u32': pack_feat1_rgba32uint(feat1_u8, H, W),
        'w_u32': pack_weights_u32(w_f16),
        # Expected activations as raw f16 bit patterns; dec1 is half-res
        # (hH x hW x 4) and stored as-is.
        'enc0_u16': np.float16(enc0.reshape(-1)).view(np.uint16),
        'dc1_u16': np.float16(dc1.reshape(-1)).view(np.uint16),
        'out_u16': np.float16(out.reshape(-1)).view(np.uint16),
        'out_f32': out.reshape(-1),
    }
def emit_c_header(v):
    """Render the test-vector dict from generate_vectors() as a C header.

    Emits W/H constants plus u32/u16 arrays of packed inputs and expected
    layer outputs (8 hex literals per row). Returns the header as one string.
    """
    lines = []
    lines.append("// Auto-generated by cnn_v3/training/gen_test_vectors.py")
    lines.append(f"// Seed={v['seed']} W={v['W']} H={v['H']}")
    lines.append("// DO NOT EDIT — regenerate with gen_test_vectors.py --header")
    lines.append("#pragma once")
    # Fix: the include line was emitted empty; uint32_t/uint16_t need <stdint.h>
    # or the generated header does not compile.
    lines.append("#include <stdint.h>")
    lines.append("")
    lines.append(f"static const int kCnnV3TestW = {v['W']};")
    lines.append(f"static const int kCnnV3TestH = {v['H']};")
    lines.append("")

    def emit_array(name, ctype, data, fmt, comment):
        # Shared emitter for both element widths (was two duplicated closures).
        lines.append(f"// {len(data)} {comment}")
        lines.append(f"static const {ctype} {name}[{len(data)}] = {{")
        row = []
        for i, x in enumerate(data):
            row.append(fmt.format(int(x)))
            if len(row) == 8 or i == len(data) - 1:
                lines.append(" " + ", ".join(row) + ",")
                row = []
        lines.append("};")
        lines.append("")

    def array_u32(name, data):
        emit_array(name, "uint32_t", data, "0x{:08x}u", "u32 values")

    def array_u16(name, data):
        emit_array(name, "uint16_t", data, "0x{:04x}u",
                   "uint16 values (raw f16 bits)")

    array_u32("kCnnV3TestFeat0U32", v['feat0_u32'])
    array_u32("kCnnV3TestFeat1U32", v['feat1_u32'])
    array_u32("kCnnV3TestWeightsU32", v['w_u32'])
    array_u16("kCnnV3ExpectedEnc0U16", v['enc0_u16'])
    lines.append(f"// kCnnV3Dec1HW = (W/2) x (H/2) = {v['W']//2} x {v['H']//2}")
    array_u16("kCnnV3ExpectedDec1U16", v['dc1_u16'])
    array_u16("kCnnV3ExpectedOutputU16", v['out_u16'])
    return "\n".join(lines)
def main():
    """CLI entry point.

    Always runs the zero-weight self-test (exits 1 on failure). With
    --header, writes the generated C header to stdout and keeps all log
    output on stderr so stdout stays a clean header.
    """
    parser = argparse.ArgumentParser(
        description="CNN v3 parity test vector generator")
    parser.add_argument('--header', action='store_true',
                        help='Emit C header to stdout')
    parser.add_argument('--W', type=int, default=8)
    parser.add_argument('--H', type=int, default=8)
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()

    # Fix: removed unused `import io`. Self-test output goes to stderr in
    # --header mode so the header on stdout stays clean.
    log = sys.stderr if args.header else sys.stdout

    if not test_zero_weights():
        sys.exit(1)

    v = generate_vectors(args.W, args.H, args.seed)
    if args.header:
        print(emit_c_header(v))  # C header → stdout only
        print("All checks passed.", file=log)
    else:
        out = v['out_f32']
        print(f"[gen_test_vectors] W={args.W} H={args.H} seed={args.seed}")
        print(f" output range: [{float(out.min()):.4f}, {float(out.max()):.4f}]")
        print(f" output mean: {float(out.mean()):.4f}")
        print(" Run with --header to emit C header for C++ parity test.")
        print("All checks passed.")


if __name__ == '__main__':
    main()