diff options
Diffstat (limited to 'training/train_cnn.py')
| -rwxr-xr-x | training/train_cnn.py | 42 |
1 files changed, 36 insertions, 6 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py index dc14192..ef7a0ae 100755 --- a/training/train_cnn.py +++ b/training/train_cnn.py @@ -240,10 +240,12 @@ class SimpleCNN(nn.Module): # Final layer: 7→1 (grayscale output) self.layers.append(nn.Conv2d(7, 1, kernel_size=kernel_size, padding=padding, bias=True)) - def forward(self, x): + def forward(self, x, return_intermediates=False): # x: [B,4,H,W] - RGBD input (D = 1/z) B, C, H, W = x.shape + intermediates = [] if return_intermediates else None + # Normalize RGBD to [-1,1] x_norm = (x - 0.5) * 2.0 @@ -261,18 +263,26 @@ class SimpleCNN(nn.Module): layer0_input = torch.cat([x_norm, x_coords, y_coords, gray], dim=1) # [B,7,H,W] out = self.layers[0](layer0_input) # [B,4,H,W] out = torch.tanh(out) # [-1,1] + if return_intermediates: + intermediates.append(out.clone()) # Inner layers for i in range(1, len(self.layers)-1): layer_input = torch.cat([out, x_coords, y_coords, gray], dim=1) out = self.layers[i](layer_input) out = torch.tanh(out) + if return_intermediates: + intermediates.append(out.clone()) # Final layer (grayscale output) final_input = torch.cat([out, x_coords, y_coords, gray], dim=1) out = self.layers[-1](final_input) # [B,1,H,W] out = torch.sigmoid(out) # Map to [0,1] with smooth gradients - return out.expand(-1, 3, -1, -1) + final_out = out.expand(-1, 3, -1, -1) + + if return_intermediates: + return final_out, intermediates + return final_out def generate_layer_shader(output_path, num_layers, kernel_sizes): @@ -693,7 +703,7 @@ def export_from_checkpoint(checkpoint_path, output_path=None): print("Export complete!") -def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32): +def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32, save_intermediates=None): """Run sliding-window inference to match WGSL shader behavior""" if not os.path.exists(checkpoint_path): @@ -724,16 +734,35 @@ def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=3 # Process full image with sliding window (matches WGSL shader) print(f"Processing full image ({W}×{H}) with sliding window...") with torch.no_grad(): - output_tensor = model(img_tensor) # [1,3,H,W] + if save_intermediates: + output_tensor, intermediates = model(img_tensor, return_intermediates=True) + else: + output_tensor = model(img_tensor) # [1,3,H,W] # Convert to numpy output = output_tensor.squeeze(0).permute(1, 2, 0).numpy() - # Save + # Save final output print(f"Saving output to: {output_path}") os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else '.', exist_ok=True) output_img = Image.fromarray((output * 255).astype(np.uint8)) output_img.save(output_path) + + # Save intermediates if requested + if save_intermediates: + os.makedirs(save_intermediates, exist_ok=True) + print(f"Saving {len(intermediates)} intermediate layers to: {save_intermediates}") + for layer_idx, layer_tensor in enumerate(intermediates): + # Convert [-1,1] to [0,1] for visualization + layer_data = (layer_tensor.squeeze(0).permute(1, 2, 0).numpy() + 1.0) * 0.5 + # Take first channel for 4-channel intermediate layers + if layer_data.shape[2] == 4: + layer_data = layer_data[:, :, :3] # Show RGB only + layer_img = Image.fromarray((layer_data.clip(0, 1) * 255).astype(np.uint8)) + layer_path = os.path.join(save_intermediates, f'layer_{layer_idx}.png') + layer_img.save(layer_path) + print(f" Saved layer {layer_idx} to {layer_path}") + print("Done!") @@ -758,6 +787,7 @@ def main(): help='Salient point detector for patch extraction (default: harris)') parser.add_argument('--early-stop-patience', type=int, default=0, help='Stop if loss changes less than eps over N epochs (default: 0 = disabled)') parser.add_argument('--early-stop-eps', type=float, default=1e-6, help='Loss change threshold for early stopping (default: 1e-6)') + parser.add_argument('--save-intermediates', help='Directory to save intermediate layer outputs (inference only)') args = parser.parse_args() @@ -769,7 +799,7 @@ def main(): sys.exit(1) output_path = args.output or 'inference_output.png' patch_size = args.patch_size or 32 - infer_from_checkpoint(checkpoint, args.infer, output_path, patch_size) + infer_from_checkpoint(checkpoint, args.infer, output_path, patch_size, args.save_intermediates) return # Export-only mode |
