From bf33fee131b1eee03bc5a765ba360299bbcead06 Mon Sep 17 00:00:00 2001
From: skal
Date: Sat, 21 Mar 2026 14:01:30 +0100
Subject: refactor(cnn_v3): code review — comments, simplifications, test fix
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

C++:
- cnn_v3_effect.cc: fix declare_nodes comment (output node declared by caller)
- cnn_v3_effect.cc: add TODO(phase-7) marker for FiLM MLP replacement

WGSL:
- cnn_v3_bottleneck.wgsl: consolidate _pad fields onto one line, explain
  why array is invalid in uniform address space
- cnn_v3_enc0.wgsl: fix "12xu8" → "12ch u8norm" in header comment
- cnn_v3_dec0.wgsl: clarify parity note (sigmoid after FiLM+ReLU, not raw conv)
- cnn_v3_common.wgsl: clarify unpack_8ch pack layout (low/high 16 bits)

Python:
- cnn_v3_utils.py: replace PIL-based _upsample_nearest (uint8 round-trip)
  with pure numpy index arithmetic; rename _resize_rgb → _resize_img
  (handles any channel count); add comment on normal zero-pad workaround
- export_cnn_v3_weights.py: add cross-ref to cnn_v3_effect.cc constants;
  clarify weight count comments with Conv notation

Test:
- test_cnn_v3_parity.cc: enc0/dec1 layer failures now return 0 (were print-only)

handoff(Gemini): CNN v3 review complete, 36/36 tests passing.
--- cnn_v3/shaders/cnn_v3_bottleneck.wgsl | 4 +--- cnn_v3/shaders/cnn_v3_common.wgsl | 2 +- cnn_v3/shaders/cnn_v3_dec0.wgsl | 2 +- cnn_v3/shaders/cnn_v3_enc0.wgsl | 2 +- 4 files changed, 4 insertions(+), 6 deletions(-) (limited to 'cnn_v3/shaders') diff --git a/cnn_v3/shaders/cnn_v3_bottleneck.wgsl b/cnn_v3/shaders/cnn_v3_bottleneck.wgsl index 909fd41..e24586f 100644 --- a/cnn_v3/shaders/cnn_v3_bottleneck.wgsl +++ b/cnn_v3/shaders/cnn_v3_bottleneck.wgsl @@ -15,9 +15,7 @@ const BN_OUT: u32 = 8u; struct Params { weight_offset: u32, - _pad0: u32, - _pad1: u32, - _pad2: u32, + _pad0: u32, _pad1: u32, _pad2: u32, // 3 explicit pads: array invalid in uniform } @group(0) @binding(0) var enc1_tex: texture_2d; diff --git a/cnn_v3/shaders/cnn_v3_common.wgsl b/cnn_v3/shaders/cnn_v3_common.wgsl index 54b0f3d..dbaf1b1 100644 --- a/cnn_v3/shaders/cnn_v3_common.wgsl +++ b/cnn_v3/shaders/cnn_v3_common.wgsl @@ -12,7 +12,7 @@ fn get_w(base: u32, idx: u32) -> f32 { } // Unpack 8 f16 channels from an rgba32uint texel (pack2x16float layout: -// u32[0]=ch0|ch1, u32[1]=ch2|ch3, u32[2]=ch4|ch5, u32[3]=ch6|ch7) +// u32[0]: ch0 in low 16 bits, ch1 in high 16 bits; same for u32[1-3]) fn unpack_8ch(tex: texture_2d, coord: vec2i) -> array { let t = textureLoad(tex, coord, 0); let v0 = unpack2x16float(t.x); diff --git a/cnn_v3/shaders/cnn_v3_dec0.wgsl b/cnn_v3/shaders/cnn_v3_dec0.wgsl index 7a4e7c9..a2a70ac 100644 --- a/cnn_v3/shaders/cnn_v3_dec0.wgsl +++ b/cnn_v3/shaders/cnn_v3_dec0.wgsl @@ -9,7 +9,7 @@ // [0 .. 8*4*9) conv: w[out][in][ky][kx] (in=8: 4 dec1 + 4 enc0 skip) // [288 .. +4) bias: b[out] // -// Parity note: sigmoid applied directly to dec0 output (matches train_cnn_v3.py forward()). +// Parity note: sigmoid applied after FiLM+ReLU, not after raw conv (matches train_cnn_v3.py). 
#include "cnn_v3/common" diff --git a/cnn_v3/shaders/cnn_v3_enc0.wgsl b/cnn_v3/shaders/cnn_v3_enc0.wgsl index f52a167..e171ca7 100644 --- a/cnn_v3/shaders/cnn_v3_enc0.wgsl +++ b/cnn_v3/shaders/cnn_v3_enc0.wgsl @@ -1,7 +1,7 @@ // CNN v3 — Encoder level 0 // Conv(20->4, 3x3, zero-pad) + FiLM + ReLU // -// Input: feat_tex0 (rgba32uint, 8xf16), feat_tex1 (rgba32uint, 12xu8) full-res +// Input: feat_tex0 (rgba32uint, 8xf16), feat_tex1 (rgba32uint, 12ch u8norm) full-res // Output: enc0_out (rgba16float, 4ch) full-res // // Weight layout (f16, OIHW + bias): -- cgit v1.2.3