diff options
| author | skal <pascal.massimino@gmail.com> | 2026-03-26 07:03:01 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-03-26 07:03:01 +0100 |
| commit | 8f14bdd66cb002b2f89265b2a578ad93249089c9 (patch) | |
| tree | 2ccdb3939b673ebc3a5df429160631240239cee2 /cnn_v3/training/gen_test_vectors.py | |
| parent | 4ca498277b033ae10134045dae9c8c249a8d2b2b (diff) | |
feat(cnn_v3): upgrade architecture to enc_channels=[8,16]
Double encoder capacity (out-channels old→new): enc0 4→8ch, enc1 8→16ch,
bottleneck 8→16ch, dec1 4→8ch, dec0 unchanged at 4ch (input 8→16).
Total weights 2476→7828 f16 (~15.3 KB).
FiLM MLP output 40→72 params (L1: 16×40→16×72).
16-ch textures split into _lo/_hi rgba32uint pairs (enc1, bottleneck).
enc0 and dec1 textures changed from rgba16float to rgba32uint (8ch).
GBUF_RGBA32UINT node gains CopySrc for parity test readback.
- WGSL shaders: all 5 passes rewritten for new channel counts
- C++ CNNv3Effect: new weight offsets/sizes, 8ch uniform structs
- Web tool (shaders.js + tester.js): matching texture formats and bindings
- Parity test: readback_rgba32uint_8ch helper, updated vector counts
- Training scripts: default enc_channels=[8,16], updated docstrings
- Docs + architecture PNG regenerated
handoff(Gemini): CNN v3 [8,16] upgrade complete. All code, tests, web
tool, training scripts, and docs updated. Next: run training pass.
Diffstat (limited to 'cnn_v3/training/gen_test_vectors.py')
| -rw-r--r-- | cnn_v3/training/gen_test_vectors.py | 91 |
1 file changed, 46 insertions, 45 deletions
diff --git a/cnn_v3/training/gen_test_vectors.py b/cnn_v3/training/gen_test_vectors.py index 2eb889c..cdda5a5 100644 --- a/cnn_v3/training/gen_test_vectors.py +++ b/cnn_v3/training/gen_test_vectors.py @@ -15,17 +15,17 @@ import argparse # Weight layout (f16 units, matching C++ cnn_v3_effect.cc constants) # --------------------------------------------------------------------------- -ENC0_IN, ENC0_OUT = 20, 4 -ENC1_IN, ENC1_OUT = 4, 8 -BN_IN, BN_OUT = 8, 8 -DEC1_IN, DEC1_OUT = 16, 4 -DEC0_IN, DEC0_OUT = 8, 4 +ENC0_IN, ENC0_OUT = 20, 8 +ENC1_IN, ENC1_OUT = 8, 16 +BN_IN, BN_OUT = 16, 16 +DEC1_IN, DEC1_OUT = 32, 8 +DEC0_IN, DEC0_OUT = 16, 4 -ENC0_WEIGHTS = ENC0_IN * ENC0_OUT * 9 + ENC0_OUT # 724 -ENC1_WEIGHTS = ENC1_IN * ENC1_OUT * 9 + ENC1_OUT # 296 -BN_WEIGHTS = BN_IN * BN_OUT * 9 + BN_OUT # 584 (3x3 dilation=2) -DEC1_WEIGHTS = DEC1_IN * DEC1_OUT * 9 + DEC1_OUT # 580 -DEC0_WEIGHTS = DEC0_IN * DEC0_OUT * 9 + DEC0_OUT # 292 +ENC0_WEIGHTS = ENC0_IN * ENC0_OUT * 9 + ENC0_OUT # 1448 +ENC1_WEIGHTS = ENC1_IN * ENC1_OUT * 9 + ENC1_OUT # 1168 +BN_WEIGHTS = BN_IN * BN_OUT * 9 + BN_OUT # 2320 (3x3 dilation=2) +DEC1_WEIGHTS = DEC1_IN * DEC1_OUT * 9 + DEC1_OUT # 2312 +DEC0_WEIGHTS = DEC0_IN * DEC0_OUT * 9 + DEC0_OUT # 580 ENC0_OFFSET = 0 ENC1_OFFSET = ENC0_OFFSET + ENC0_WEIGHTS @@ -33,7 +33,7 @@ BN_OFFSET = ENC1_OFFSET + ENC1_WEIGHTS DEC1_OFFSET = BN_OFFSET + BN_WEIGHTS DEC0_OFFSET = DEC1_OFFSET + DEC1_WEIGHTS TOTAL_F16 = DEC0_OFFSET + DEC0_WEIGHTS -# 724 + 296 + 584 + 580 + 292 = 2476 (BN is now 3x3 dilation=2, was 72) +# 1448 + 1168 + 2320 + 2312 + 580 = 7828 # --------------------------------------------------------------------------- # Helpers @@ -50,11 +50,11 @@ def get_w(w_f32, base, idx): def enc0_forward(feat0, feat1, w, gamma, beta): """ - Conv(20->4, 3x3, zero-pad) + FiLM + ReLU → rgba16float (f16 stored). + Conv(20->8, 3x3, zero-pad) + FiLM + ReLU → rgba32uint (pack2x16float, f16 stored). 
feat0: (H, W, 8) f32 — channels from unpack2x16float(feat_tex0) feat1: (H, W, 12) f32 — channels from unpack4x8unorm(feat_tex1) - gamma, beta: (ENC0_OUT,) f32 — FiLM params - Returns: (H, W, 4) f32 — f16 precision (rgba16float texture boundary) + gamma, beta: (ENC0_OUT=8,) f32 — FiLM params + Returns: (H, W, 8) f32 — f16 precision (pack2x16float boundary) """ H, W = feat0.shape[:2] wo = ENC0_OFFSET @@ -72,14 +72,15 @@ def enc0_forward(feat0, feat1, w, gamma, beta): s += wv * fp[ky:ky+H, kx:kx+W, i] out[:, :, o] = np.maximum(0.0, gamma[o] * s + beta[o]) - return np.float16(out).astype(np.float32) # rgba16float texture boundary + return np.float16(out).astype(np.float32) # pack2x16float boundary (rgba32uint) -def enc1_forward(enc0, w, gamma_lo, gamma_hi, beta_lo, beta_hi): +def enc1_forward(enc0, w, gamma, beta): """ - AvgPool2x2(enc0, clamp-border) + Conv(4->8, 3x3, zero-pad) + FiLM + ReLU - → rgba32uint (pack2x16float, f16 precision, half-res). - enc0: (H, W, 4) f32 — rgba16float precision + AvgPool2x2(enc0, clamp-border) + Conv(8->16, 3x3, zero-pad) + FiLM + ReLU + → 2x rgba32uint (pack2x16float, f16 precision, half-res). + enc0: (H, W, 8) f32 — pack2x16float precision + gamma, beta: (ENC1_OUT=16,) f32 — FiLM params """ H, W = enc0.shape[:2] hH, hW = H // 2, W // 2 @@ -99,8 +100,6 @@ def enc1_forward(enc0, w, gamma_lo, gamma_hi, beta_lo, beta_hi): # 3x3 conv with zero-padding at half-res borders ap = np.pad(avg, ((1, 1), (1, 1), (0, 0)), mode='constant') - gamma = np.concatenate([gamma_lo, gamma_hi]) - beta = np.concatenate([beta_lo, beta_hi]) out = np.zeros((hH, hW, ENC1_OUT), dtype=np.float32) for o in range(ENC1_OUT): @@ -159,10 +158,11 @@ def bottleneck_forward(enc1, w): def dec1_forward(bn, enc1, w, gamma, beta): """ - NearestUp2x(bn) + cat(enc1_skip) → Conv(16->4, 3x3, zero-pad) + FiLM + ReLU - → rgba16float (half-res). 
- bn: (qH, qW, 8) f32 — quarter-res bottleneck - enc1: (hH, hW, 8) f32 — half-res skip connection + NearestUp2x(bn) + cat(enc1_skip) → Conv(32->8, 3x3, zero-pad) + FiLM + ReLU + → rgba32uint (pack2x16float, half-res). + bn: (qH, qW, 16) f32 — quarter-res bottleneck + enc1: (hH, hW, 16) f32 — half-res skip connection + gamma, beta: (DEC1_OUT=8,) f32 — FiLM params """ hH, hW = enc1.shape[:2] qH, qW = bn.shape[:2] @@ -188,15 +188,15 @@ def dec1_forward(bn, enc1, w, gamma, beta): s += wv * fp[ky:ky+hH, kx:kx+hW, i] out[:, :, o] = np.maximum(0.0, gamma[o] * s + beta[o]) - return np.float16(out).astype(np.float32) # rgba16float boundary + return np.float16(out).astype(np.float32) # pack2x16float boundary (rgba32uint) def dec0_forward(dec1, enc0, w, gamma, beta): """ - NearestUp2x(dec1) + cat(enc0_skip) → Conv(8->4, 3x3, zero-pad) + FiLM + ReLU + sigmoid + NearestUp2x(dec1) + cat(enc0_skip) → Conv(16->4, 3x3, zero-pad) + FiLM + ReLU + sigmoid → rgba16float (full-res, final output). - dec1: (hH, hW, 4) f32 — half-res - enc0: (H, W, 4) f32 — full-res enc0 skip + dec1: (hH, hW, 8) f32 — half-res + enc0: (H, W, 8) f32 — full-res enc0 skip """ H, W = enc0.shape[:2] hH, hW = dec1.shape[:2] @@ -231,8 +231,7 @@ def forward_pass(feat0, feat1, w_f32, film): enc0 = enc0_forward(feat0, feat1, w_f32, film['enc0_gamma'], film['enc0_beta']) enc1 = enc1_forward(enc0, w_f32, - film['enc1_gamma_lo'], film['enc1_gamma_hi'], - film['enc1_beta_lo'], film['enc1_beta_hi']) + film['enc1_gamma'], film['enc1_beta']) bn = bottleneck_forward(enc1, w_f32) dc1 = dec1_forward(bn, enc1, w_f32, film['dec1_gamma'], film['dec1_beta']) dc0 = dec0_forward(dc1, enc0, w_f32, film['dec0_gamma'], film['dec0_beta']) @@ -241,16 +240,14 @@ def forward_pass(feat0, feat1, w_f32, film): def identity_film(): return { - 'enc0_gamma': np.ones(ENC0_OUT, dtype=np.float32), - 'enc0_beta': np.zeros(ENC0_OUT, dtype=np.float32), - 'enc1_gamma_lo': np.ones(4, dtype=np.float32), - 'enc1_gamma_hi': np.ones(4, dtype=np.float32), - 
'enc1_beta_lo': np.zeros(4, dtype=np.float32), - 'enc1_beta_hi': np.zeros(4, dtype=np.float32), - 'dec1_gamma': np.ones(DEC1_OUT, dtype=np.float32), - 'dec1_beta': np.zeros(DEC1_OUT, dtype=np.float32), - 'dec0_gamma': np.ones(DEC0_OUT, dtype=np.float32), - 'dec0_beta': np.zeros(DEC0_OUT, dtype=np.float32), + 'enc0_gamma': np.ones(ENC0_OUT, dtype=np.float32), # 8 + 'enc0_beta': np.zeros(ENC0_OUT, dtype=np.float32), # 8 + 'enc1_gamma': np.ones(ENC1_OUT, dtype=np.float32), # 16 + 'enc1_beta': np.zeros(ENC1_OUT, dtype=np.float32), # 16 + 'dec1_gamma': np.ones(DEC1_OUT, dtype=np.float32), # 8 + 'dec1_beta': np.zeros(DEC1_OUT, dtype=np.float32), # 8 + 'dec0_gamma': np.ones(DEC0_OUT, dtype=np.float32), # 4 + 'dec0_beta': np.zeros(DEC0_OUT, dtype=np.float32), # 4 } @@ -324,8 +321,7 @@ def generate_vectors(W=8, H=8, seed=42): enc0 = enc0_forward(feat0, feat1, w_f32, film['enc0_gamma'], film['enc0_beta']) enc1 = enc1_forward(enc0, w_f32, - film['enc1_gamma_lo'], film['enc1_gamma_hi'], - film['enc1_beta_lo'], film['enc1_beta_hi']) + film['enc1_gamma'], film['enc1_beta']) bn = bottleneck_forward(enc1, w_f32) dc1 = dec1_forward(bn, enc1, w_f32, film['dec1_gamma'], film['dec1_beta']) out = dec0_forward(dc1, enc0, w_f32, film['dec0_gamma'], film['dec0_beta']) @@ -333,8 +329,9 @@ def generate_vectors(W=8, H=8, seed=42): feat0_u32 = pack_feat0_rgba32uint(feat0, H, W) feat1_u32 = pack_feat1_rgba32uint(feat1_u8, H, W) w_u32 = pack_weights_u32(w_f16) + # enc0: 8ch stored as pack2x16float → H*W*8 f16 values enc0_u16 = np.float16(enc0.reshape(-1)).view(np.uint16) - # dec1 is half-res (hH x hW x 4); store as-is + # dec1: 8ch half-res stored as pack2x16float → (H/2)*(W/2)*8 f16 values dc1_u16 = np.float16(dc1.reshape(-1)).view(np.uint16) out_u16 = np.float16(out.reshape(-1)).view(np.uint16) # raw f16 bits @@ -386,11 +383,15 @@ def emit_c_header(v): lines.append("};") lines.append("") + lines.append(f"// ENC0_OUT={ENC0_OUT} ENC1_OUT={ENC1_OUT} BN={BN_OUT} DEC1_OUT={DEC1_OUT} 
DEC0_OUT={DEC0_OUT}") + lines.append(f"// TOTAL_F16={TOTAL_F16} (enc_channels=[{ENC0_OUT},{ENC1_OUT}])") + lines.append("") array_u32("kCnnV3TestFeat0U32", v['feat0_u32']) array_u32("kCnnV3TestFeat1U32", v['feat1_u32']) array_u32("kCnnV3TestWeightsU32", v['w_u32']) + lines.append(f"// enc0: {ENC0_OUT}ch rgba32uint → W*H*{ENC0_OUT} f16 values") array_u16("kCnnV3ExpectedEnc0U16", v['enc0_u16']) - lines.append(f"// kCnnV3Dec1HW = (W/2) x (H/2) = {v['W']//2} x {v['H']//2}") + lines.append(f"// dec1: {DEC1_OUT}ch rgba32uint half-res → (W/2)*(H/2)*{DEC1_OUT} f16 values") array_u16("kCnnV3ExpectedDec1U16", v['dc1_u16']) array_u16("kCnnV3ExpectedOutputU16", v['out_u16']) return "\n".join(lines) |
