summaryrefslogtreecommitdiff
path: root/cnn_v3/training/gen_test_vectors.py
blob: 640971ca0c8f76538ad89ea8fed3208bb67be5cb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
#!/usr/bin/env python3
# CNN v3 parity reference — numpy forward pass matching WGSL shaders exactly.
# Generates test vectors for C++ GPU parity validation.
#
# Usage:
#   python3 cnn_v3/training/gen_test_vectors.py            # self-test only
#   python3 cnn_v3/training/gen_test_vectors.py --header   # emit C header to stdout

import numpy as np
import struct
import sys
import argparse

# ---------------------------------------------------------------------------
# Weight layout (f16 units, matching C++ cnn_v3_effect.cc constants)
# ---------------------------------------------------------------------------

# Per-layer channel counts: (input channels, output channels).
ENC0_IN,  ENC0_OUT  = 20,  4
ENC1_IN,  ENC1_OUT  =  4,  8
BN_IN,    BN_OUT    =  8,  8
DEC1_IN,  DEC1_OUT  = 16,  4
DEC0_IN,  DEC0_OUT  =  8,  4

# Per-layer f16 counts: OUT * IN * K*K conv weights followed by OUT biases
# (K = 3 for every layer except the 1x1 bottleneck).
ENC0_WEIGHTS  = ENC0_IN * ENC0_OUT * 9 + ENC0_OUT   # 724
ENC1_WEIGHTS  = ENC1_IN * ENC1_OUT * 9 + ENC1_OUT   # 296
BN_WEIGHTS    = BN_IN   * BN_OUT   * 1 + BN_OUT     # 72
DEC1_WEIGHTS  = DEC1_IN * DEC1_OUT * 9 + DEC1_OUT   # 580
DEC0_WEIGHTS  = DEC0_IN * DEC0_OUT * 9 + DEC0_OUT   # 292

# Layers are stored back to back in the flat weight buffer, in this order.
ENC0_OFFSET = 0
ENC1_OFFSET = ENC0_OFFSET + ENC0_WEIGHTS
BN_OFFSET   = ENC1_OFFSET + ENC1_WEIGHTS
DEC1_OFFSET = BN_OFFSET   + BN_WEIGHTS
DEC0_OFFSET = DEC1_OFFSET + DEC1_WEIGHTS
TOTAL_F16   = DEC0_OFFSET + DEC0_WEIGHTS   # 1964
# Breakdown of the total (conv weights + biases):
#   enc0:       20*4*9 = 720  + 4 = 724
#   enc1:        4*8*9 = 288  + 8 = 296
#   bottleneck:  8*8*1 =  64  + 8 =  72
#   dec1:       16*4*9 = 576  + 4 = 580
#   dec0:        8*4*9 = 288  + 4 = 292
#   total:      724 + 296 + 72 + 580 + 292 = 1964
#
# Conv weight index inside a 3x3 layer: o * IN * 9 + i * 9 + ki
# (o < OUT, i < IN, ki = ky * 3 + kx); the OUT biases follow the conv weights.
#
# NOTE: HOWTO.md says the total is 2064, but its own per-layer table sums to
# 1964 (a difference of 100) — possibly a typo in the doc. This script uses
# the value derived from the formulas above: 1964.

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def get_w(w_f32, base, idx):
    """Fetch a single weight as a Python float.

    Counterpart of the WGSL ``get_w()`` helper: ``base`` is the layer's start
    offset in the flat weight buffer and ``idx`` the position inside that layer.
    """
    position = base + idx
    value = w_f32[position]
    return float(value)


# ---------------------------------------------------------------------------
# Layer forward passes — each matches the corresponding WGSL compute shader
# ---------------------------------------------------------------------------

def enc0_forward(feat0, feat1, w, gamma, beta):
    """
    Encoder level 0: Conv(20->4, 3x3, zero-pad) + FiLM + ReLU.

    feat0: (H, W, 8)  f32 — channels from unpack2x16float(feat_tex0)
    feat1: (H, W, 12) f32 — channels from unpack4x8unorm(feat_tex1)
    w:     flat weight buffer; this layer starts at ENC0_OFFSET
    gamma, beta: (ENC0_OUT,) f32 — FiLM scale / shift per output channel
    Returns: (H, W, 4) f32 rounded through f16 (rgba16float texture boundary)
    """
    rows, cols = feat0.shape[:2]
    bias_base = ENC0_OUT * ENC0_IN * 9   # biases sit after all conv weights

    # Stack both feature textures into 20 channels, then add a 1-pixel zero border.
    stacked = np.concatenate([feat0, feat1], axis=2)
    padded = np.pad(stacked, ((1, 1), (1, 1), (0, 0)), mode='constant')

    result = np.zeros((rows, cols, ENC0_OUT), dtype=np.float32)
    for oc in range(ENC0_OUT):
        acc = np.full((rows, cols), get_w(w, ENC0_OFFSET, bias_base + oc),
                      dtype=np.float32)
        for ic in range(ENC0_IN):
            kernel_base = (oc * ENC0_IN + ic) * 9
            # Taps are visited row-major (ky outer, kx inner) so the float32
            # accumulation order stays fixed.
            for tap in range(9):
                ky, kx = divmod(tap, 3)
                weight = get_w(w, ENC0_OFFSET, kernel_base + tap)
                acc += weight * padded[ky:ky + rows, kx:kx + cols, ic]
        # FiLM (gamma * x + beta) followed by ReLU.
        result[:, :, oc] = np.maximum(0.0, gamma[oc] * acc + beta[oc])

    # Round-trip through f16 to model the rgba16float texture store.
    return result.astype(np.float16).astype(np.float32)


def enc1_forward(enc0, w, gamma_lo, gamma_hi, beta_lo, beta_hi):
    """
    Encoder level 1: AvgPool2x2(enc0) + Conv(4->8, 3x3, zero-pad) + FiLM + ReLU.

    enc0: (H, W, 4) f32 — rgba16float precision
    w:    flat weight buffer; this layer starts at ENC1_OFFSET
    gamma_lo/gamma_hi, beta_lo/beta_hi: FiLM params for the low / high halves
        of the output channels (concatenated lo-then-hi)
    Returns: (H//2, W//2, 8) f32 rounded through f16 (pack2x16float boundary)
    """
    full_h, full_w = enc0.shape[:2]
    half_h, half_w = full_h // 2, full_w // 2

    # 2x2 average pool. The output is floor(H/2) x floor(W/2), so every source
    # index 2*h + d (d in {0, 1}) is already in range and the border clamp of
    # the WGSL load_enc0_avg never triggers; strided slices read the same
    # texels. The four taps are summed in (dy, dx) row-major order.
    pooled = np.zeros((half_h, half_w, ENC1_IN), dtype=np.float32)
    for dy in range(2):
        for dx in range(2):
            pooled += enc0[dy:2 * half_h:2, dx:2 * half_w:2, :]
    pooled = pooled * 0.25

    # Zero border for the 3x3 conv at half resolution.
    padded = np.pad(pooled, ((1, 1), (1, 1), (0, 0)), mode='constant')
    film_scale = np.concatenate([gamma_lo, gamma_hi])
    film_shift = np.concatenate([beta_lo, beta_hi])
    bias_base = ENC1_OUT * ENC1_IN * 9   # biases sit after all conv weights

    result = np.zeros((half_h, half_w, ENC1_OUT), dtype=np.float32)
    for oc in range(ENC1_OUT):
        acc = np.full((half_h, half_w), get_w(w, ENC1_OFFSET, bias_base + oc),
                      dtype=np.float32)
        for ic in range(ENC1_IN):
            kernel_base = (oc * ENC1_IN + ic) * 9
            for tap in range(9):
                ky, kx = divmod(tap, 3)
                weight = get_w(w, ENC1_OFFSET, kernel_base + tap)
                acc += weight * padded[ky:ky + half_h, kx:kx + half_w, ic]
        # FiLM followed by ReLU.
        result[:, :, oc] = np.maximum(0.0, film_scale[oc] * acc + film_shift[oc])

    # Round-trip through f16 to model the pack2x16float store.
    return result.astype(np.float16).astype(np.float32)


def bottleneck_forward(enc1, w):
    """
    Bottleneck: AvgPool2x2(enc1) + Conv(8->8, 1x1) + ReLU. No FiLM.

    enc1: (hH, hW, 8) f32 — half-res
    w:    flat weight buffer; this layer starts at BN_OFFSET
    Returns: (hH//2, hW//2, 8) f32 rounded through f16 (quarter-res)
    """
    half_h, half_w = enc1.shape[:2]
    quart_h, quart_w = half_h // 2, half_w // 2

    # 2x2 average pool. With floor-halved output dims the source indices are
    # always in range, so the border clamp of the WGSL load_enc1_avg is a
    # no-op here; taps are summed in (dy, dx) row-major order.
    pooled = np.zeros((quart_h, quart_w, BN_IN), dtype=np.float32)
    for dy in range(2):
        for dx in range(2):
            pooled += enc1[dy:2 * quart_h:2, dx:2 * quart_w:2, :]
    pooled = pooled * 0.25

    # 1x1 conv: a per-pixel dot product over input channels.
    bias_base = BN_OUT * BN_IN   # biases sit after the conv weights
    result = np.zeros((quart_h, quart_w, BN_OUT), dtype=np.float32)
    for oc in range(BN_OUT):
        acc = np.full((quart_h, quart_w), get_w(w, BN_OFFSET, bias_base + oc),
                      dtype=np.float32)
        for ic in range(BN_IN):
            acc += get_w(w, BN_OFFSET, oc * BN_IN + ic) * pooled[:, :, ic]
        result[:, :, oc] = np.maximum(0.0, acc)   # ReLU only

    # Round-trip through f16 to model the pack2x16float store.
    return result.astype(np.float16).astype(np.float32)


def dec1_forward(bn, enc1, w, gamma, beta):
    """
    Decoder level 1: NearestUp2x(bn) + cat(enc1_skip)
    → Conv(16->4, 3x3, zero-pad) + FiLM + ReLU.

    bn:   (qH, qW, 8) f32 — quarter-res bottleneck
    enc1: (hH, hW, 8) f32 — half-res skip connection
    w:    flat weight buffer; this layer starts at DEC1_OFFSET
    gamma, beta: (DEC1_OUT,) f32 — FiLM scale / shift per output channel
    Returns: (hH, hW, 4) f32 rounded through f16 (rgba16float, half-res)
    """
    half_h, half_w = enc1.shape[:2]
    quart_h, quart_w = bn.shape[:2]

    # Nearest-neighbour source coordinates, clamped to the last bottleneck
    # row/column (covers an odd half-res size).
    src_rows = [min(hy // 2, quart_h - 1) for hy in range(half_h)]
    src_cols = [min(hx // 2, quart_w - 1) for hx in range(half_w)]

    # 16-channel conv input [upsampled bn, enc1 skip]; the 1-pixel border stays
    # zero, which is what out-of-bounds loads return (load_dec1_concat).
    padded = np.zeros((half_h + 2, half_w + 2, DEC1_IN), dtype=np.float32)
    for hy, qy in enumerate(src_rows):
        for hx, qx in enumerate(src_cols):
            padded[hy + 1, hx + 1, :] = np.concatenate(
                [bn[qy, qx, :], enc1[hy, hx, :]])

    bias_base = DEC1_OUT * DEC1_IN * 9   # biases sit after all conv weights
    result = np.zeros((half_h, half_w, DEC1_OUT), dtype=np.float32)
    for oc in range(DEC1_OUT):
        acc = np.full((half_h, half_w), get_w(w, DEC1_OFFSET, bias_base + oc),
                      dtype=np.float32)
        for ic in range(DEC1_IN):
            kernel_base = (oc * DEC1_IN + ic) * 9
            for tap in range(9):
                ky, kx = divmod(tap, 3)
                weight = get_w(w, DEC1_OFFSET, kernel_base + tap)
                acc += weight * padded[ky:ky + half_h, kx:kx + half_w, ic]
        # FiLM followed by ReLU.
        result[:, :, oc] = np.maximum(0.0, gamma[oc] * acc + beta[oc])

    # Round-trip through f16 to model the rgba16float texture store.
    return result.astype(np.float16).astype(np.float32)


def dec0_forward(dec1, enc0, w, gamma, beta):
    """
    Decoder level 0 (final output): NearestUp2x(dec1) + cat(enc0_skip)
    → Conv(8->4, 3x3, zero-pad) + FiLM + ReLU + sigmoid.

    dec1: (hH, hW, 4) f32 — half-res
    enc0: (H,  W,  4) f32 — full-res enc0 skip
    w:    flat weight buffer; this layer starts at DEC0_OFFSET
    gamma, beta: (DEC0_OUT,) f32 — FiLM scale / shift per output channel
    Returns: (H, W, 4) f32 rounded through f16 (rgba16float, full-res)
    """
    full_h, full_w = enc0.shape[:2]
    half_h, half_w = dec1.shape[:2]

    # Nearest-neighbour source coordinates, clamped to the last dec1 row/column
    # (covers an odd full-res size).
    src_rows = [min(y // 2, half_h - 1) for y in range(full_h)]
    src_cols = [min(x // 2, half_w - 1) for x in range(full_w)]

    # 8-channel conv input [upsampled dec1, enc0 skip] with a zero border.
    padded = np.zeros((full_h + 2, full_w + 2, DEC0_IN), dtype=np.float32)
    for y, hy in enumerate(src_rows):
        for x, hx in enumerate(src_cols):
            padded[y + 1, x + 1, :] = np.concatenate(
                [dec1[hy, hx, :], enc0[y, x, :]])

    bias_base = DEC0_OUT * DEC0_IN * 9   # biases sit after all conv weights
    result = np.zeros((full_h, full_w, DEC0_OUT), dtype=np.float32)
    for oc in range(DEC0_OUT):
        acc = np.full((full_h, full_w), get_w(w, DEC0_OFFSET, bias_base + oc),
                      dtype=np.float32)
        for ic in range(DEC0_IN):
            kernel_base = (oc * DEC0_IN + ic) * 9
            for tap in range(9):
                ky, kx = divmod(tap, 3)
                weight = get_w(w, DEC0_OFFSET, kernel_base + tap)
                acc += weight * padded[ky:ky + full_h, kx:kx + full_w, ic]
        # FiLM + ReLU, then sigmoid (matches WGSL dec0 shader).
        activated = np.maximum(0.0, gamma[oc] * acc + beta[oc])
        # The sigmoid denominator is evaluated in float64 and rounded to
        # float32 *before* the reciprocal is taken; keep this exact sequence.
        denominator = (1.0 + np.exp(-activated.astype(np.float64))).astype(np.float32)
        result[:, :, oc] = 1.0 / denominator

    # Round-trip through f16 to model the rgba16float texture store.
    return result.astype(np.float16).astype(np.float32)


def forward_pass(feat0, feat1, w_f32, film):
    """Run the whole U-Net and return the final dec0 output.

    ``film`` is a dict of gamma/beta arrays keyed as produced by
    ``identity_film()``. The encoder outputs are kept around because the
    decoder stages consume them as skip connections.
    """
    # Encoder path.
    enc0 = enc0_forward(
        feat0, feat1, w_f32, film['enc0_gamma'], film['enc0_beta'])
    enc1 = enc1_forward(
        enc0, w_f32,
        film['enc1_gamma_lo'], film['enc1_gamma_hi'],
        film['enc1_beta_lo'], film['enc1_beta_hi'])

    # Bottleneck (no FiLM).
    bottleneck = bottleneck_forward(enc1, w_f32)

    # Decoder path with skip connections.
    dec1 = dec1_forward(
        bottleneck, enc1, w_f32, film['dec1_gamma'], film['dec1_beta'])
    return dec0_forward(
        dec1, enc0, w_f32, film['dec0_gamma'], film['dec0_beta'])


def identity_film():
    """Return FiLM parameters that leave activations unchanged (gamma=1, beta=0).

    The enc1 parameters are split into two 4-channel halves (``_lo`` / ``_hi``).
    """
    def unit(n):
        return np.ones(n, dtype=np.float32)

    def zero(n):
        return np.zeros(n, dtype=np.float32)

    film = {}
    film['enc0_gamma'] = unit(ENC0_OUT)
    film['enc0_beta'] = zero(ENC0_OUT)
    film['enc1_gamma_lo'] = unit(4)
    film['enc1_gamma_hi'] = unit(4)
    film['enc1_beta_lo'] = zero(4)
    film['enc1_beta_hi'] = zero(4)
    film['dec1_gamma'] = unit(DEC1_OUT)
    film['dec1_beta'] = zero(DEC1_OUT)
    film['dec0_gamma'] = unit(DEC0_OUT)
    film['dec0_beta'] = zero(DEC0_OUT)
    return film


# ---------------------------------------------------------------------------
# Self-test: zero weights → output must be exactly 0.5
# ---------------------------------------------------------------------------

def test_zero_weights():
    """Self-test: all-zero weights and inputs must yield 0.5 everywhere.

    Every conv output is then 0, ReLU keeps it at 0 and the final sigmoid maps
    0 to 0.5. Prints the result to stderr and returns True on success.
    """
    side = 8
    zero_weights = np.zeros(TOTAL_F16, dtype=np.float32)
    zero_feat0 = np.zeros((side, side, 8), dtype=np.float32)
    zero_feat1 = np.zeros((side, side, 12), dtype=np.float32)

    result = forward_pass(zero_feat0, zero_feat1, zero_weights, identity_film())

    max_err = float(np.abs(result - 0.5).max())
    passed = max_err < 1e-5
    verdict = 'OK' if passed else 'FAIL'
    print(f"[test_zero_weights] max_err={max_err:.2e}  {verdict}",
          file=sys.stderr)
    return passed


# ---------------------------------------------------------------------------
# Test vector generation and C header emission
# ---------------------------------------------------------------------------

def pack_feat0_rgba32uint(feat0_f32, H, W):
    """Pack 8 channels per pixel into 4 u32 per pixel (pack2x16float layout).

    Each value is converted to f16; channel 2*j goes into the low 16 bits and
    channel 2*j+1 into the high 16 bits of word j. Returns a flat array of
    H*W*4 uint32.
    """
    pixel_count = H * W
    half_bits = feat0_f32.reshape(pixel_count, 8).astype(np.float16).view(np.uint16)
    low_halves = half_bits[:, 0::2].astype(np.uint32)    # even channels
    high_halves = half_bits[:, 1::2].astype(np.uint32)   # odd channels
    packed = low_halves | (high_halves << 16)
    return packed.flatten()    # H*W*4 u32


def pack_feat1_rgba32uint(feat1_u8, H, W):
    """Pack 12 u8 channels per pixel into 4 u32 per pixel (pack4x8unorm layout).

    Word j (j = 0..2) holds channels 4*j .. 4*j+3, lowest channel in the lowest
    byte. The fourth word of every pixel is left at 0. Returns a flat array of
    H*W*4 uint32.
    """
    pixel_count = H * W
    # View the 12 bytes of each pixel as 3 words x 4 bytes.
    quads = feat1_u8.reshape(pixel_count, 12).reshape(pixel_count, 3, 4).astype(np.uint32)
    byte_shifts = np.array([0, 8, 16, 24], dtype=np.uint32)
    words = np.zeros((pixel_count, 4), dtype=np.uint32)
    words[:, :3] = np.bitwise_or.reduce(quads << byte_shifts, axis=2)
    return words.flatten()    # H*W*4 u32


def pack_weights_u32(w_f16):
    """Pack a flat f16 array two-per-u32, the layout WGSL get_w() reads.

    Element 2*k lands in the low 16 bits and element 2*k+1 in the high 16 bits
    of word k. An odd-length input is padded with one trailing zero.
    """
    halves = w_f16
    if len(halves) % 2 != 0:
        halves = np.append(halves, np.float16(0))
    raw_bits = halves.view(np.uint16)
    low_words = raw_bits[0::2].astype(np.uint32)
    high_words = raw_bits[1::2].astype(np.uint32)
    return low_words | (high_words << 16)


def generate_vectors(W=8, H=8, seed=42):
    """Build one seeded random test case and its reference outputs.

    Draws random weights and input features, runs the numpy forward pass with
    identity FiLM, and returns a dict with the packed GPU inputs
    (``feat0_u32``, ``feat1_u32``, ``w_u32``), the expected enc0 / dec1 /
    final outputs as raw f16 bit patterns (``*_u16``), the final output as
    flat f32 (``out_f32``), and the ``W`` / ``H`` / ``seed`` used.
    """
    rng = np.random.default_rng(seed)

    # NOTE: the order of the three rng draws below determines the generated
    # data; do not reorder them.

    # Weights: small range to avoid NaN/Inf cascading; quantized to f16.
    weights_f16 = rng.uniform(-0.3, 0.3, TOTAL_F16).astype(np.float16)
    weights = weights_f16.astype(np.float32)

    # feat0: 8 channels at f16 precision.
    feat0 = rng.uniform(0.0, 1.0, (H, W, 8)).astype(np.float16).astype(np.float32)

    # feat1: 12 u8 channels, seen by the network as unorm values in [0, 1].
    feat1_bytes = rng.integers(0, 256, (H, W, 12), dtype=np.uint8)
    feat1 = feat1_bytes.astype(np.float32) / 255.0

    # Run the layers one by one so intermediate activations can be exported.
    film = identity_film()
    enc0 = enc0_forward(feat0, feat1, weights,
                        film['enc0_gamma'], film['enc0_beta'])
    enc1 = enc1_forward(enc0, weights,
                        film['enc1_gamma_lo'], film['enc1_gamma_hi'],
                        film['enc1_beta_lo'], film['enc1_beta_hi'])
    bottleneck = bottleneck_forward(enc1, weights)
    dec1 = dec1_forward(bottleneck, enc1, weights,
                        film['dec1_gamma'], film['dec1_beta'])
    final = dec0_forward(dec1, enc0, weights,
                         film['dec0_gamma'], film['dec0_beta'])

    def f16_bits(activations):
        # Flatten and expose the raw f16 bit pattern of every value.
        return activations.reshape(-1).astype(np.float16).view(np.uint16)

    vectors = {'W': W, 'H': H, 'seed': seed}
    vectors['feat0_u32'] = pack_feat0_rgba32uint(feat0, H, W)
    vectors['feat1_u32'] = pack_feat1_rgba32uint(feat1_bytes, H, W)
    vectors['w_u32'] = pack_weights_u32(weights_f16)
    vectors['enc0_u16'] = f16_bits(enc0)
    vectors['dc1_u16'] = f16_bits(dec1)     # half-res (hH x hW x 4), stored as-is
    vectors['out_u16'] = f16_bits(final)
    vectors['out_f32'] = final.reshape(-1)
    return vectors


def emit_c_header(v):
    """Render the test-vector dict ``v`` as the text of a C++ header.

    ``v`` must provide 'seed', 'W', 'H' and the arrays 'feat0_u32',
    'feat1_u32', 'w_u32', 'enc0_u16', 'dc1_u16' and 'out_u16'. Arrays are
    written as hex literals, eight per row. Returns the header as one string.
    """
    text = [
        "// Auto-generated by cnn_v3/training/gen_test_vectors.py",
        f"// Seed={v['seed']}  W={v['W']}  H={v['H']}",
        "// DO NOT EDIT — regenerate with gen_test_vectors.py --header",
        "#pragma once",
        "#include <cstdint>",
        "",
        f"static const int kCnnV3TestW = {v['W']};",
        f"static const int kCnnV3TestH = {v['H']};",
        "",
    ]

    def emit_array(ctype, description, hex_digits, name, values):
        # One C array definition: count comment, declaration, rows of 8, closer.
        count = len(values)
        text.append(f"// {count} {description}")
        text.append(f"static const {ctype} {name}[{count}] = {{")
        for start in range(0, count, 8):
            cells = [f"0x{int(x):0{hex_digits}x}u" for x in values[start:start + 8]]
            text.append("    " + ", ".join(cells) + ",")
        text.extend(["};", ""])

    def emit_u32(name, values):
        emit_array("uint32_t", "u32 values", 8, name, values)

    def emit_u16(name, values):
        emit_array("uint16_t", "uint16 values (raw f16 bits)", 4, name, values)

    emit_u32("kCnnV3TestFeat0U32", v['feat0_u32'])
    emit_u32("kCnnV3TestFeat1U32", v['feat1_u32'])
    emit_u32("kCnnV3TestWeightsU32", v['w_u32'])
    emit_u16("kCnnV3ExpectedEnc0U16", v['enc0_u16'])
    text.append(f"// kCnnV3Dec1HW = (W/2) x (H/2) = {v['W']//2} x {v['H']//2}")
    emit_u16("kCnnV3ExpectedDec1U16", v['dc1_u16'])
    emit_u16("kCnnV3ExpectedOutputU16", v['out_u16'])
    return "\n".join(text)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """Command-line entry point.

    Always runs the zero-weight self-test first (its report goes to stderr) and
    exits with status 1 if it fails. Then generates the seeded test vectors and
    either prints the C header to stdout (``--header``) or prints a short
    summary of the output range.

    Invalid ``--W`` / ``--H`` / ``--seed`` values are rejected up front through
    ``parser.error`` (usage message, exit status 2) instead of failing later
    with a traceback inside the forward pass or the RNG constructor.
    """
    parser = argparse.ArgumentParser(description="CNN v3 parity test vector generator")
    parser.add_argument('--header', action='store_true',
                        help='Emit C header to stdout')
    parser.add_argument('--W', type=int, default=8)
    parser.add_argument('--H', type=int, default=8)
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()

    # The network halves the resolution twice; below 4 pixels the quarter-res
    # bottleneck is empty and dec1_forward indexes out of range.
    if args.W < 4 or args.H < 4:
        parser.error("--W and --H must both be at least 4")
    # np.random.default_rng rejects negative integer seeds.
    if args.seed < 0:
        parser.error("--seed must be non-negative")

    if not test_zero_weights():
        sys.exit(1)

    v = generate_vectors(args.W, args.H, args.seed)

    if args.header:
        # stdout carries only the header; status messages go to stderr.
        print(emit_c_header(v))  # C header → stdout only
        print("All checks passed.", file=sys.stderr)
        return

    out = v['out_f32']
    print(f"[gen_test_vectors] W={args.W} H={args.H} seed={args.seed}")
    print(f"  output range: [{float(out.min()):.4f}, {float(out.max()):.4f}]")
    print(f"  output mean:  {float(out.mean()):.4f}")
    print("  Run with --header to emit C header for C++ parity test.")
    print("All checks passed.")


if __name__ == '__main__':
    main()