diff options
Diffstat (limited to 'training/train_cnn_v2.py')
| -rwxr-xr-x | training/train_cnn_v2.py | 28 |
1 files changed, 25 insertions, 3 deletions
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py index 70229ce..9e5df2f 100755 --- a/training/train_cnn_v2.py +++ b/training/train_cnn_v2.py @@ -61,7 +61,7 @@ def compute_static_features(rgb, depth=None, mip_level=0): p0 = mip_rgb[:, :, 0].astype(np.float32) p1 = mip_rgb[:, :, 1].astype(np.float32) p2 = mip_rgb[:, :, 2].astype(np.float32) - p3 = depth if depth is not None else np.ones((h, w), dtype=np.float32) # Default 1.0 = far plane + p3 = depth.astype(np.float32) if depth is not None else np.ones((h, w), dtype=np.float32) # Default 1.0 = far plane # UV coordinates (normalized [0, 1]) uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0).astype(np.float32) @@ -121,7 +121,7 @@ class CNNv2(nn.Module): # Layer 0: input RGBD (4D) + static (8D) = 12D x = torch.cat([input_rgbd, static_features], dim=1) x = self.layers[0](x) - x = torch.clamp(x, 0, 1) # Output [0,1] for layer 0 + x = torch.sigmoid(x) # Soft [0,1] for layer 0 # Layer 1+: previous (4D) + static (8D) = 12D for i in range(1, self.num_layers): @@ -130,7 +130,7 @@ class CNNv2(nn.Module): if i < self.num_layers - 1: x = F.relu(x) else: - x = torch.clamp(x, 0, 1) # Final output [0,1] + x = torch.sigmoid(x) # Soft [0,1] for final layer return x @@ -329,6 +329,9 @@ def train(args): kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')] if len(kernel_sizes) == 1: kernel_sizes = kernel_sizes * args.num_layers + else: + # When multiple kernel sizes provided, derive num_layers from list length + args.num_layers = len(kernel_sizes) # Create model model = CNNv2(kernel_sizes=kernel_sizes, num_layers=args.num_layers).to(device) @@ -397,6 +400,25 @@ def train(args): }, checkpoint_path) print(f" → Saved checkpoint: {checkpoint_path}") + # Always save final checkpoint + print() # Newline after training + final_checkpoint = Path(args.checkpoint_dir) / f"checkpoint_epoch_{args.epochs}.pth" + final_checkpoint.parent.mkdir(parents=True, exist_ok=True) + torch.save({ + 'epoch': args.epochs, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': avg_loss, + 'config': { + 'kernel_sizes': kernel_sizes, + 'num_layers': args.num_layers, + 'mip_level': args.mip_level, + 'grayscale_loss': args.grayscale_loss, + 'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin20_y', 'bias'] + } + }, final_checkpoint) + print(f" → Saved final checkpoint: {final_checkpoint}") + print(f"\nTraining complete! Total time: {time.time() - start_time:.1f}s") return model |
