summary | refs | log | tree | commit | diff
path: root/training/train_cnn_v2.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/train_cnn_v2.py')
-rwxr-xr-x  training/train_cnn_v2.py  16
1 files changed, 11 insertions, 5 deletions
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
index abe07bc..70229ce 100755
--- a/training/train_cnn_v2.py
+++ b/training/train_cnn_v2.py
@@ -26,13 +26,13 @@ def compute_static_features(rgb, depth=None, mip_level=0):
Args:
rgb: (H, W, 3) RGB image [0, 1]
- depth: (H, W) depth map [0, 1], optional
+ depth: (H, W) depth map [0, 1], optional (defaults to 1.0 = far plane)
mip_level: Mip level for p0-p3 (0=original, 1=half, 2=quarter, 3=eighth)
Returns:
(H, W, 8) static features: [p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias]
- Note: p0-p3 are parametric features generated from specified mip level
+ Note: p0-p2 are the R, G, B channels taken from the image at the given mip level. p3 is depth (callers pass the target alpha channel), or 1.0 everywhere when depth is None
TODO: Binary format should support arbitrary layout and ordering for feature vector (7D),
alongside mip-level indication. Current layout is hardcoded as:
@@ -61,7 +61,7 @@ def compute_static_features(rgb, depth=None, mip_level=0):
p0 = mip_rgb[:, :, 0].astype(np.float32)
p1 = mip_rgb[:, :, 1].astype(np.float32)
p2 = mip_rgb[:, :, 2].astype(np.float32)
- p3 = depth if depth is not None else np.zeros((h, w), dtype=np.float32)
+ p3 = depth if depth is not None else np.ones((h, w), dtype=np.float32) # Default 1.0 = far plane
# UV coordinates (normalized [0, 1])
uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0).astype(np.float32)
@@ -244,8 +244,11 @@ class PatchDataset(Dataset):
input_patch = input_img[y1:y2, x1:x2]
target_patch = target_img[y1:y2, x1:x2] # RGBA
+ # Extract depth from target alpha channel (or default to 1.0)
+ depth = target_patch[:, :, 3] if target_patch.shape[2] == 4 else None
+
# Compute static features for patch
- static_feat = compute_static_features(input_patch.astype(np.float32), mip_level=self.mip_level)
+ static_feat = compute_static_features(input_patch.astype(np.float32), depth=depth, mip_level=self.mip_level)
# Input RGBD (mip 0) - add depth channel
input_rgbd = np.concatenate([input_patch, np.zeros((self.patch_size, self.patch_size, 1))], axis=-1)
@@ -284,8 +287,11 @@ class ImagePairDataset(Dataset):
input_img = np.array(input_pil) / 255.0
target_img = np.array(target_pil.convert('RGBA')) / 255.0 # Preserve alpha
+ # Extract depth from target alpha channel (or default to 1.0)
+ depth = target_img[:, :, 3] if target_img.shape[2] == 4 else None
+
# Compute static features
- static_feat = compute_static_features(input_img.astype(np.float32), mip_level=self.mip_level)
+ static_feat = compute_static_features(input_img.astype(np.float32), depth=depth, mip_level=self.mip_level)
# Input RGBD (mip 0) - add depth channel
h, w = input_img.shape[:2]