summaryrefslogtreecommitdiff
path: root/training/train_cnn.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/train_cnn.py')
-rwxr-xr-xtraining/train_cnn.py8
1 files changed, 3 insertions, 5 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py
index 5ad922e..1ea42a3 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -249,11 +249,9 @@ class SimpleCNN(nn.Module):
# Normalize RGBD to [-1,1]
x_norm = (x - 0.5) * 2.0
- # Compute coordinates [0,1] then normalize to [-1,1]
- y_coords = torch.linspace(0, 1, H, device=x.device).view(1,1,H,1).expand(B,1,H,W)
- x_coords = torch.linspace(0, 1, W, device=x.device).view(1,1,1,W).expand(B,1,H,W)
- y_coords = (y_coords - 0.5) * 2.0 # [-1,1]
- x_coords = (x_coords - 0.5) * 2.0 # [-1,1]
+ # Compute normalized coordinates [-1,1]
+ y_coords = torch.linspace(-1, 1, H, device=x.device).view(1,1,H,1).expand(B,1,H,W)
+ x_coords = torch.linspace(-1, 1, W, device=x.device).view(1,1,1,W).expand(B,1,H,W)
# Compute grayscale from original RGB (Rec.709) and normalize to [-1,1]
gray = 0.2126*x[:,0:1] + 0.7152*x[:,1:2] + 0.0722*x[:,2:3] # [B,1,H,W] in [0,1]