summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-02-11 00:32:02 +0100
committerskal <pascal.massimino@gmail.com>2026-02-11 00:32:02 +0100
commit01e640be66f9d72c22417403eb88e18d6747866f (patch)
treec70ffb6abf0910f07d727212dbe4665b1c24a55e /training
parentc49d828f101b435d73a76fcfc8444cf76aeda22f (diff)
fix: Use patch-based inference to match CNN training distribution
Inference now tiles images into patches matching training patch size, preventing distribution mismatch between patch training and full-image inference. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'training')
-rwxr-xr-xtraining/train_cnn.py38
1 file changed, 28 insertions, 10 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py
index d8522ed..57b4da8 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -668,8 +668,8 @@ def export_from_checkpoint(checkpoint_path, output_path=None):
print("Export complete!")
-def infer_from_checkpoint(checkpoint_path, input_path, output_path):
- """Run inference on single image to generate ground truth"""
+def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32):
+ """Run patch-based inference to match training distribution"""
if not os.path.exists(checkpoint_path):
print(f"Error: Checkpoint '{checkpoint_path}' not found")
@@ -690,20 +690,37 @@ def infer_from_checkpoint(checkpoint_path, input_path, output_path):
model.load_state_dict(checkpoint['model_state'])
model.eval()
- # Load image [0,1]
+ # Load image
print(f"Loading input image: {input_path}")
img = Image.open(input_path).convert('RGBA')
- img_tensor = transforms.ToTensor()(img).unsqueeze(0) # [1,4,H,W]
+ W, H = img.size
- # Inference
- print("Running inference...")
- with torch.no_grad():
- out = model(img_tensor) # [1,3,H,W] in [0,1]
+ # Tile into patches
+ print(f"Processing {patch_size}×{patch_size} patches...")
+ output = np.zeros((H, W, 3), dtype=np.float32)
+
+ for y in range(0, H, patch_size):
+ for x in range(0, W, patch_size):
+ x_end = min(x + patch_size, W)
+ y_end = min(y + patch_size, H)
+
+ # Extract patch
+ patch = img.crop((x, y, x_end, y_end))
+ patch_tensor = transforms.ToTensor()(patch).unsqueeze(0) # [1,4,h,w]
+
+ # Inference
+ with torch.no_grad():
+ out_patch = model(patch_tensor) # [1,3,h,w]
+
+ # Write to output
+ out_np = out_patch.squeeze(0).permute(1, 2, 0).numpy()
+ output[y:y_end, x:x_end] = out_np
# Save
print(f"Saving output to: {output_path}")
os.makedirs(os.path.dirname(output_path), exist_ok=True)
- transforms.ToPILImage()(out.squeeze(0)).save(output_path)
+ output_img = Image.fromarray((output * 255).astype(np.uint8))
+ output_img.save(output_path)
print("Done!")
@@ -736,7 +753,8 @@ def main():
print("Error: --infer requires --export-only <checkpoint>")
sys.exit(1)
output_path = args.output or 'inference_output.png'
- infer_from_checkpoint(checkpoint, args.infer, output_path)
+ patch_size = args.patch_size or 32
+ infer_from_checkpoint(checkpoint, args.infer, output_path, patch_size)
return
# Export-only mode