diff options
| -rwxr-xr-x | training/train_cnn.py | 8 |
1 files changed, 3 insertions, 5 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py index 5ad922e..1ea42a3 100755 --- a/training/train_cnn.py +++ b/training/train_cnn.py @@ -249,11 +249,9 @@ class SimpleCNN(nn.Module): # Normalize RGBD to [-1,1] x_norm = (x - 0.5) * 2.0 - # Compute coordinates [0,1] then normalize to [-1,1] - y_coords = torch.linspace(0, 1, H, device=x.device).view(1,1,H,1).expand(B,1,H,W) - x_coords = torch.linspace(0, 1, W, device=x.device).view(1,1,1,W).expand(B,1,H,W) - y_coords = (y_coords - 0.5) * 2.0 # [-1,1] - x_coords = (x_coords - 0.5) * 2.0 # [-1,1] + # Compute normalized coordinates [-1,1] + y_coords = torch.linspace(-1, 1, H, device=x.device).view(1,1,H,1).expand(B,1,H,W) + x_coords = torch.linspace(-1, 1, W, device=x.device).view(1,1,1,W).expand(B,1,H,W) # Compute grayscale from original RGB (Rec.709) and normalize to [-1,1] gray = 0.2126*x[:,0:1] + 0.7152*x[:,1:2] + 0.0722*x[:,2:3] # [B,1,H,W] in [0,1] |
