summaryrefslogtreecommitdiff
path: root/cnn_v3/training/cnn_v3_utils.py
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-03-21 14:01:30 +0100
committerskal <pascal.massimino@gmail.com>2026-03-21 14:01:30 +0100
commitbf33fee131b1eee03bc5a765ba360299bbcead06 (patch)
treeb6a076ec977bb250a13b6a69be1092a183ae18ce /cnn_v3/training/cnn_v3_utils.py
parent35355b17576e93b035a2a78ecd05771e98f068ee (diff)
refactor(cnn_v3): code review — comments, simplifications, test fix
C++: - cnn_v3_effect.cc: fix declare_nodes comment (output node declared by caller) - cnn_v3_effect.cc: add TODO(phase-7) marker for FiLM MLP replacement WGSL: - cnn_v3_bottleneck.wgsl: consolidate _pad fields onto one line, explain why array<u32,3> is invalid in uniform address space - cnn_v3_enc0.wgsl: fix "12xu8" → "12ch u8norm" in header comment - cnn_v3_dec0.wgsl: clarify parity note (sigmoid after FiLM+ReLU, not raw conv) - cnn_v3_common.wgsl: clarify unpack_8ch pack layout (low/high 16 bits) Python: - cnn_v3_utils.py: replace PIL-based _upsample_nearest (uint8 round-trip) with pure numpy index arithmetic; rename _resize_rgb → _resize_img (handles any channel count); add comment on normal zero-pad workaround - export_cnn_v3_weights.py: add cross-ref to cnn_v3_effect.cc constants; clarify weight count comments with Conv notation Test: - test_cnn_v3_parity.cc: enc0/dec1 layer failures now return 0 (were print-only) handoff(Gemini): CNN v3 review complete, 36/36 tests passing.
Diffstat (limited to 'cnn_v3/training/cnn_v3_utils.py')
-rw-r--r--cnn_v3/training/cnn_v3_utils.py20
1 files changed, 11 insertions, 9 deletions
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py
index ecdbd6b..8da276e 100644
--- a/cnn_v3/training/cnn_v3_utils.py
+++ b/cnn_v3/training/cnn_v3_utils.py
@@ -94,10 +94,11 @@ def depth_gradient(depth: np.ndarray) -> np.ndarray:
def _upsample_nearest(a: np.ndarray, h: int, w: int) -> np.ndarray:
- """Nearest-neighbour upsample (H,W,C) f32 [0,1] to (h,w,C)."""
- img = Image.fromarray((np.clip(a, 0, 1) * 255).astype(np.uint8))
- img = img.resize((w, h), Image.NEAREST)
- return np.asarray(img, dtype=np.float32) / 255.0
+ """Nearest-neighbour upsample (H,W,C) f32 to (h,w,C) — pure numpy, no precision loss."""
+ sh, sw = a.shape[:2]
+ ys = np.arange(h) * sh // h
+ xs = np.arange(w) * sw // w
+ return a[np.ix_(ys, xs)]
def assemble_features(albedo: np.ndarray, normal: np.ndarray,
@@ -291,7 +292,8 @@ class CNNv3Dataset(Dataset):
if self.full_image:
sz = self.image_size
- def _resize_rgb(a):
+ def _resize_img(a):
+ # PIL handles RGB, RGBA, and grayscale by channel count
img = Image.fromarray((np.clip(a, 0, 1) * 255).astype(np.uint8))
return np.asarray(img.resize((sz, sz), Image.LANCZOS), dtype=np.float32) / 255.0
@@ -299,14 +301,14 @@ class CNNv3Dataset(Dataset):
img = Image.fromarray((np.clip(a, 0, 1) * 255).astype(np.uint8), mode='L')
return np.asarray(img.resize((sz, sz), Image.LANCZOS), dtype=np.float32) / 255.0
- albedo = _resize_rgb(albedo)
- normal = _resize_rgb(np.concatenate(
- [normal, np.zeros_like(normal[..., :1])], -1))[..., :2]
+ albedo = _resize_img(albedo)
+ normal = _resize_img(np.concatenate(
+ [normal, np.zeros_like(normal[..., :1])], -1))[..., :2] # pad to 3ch for PIL
depth = _resize_gray(depth)
matid = _resize_gray(matid)
shadow = _resize_gray(shadow)
transp = _resize_gray(transp)
- target = _resize_rgb(target)
+ target = _resize_img(target)
else:
ps = self.patch_size
half = ps // 2