summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-03-21 14:01:30 +0100
committerskal <pascal.massimino@gmail.com>2026-03-21 14:01:30 +0100
commitbf33fee131b1eee03bc5a765ba360299bbcead06 (patch)
treeb6a076ec977bb250a13b6a69be1092a183ae18ce
parent35355b17576e93b035a2a78ecd05771e98f068ee (diff)
refactor(cnn_v3): code review — comments, simplifications, test fix
C++: - cnn_v3_effect.cc: fix declare_nodes comment (output node declared by caller) - cnn_v3_effect.cc: add TODO(phase-7) marker for FiLM MLP replacement WGSL: - cnn_v3_bottleneck.wgsl: consolidate _pad fields onto one line, explain why array<u32,3> is invalid in uniform address space - cnn_v3_enc0.wgsl: fix "12xu8" → "12ch u8norm" in header comment - cnn_v3_dec0.wgsl: clarify parity note (sigmoid after FiLM+ReLU, not raw conv) - cnn_v3_common.wgsl: clarify unpack_8ch pack layout (low/high 16 bits) Python: - cnn_v3_utils.py: replace PIL-based _upsample_nearest (uint8 round-trip) with pure numpy index arithmetic; rename _resize_rgb → _resize_img (handles any channel count); add comment on normal zero-pad workaround - export_cnn_v3_weights.py: add cross-ref to cnn_v3_effect.cc constants; clarify weight count comments with Conv notation Test: - test_cnn_v3_parity.cc: enc0/dec1 layer failures now return 0 (were print-only) handoff(Gemini): CNN v3 review complete, 36/36 tests passing.
-rw-r--r--cnn_v3/shaders/cnn_v3_bottleneck.wgsl4
-rw-r--r--cnn_v3/shaders/cnn_v3_common.wgsl2
-rw-r--r--cnn_v3/shaders/cnn_v3_dec0.wgsl2
-rw-r--r--cnn_v3/shaders/cnn_v3_enc0.wgsl2
-rw-r--r--cnn_v3/src/cnn_v3_effect.cc5
-rw-r--r--cnn_v3/training/cnn_v3_utils.py20
-rw-r--r--cnn_v3/training/export_cnn_v3_weights.py14
-rw-r--r--src/tests/gpu/test_cnn_v3_parity.cc15
8 files changed, 35 insertions, 29 deletions
diff --git a/cnn_v3/shaders/cnn_v3_bottleneck.wgsl b/cnn_v3/shaders/cnn_v3_bottleneck.wgsl
index 909fd41..e24586f 100644
--- a/cnn_v3/shaders/cnn_v3_bottleneck.wgsl
+++ b/cnn_v3/shaders/cnn_v3_bottleneck.wgsl
@@ -15,9 +15,7 @@ const BN_OUT: u32 = 8u;
struct Params {
weight_offset: u32,
- _pad0: u32,
- _pad1: u32,
- _pad2: u32,
+ _pad0: u32, _pad1: u32, _pad2: u32, // 3 explicit pads: array<u32,3> invalid in uniform
}
@group(0) @binding(0) var enc1_tex: texture_2d<u32>;
diff --git a/cnn_v3/shaders/cnn_v3_common.wgsl b/cnn_v3/shaders/cnn_v3_common.wgsl
index 54b0f3d..dbaf1b1 100644
--- a/cnn_v3/shaders/cnn_v3_common.wgsl
+++ b/cnn_v3/shaders/cnn_v3_common.wgsl
@@ -12,7 +12,7 @@ fn get_w(base: u32, idx: u32) -> f32 {
}
// Unpack 8 f16 channels from an rgba32uint texel (pack2x16float layout:
-// u32[0]=ch0|ch1, u32[1]=ch2|ch3, u32[2]=ch4|ch5, u32[3]=ch6|ch7)
+// u32[0]: ch0 in low 16 bits, ch1 in high 16 bits; same for u32[1-3])
fn unpack_8ch(tex: texture_2d<u32>, coord: vec2i) -> array<f32, 8> {
let t = textureLoad(tex, coord, 0);
let v0 = unpack2x16float(t.x);
diff --git a/cnn_v3/shaders/cnn_v3_dec0.wgsl b/cnn_v3/shaders/cnn_v3_dec0.wgsl
index 7a4e7c9..a2a70ac 100644
--- a/cnn_v3/shaders/cnn_v3_dec0.wgsl
+++ b/cnn_v3/shaders/cnn_v3_dec0.wgsl
@@ -9,7 +9,7 @@
// [0 .. 8*4*9) conv: w[out][in][ky][kx] (in=8: 4 dec1 + 4 enc0 skip)
// [288 .. +4) bias: b[out]
//
-// Parity note: sigmoid applied directly to dec0 output (matches train_cnn_v3.py forward()).
+// Parity note: sigmoid applied after FiLM+ReLU, not after raw conv (matches train_cnn_v3.py).
#include "cnn_v3/common"
diff --git a/cnn_v3/shaders/cnn_v3_enc0.wgsl b/cnn_v3/shaders/cnn_v3_enc0.wgsl
index f52a167..e171ca7 100644
--- a/cnn_v3/shaders/cnn_v3_enc0.wgsl
+++ b/cnn_v3/shaders/cnn_v3_enc0.wgsl
@@ -1,7 +1,7 @@
// CNN v3 — Encoder level 0
// Conv(20->4, 3x3, zero-pad) + FiLM + ReLU
//
-// Input: feat_tex0 (rgba32uint, 8xf16), feat_tex1 (rgba32uint, 12xu8) full-res
+// Input: feat_tex0 (rgba32uint, 8xf16), feat_tex1 (rgba32uint, 12ch u8norm) full-res
// Output: enc0_out (rgba16float, 4ch) full-res
//
// Weight layout (f16, OIHW + bias):
diff --git a/cnn_v3/src/cnn_v3_effect.cc b/cnn_v3/src/cnn_v3_effect.cc
index 92178f7..759e7ef 100644
--- a/cnn_v3/src/cnn_v3_effect.cc
+++ b/cnn_v3/src/cnn_v3_effect.cc
@@ -198,11 +198,12 @@ void CNNv3Effect::declare_nodes(NodeRegistry& registry) {
registry.declare_node(node_bottleneck_, NodeType::GBUF_RGBA32UINT, W / 4, H / 4);
// dec1_tex: rgba16float half-res
registry.declare_node(node_dec1_, NodeType::GBUF_ALBEDO, W / 2, H / 2);
- // output_tex: rgba16float full-res (the declared output_nodes_[0])
+ // output_nodes_[0]: rgba16float full-res — declared externally by caller
}
// ---------------------------------------------------------------------------
-// set_film_params — simple linear mapping, no MLP yet
+// set_film_params — simple linear mapping (placeholder, no MLP yet)
+// TODO(phase-7): replace with CPU forward pass through cnn_v3_film_mlp.bin
// ---------------------------------------------------------------------------
void CNNv3Effect::upload_weights(WGPUQueue queue, const void* data,
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py
index ecdbd6b..8da276e 100644
--- a/cnn_v3/training/cnn_v3_utils.py
+++ b/cnn_v3/training/cnn_v3_utils.py
@@ -94,10 +94,11 @@ def depth_gradient(depth: np.ndarray) -> np.ndarray:
def _upsample_nearest(a: np.ndarray, h: int, w: int) -> np.ndarray:
- """Nearest-neighbour upsample (H,W,C) f32 [0,1] to (h,w,C)."""
- img = Image.fromarray((np.clip(a, 0, 1) * 255).astype(np.uint8))
- img = img.resize((w, h), Image.NEAREST)
- return np.asarray(img, dtype=np.float32) / 255.0
+ """Nearest-neighbour upsample (H,W,C) f32 to (h,w,C) — pure numpy, no precision loss."""
+ sh, sw = a.shape[:2]
+ ys = np.arange(h) * sh // h
+ xs = np.arange(w) * sw // w
+ return a[np.ix_(ys, xs)]
def assemble_features(albedo: np.ndarray, normal: np.ndarray,
@@ -291,7 +292,8 @@ class CNNv3Dataset(Dataset):
if self.full_image:
sz = self.image_size
- def _resize_rgb(a):
+ def _resize_img(a):
+ # PIL handles RGB, RGBA, and grayscale by channel count
img = Image.fromarray((np.clip(a, 0, 1) * 255).astype(np.uint8))
return np.asarray(img.resize((sz, sz), Image.LANCZOS), dtype=np.float32) / 255.0
@@ -299,14 +301,14 @@ class CNNv3Dataset(Dataset):
img = Image.fromarray((np.clip(a, 0, 1) * 255).astype(np.uint8), mode='L')
return np.asarray(img.resize((sz, sz), Image.LANCZOS), dtype=np.float32) / 255.0
- albedo = _resize_rgb(albedo)
- normal = _resize_rgb(np.concatenate(
- [normal, np.zeros_like(normal[..., :1])], -1))[..., :2]
+ albedo = _resize_img(albedo)
+ normal = _resize_img(np.concatenate(
+ [normal, np.zeros_like(normal[..., :1])], -1))[..., :2] # pad to 3ch for PIL
depth = _resize_gray(depth)
matid = _resize_gray(matid)
shadow = _resize_gray(shadow)
transp = _resize_gray(transp)
- target = _resize_rgb(target)
+ target = _resize_img(target)
else:
ps = self.patch_size
half = ps // 2
diff --git a/cnn_v3/training/export_cnn_v3_weights.py b/cnn_v3/training/export_cnn_v3_weights.py
index a1ad42d..6d99af9 100644
--- a/cnn_v3/training/export_cnn_v3_weights.py
+++ b/cnn_v3/training/export_cnn_v3_weights.py
@@ -34,13 +34,15 @@ sys.path.insert(0, str(Path(__file__).parent))
from train_cnn_v3 import CNNv3
# ---------------------------------------------------------------------------
-# Weight layout constants (must match cnn_v3_effect.cc and gen_test_vectors.py)
+# Weight layout constants — must stay in sync with:
+# cnn_v3/src/cnn_v3_effect.cc (kEnc0Weights, kEnc1Weights, …)
+# cnn_v3/training/gen_test_vectors.py (same constants)
# ---------------------------------------------------------------------------
-ENC0_WEIGHTS = 20 * 4 * 9 + 4 # 724
-ENC1_WEIGHTS = 4 * 8 * 9 + 8 # 296
-BN_WEIGHTS = 8 * 8 * 1 + 8 # 72
-DEC1_WEIGHTS = 16 * 4 * 9 + 4 # 580
-DEC0_WEIGHTS = 8 * 4 * 9 + 4 # 292
+ENC0_WEIGHTS = 20 * 4 * 9 + 4 # Conv(20→4,3×3)+bias = 724
+ENC1_WEIGHTS = 4 * 8 * 9 + 8 # Conv(4→8,3×3)+bias = 296
+BN_WEIGHTS = 8 * 8 * 1 + 8 # Conv(8→8,1×1)+bias = 72
+DEC1_WEIGHTS = 16 * 4 * 9 + 4 # Conv(16→4,3×3)+bias = 580
+DEC0_WEIGHTS = 8 * 4 * 9 + 4 # Conv(8→4,3×3)+bias = 292
TOTAL_F16 = ENC0_WEIGHTS + ENC1_WEIGHTS + BN_WEIGHTS + DEC1_WEIGHTS + DEC0_WEIGHTS
# = 1964
diff --git a/src/tests/gpu/test_cnn_v3_parity.cc b/src/tests/gpu/test_cnn_v3_parity.cc
index 608decb..4b41c94 100644
--- a/src/tests/gpu/test_cnn_v3_parity.cc
+++ b/src/tests/gpu/test_cnn_v3_parity.cc
@@ -301,7 +301,8 @@ static int test_random_weights() {
float err = fabsf(enc0_pixels[i] - ref);
if (err > enc0_max_err) { enc0_max_err = err; enc0_worst = i; }
}
- if (enc0_max_err > tol) {
+ bool enc0_ok = (enc0_max_err <= tol);
+ if (!enc0_ok) {
int px = enc0_worst / 4, ch = enc0_worst % 4;
fprintf(stderr, " ✗ enc0 mismatch: max_err=%.5f > %.5f at px=%d ch=%d"
" gpu=%.5f ref=%.5f\n",
@@ -321,7 +322,8 @@ static int test_random_weights() {
float err = fabsf(dec1_pixels[i] - ref);
if (err > dec1_max_err) { dec1_max_err = err; dec1_worst = i; }
}
- if (dec1_max_err > tol) {
+ bool dec1_ok = (dec1_max_err <= tol);
+ if (!dec1_ok) {
int px = dec1_worst / 4, ch = dec1_worst % 4;
fprintf(stderr, " ✗ dec1 mismatch: max_err=%.5f > %.5f at px=%d ch=%d"
" gpu=%.5f ref=%.5f\n",
@@ -342,17 +344,18 @@ static int test_random_weights() {
if (err > max_err) { max_err = err; worst = i; }
}
- if (max_err > tol) {
+ bool out_ok = (max_err <= tol);
+ if (!out_ok) {
int px = worst / 4, ch = worst % 4;
fprintf(stderr, " ✗ random_weights: max_err=%.5f > %.5f at px=%d ch=%d"
" gpu=%.5f ref=%.5f\n",
max_err, tol, px, ch,
pixels[worst],
fp16_bits_to_f32(kCnnV3ExpectedOutputU16[worst]));
- return 0;
+ } else {
+ fprintf(stdout, " ✓ random_weights: max_err=%.2e OK\n", max_err);
}
- fprintf(stdout, " ✓ random_weights: max_err=%.2e OK\n", max_err);
- return 1;
+ return (enc0_ok && dec1_ok && out_ok) ? 1 : 0;
}
// ---------------------------------------------------------------------------