summary | refs | log | tree | commit | diff
path: root/training
diff options
context:
space:
mode:
author skal <pascal.massimino@gmail.com> 2026-02-11 10:51:06 +0100
committer skal <pascal.massimino@gmail.com> 2026-02-11 10:51:06 +0100
commit 4da0a3a5369142078fd7c681e3f0f1817bd6e2f3 (patch)
tree d69429d6800dad0bb819f164122df634543796a5 /training
parent 7dd1ac57178055aa8407777d4fb03787e21e6f66 (diff)
add --save-intermediates to train.py and cnn_test
Diffstat (limited to 'training')
-rwxr-xr-x training/train_cnn.py 42
1 files changed, 36 insertions, 6 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py
index dc14192..ef7a0ae 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -240,10 +240,12 @@ class SimpleCNN(nn.Module):
# Final layer: 7→1 (grayscale output)
self.layers.append(nn.Conv2d(7, 1, kernel_size=kernel_size, padding=padding, bias=True))
- def forward(self, x):
+ def forward(self, x, return_intermediates=False):
# x: [B,4,H,W] - RGBD input (D = 1/z)
B, C, H, W = x.shape
+ intermediates = [] if return_intermediates else None
+
# Normalize RGBD to [-1,1]
x_norm = (x - 0.5) * 2.0
@@ -261,18 +263,26 @@ class SimpleCNN(nn.Module):
layer0_input = torch.cat([x_norm, x_coords, y_coords, gray], dim=1) # [B,7,H,W]
out = self.layers[0](layer0_input) # [B,4,H,W]
out = torch.tanh(out) # [-1,1]
+ if return_intermediates:
+ intermediates.append(out.clone())
# Inner layers
for i in range(1, len(self.layers)-1):
layer_input = torch.cat([out, x_coords, y_coords, gray], dim=1)
out = self.layers[i](layer_input)
out = torch.tanh(out)
+ if return_intermediates:
+ intermediates.append(out.clone())
# Final layer (grayscale output)
final_input = torch.cat([out, x_coords, y_coords, gray], dim=1)
out = self.layers[-1](final_input) # [B,1,H,W]
out = torch.sigmoid(out) # Map to [0,1] with smooth gradients
- return out.expand(-1, 3, -1, -1)
+ final_out = out.expand(-1, 3, -1, -1)
+
+ if return_intermediates:
+ return final_out, intermediates
+ return final_out
def generate_layer_shader(output_path, num_layers, kernel_sizes):
@@ -693,7 +703,7 @@ def export_from_checkpoint(checkpoint_path, output_path=None):
print("Export complete!")
-def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32):
+def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32, save_intermediates=None):
"""Run sliding-window inference to match WGSL shader behavior"""
if not os.path.exists(checkpoint_path):
@@ -724,16 +734,35 @@ def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=3
# Process full image with sliding window (matches WGSL shader)
print(f"Processing full image ({W}×{H}) with sliding window...")
with torch.no_grad():
- output_tensor = model(img_tensor) # [1,3,H,W]
+ if save_intermediates:
+ output_tensor, intermediates = model(img_tensor, return_intermediates=True)
+ else:
+ output_tensor = model(img_tensor) # [1,3,H,W]
# Convert to numpy
output = output_tensor.squeeze(0).permute(1, 2, 0).numpy()
- # Save
+ # Save final output
print(f"Saving output to: {output_path}")
os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else '.', exist_ok=True)
output_img = Image.fromarray((output * 255).astype(np.uint8))
output_img.save(output_path)
+
+ # Save intermediates if requested
+ if save_intermediates:
+ os.makedirs(save_intermediates, exist_ok=True)
+ print(f"Saving {len(intermediates)} intermediate layers to: {save_intermediates}")
+ for layer_idx, layer_tensor in enumerate(intermediates):
+ # Convert [-1,1] to [0,1] for visualization
+ layer_data = (layer_tensor.squeeze(0).permute(1, 2, 0).numpy() + 1.0) * 0.5
+ # Keep only the first three channels of 4-channel intermediate layers so they can be saved as an RGB image
+ if layer_data.shape[2] == 4:
+ layer_data = layer_data[:, :, :3] # Show RGB only
+ layer_img = Image.fromarray((layer_data.clip(0, 1) * 255).astype(np.uint8))
+ layer_path = os.path.join(save_intermediates, f'layer_{layer_idx}.png')
+ layer_img.save(layer_path)
+ print(f" Saved layer {layer_idx} to {layer_path}")
+
print("Done!")
@@ -758,6 +787,7 @@ def main():
help='Salient point detector for patch extraction (default: harris)')
parser.add_argument('--early-stop-patience', type=int, default=0, help='Stop if loss changes less than eps over N epochs (default: 0 = disabled)')
parser.add_argument('--early-stop-eps', type=float, default=1e-6, help='Loss change threshold for early stopping (default: 1e-6)')
+ parser.add_argument('--save-intermediates', help='Directory to save intermediate layer outputs (inference only)')
args = parser.parse_args()
@@ -769,7 +799,7 @@ def main():
sys.exit(1)
output_path = args.output or 'inference_output.png'
patch_size = args.patch_size or 32
- infer_from_checkpoint(checkpoint, args.infer, output_path, patch_size)
+ infer_from_checkpoint(checkpoint, args.infer, output_path, patch_size, args.save_intermediates)
return
# Export-only mode