summary | refs | log | tree | commit | diff
path: root/cnn_v3/training/gen_test_vectors.py
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v3/training/gen_test_vectors.py')
-rw-r--r-- cnn_v3/training/gen_test_vectors.py | 91
1 file changed, 46 insertions, 45 deletions
diff --git a/cnn_v3/training/gen_test_vectors.py b/cnn_v3/training/gen_test_vectors.py
index 2eb889c..cdda5a5 100644
--- a/cnn_v3/training/gen_test_vectors.py
+++ b/cnn_v3/training/gen_test_vectors.py
@@ -15,17 +15,17 @@ import argparse
# Weight layout (f16 units, matching C++ cnn_v3_effect.cc constants)
# ---------------------------------------------------------------------------
-ENC0_IN, ENC0_OUT = 20, 4
-ENC1_IN, ENC1_OUT = 4, 8
-BN_IN, BN_OUT = 8, 8
-DEC1_IN, DEC1_OUT = 16, 4
-DEC0_IN, DEC0_OUT = 8, 4
+ENC0_IN, ENC0_OUT = 20, 8
+ENC1_IN, ENC1_OUT = 8, 16
+BN_IN, BN_OUT = 16, 16
+DEC1_IN, DEC1_OUT = 32, 8
+DEC0_IN, DEC0_OUT = 16, 4
-ENC0_WEIGHTS = ENC0_IN * ENC0_OUT * 9 + ENC0_OUT # 724
-ENC1_WEIGHTS = ENC1_IN * ENC1_OUT * 9 + ENC1_OUT # 296
-BN_WEIGHTS = BN_IN * BN_OUT * 9 + BN_OUT # 584 (3x3 dilation=2)
-DEC1_WEIGHTS = DEC1_IN * DEC1_OUT * 9 + DEC1_OUT # 580
-DEC0_WEIGHTS = DEC0_IN * DEC0_OUT * 9 + DEC0_OUT # 292
+ENC0_WEIGHTS = ENC0_IN * ENC0_OUT * 9 + ENC0_OUT # 1448
+ENC1_WEIGHTS = ENC1_IN * ENC1_OUT * 9 + ENC1_OUT # 1168
+BN_WEIGHTS = BN_IN * BN_OUT * 9 + BN_OUT # 2320 (3x3 dilation=2)
+DEC1_WEIGHTS = DEC1_IN * DEC1_OUT * 9 + DEC1_OUT # 2312
+DEC0_WEIGHTS = DEC0_IN * DEC0_OUT * 9 + DEC0_OUT # 580
ENC0_OFFSET = 0
ENC1_OFFSET = ENC0_OFFSET + ENC0_WEIGHTS
@@ -33,7 +33,7 @@ BN_OFFSET = ENC1_OFFSET + ENC1_WEIGHTS
DEC1_OFFSET = BN_OFFSET + BN_WEIGHTS
DEC0_OFFSET = DEC1_OFFSET + DEC1_WEIGHTS
TOTAL_F16 = DEC0_OFFSET + DEC0_WEIGHTS
-# 724 + 296 + 584 + 580 + 292 = 2476 (BN is now 3x3 dilation=2, was 72)
+# 1448 + 1168 + 2320 + 2312 + 580 = 7828
# ---------------------------------------------------------------------------
# Helpers
@@ -50,11 +50,11 @@ def get_w(w_f32, base, idx):
def enc0_forward(feat0, feat1, w, gamma, beta):
"""
- Conv(20->4, 3x3, zero-pad) + FiLM + ReLU → rgba16float (f16 stored).
+ Conv(20->8, 3x3, zero-pad) + FiLM + ReLU → rgba32uint (pack2x16float, f16 stored).
feat0: (H, W, 8) f32 — channels from unpack2x16float(feat_tex0)
feat1: (H, W, 12) f32 — channels from unpack4x8unorm(feat_tex1)
- gamma, beta: (ENC0_OUT,) f32 — FiLM params
- Returns: (H, W, 4) f32 — f16 precision (rgba16float texture boundary)
+ gamma, beta: (ENC0_OUT=8,) f32 — FiLM params
+ Returns: (H, W, 8) f32 — f16 precision (pack2x16float boundary)
"""
H, W = feat0.shape[:2]
wo = ENC0_OFFSET
@@ -72,14 +72,15 @@ def enc0_forward(feat0, feat1, w, gamma, beta):
s += wv * fp[ky:ky+H, kx:kx+W, i]
out[:, :, o] = np.maximum(0.0, gamma[o] * s + beta[o])
- return np.float16(out).astype(np.float32) # rgba16float texture boundary
+ return np.float16(out).astype(np.float32) # pack2x16float boundary (rgba32uint)
-def enc1_forward(enc0, w, gamma_lo, gamma_hi, beta_lo, beta_hi):
+def enc1_forward(enc0, w, gamma, beta):
"""
- AvgPool2x2(enc0, clamp-border) + Conv(4->8, 3x3, zero-pad) + FiLM + ReLU
- → rgba32uint (pack2x16float, f16 precision, half-res).
- enc0: (H, W, 4) f32 — rgba16float precision
+ AvgPool2x2(enc0, clamp-border) + Conv(8->16, 3x3, zero-pad) + FiLM + ReLU
+ → 2x rgba32uint (pack2x16float, f16 precision, half-res).
+ enc0: (H, W, 8) f32 — pack2x16float precision
+ gamma, beta: (ENC1_OUT=16,) f32 — FiLM params
"""
H, W = enc0.shape[:2]
hH, hW = H // 2, W // 2
@@ -99,8 +100,6 @@ def enc1_forward(enc0, w, gamma_lo, gamma_hi, beta_lo, beta_hi):
# 3x3 conv with zero-padding at half-res borders
ap = np.pad(avg, ((1, 1), (1, 1), (0, 0)), mode='constant')
- gamma = np.concatenate([gamma_lo, gamma_hi])
- beta = np.concatenate([beta_lo, beta_hi])
out = np.zeros((hH, hW, ENC1_OUT), dtype=np.float32)
for o in range(ENC1_OUT):
@@ -159,10 +158,11 @@ def bottleneck_forward(enc1, w):
def dec1_forward(bn, enc1, w, gamma, beta):
"""
- NearestUp2x(bn) + cat(enc1_skip) → Conv(16->4, 3x3, zero-pad) + FiLM + ReLU
- → rgba16float (half-res).
- bn: (qH, qW, 8) f32 — quarter-res bottleneck
- enc1: (hH, hW, 8) f32 — half-res skip connection
+ NearestUp2x(bn) + cat(enc1_skip) → Conv(32->8, 3x3, zero-pad) + FiLM + ReLU
+ → rgba32uint (pack2x16float, half-res).
+ bn: (qH, qW, 16) f32 — quarter-res bottleneck
+ enc1: (hH, hW, 16) f32 — half-res skip connection
+ gamma, beta: (DEC1_OUT=8,) f32 — FiLM params
"""
hH, hW = enc1.shape[:2]
qH, qW = bn.shape[:2]
@@ -188,15 +188,15 @@ def dec1_forward(bn, enc1, w, gamma, beta):
s += wv * fp[ky:ky+hH, kx:kx+hW, i]
out[:, :, o] = np.maximum(0.0, gamma[o] * s + beta[o])
- return np.float16(out).astype(np.float32) # rgba16float boundary
+ return np.float16(out).astype(np.float32) # pack2x16float boundary (rgba32uint)
def dec0_forward(dec1, enc0, w, gamma, beta):
"""
- NearestUp2x(dec1) + cat(enc0_skip) → Conv(8->4, 3x3, zero-pad) + FiLM + ReLU + sigmoid
+ NearestUp2x(dec1) + cat(enc0_skip) → Conv(16->4, 3x3, zero-pad) + FiLM + ReLU + sigmoid
→ rgba16float (full-res, final output).
- dec1: (hH, hW, 4) f32 — half-res
- enc0: (H, W, 4) f32 — full-res enc0 skip
+ dec1: (hH, hW, 8) f32 — half-res
+ enc0: (H, W, 8) f32 — full-res enc0 skip
"""
H, W = enc0.shape[:2]
hH, hW = dec1.shape[:2]
@@ -231,8 +231,7 @@ def forward_pass(feat0, feat1, w_f32, film):
enc0 = enc0_forward(feat0, feat1, w_f32,
film['enc0_gamma'], film['enc0_beta'])
enc1 = enc1_forward(enc0, w_f32,
- film['enc1_gamma_lo'], film['enc1_gamma_hi'],
- film['enc1_beta_lo'], film['enc1_beta_hi'])
+ film['enc1_gamma'], film['enc1_beta'])
bn = bottleneck_forward(enc1, w_f32)
dc1 = dec1_forward(bn, enc1, w_f32, film['dec1_gamma'], film['dec1_beta'])
dc0 = dec0_forward(dc1, enc0, w_f32, film['dec0_gamma'], film['dec0_beta'])
@@ -241,16 +240,14 @@ def forward_pass(feat0, feat1, w_f32, film):
def identity_film():
return {
- 'enc0_gamma': np.ones(ENC0_OUT, dtype=np.float32),
- 'enc0_beta': np.zeros(ENC0_OUT, dtype=np.float32),
- 'enc1_gamma_lo': np.ones(4, dtype=np.float32),
- 'enc1_gamma_hi': np.ones(4, dtype=np.float32),
- 'enc1_beta_lo': np.zeros(4, dtype=np.float32),
- 'enc1_beta_hi': np.zeros(4, dtype=np.float32),
- 'dec1_gamma': np.ones(DEC1_OUT, dtype=np.float32),
- 'dec1_beta': np.zeros(DEC1_OUT, dtype=np.float32),
- 'dec0_gamma': np.ones(DEC0_OUT, dtype=np.float32),
- 'dec0_beta': np.zeros(DEC0_OUT, dtype=np.float32),
+ 'enc0_gamma': np.ones(ENC0_OUT, dtype=np.float32), # 8
+ 'enc0_beta': np.zeros(ENC0_OUT, dtype=np.float32), # 8
+ 'enc1_gamma': np.ones(ENC1_OUT, dtype=np.float32), # 16
+ 'enc1_beta': np.zeros(ENC1_OUT, dtype=np.float32), # 16
+ 'dec1_gamma': np.ones(DEC1_OUT, dtype=np.float32), # 8
+ 'dec1_beta': np.zeros(DEC1_OUT, dtype=np.float32), # 8
+ 'dec0_gamma': np.ones(DEC0_OUT, dtype=np.float32), # 4
+ 'dec0_beta': np.zeros(DEC0_OUT, dtype=np.float32), # 4
}
@@ -324,8 +321,7 @@ def generate_vectors(W=8, H=8, seed=42):
enc0 = enc0_forward(feat0, feat1, w_f32,
film['enc0_gamma'], film['enc0_beta'])
enc1 = enc1_forward(enc0, w_f32,
- film['enc1_gamma_lo'], film['enc1_gamma_hi'],
- film['enc1_beta_lo'], film['enc1_beta_hi'])
+ film['enc1_gamma'], film['enc1_beta'])
bn = bottleneck_forward(enc1, w_f32)
dc1 = dec1_forward(bn, enc1, w_f32, film['dec1_gamma'], film['dec1_beta'])
out = dec0_forward(dc1, enc0, w_f32, film['dec0_gamma'], film['dec0_beta'])
@@ -333,8 +329,9 @@ def generate_vectors(W=8, H=8, seed=42):
feat0_u32 = pack_feat0_rgba32uint(feat0, H, W)
feat1_u32 = pack_feat1_rgba32uint(feat1_u8, H, W)
w_u32 = pack_weights_u32(w_f16)
+ # enc0: 8ch stored as pack2x16float → H*W*8 f16 values
enc0_u16 = np.float16(enc0.reshape(-1)).view(np.uint16)
- # dec1 is half-res (hH x hW x 4); store as-is
+ # dec1: 8ch half-res stored as pack2x16float → (H/2)*(W/2)*8 f16 values
dc1_u16 = np.float16(dc1.reshape(-1)).view(np.uint16)
out_u16 = np.float16(out.reshape(-1)).view(np.uint16) # raw f16 bits
@@ -386,11 +383,15 @@ def emit_c_header(v):
lines.append("};")
lines.append("")
+ lines.append(f"// ENC0_OUT={ENC0_OUT} ENC1_OUT={ENC1_OUT} BN={BN_OUT} DEC1_OUT={DEC1_OUT} DEC0_OUT={DEC0_OUT}")
+ lines.append(f"// TOTAL_F16={TOTAL_F16} (enc_channels=[{ENC0_OUT},{ENC1_OUT}])")
+ lines.append("")
array_u32("kCnnV3TestFeat0U32", v['feat0_u32'])
array_u32("kCnnV3TestFeat1U32", v['feat1_u32'])
array_u32("kCnnV3TestWeightsU32", v['w_u32'])
+ lines.append(f"// enc0: {ENC0_OUT}ch rgba32uint → W*H*{ENC0_OUT} f16 values")
array_u16("kCnnV3ExpectedEnc0U16", v['enc0_u16'])
- lines.append(f"// kCnnV3Dec1HW = (W/2) x (H/2) = {v['W']//2} x {v['H']//2}")
+ lines.append(f"// dec1: {DEC1_OUT}ch rgba32uint half-res → (W/2)*(H/2)*{DEC1_OUT} f16 values")
array_u16("kCnnV3ExpectedDec1U16", v['dc1_u16'])
array_u16("kCnnV3ExpectedOutputU16", v['out_u16'])
return "\n".join(lines)