summary refs log tree commit diff
path: root/training/train_cnn_v2.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/train_cnn_v2.py')
-rwxr-xr-x training/train_cnn_v2.py 28
1 file changed, 25 insertions, 3 deletions
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
index 70229ce..9e5df2f 100755
--- a/training/train_cnn_v2.py
+++ b/training/train_cnn_v2.py
@@ -61,7 +61,7 @@ def compute_static_features(rgb, depth=None, mip_level=0):
p0 = mip_rgb[:, :, 0].astype(np.float32)
p1 = mip_rgb[:, :, 1].astype(np.float32)
p2 = mip_rgb[:, :, 2].astype(np.float32)
- p3 = depth if depth is not None else np.ones((h, w), dtype=np.float32) # Default 1.0 = far plane
+ p3 = depth.astype(np.float32) if depth is not None else np.ones((h, w), dtype=np.float32) # Default 1.0 = far plane
# UV coordinates (normalized [0, 1])
uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0).astype(np.float32)
@@ -121,7 +121,7 @@ class CNNv2(nn.Module):
# Layer 0: input RGBD (4D) + static (8D) = 12D
x = torch.cat([input_rgbd, static_features], dim=1)
x = self.layers[0](x)
- x = torch.clamp(x, 0, 1) # Output [0,1] for layer 0
+ x = torch.sigmoid(x) # Soft [0,1] for layer 0
# Layer 1+: previous (4D) + static (8D) = 12D
for i in range(1, self.num_layers):
@@ -130,7 +130,7 @@ class CNNv2(nn.Module):
if i < self.num_layers - 1:
x = F.relu(x)
else:
- x = torch.clamp(x, 0, 1) # Final output [0,1]
+ x = torch.sigmoid(x) # Soft [0,1] for final layer
return x
@@ -329,6 +329,9 @@ def train(args):
kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
if len(kernel_sizes) == 1:
kernel_sizes = kernel_sizes * args.num_layers
+ else:
+ # When multiple kernel sizes provided, derive num_layers from list length
+ args.num_layers = len(kernel_sizes)
# Create model
model = CNNv2(kernel_sizes=kernel_sizes, num_layers=args.num_layers).to(device)
@@ -397,6 +400,25 @@ def train(args):
}, checkpoint_path)
print(f" → Saved checkpoint: {checkpoint_path}")
+ # Always save final checkpoint
+ print() # Newline after training
+ final_checkpoint = Path(args.checkpoint_dir) / f"checkpoint_epoch_{args.epochs}.pth"
+ final_checkpoint.parent.mkdir(parents=True, exist_ok=True)
+ torch.save({
+ 'epoch': args.epochs,
+ 'model_state_dict': model.state_dict(),
+ 'optimizer_state_dict': optimizer.state_dict(),
+ 'loss': avg_loss,
+ 'config': {
+ 'kernel_sizes': kernel_sizes,
+ 'num_layers': args.num_layers,
+ 'mip_level': args.mip_level,
+ 'grayscale_loss': args.grayscale_loss,
+ 'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin20_y', 'bias']
+ }
+ }, final_checkpoint)
+ print(f" → Saved final checkpoint: {final_checkpoint}")
+
print(f"\nTraining complete! Total time: {time.time() - start_time:.1f}s")
return model