summaryrefslogtreecommitdiff
path: root/cnn_v3/training/gen_test_vectors.py
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-03-25 10:05:42 +0100
committerskal <pascal.massimino@gmail.com>2026-03-25 10:05:42 +0100
commitce6e5b99f26e4e7c69a3cacf360bd0d492de928c (patch)
treea8d64b33a7ea1109b6b7e1043ced946cac416756 /cnn_v3/training/gen_test_vectors.py
parent8b4d7a49f038d7e849e6764dcc3abd1e1be01061 (diff)
feat(cnn_v3): 3×3 dilated bottleneck + Sobel loss + FiLM warmup + architecture PNG
- Replace 1×1 pointwise bottleneck with Conv(8→8, 3×3, dilation=2): effective RF grows from ~13px to ~29px at ¼res (~+1 KB weights) - Add Sobel edge loss in training (--edge-loss-weight, default 0.1) - Add FiLM 2-phase training: freeze MLP for warmup epochs then unfreeze at lr×0.1 (--film-warmup-epochs, default 50) - Update weight layout: BN 72→584 f16, total 1964→2476 f16 (4952 B) - Cascade offsets in C++ effect, JS tool, export/gen_test_vectors scripts - Regenerate test_vectors.h (1238 u32); parity max_err=9.77e-04 - Generate dark-theme U-Net+FiLM architecture PNG (gen_architecture_png.py) - Replace ASCII art in CNN_V3.md and HOW_TO_CNN.md with PNG embed handoff(Gemini): bottleneck dilation + Sobel loss + FiLM warmup landed. Next: run first real training pass (see cnn_v3/docs/HOWTO.md §3).
Diffstat (limited to 'cnn_v3/training/gen_test_vectors.py')
-rw-r--r--cnn_v3/training/gen_test_vectors.py72
1 files changed, 28 insertions, 44 deletions
diff --git a/cnn_v3/training/gen_test_vectors.py b/cnn_v3/training/gen_test_vectors.py
index 640971c..2eb889c 100644
--- a/cnn_v3/training/gen_test_vectors.py
+++ b/cnn_v3/training/gen_test_vectors.py
@@ -23,7 +23,7 @@ DEC0_IN, DEC0_OUT = 8, 4
ENC0_WEIGHTS = ENC0_IN * ENC0_OUT * 9 + ENC0_OUT # 724
ENC1_WEIGHTS = ENC1_IN * ENC1_OUT * 9 + ENC1_OUT # 296
-BN_WEIGHTS = BN_IN * BN_OUT * 1 + BN_OUT # 72
+BN_WEIGHTS = BN_IN * BN_OUT * 9 + BN_OUT # 584 (3x3 dilation=2)
DEC1_WEIGHTS = DEC1_IN * DEC1_OUT * 9 + DEC1_OUT # 580
DEC0_WEIGHTS = DEC0_IN * DEC0_OUT * 9 + DEC0_OUT # 292
@@ -32,30 +32,8 @@ ENC1_OFFSET = ENC0_OFFSET + ENC0_WEIGHTS
BN_OFFSET = ENC1_OFFSET + ENC1_WEIGHTS
DEC1_OFFSET = BN_OFFSET + BN_WEIGHTS
DEC0_OFFSET = DEC1_OFFSET + DEC1_WEIGHTS
-TOTAL_F16 = DEC0_OFFSET + DEC0_WEIGHTS # 1964 + 292 = 2256? let me check
-# 724 + 296 + 72 + 580 + 292 = 1964 ... actually let me recount
-# ENC0: 20*4*9 + 4 = 720+4 = 724
-# ENC1: 4*8*9 + 8 = 288+8 = 296
-# BN: 8*8*1 + 8 = 64+8 = 72
-# DEC1: 16*4*9 + 4 = 576+4 = 580
-# DEC0: 8*4*9 + 4 = 288+4 = 292
-# Total = 724+296+72+580+292 = 1964 ... but HOWTO.md says 2064. Let me recheck.
-# DEC1: 16*4*9 = 576 ... but the shader says Conv(16->4) which is IN=16, OUT=4
-# weight idx: o * DEC1_IN * 9 + i * 9 + ki where o<DEC1_OUT, i<DEC1_IN
-# So total conv weights = DEC1_OUT * DEC1_IN * 9 = 4*16*9 = 576, bias = 4
-# Total DEC1 = 580. OK that's right.
-# Let me add: 724+296+72+580+292 = 1964. But HOWTO says 2064?
-# DEC1: Conv(16->4) = OUT*IN*K^2 = 4*16*9 = 576 + bias 4 = 580. HOWTO says 576+4=580 OK.
-# Total = 724+296+72+580+292 = let me sum: 724+296=1020, +72=1092, +580=1672, +292=1964.
-# Hmm, HOWTO.md says 2064. Let me recheck HOWTO weight table:
-# enc0: 20*4*9=720 +4 = 724
-# enc1: 4*8*9=288 +8 = 296
-# bottleneck: 8*8*1=64 +8 = 72
-# dec1: 16*4*9=576 +4 = 580
-# dec0: 8*4*9=288 +4 = 292
-# Total = 724+296+72+580+292 = 1964
-# The HOWTO says 2064 but I get 1964... 100 difference. Possible typo in doc.
-# I'll use the correct value derived from the formulas: 1964.
+TOTAL_F16 = DEC0_OFFSET + DEC0_WEIGHTS
+# 724 + 296 + 584 + 580 + 292 = 2476 (BN is now 3x3 dilation=2, was 72)
# ---------------------------------------------------------------------------
# Helpers
@@ -140,35 +118,41 @@ def enc1_forward(enc0, w, gamma_lo, gamma_hi, beta_lo, beta_hi):
def bottleneck_forward(enc1, w):
"""
- AvgPool2x2(enc1, clamp-border) + Conv(8->8, 1x1) + ReLU
+ AvgPool2x2(enc1, clamp-border) + Conv(8->8, 3x3, dilation=2) + ReLU
→ rgba32uint (f16, quarter-res). No FiLM.
enc1: (hH, hW, 8) f32 — half-res
+ Matches cnn_v3_bottleneck.wgsl exactly.
"""
hH, hW = enc1.shape[:2]
qH, qW = hH // 2, hW // 2
wo = BN_OFFSET
- # AvgPool2x2 with clamp (matches load_enc1_avg in WGSL)
- avg = np.zeros((qH, qW, BN_IN), dtype=np.float32)
- for qy in range(qH):
- for qx in range(qW):
- s = np.zeros(BN_IN, dtype=np.float32)
- for dy in range(2):
- for dx in range(2):
- hy = min(qy * 2 + dy, hH - 1)
- hx = min(qx * 2 + dx, hW - 1)
- s += enc1[hy, hx, :]
- avg[qy, qx, :] = s * 0.25
+ def load_enc1_avg(qy, qx):
+ """Avg-pool 2x2 from enc1 at quarter-res coord. Zero for OOB (matches WGSL)."""
+ if qy < 0 or qx < 0 or qy >= qH or qx >= qW:
+ return np.zeros(BN_IN, dtype=np.float32)
+ s = np.zeros(BN_IN, dtype=np.float32)
+ for dy in range(2):
+ for dx in range(2):
+ hy = min(qy * 2 + dy, hH - 1)
+ hx = min(qx * 2 + dx, hW - 1)
+ s += enc1[hy, hx, :]
+ return s * 0.25
- # 1x1 conv (no spatial loop, just channel dot-product)
+ # 3x3 conv with dilation=2 in quarter-res space
out = np.zeros((qH, qW, BN_OUT), dtype=np.float32)
for o in range(BN_OUT):
- bias = get_w(w, wo, BN_OUT * BN_IN + o)
- s = np.full((qH, qW), bias, dtype=np.float32)
- for i in range(BN_IN):
- wv = get_w(w, wo, o * BN_IN + i)
- s += wv * avg[:, :, i]
- out[:, :, o] = np.maximum(0.0, s)
+ bias = get_w(w, wo, BN_OUT * BN_IN * 9 + o)
+ for qy in range(qH):
+ for qx in range(qW):
+ s = bias
+ for ky in range(-1, 2):
+ for kx in range(-1, 2):
+ feat = load_enc1_avg(qy + ky * 2, qx + kx * 2) # dilation=2
+ ki = (ky + 1) * 3 + (kx + 1)
+ for i in range(BN_IN):
+ s += get_w(w, wo, o * BN_IN * 9 + i * 9 + ki) * feat[i]
+ out[qy, qx, o] = max(0.0, s)
return np.float16(out).astype(np.float32) # pack2x16float boundary