summaryrefslogtreecommitdiff
path: root/cnn_v3/shaders/cnn_v3_common.wgsl
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-03-21 14:01:30 +0100
committerskal <pascal.massimino@gmail.com>2026-03-21 14:01:30 +0100
commitbf33fee131b1eee03bc5a765ba360299bbcead06 (patch)
treeb6a076ec977bb250a13b6a69be1092a183ae18ce /cnn_v3/shaders/cnn_v3_common.wgsl
parent35355b17576e93b035a2a78ecd05771e98f068ee (diff)
refactor(cnn_v3): code review — comments, simplifications, test fix
C++: - cnn_v3_effect.cc: fix declare_nodes comment (output node declared by caller) - cnn_v3_effect.cc: add TODO(phase-7) marker for FiLM MLP replacement WGSL: - cnn_v3_bottleneck.wgsl: consolidate _pad fields onto one line, explain why array<u32,3> is invalid in uniform address space - cnn_v3_enc0.wgsl: fix "12xu8" → "12ch u8norm" in header comment - cnn_v3_dec0.wgsl: clarify parity note (sigmoid after FiLM+ReLU, not raw conv) - cnn_v3_common.wgsl: clarify unpack_8ch pack layout (low/high 16 bits) Python: - cnn_v3_utils.py: replace PIL-based _upsample_nearest (uint8 round-trip) with pure numpy index arithmetic; rename _resize_rgb → _resize_img (handles any channel count); add comment on normal zero-pad workaround - export_cnn_v3_weights.py: add cross-ref to cnn_v3_effect.cc constants; clarify weight count comments with Conv notation Test: - test_cnn_v3_parity.cc: enc0/dec1 layer failures now return 0 (were print-only) handoff(Gemini): CNN v3 review complete, 36/36 tests passing.
Diffstat (limited to 'cnn_v3/shaders/cnn_v3_common.wgsl')
-rw-r--r--cnn_v3/shaders/cnn_v3_common.wgsl2
1 files changed, 1 insertions, 1 deletions
diff --git a/cnn_v3/shaders/cnn_v3_common.wgsl b/cnn_v3/shaders/cnn_v3_common.wgsl
index 54b0f3d..dbaf1b1 100644
--- a/cnn_v3/shaders/cnn_v3_common.wgsl
+++ b/cnn_v3/shaders/cnn_v3_common.wgsl
@@ -12,7 +12,7 @@ fn get_w(base: u32, idx: u32) -> f32 {
}
// Unpack 8 f16 channels from an rgba32uint texel (pack2x16float layout:
-// u32[0]=ch0|ch1, u32[1]=ch2|ch3, u32[2]=ch4|ch5, u32[3]=ch6|ch7)
+// u32[0]: ch0 in low 16 bits, ch1 in high 16 bits; same for u32[1-3])
fn unpack_8ch(tex: texture_2d<u32>, coord: vec2i) -> array<f32, 8> {
let t = textureLoad(tex, coord, 0);
let v0 = unpack2x16float(t.x);