| author | skal <pascal.massimino@gmail.com> | 2026-02-13 22:42:45 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-02-13 22:42:45 +0100 |
| commit | f81a30d15e1e7db0492f45a0b9bec6aaa20ae5c2 (patch) | |
| tree | deb202a7d995895ec90e8ddc8c3fbf92082ea434 /training | |
| parent | 7c1f937222d0e36294ebd25db949c6227aed6985 (diff) | |
CNN v2: Use alpha channel for p3 depth feature + layer visualization
Training changes (train_cnn_v2.py):
- p3 now uses target image alpha channel (depth proxy for 2D images)
- Default changed from 0.0 → 1.0 (far plane semantics)
- Both PatchDataset and ImagePairDataset updated
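Both dataset classes extract depth the same way; a minimal sketch of that pattern (the `extract_depth` helper name is ours, not the patch's):

```python
import numpy as np

def extract_depth(target_img):
    """Return the alpha channel of an RGBA target as a depth proxy, or None
    when the image has no alpha (the caller then defaults p3 to 1.0)."""
    if target_img.ndim == 3 and target_img.shape[2] == 4:
        return target_img[:, :, 3].astype(np.float32)
    return None

# RGBA target: alpha becomes the depth map.
rgba = np.dstack([np.zeros((4, 4, 3)), np.full((4, 4), 0.5)])
assert extract_depth(rgba).shape == (4, 4)

# RGB target: no alpha, so p3 falls back to 1.0 (far plane).
assert extract_depth(np.zeros((4, 4, 3))) is None
```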
Test tools (cnn_test.cc):
- New load_depth_from_alpha() extracts PNG alpha → p3 texture
- Fixed bind group layout: use UnfilterableFloat for R32Float depth
- Added --save-intermediates support for CNN v2:
* Each layer_N.png shows 4 channels horizontally (1812×345 grayscale)
* layers_composite.png stacks all layers vertically (1812×1380)
* static_features.png shows 4 feature channels horizontally
- Per-channel visualization enables debugging layer-by-layer differences
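The horizontal channel layout of each layer_N.png can be sketched as follows; `channels_to_strip` is our illustrative name, and per-channel min/max normalization is an assumption (the actual scaling in cnn_test.cc is not shown here):

```python
import numpy as np

def channels_to_strip(layer):
    """Lay out each channel of an (H, W, C) activation side by side as one
    (H, W*C) grayscale image, normalized per channel for visibility."""
    strips = []
    for c in range(layer.shape[2]):
        ch = layer[:, :, c]
        lo, hi = ch.min(), ch.max()
        # Avoid division by zero for constant channels.
        strips.append((ch - lo) / (hi - lo) if hi > lo else np.zeros_like(ch))
    return np.hstack(strips)

layer = np.random.rand(8, 16, 4).astype(np.float32)
strip = channels_to_strip(layer)
assert strip.shape == (8, 64)  # four 16-wide channels side by side
```

Stacking the per-layer strips with `np.vstack` would then give the layers_composite.png arrangement described above.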
HTML tool (index.html):
- Extract alpha channel from input image → depth texture
- Matches training data distribution for validation
Note: existing weights were trained with p3=0 and are now mismatched with
the new inputs. Both tools use p3=alpha consistently, so their outputs remain
comparable for debugging. Retraining is required for optimal quality.
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'training')
| -rwxr-xr-x | training/train_cnn_v2.py | 16 |
1 file changed, 11 insertions, 5 deletions
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
index abe07bc..70229ce 100755
--- a/training/train_cnn_v2.py
+++ b/training/train_cnn_v2.py
@@ -26,13 +26,13 @@ def compute_static_features(rgb, depth=None, mip_level=0):
     Args:
         rgb: (H, W, 3) RGB image [0, 1]
-        depth: (H, W) depth map [0, 1], optional
+        depth: (H, W) depth map [0, 1], optional (defaults to 1.0 = far plane)
         mip_level: Mip level for p0-p3 (0=original, 1=half, 2=quarter, 3=eighth)
 
     Returns:
         (H, W, 8) static features: [p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias]
 
-    Note: p0-p3 are parametric features generated from specified mip level
+    Note: p0-p3 are parametric features from mip level. p3 uses depth (alpha channel) or 1.0
 
     TODO: Binary format should support arbitrary layout and ordering for feature
           vector (7D), alongside mip-level indication. Current layout is hardcoded as:
@@ -61,7 +61,7 @@ def compute_static_features(rgb, depth=None, mip_level=0):
     p0 = mip_rgb[:, :, 0].astype(np.float32)
     p1 = mip_rgb[:, :, 1].astype(np.float32)
     p2 = mip_rgb[:, :, 2].astype(np.float32)
-    p3 = depth if depth is not None else np.zeros((h, w), dtype=np.float32)
+    p3 = depth if depth is not None else np.ones((h, w), dtype=np.float32)  # Default 1.0 = far plane
 
     # UV coordinates (normalized [0, 1])
     uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0).astype(np.float32)
@@ -244,8 +244,11 @@ class PatchDataset(Dataset):
         input_patch = input_img[y1:y2, x1:x2]
         target_patch = target_img[y1:y2, x1:x2]  # RGBA
 
+        # Extract depth from target alpha channel (or default to 1.0)
+        depth = target_patch[:, :, 3] if target_patch.shape[2] == 4 else None
+
         # Compute static features for patch
-        static_feat = compute_static_features(input_patch.astype(np.float32), mip_level=self.mip_level)
+        static_feat = compute_static_features(input_patch.astype(np.float32), depth=depth, mip_level=self.mip_level)
 
         # Input RGBD (mip 0) - add depth channel
         input_rgbd = np.concatenate([input_patch, np.zeros((self.patch_size, self.patch_size, 1))], axis=-1)
@@ -284,8 +287,11 @@ class ImagePairDataset(Dataset):
         input_img = np.array(input_pil) / 255.0
         target_img = np.array(target_pil.convert('RGBA')) / 255.0  # Preserve alpha
 
+        # Extract depth from target alpha channel (or default to 1.0)
+        depth = target_img[:, :, 3] if target_img.shape[2] == 4 else None
+
         # Compute static features
-        static_feat = compute_static_features(input_img.astype(np.float32), mip_level=self.mip_level)
+        static_feat = compute_static_features(input_img.astype(np.float32), depth=depth, mip_level=self.mip_level)
 
         # Input RGBD (mip 0) - add depth channel
         h, w = input_img.shape[:2]
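Taken together, the hunks above imply the following static feature assembly at mip level 0. This is a sketch, not the patch's code: p0-p3 and the UV grid follow the diff, while the `sin20_y` formula and the constant-1 `bias` are our assumptions (the patch does not show how they are computed):

```python
import numpy as np

def static_features_sketch(rgb, depth=None):
    """Assemble the 8-channel static feature map
    [p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias] at mip level 0."""
    h, w = rgb.shape[:2]
    p0 = rgb[:, :, 0].astype(np.float32)
    p1 = rgb[:, :, 1].astype(np.float32)
    p2 = rgb[:, :, 2].astype(np.float32)
    # Default 1.0 = far plane (the behavior this commit changes from 0.0).
    p3 = depth if depth is not None else np.ones((h, w), dtype=np.float32)
    # UV coordinates (normalized [0, 1]), as in the diff.
    uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0).astype(np.float32)
    uv_y = np.linspace(0, 1, h)[:, None].repeat(w, axis=1).astype(np.float32)
    sin20_y = np.sin(20.0 * uv_y).astype(np.float32)  # assumed formula
    bias = np.ones((h, w), dtype=np.float32)          # assumed constant 1
    return np.stack([p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias], axis=-1)

feat = static_features_sketch(np.zeros((6, 5, 3)))
assert feat.shape == (6, 5, 8)
assert np.all(feat[:, :, 3] == 1.0)  # p3 defaults to far plane
```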
