From 25f3c735e304a9af7c0bf8f7d62228907d03e5c5 Mon Sep 17 00:00:00 2001 From: skal Date: Fri, 13 Feb 2026 23:20:25 +0100 Subject: CNN v2 training: Fix float64/float32 dtype mismatch in depth feature Cast depth array to float32 when provided, preventing torch Double/Float dtype mismatch during forward pass. Co-Authored-By: Claude Sonnet 4.5 --- training/train_cnn_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py index 70229ce..134a5ae 100755 --- a/training/train_cnn_v2.py +++ b/training/train_cnn_v2.py @@ -61,7 +61,7 @@ def compute_static_features(rgb, depth=None, mip_level=0): p0 = mip_rgb[:, :, 0].astype(np.float32) p1 = mip_rgb[:, :, 1].astype(np.float32) p2 = mip_rgb[:, :, 2].astype(np.float32) - p3 = depth if depth is not None else np.ones((h, w), dtype=np.float32) # Default 1.0 = far plane + p3 = depth.astype(np.float32) if depth is not None else np.ones((h, w), dtype=np.float32) # Default 1.0 = far plane # UV coordinates (normalized [0, 1]) uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0).astype(np.float32) -- cgit v1.2.3