summaryrefslogtreecommitdiff
path: root/training/train_cnn.py
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-02-11 16:46:08 +0100
committerskal <pascal.massimino@gmail.com>2026-02-11 16:46:08 +0100
commit606a3e8027e901b5a3f9e68444d931982080bdd9 (patch)
tree1ec466f72709922002ad7d03aee34d8722d199d6 /training/train_cnn.py
parentf71e4b6c3ae7c2b5a0c71fa6b379c44b5d527874 (diff)
refactor: Use linspace(-1,1) directly for coords
Simplify coordinate initialization by generating [-1,1] range directly instead of [0,1] then normalizing. Mathematically equivalent, clearer. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'training/train_cnn.py')
-rwxr-xr-xtraining/train_cnn.py8
1 files changed, 3 insertions, 5 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py
index 5ad922e..1ea42a3 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -249,11 +249,9 @@ class SimpleCNN(nn.Module):
# Normalize RGBD to [-1,1]
x_norm = (x - 0.5) * 2.0
- # Compute coordinates [0,1] then normalize to [-1,1]
- y_coords = torch.linspace(0, 1, H, device=x.device).view(1,1,H,1).expand(B,1,H,W)
- x_coords = torch.linspace(0, 1, W, device=x.device).view(1,1,1,W).expand(B,1,H,W)
- y_coords = (y_coords - 0.5) * 2.0 # [-1,1]
- x_coords = (x_coords - 0.5) * 2.0 # [-1,1]
+ # Compute normalized coordinates [-1,1]
+ y_coords = torch.linspace(-1, 1, H, device=x.device).view(1,1,H,1).expand(B,1,H,W)
+ x_coords = torch.linspace(-1, 1, W, device=x.device).view(1,1,1,W).expand(B,1,H,W)
# Compute grayscale from original RGB (Rec.709) and normalize to [-1,1]
gray = 0.2126*x[:,0:1] + 0.7152*x[:,1:2] + 0.0722*x[:,2:3] # [B,1,H,W] in [0,1]