diff options
| author | skal <pascal.massimino@gmail.com> | 2026-02-13 23:20:25 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-02-13 23:20:25 +0100 |
| commit | 25f3c735e304a9af7c0bf8f7d62228907d03e5c5 (patch) | |
| tree | e30a4df053055c78b0612099ca3688c59c755b31 /training | |
| parent | 6fa9ccf86b0bbefb48cefae19d4162115a3d63d3 (diff) | |
CNN v2 training: Fix float64/float32 dtype mismatch in depth feature
Cast depth array to float32 when provided, preventing torch Double/Float
dtype mismatch during forward pass.
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'training')
| -rwxr-xr-x | training/train_cnn_v2.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py index 70229ce..134a5ae 100755 --- a/training/train_cnn_v2.py +++ b/training/train_cnn_v2.py @@ -61,7 +61,7 @@ def compute_static_features(rgb, depth=None, mip_level=0): p0 = mip_rgb[:, :, 0].astype(np.float32) p1 = mip_rgb[:, :, 1].astype(np.float32) p2 = mip_rgb[:, :, 2].astype(np.float32) - p3 = depth if depth is not None else np.ones((h, w), dtype=np.float32) # Default 1.0 = far plane + p3 = depth.astype(np.float32) if depth is not None else np.ones((h, w), dtype=np.float32) # Default 1.0 = far plane # UV coordinates (normalized [0, 1]) uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0).astype(np.float32) |
