summaryrefslogtreecommitdiff
path: root/src/tests
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-03-21 14:01:30 +0100
committerskal <pascal.massimino@gmail.com>2026-03-21 14:01:30 +0100
commitbf33fee131b1eee03bc5a765ba360299bbcead06 (patch)
treeb6a076ec977bb250a13b6a69be1092a183ae18ce /src/tests
parent35355b17576e93b035a2a78ecd05771e98f068ee (diff)
refactor(cnn_v3): code review — comments, simplifications, test fix
C++: - cnn_v3_effect.cc: fix declare_nodes comment (output node declared by caller) - cnn_v3_effect.cc: add TODO(phase-7) marker for FiLM MLP replacement WGSL: - cnn_v3_bottleneck.wgsl: consolidate _pad fields onto one line, explain why array<u32,3> is invalid in uniform address space - cnn_v3_enc0.wgsl: fix "12xu8" → "12ch u8norm" in header comment - cnn_v3_dec0.wgsl: clarify parity note (sigmoid after FiLM+ReLU, not raw conv) - cnn_v3_common.wgsl: clarify unpack_8ch pack layout (low/high 16 bits) Python: - cnn_v3_utils.py: replace PIL-based _upsample_nearest (uint8 round-trip) with pure numpy index arithmetic; rename _resize_rgb → _resize_img (handles any channel count); add comment on normal zero-pad workaround - export_cnn_v3_weights.py: add cross-ref to cnn_v3_effect.cc constants; clarify weight count comments with Conv notation Test: - test_cnn_v3_parity.cc: enc0/dec1 layer failures now return 0 (were print-only) handoff(Gemini): CNN v3 review complete, 36/36 tests passing.
Diffstat (limited to 'src/tests')
-rw-r--r--src/tests/gpu/test_cnn_v3_parity.cc15
1 files changed, 9 insertions, 6 deletions
diff --git a/src/tests/gpu/test_cnn_v3_parity.cc b/src/tests/gpu/test_cnn_v3_parity.cc
index 608decb..4b41c94 100644
--- a/src/tests/gpu/test_cnn_v3_parity.cc
+++ b/src/tests/gpu/test_cnn_v3_parity.cc
@@ -301,7 +301,8 @@ static int test_random_weights() {
float err = fabsf(enc0_pixels[i] - ref);
if (err > enc0_max_err) { enc0_max_err = err; enc0_worst = i; }
}
- if (enc0_max_err > tol) {
+ bool enc0_ok = (enc0_max_err <= tol);
+ if (!enc0_ok) {
int px = enc0_worst / 4, ch = enc0_worst % 4;
fprintf(stderr, " ✗ enc0 mismatch: max_err=%.5f > %.5f at px=%d ch=%d"
" gpu=%.5f ref=%.5f\n",
@@ -321,7 +322,8 @@ static int test_random_weights() {
float err = fabsf(dec1_pixels[i] - ref);
if (err > dec1_max_err) { dec1_max_err = err; dec1_worst = i; }
}
- if (dec1_max_err > tol) {
+ bool dec1_ok = (dec1_max_err <= tol);
+ if (!dec1_ok) {
int px = dec1_worst / 4, ch = dec1_worst % 4;
fprintf(stderr, " ✗ dec1 mismatch: max_err=%.5f > %.5f at px=%d ch=%d"
" gpu=%.5f ref=%.5f\n",
@@ -342,17 +344,18 @@ static int test_random_weights() {
if (err > max_err) { max_err = err; worst = i; }
}
- if (max_err > tol) {
+ bool out_ok = (max_err <= tol);
+ if (!out_ok) {
int px = worst / 4, ch = worst % 4;
fprintf(stderr, " ✗ random_weights: max_err=%.5f > %.5f at px=%d ch=%d"
" gpu=%.5f ref=%.5f\n",
max_err, tol, px, ch,
pixels[worst],
fp16_bits_to_f32(kCnnV3ExpectedOutputU16[worst]));
- return 0;
+ } else {
+ fprintf(stdout, " ✓ random_weights: max_err=%.2e OK\n", max_err);
}
- fprintf(stdout, " ✓ random_weights: max_err=%.2e OK\n", max_err);
- return 1;
+ return (enc0_ok && dec1_ok && out_ok) ? 1 : 0;
}
// ---------------------------------------------------------------------------