-rw-r--r--  doc/HOWTO.md           | 15
-rwxr-xr-x  training/train_cnn.py  | 45
2 files changed, 34 insertions(+), 26 deletions(-)
diff --git a/doc/HOWTO.md b/doc/HOWTO.md
index db324ec..7b0daa0 100644
--- a/doc/HOWTO.md
+++ b/doc/HOWTO.md
@@ -89,7 +89,7 @@ make run_util_tests  # Utility tests
 ## Training
 
 ### Patch-Based (Recommended)
-Extracts patches at salient points, preserves natural pixel scale:
+Extracts patches at salient points, trains on center pixels only (matches WGSL sliding window):
 ```bash
 # Train with 32×32 patches at detected corners/edges
 ./training/train_cnn.py \
@@ -99,10 +99,15 @@ Extracts patches at salient points, preserves natural pixel scale:
     --checkpoint-every 1000
 ```
 
+**Training behavior:**
+- Loss computed only on center pixels (excludes conv padding borders)
+- For 3-layer network: excludes 3px border on each side
+- Matches GPU shader sliding-window paradigm
+
 **Detectors:** `harris` (default), `fast`, `shi-tomasi`, `gradient`
 
-### Full-Image (Legacy)
-Resizes to 256×256 (distorts scale):
+### Full-Image
+Processes entire image with sliding window (matches WGSL):
 ```bash
 ./training/train_cnn.py \
     --input training/input/ --target training/output/ \
@@ -115,12 +120,14 @@ Resizes to 256×256 (distorts scale):
 # Generate shaders from checkpoint
 ./training/train_cnn.py --export-only checkpoints/checkpoint_epoch_5000.pth
 
-# Generate ground truth for comparison
+# Generate ground truth (sliding window, no tiling)
 ./training/train_cnn.py --infer input.png \
     --export-only checkpoints/checkpoint_epoch_5000.pth \
     --output ground_truth.png
 ```
 
+**Inference:** Processes full image with sliding window (each pixel from NxN neighborhood). No tiling artifacts.
+
 **Kernel sizes:** 3×3 (36 weights), 5×5 (100 weights), 7×7 (196 weights)
 
 ---
diff --git a/training/train_cnn.py b/training/train_cnn.py
index ea1c4d7..89c50d5 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -536,8 +536,13 @@ def train(args):
     else:
         print(f"Warning: Checkpoint file '{args.resume}' not found, starting from scratch")
 
+    # Compute valid center region (exclude conv padding borders)
+    num_layers = args.layers
+    border = num_layers  # Each 3x3 layer needs 1px, accumulates across layers
+
     # Training loop
     print(f"\nTraining for {args.epochs} epochs (starting from epoch {start_epoch})...")
+    print(f"Computing loss on center region only (excluding {border}px border)")
     for epoch in range(start_epoch, args.epochs):
         epoch_loss = 0.0
         for batch_idx, (inputs, targets) in enumerate(dataloader):
@@ -545,7 +550,15 @@
             optimizer.zero_grad()
             outputs = model(inputs)
-            loss = criterion(outputs, targets)
+
+            # Only compute loss on center pixels with valid neighborhoods
+            if border > 0 and outputs.shape[2] > 2*border and outputs.shape[3] > 2*border:
+                outputs_center = outputs[:, :, border:-border, border:-border]
+                targets_center = targets[:, :, border:-border, border:-border]
+                loss = criterion(outputs_center, targets_center)
+            else:
+                loss = criterion(outputs, targets)
+
             loss.backward()
             optimizer.step()
 
             epoch_loss += loss.item()
@@ -666,7 +679,7 @@ def export_from_checkpoint(checkpoint_path, output_path=None):
 
 
 def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32):
-    """Run patch-based inference to match training distribution"""
+    """Run sliding-window inference to match WGSL shader behavior"""
 
     if not os.path.exists(checkpoint_path):
         print(f"Error: Checkpoint '{checkpoint_path}' not found")
@@ -690,32 +703,20 @@ def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=3
     # Load image
     print(f"Loading input image: {input_path}")
     img = Image.open(input_path).convert('RGBA')
+    img_tensor = transforms.ToTensor()(img).unsqueeze(0)  # [1,4,H,W]
     W, H = img.size
 
-    # Tile into patches
-    print(f"Processing {patch_size}×{patch_size} patches...")
-    output = np.zeros((H, W, 3), dtype=np.float32)
+    # Process full image with sliding window (matches WGSL shader)
+    print(f"Processing full image ({W}×{H}) with sliding window...")
+    with torch.no_grad():
+        output_tensor = model(img_tensor)  # [1,3,H,W]
 
-    for y in range(0, H, patch_size):
-        for x in range(0, W, patch_size):
-            x_end = min(x + patch_size, W)
-            y_end = min(y + patch_size, H)
-
-            # Extract patch
-            patch = img.crop((x, y, x_end, y_end))
-            patch_tensor = transforms.ToTensor()(patch).unsqueeze(0)  # [1,4,h,w]
-
-            # Inference
-            with torch.no_grad():
-                out_patch = model(patch_tensor)  # [1,3,h,w]
-
-            # Write to output
-            out_np = out_patch.squeeze(0).permute(1, 2, 0).numpy()
-            output[y:y_end, x:x_end] = out_np
+    # Convert to numpy
+    output = output_tensor.squeeze(0).permute(1, 2, 0).numpy()
 
     # Save
     print(f"Saving output to: {output_path}")
-    os.makedirs(os.path.dirname(output_path), exist_ok=True)
+    os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else '.', exist_ok=True)
     output_img = Image.fromarray((output * 255).astype(np.uint8))
     output_img.save(output_path)
     print("Done!")
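
The center-crop loss added in `train()` can be sanity-checked in isolation. Below is a minimal sketch (not part of the patch), assuming the diff's setup of three 3×3 conv layers (so `border = 3`) and an MSE criterion; the tensors are hypothetical stand-ins for `model(inputs)` and a target batch:

```python
import torch
import torch.nn as nn

# Hypothetical batch: 8 patches, RGB output, 32x32 spatial size
outputs = torch.rand(8, 3, 32, 32)
targets = torch.rand(8, 3, 32, 32)

criterion = nn.MSELoss()
border = 3  # three 3x3 layers, each contributing 1px of padding-contaminated border

# Crop both tensors to the valid center before computing the loss,
# guarding against patches too small to contain a valid center region
if border > 0 and outputs.shape[2] > 2 * border and outputs.shape[3] > 2 * border:
    loss = criterion(outputs[:, :, border:-border, border:-border],
                     targets[:, :, border:-border, border:-border])
else:
    loss = criterion(outputs, targets)

print(loss.item())  # MSE over the 26x26 center only
```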

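The inference rewrite relies on a convolution identity: one full-image forward pass is mathematically the same per-pixel sliding window the WGSL shader runs, which is why the tiling seams disappear. A minimal sketch of that equivalence, using a hypothetical single 3×3 conv layer as a stand-in for the trained model:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv2d(4, 3, kernel_size=3, padding=1)  # stand-in for the real model
img = torch.rand(1, 4, 16, 16)                    # [1,4,H,W], like the RGBA input

with torch.no_grad():
    # Full-image pass, as in the rewritten infer_from_checkpoint()
    full = conv(img)

    # Explicit sliding window: evaluate one 3x3 neighborhood for one pixel
    y, x = 7, 5                                   # arbitrary output pixel
    padded = F.pad(img, (1, 1, 1, 1))             # same zero padding as padding=1
    window = padded[:, :, y:y + 3, x:x + 3]       # that pixel's 3x3 neighborhood
    pixel = F.conv2d(window, conv.weight, conv.bias)  # [1,3,1,1]

# Both paths produce the same pixel value (up to float rounding)
assert torch.allclose(full[:, :, y, x], pixel.flatten(1), atol=1e-6)
```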