diff options
| author | skal <pascal.massimino@gmail.com> | 2026-02-11 16:46:08 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-02-11 16:46:08 +0100 |
| commit | 606a3e8027e901b5a3f9e68444d931982080bdd9 (patch) | |
| tree | 1ec466f72709922002ad7d03aee34d8722d199d6 /training | |
| parent | f71e4b6c3ae7c2b5a0c71fa6b379c44b5d527874 (diff) | |
refactor: Use linspace(-1,1) directly for coords
Simplify coordinate initialization by generating [-1,1] range directly
instead of [0,1] then normalizing. Mathematically equivalent, clearer.
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'training')
| -rwxr-xr-x | training/train_cnn.py | 8 |
1 files changed, 3 insertions, 5 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py index 5ad922e..1ea42a3 100755 --- a/training/train_cnn.py +++ b/training/train_cnn.py @@ -249,11 +249,9 @@ class SimpleCNN(nn.Module): # Normalize RGBD to [-1,1] x_norm = (x - 0.5) * 2.0 - # Compute coordinates [0,1] then normalize to [-1,1] - y_coords = torch.linspace(0, 1, H, device=x.device).view(1,1,H,1).expand(B,1,H,W) - x_coords = torch.linspace(0, 1, W, device=x.device).view(1,1,1,W).expand(B,1,H,W) - y_coords = (y_coords - 0.5) * 2.0 # [-1,1] - x_coords = (x_coords - 0.5) * 2.0 # [-1,1] + # Compute normalized coordinates [-1,1] + y_coords = torch.linspace(-1, 1, H, device=x.device).view(1,1,H,1).expand(B,1,H,W) + x_coords = torch.linspace(-1, 1, W, device=x.device).view(1,1,1,W).expand(B,1,H,W) # Compute grayscale from original RGB (Rec.709) and normalize to [-1,1] gray = 0.2126*x[:,0:1] + 0.7152*x[:,1:2] + 0.0722*x[:,2:3] # [B,1,H,W] in [0,1] |
