diff options
Diffstat (limited to 'training')
| -rwxr-xr-x | training/train_cnn_v2.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py index d80e3a5..9e5df2f 100755 --- a/training/train_cnn_v2.py +++ b/training/train_cnn_v2.py @@ -121,7 +121,7 @@ class CNNv2(nn.Module): # Layer 0: input RGBD (4D) + static (8D) = 12D x = torch.cat([input_rgbd, static_features], dim=1) x = self.layers[0](x) - x = torch.clamp(x, 0, 1) # Output [0,1] for layer 0 + x = torch.sigmoid(x) # Soft [0,1] for layer 0 # Layer 1+: previous (4D) + static (8D) = 12D for i in range(1, self.num_layers): @@ -130,7 +130,7 @@ class CNNv2(nn.Module): if i < self.num_layers - 1: x = F.relu(x) else: - x = torch.clamp(x, 0, 1) # Final output [0,1] + x = torch.sigmoid(x) # Soft [0,1] for final layer return x |
