diff options
| author | skal <pascal.massimino@gmail.com> | 2026-02-11 00:32:02 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-02-11 00:32:02 +0100 |
| commit | 01e640be66f9d72c22417403eb88e18d6747866f (patch) | |
| tree | c70ffb6abf0910f07d727212dbe4665b1c24a55e /training | |
| parent | c49d828f101b435d73a76fcfc8444cf76aeda22f (diff) | |
fix: Use patch-based inference to match CNN training distribution
Inference now tiles images into patches matching training patch size,
preventing distribution mismatch between patch training and full-image inference.
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'training')
| -rwxr-xr-x | training/train_cnn.py | 38 |
1 file changed, 28 insertions, 10 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py index d8522ed..57b4da8 100755 --- a/training/train_cnn.py +++ b/training/train_cnn.py @@ -668,8 +668,8 @@ def export_from_checkpoint(checkpoint_path, output_path=None): print("Export complete!") -def infer_from_checkpoint(checkpoint_path, input_path, output_path): - """Run inference on single image to generate ground truth""" +def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32): + """Run patch-based inference to match training distribution""" if not os.path.exists(checkpoint_path): print(f"Error: Checkpoint '{checkpoint_path}' not found") @@ -690,20 +690,37 @@ def infer_from_checkpoint(checkpoint_path, input_path, output_path): model.load_state_dict(checkpoint['model_state']) model.eval() - # Load image [0,1] + # Load image print(f"Loading input image: {input_path}") img = Image.open(input_path).convert('RGBA') - img_tensor = transforms.ToTensor()(img).unsqueeze(0) # [1,4,H,W] + W, H = img.size - # Inference - print("Running inference...") - with torch.no_grad(): - out = model(img_tensor) # [1,3,H,W] in [0,1] + # Tile into patches + print(f"Processing {patch_size}×{patch_size} patches...") + output = np.zeros((H, W, 3), dtype=np.float32) + + for y in range(0, H, patch_size): + for x in range(0, W, patch_size): + x_end = min(x + patch_size, W) + y_end = min(y + patch_size, H) + + # Extract patch + patch = img.crop((x, y, x_end, y_end)) + patch_tensor = transforms.ToTensor()(patch).unsqueeze(0) # [1,4,h,w] + + # Inference + with torch.no_grad(): + out_patch = model(patch_tensor) # [1,3,h,w] + + # Write to output + out_np = out_patch.squeeze(0).permute(1, 2, 0).numpy() + output[y:y_end, x:x_end] = out_np # Save print(f"Saving output to: {output_path}") os.makedirs(os.path.dirname(output_path), exist_ok=True) - transforms.ToPILImage()(out.squeeze(0)).save(output_path) + output_img = Image.fromarray((output * 255).astype(np.uint8)) + output_img.save(output_path) print("Done!") @@ 
-736,7 +753,8 @@ def main(): print("Error: --infer requires --export-only <checkpoint>") sys.exit(1) output_path = args.output or 'inference_output.png' - infer_from_checkpoint(checkpoint, args.infer, output_path) + patch_size = args.patch_size or 32 + infer_from_checkpoint(checkpoint, args.infer, output_path, patch_size) return # Export-only mode |
