diff options
| author | skal <pascal.massimino@gmail.com> | 2026-03-21 14:01:30 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-03-21 14:01:30 +0100 |
| commit | bf33fee131b1eee03bc5a765ba360299bbcead06 (patch) | |
| tree | b6a076ec977bb250a13b6a69be1092a183ae18ce /cnn_v3 | |
| parent | 35355b17576e93b035a2a78ecd05771e98f068ee (diff) | |
refactor(cnn_v3): code review — comments, simplifications, test fix
C++:
- cnn_v3_effect.cc: fix declare_nodes comment (output node declared by caller)
- cnn_v3_effect.cc: add TODO(phase-7) marker for FiLM MLP replacement
WGSL:
- cnn_v3_bottleneck.wgsl: consolidate _pad fields onto one line, explain why
array<u32,3> is invalid in uniform address space
- cnn_v3_enc0.wgsl: fix "12xu8" → "12ch u8norm" in header comment
- cnn_v3_dec0.wgsl: clarify parity note (sigmoid after FiLM+ReLU, not raw conv)
- cnn_v3_common.wgsl: clarify unpack_8ch pack layout (low/high 16 bits)
Python:
- cnn_v3_utils.py: replace PIL-based _upsample_nearest (uint8 round-trip) with
pure numpy index arithmetic; rename _resize_rgb → _resize_img (handles RGB, RGBA
and grayscale inputs); add comment on normal zero-pad workaround
- export_cnn_v3_weights.py: add cross-ref to cnn_v3_effect.cc constants;
clarify weight count comments with Conv notation
Test:
- test_cnn_v3_parity.cc: enc0/dec1 layer failures now return 0 (previously they were only printed)
handoff(Gemini): CNN v3 review complete, 36/36 tests passing.
Diffstat (limited to 'cnn_v3')
| -rw-r--r-- | cnn_v3/shaders/cnn_v3_bottleneck.wgsl | 4 | ||||
| -rw-r--r-- | cnn_v3/shaders/cnn_v3_common.wgsl | 2 | ||||
| -rw-r--r-- | cnn_v3/shaders/cnn_v3_dec0.wgsl | 2 | ||||
| -rw-r--r-- | cnn_v3/shaders/cnn_v3_enc0.wgsl | 2 | ||||
| -rw-r--r-- | cnn_v3/src/cnn_v3_effect.cc | 5 | ||||
| -rw-r--r-- | cnn_v3/training/cnn_v3_utils.py | 20 | ||||
| -rw-r--r-- | cnn_v3/training/export_cnn_v3_weights.py | 14 |
7 files changed, 26 insertions, 23 deletions
diff --git a/cnn_v3/shaders/cnn_v3_bottleneck.wgsl b/cnn_v3/shaders/cnn_v3_bottleneck.wgsl index 909fd41..e24586f 100644 --- a/cnn_v3/shaders/cnn_v3_bottleneck.wgsl +++ b/cnn_v3/shaders/cnn_v3_bottleneck.wgsl @@ -15,9 +15,7 @@ const BN_OUT: u32 = 8u; struct Params { weight_offset: u32, - _pad0: u32, - _pad1: u32, - _pad2: u32, + _pad0: u32, _pad1: u32, _pad2: u32, // 3 explicit pads: array<u32,3> invalid in uniform } @group(0) @binding(0) var enc1_tex: texture_2d<u32>; diff --git a/cnn_v3/shaders/cnn_v3_common.wgsl b/cnn_v3/shaders/cnn_v3_common.wgsl index 54b0f3d..dbaf1b1 100644 --- a/cnn_v3/shaders/cnn_v3_common.wgsl +++ b/cnn_v3/shaders/cnn_v3_common.wgsl @@ -12,7 +12,7 @@ fn get_w(base: u32, idx: u32) -> f32 { } // Unpack 8 f16 channels from an rgba32uint texel (pack2x16float layout: -// u32[0]=ch0|ch1, u32[1]=ch2|ch3, u32[2]=ch4|ch5, u32[3]=ch6|ch7) +// u32[0]: ch0 in low 16 bits, ch1 in high 16 bits; same for u32[1-3]) fn unpack_8ch(tex: texture_2d<u32>, coord: vec2i) -> array<f32, 8> { let t = textureLoad(tex, coord, 0); let v0 = unpack2x16float(t.x); diff --git a/cnn_v3/shaders/cnn_v3_dec0.wgsl b/cnn_v3/shaders/cnn_v3_dec0.wgsl index 7a4e7c9..a2a70ac 100644 --- a/cnn_v3/shaders/cnn_v3_dec0.wgsl +++ b/cnn_v3/shaders/cnn_v3_dec0.wgsl @@ -9,7 +9,7 @@ // [0 .. 8*4*9) conv: w[out][in][ky][kx] (in=8: 4 dec1 + 4 enc0 skip) // [288 .. +4) bias: b[out] // -// Parity note: sigmoid applied directly to dec0 output (matches train_cnn_v3.py forward()). +// Parity note: sigmoid applied after FiLM+ReLU, not after raw conv (matches train_cnn_v3.py). 
#include "cnn_v3/common" diff --git a/cnn_v3/shaders/cnn_v3_enc0.wgsl b/cnn_v3/shaders/cnn_v3_enc0.wgsl index f52a167..e171ca7 100644 --- a/cnn_v3/shaders/cnn_v3_enc0.wgsl +++ b/cnn_v3/shaders/cnn_v3_enc0.wgsl @@ -1,7 +1,7 @@ // CNN v3 — Encoder level 0 // Conv(20->4, 3x3, zero-pad) + FiLM + ReLU // -// Input: feat_tex0 (rgba32uint, 8xf16), feat_tex1 (rgba32uint, 12xu8) full-res +// Input: feat_tex0 (rgba32uint, 8xf16), feat_tex1 (rgba32uint, 12ch u8norm) full-res // Output: enc0_out (rgba16float, 4ch) full-res // // Weight layout (f16, OIHW + bias): diff --git a/cnn_v3/src/cnn_v3_effect.cc b/cnn_v3/src/cnn_v3_effect.cc index 92178f7..759e7ef 100644 --- a/cnn_v3/src/cnn_v3_effect.cc +++ b/cnn_v3/src/cnn_v3_effect.cc @@ -198,11 +198,12 @@ void CNNv3Effect::declare_nodes(NodeRegistry& registry) { registry.declare_node(node_bottleneck_, NodeType::GBUF_RGBA32UINT, W / 4, H / 4); // dec1_tex: rgba16float half-res registry.declare_node(node_dec1_, NodeType::GBUF_ALBEDO, W / 2, H / 2); - // output_tex: rgba16float full-res (the declared output_nodes_[0]) + // output_nodes_[0]: rgba16float full-res — declared externally by caller } // --------------------------------------------------------------------------- -// set_film_params — simple linear mapping, no MLP yet +// set_film_params — simple linear mapping (placeholder, no MLP yet) +// TODO(phase-7): replace with CPU forward pass through cnn_v3_film_mlp.bin // --------------------------------------------------------------------------- void CNNv3Effect::upload_weights(WGPUQueue queue, const void* data, diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py index ecdbd6b..8da276e 100644 --- a/cnn_v3/training/cnn_v3_utils.py +++ b/cnn_v3/training/cnn_v3_utils.py @@ -94,10 +94,11 @@ def depth_gradient(depth: np.ndarray) -> np.ndarray: def _upsample_nearest(a: np.ndarray, h: int, w: int) -> np.ndarray: - """Nearest-neighbour upsample (H,W,C) f32 [0,1] to (h,w,C).""" - img = Image.fromarray((np.clip(a, 
0, 1) * 255).astype(np.uint8)) - img = img.resize((w, h), Image.NEAREST) - return np.asarray(img, dtype=np.float32) / 255.0 + """Nearest-neighbour upsample (H,W,C) f32 to (h,w,C) — pure numpy, no precision loss.""" + sh, sw = a.shape[:2] + ys = np.arange(h) * sh // h + xs = np.arange(w) * sw // w + return a[np.ix_(ys, xs)] def assemble_features(albedo: np.ndarray, normal: np.ndarray, @@ -291,7 +292,8 @@ class CNNv3Dataset(Dataset): if self.full_image: sz = self.image_size - def _resize_rgb(a): + def _resize_img(a): + # PIL handles RGB, RGBA, and grayscale by channel count img = Image.fromarray((np.clip(a, 0, 1) * 255).astype(np.uint8)) return np.asarray(img.resize((sz, sz), Image.LANCZOS), dtype=np.float32) / 255.0 @@ -299,14 +301,14 @@ class CNNv3Dataset(Dataset): img = Image.fromarray((np.clip(a, 0, 1) * 255).astype(np.uint8), mode='L') return np.asarray(img.resize((sz, sz), Image.LANCZOS), dtype=np.float32) / 255.0 - albedo = _resize_rgb(albedo) - normal = _resize_rgb(np.concatenate( - [normal, np.zeros_like(normal[..., :1])], -1))[..., :2] + albedo = _resize_img(albedo) + normal = _resize_img(np.concatenate( + [normal, np.zeros_like(normal[..., :1])], -1))[..., :2] # pad to 3ch for PIL depth = _resize_gray(depth) matid = _resize_gray(matid) shadow = _resize_gray(shadow) transp = _resize_gray(transp) - target = _resize_rgb(target) + target = _resize_img(target) else: ps = self.patch_size half = ps // 2 diff --git a/cnn_v3/training/export_cnn_v3_weights.py b/cnn_v3/training/export_cnn_v3_weights.py index a1ad42d..6d99af9 100644 --- a/cnn_v3/training/export_cnn_v3_weights.py +++ b/cnn_v3/training/export_cnn_v3_weights.py @@ -34,13 +34,15 @@ sys.path.insert(0, str(Path(__file__).parent)) from train_cnn_v3 import CNNv3 # --------------------------------------------------------------------------- -# Weight layout constants (must match cnn_v3_effect.cc and gen_test_vectors.py) +# Weight layout constants — must stay in sync with: +# cnn_v3/src/cnn_v3_effect.cc 
(kEnc0Weights, kEnc1Weights, …) +# cnn_v3/training/gen_test_vectors.py (same constants) # --------------------------------------------------------------------------- -ENC0_WEIGHTS = 20 * 4 * 9 + 4 # 724 -ENC1_WEIGHTS = 4 * 8 * 9 + 8 # 296 -BN_WEIGHTS = 8 * 8 * 1 + 8 # 72 -DEC1_WEIGHTS = 16 * 4 * 9 + 4 # 580 -DEC0_WEIGHTS = 8 * 4 * 9 + 4 # 292 +ENC0_WEIGHTS = 20 * 4 * 9 + 4 # Conv(20→4,3×3)+bias = 724 +ENC1_WEIGHTS = 4 * 8 * 9 + 8 # Conv(4→8,3×3)+bias = 296 +BN_WEIGHTS = 8 * 8 * 1 + 8 # Conv(8→8,1×1)+bias = 72 +DEC1_WEIGHTS = 16 * 4 * 9 + 4 # Conv(16→4,3×3)+bias = 580 +DEC0_WEIGHTS = 8 * 4 * 9 + 4 # Conv(8→4,3×3)+bias = 292 TOTAL_F16 = ENC0_WEIGHTS + ENC1_WEIGHTS + BN_WEIGHTS + DEC1_WEIGHTS + DEC0_WEIGHTS # = 1964 |
