From f81a30d15e1e7db0492f45a0b9bec6aaa20ae5c2 Mon Sep 17 00:00:00 2001 From: skal Date: Fri, 13 Feb 2026 22:42:45 +0100 Subject: CNN v2: Use alpha channel for p3 depth feature + layer visualization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Training changes (train_cnn_v2.py): - p3 now uses target image alpha channel (depth proxy for 2D images) - Default changed from 0.0 → 1.0 (far plane semantics) - Both PatchDataset and ImagePairDataset updated Test tools (cnn_test.cc): - New load_depth_from_alpha() extracts PNG alpha → p3 texture - Fixed bind group layout: use UnfilterableFloat for R32Float depth - Added --save-intermediates support for CNN v2: * Each layer_N.png shows 4 channels horizontally (1812×345 grayscale) * layers_composite.png stacks all layers vertically (1812×1380) * static_features.png shows 4 feature channels horizontally - Per-channel visualization enables debugging layer-by-layer differences HTML tool (index.html): - Extract alpha channel from input image → depth texture - Matches training data distribution for validation Note: Current weights trained with p3=0 are now mismatched. Both tools use p3=alpha consistently, so outputs remain comparable for debugging. Retrain required for optimal quality. Co-Authored-By: Claude Sonnet 4.5 --- training/train_cnn_v2.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) (limited to 'training') diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py index abe07bc..70229ce 100755 --- a/training/train_cnn_v2.py +++ b/training/train_cnn_v2.py @@ -26,13 +26,13 @@ def compute_static_features(rgb, depth=None, mip_level=0): Args: rgb: (H, W, 3) RGB image [0, 1] - depth: (H, W) depth map [0, 1], optional + depth: (H, W) depth map [0, 1], optional (defaults to 1.0 = far plane) mip_level: Mip level for p0-p3 (0=original, 1=half, 2=quarter, 3=eighth) Returns: (H, W, 8) static features: [p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias] - Note: p0-p3 are parametric features generated from specified mip level + Note: p0-p3 are parametric features from mip level. p3 uses depth (alpha channel) or 1.0 TODO: Binary format should support arbitrary layout and ordering for feature vector (7D), alongside mip-level indication. Current layout is hardcoded as: @@ -61,7 +61,7 @@ def compute_static_features(rgb, depth=None, mip_level=0): p0 = mip_rgb[:, :, 0].astype(np.float32) p1 = mip_rgb[:, :, 1].astype(np.float32) p2 = mip_rgb[:, :, 2].astype(np.float32) - p3 = depth if depth is not None else np.zeros((h, w), dtype=np.float32) + p3 = depth if depth is not None else np.ones((h, w), dtype=np.float32) # Default 1.0 = far plane # UV coordinates (normalized [0, 1]) uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0).astype(np.float32) @@ -244,8 +244,11 @@ class PatchDataset(Dataset): input_patch = input_img[y1:y2, x1:x2] target_patch = target_img[y1:y2, x1:x2] # RGBA + # Extract depth from target alpha channel (or default to 1.0) + depth = target_patch[:, :, 3] if target_patch.shape[2] == 4 else None + # Compute static features for patch - static_feat = compute_static_features(input_patch.astype(np.float32), mip_level=self.mip_level) + static_feat = compute_static_features(input_patch.astype(np.float32), depth=depth, mip_level=self.mip_level) # Input RGBD (mip 0) - add depth channel input_rgbd = np.concatenate([input_patch, np.zeros((self.patch_size, self.patch_size, 1))], axis=-1) @@ -284,8 +287,11 @@ class ImagePairDataset(Dataset): input_img = np.array(input_pil) / 255.0 target_img = np.array(target_pil.convert('RGBA')) / 255.0 # Preserve alpha + # Extract depth from target alpha channel (or default to 1.0) + depth = target_img[:, :, 3] if target_img.shape[2] == 4 else None + # Compute static features - static_feat = compute_static_features(input_img.astype(np.float32), mip_level=self.mip_level) + static_feat = compute_static_features(input_img.astype(np.float32), depth=depth, mip_level=self.mip_level) # Input RGBD (mip 0) - add depth channel h, w = input_img.shape[:2] -- cgit v1.2.3