diff options
Diffstat (limited to 'cnn_v3/training')
| -rw-r--r-- | cnn_v3/training/cnn_v3_utils.py | 20 | ||||
| -rw-r--r-- | cnn_v3/training/export_cnn_v3_weights.py | 14 |
2 files changed, 19 insertions, 15 deletions
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py index ecdbd6b..8da276e 100644 --- a/cnn_v3/training/cnn_v3_utils.py +++ b/cnn_v3/training/cnn_v3_utils.py @@ -94,10 +94,11 @@ def depth_gradient(depth: np.ndarray) -> np.ndarray: def _upsample_nearest(a: np.ndarray, h: int, w: int) -> np.ndarray: - """Nearest-neighbour upsample (H,W,C) f32 [0,1] to (h,w,C).""" - img = Image.fromarray((np.clip(a, 0, 1) * 255).astype(np.uint8)) - img = img.resize((w, h), Image.NEAREST) - return np.asarray(img, dtype=np.float32) / 255.0 + """Nearest-neighbour upsample (H,W,C) f32 to (h,w,C) — pure numpy, no precision loss.""" + sh, sw = a.shape[:2] + ys = np.arange(h) * sh // h + xs = np.arange(w) * sw // w + return a[np.ix_(ys, xs)] def assemble_features(albedo: np.ndarray, normal: np.ndarray, @@ -291,7 +292,8 @@ class CNNv3Dataset(Dataset): if self.full_image: sz = self.image_size - def _resize_rgb(a): + def _resize_img(a): + # PIL handles RGB, RGBA, and grayscale by channel count img = Image.fromarray((np.clip(a, 0, 1) * 255).astype(np.uint8)) return np.asarray(img.resize((sz, sz), Image.LANCZOS), dtype=np.float32) / 255.0 @@ -299,14 +301,14 @@ class CNNv3Dataset(Dataset): img = Image.fromarray((np.clip(a, 0, 1) * 255).astype(np.uint8), mode='L') return np.asarray(img.resize((sz, sz), Image.LANCZOS), dtype=np.float32) / 255.0 - albedo = _resize_rgb(albedo) - normal = _resize_rgb(np.concatenate( - [normal, np.zeros_like(normal[..., :1])], -1))[..., :2] + albedo = _resize_img(albedo) + normal = _resize_img(np.concatenate( + [normal, np.zeros_like(normal[..., :1])], -1))[..., :2] # pad to 3ch for PIL depth = _resize_gray(depth) matid = _resize_gray(matid) shadow = _resize_gray(shadow) transp = _resize_gray(transp) - target = _resize_rgb(target) + target = _resize_img(target) else: ps = self.patch_size half = ps // 2 diff --git a/cnn_v3/training/export_cnn_v3_weights.py b/cnn_v3/training/export_cnn_v3_weights.py index a1ad42d..6d99af9 100644 --- a/cnn_v3/training/export_cnn_v3_weights.py +++ b/cnn_v3/training/export_cnn_v3_weights.py @@ -34,13 +34,15 @@ sys.path.insert(0, str(Path(__file__).parent)) from train_cnn_v3 import CNNv3 # --------------------------------------------------------------------------- -# Weight layout constants (must match cnn_v3_effect.cc and gen_test_vectors.py) +# Weight layout constants — must stay in sync with: +# cnn_v3/src/cnn_v3_effect.cc (kEnc0Weights, kEnc1Weights, …) +# cnn_v3/training/gen_test_vectors.py (same constants) # --------------------------------------------------------------------------- -ENC0_WEIGHTS = 20 * 4 * 9 + 4 # 724 -ENC1_WEIGHTS = 4 * 8 * 9 + 8 # 296 -BN_WEIGHTS = 8 * 8 * 1 + 8 # 72 -DEC1_WEIGHTS = 16 * 4 * 9 + 4 # 580 -DEC0_WEIGHTS = 8 * 4 * 9 + 4 # 292 +ENC0_WEIGHTS = 20 * 4 * 9 + 4 # Conv(20→4,3×3)+bias = 724 +ENC1_WEIGHTS = 4 * 8 * 9 + 8 # Conv(4→8,3×3)+bias = 296 +BN_WEIGHTS = 8 * 8 * 1 + 8 # Conv(8→8,1×1)+bias = 72 +DEC1_WEIGHTS = 16 * 4 * 9 + 4 # Conv(16→4,3×3)+bias = 580 +DEC0_WEIGHTS = 8 * 4 * 9 + 4 # Conv(8→4,3×3)+bias = 292 TOTAL_F16 = ENC0_WEIGHTS + ENC1_WEIGHTS + BN_WEIGHTS + DEC1_WEIGHTS + DEC0_WEIGHTS # = 1964 |
