summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-02-13 23:20:25 +0100
committerskal <pascal.massimino@gmail.com>2026-02-13 23:20:25 +0100
commit25f3c735e304a9af7c0bf8f7d62228907d03e5c5 (patch)
treee30a4df053055c78b0612099ca3688c59c755b31 /training
parent6fa9ccf86b0bbefb48cefae19d4162115a3d63d3 (diff)
CNN v2 training: Fix float64/float32 dtype mismatch in depth feature
Cast depth array to float32 when provided, preventing torch Double/Float dtype mismatch during forward pass. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'training')
-rwxr-xr-xtraining/train_cnn_v2.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
index 70229ce..134a5ae 100755
--- a/training/train_cnn_v2.py
+++ b/training/train_cnn_v2.py
@@ -61,7 +61,7 @@ def compute_static_features(rgb, depth=None, mip_level=0):
p0 = mip_rgb[:, :, 0].astype(np.float32)
p1 = mip_rgb[:, :, 1].astype(np.float32)
p2 = mip_rgb[:, :, 2].astype(np.float32)
- p3 = depth if depth is not None else np.ones((h, w), dtype=np.float32) # Default 1.0 = far plane
+ p3 = depth.astype(np.float32) if depth is not None else np.ones((h, w), dtype=np.float32) # Default 1.0 = far plane
# UV coordinates (normalized [0, 1])
uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0).astype(np.float32)