summaryrefslogtreecommitdiff
path: root/training/train_cnn_v2.py
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-02-13 22:42:45 +0100
committerskal <pascal.massimino@gmail.com>2026-02-13 22:42:45 +0100
commitf81a30d15e1e7db0492f45a0b9bec6aaa20ae5c2 (patch)
treedeb202a7d995895ec90e8ddc8c3fbf92082ea434 /training/train_cnn_v2.py
parent7c1f937222d0e36294ebd25db949c6227aed6985 (diff)
CNN v2: Use alpha channel for p3 depth feature + layer visualization
Training changes (train_cnn_v2.py): - p3 now uses target image alpha channel (depth proxy for 2D images) - Default changed from 0.0 → 1.0 (far plane semantics) - Both PatchDataset and ImagePairDataset updated Test tools (cnn_test.cc): - New load_depth_from_alpha() extracts PNG alpha → p3 texture - Fixed bind group layout: use UnfilterableFloat for R32Float depth - Added --save-intermediates support for CNN v2: * Each layer_N.png shows 4 channels horizontally (1812×345 grayscale) * layers_composite.png stacks all layers vertically (1812×1380) * static_features.png shows 4 feature channels horizontally - Per-channel visualization enables debugging layer-by-layer differences HTML tool (index.html): - Extract alpha channel from input image → depth texture - Matches training data distribution for validation Note: Current weights trained with p3=0 are now mismatched. Both tools use p3=alpha consistently, so outputs remain comparable for debugging. Retrain required for optimal quality. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'training/train_cnn_v2.py')
-rwxr-xr-xtraining/train_cnn_v2.py16
1 files changed, 11 insertions, 5 deletions
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
index abe07bc..70229ce 100755
--- a/training/train_cnn_v2.py
+++ b/training/train_cnn_v2.py
@@ -26,13 +26,13 @@ def compute_static_features(rgb, depth=None, mip_level=0):
Args:
rgb: (H, W, 3) RGB image [0, 1]
- depth: (H, W) depth map [0, 1], optional
+ depth: (H, W) depth map [0, 1], optional (defaults to 1.0 = far plane)
mip_level: Mip level for p0-p3 (0=original, 1=half, 2=quarter, 3=eighth)
Returns:
(H, W, 8) static features: [p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias]
- Note: p0-p3 are parametric features generated from specified mip level
+ Note: p0-p3 are parametric features from mip level. p3 uses depth (alpha channel) or 1.0
TODO: Binary format should support arbitrary layout and ordering for feature vector (7D),
alongside mip-level indication. Current layout is hardcoded as:
@@ -61,7 +61,7 @@ def compute_static_features(rgb, depth=None, mip_level=0):
p0 = mip_rgb[:, :, 0].astype(np.float32)
p1 = mip_rgb[:, :, 1].astype(np.float32)
p2 = mip_rgb[:, :, 2].astype(np.float32)
- p3 = depth if depth is not None else np.zeros((h, w), dtype=np.float32)
+ p3 = depth if depth is not None else np.ones((h, w), dtype=np.float32) # Default 1.0 = far plane
# UV coordinates (normalized [0, 1])
uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0).astype(np.float32)
@@ -244,8 +244,11 @@ class PatchDataset(Dataset):
input_patch = input_img[y1:y2, x1:x2]
target_patch = target_img[y1:y2, x1:x2] # RGBA
+ # Extract depth from target alpha channel (or default to 1.0)
+ depth = target_patch[:, :, 3] if target_patch.shape[2] == 4 else None
+
# Compute static features for patch
- static_feat = compute_static_features(input_patch.astype(np.float32), mip_level=self.mip_level)
+ static_feat = compute_static_features(input_patch.astype(np.float32), depth=depth, mip_level=self.mip_level)
# Input RGBD (mip 0) - add depth channel
input_rgbd = np.concatenate([input_patch, np.zeros((self.patch_size, self.patch_size, 1))], axis=-1)
@@ -284,8 +287,11 @@ class ImagePairDataset(Dataset):
input_img = np.array(input_pil) / 255.0
target_img = np.array(target_pil.convert('RGBA')) / 255.0 # Preserve alpha
+ # Extract depth from target alpha channel (or default to 1.0)
+ depth = target_img[:, :, 3] if target_img.shape[2] == 4 else None
+
# Compute static features
- static_feat = compute_static_features(input_img.astype(np.float32), mip_level=self.mip_level)
+ static_feat = compute_static_features(input_img.astype(np.float32), depth=depth, mip_level=self.mip_level)
# Input RGBD (mip 0) - add depth channel
h, w = input_img.shape[:2]