From 5cc6da3831d5bce35af353c14f15c30dbc66b081 Mon Sep 17 00:00:00 2001 From: skal Date: Wed, 11 Feb 2026 09:27:06 +0100 Subject: fix: CNN training/inference to match WGSL sliding window Training now computes loss only on center pixels (excludes conv padding borders). Inference changed from tiling to full-image sliding window. Both match cnn_layer.wgsl: each pixel processed from NxN neighborhood. Co-Authored-By: Claude Sonnet 4.5 --- training/train_cnn.py | 45 +++++++++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 22 deletions(-) (limited to 'training') diff --git a/training/train_cnn.py b/training/train_cnn.py index ea1c4d7..89c50d5 100755 --- a/training/train_cnn.py +++ b/training/train_cnn.py @@ -536,8 +536,13 @@ def train(args): else: print(f"Warning: Checkpoint file '{args.resume}' not found, starting from scratch") + # Compute valid center region (exclude conv padding borders) + num_layers = args.layers + border = num_layers # Each 3x3 layer needs 1px, accumulates across layers + # Training loop print(f"\nTraining for {args.epochs} epochs (starting from epoch {start_epoch})...") + print(f"Computing loss on center region only (excluding {border}px border)") for epoch in range(start_epoch, args.epochs): epoch_loss = 0.0 for batch_idx, (inputs, targets) in enumerate(dataloader): @@ -545,7 +550,15 @@ def train(args): optimizer.zero_grad() outputs = model(inputs) - loss = criterion(outputs, targets) + + # Only compute loss on center pixels with valid neighborhoods + if border > 0 and outputs.shape[2] > 2*border and outputs.shape[3] > 2*border: + outputs_center = outputs[:, :, border:-border, border:-border] + targets_center = targets[:, :, border:-border, border:-border] + loss = criterion(outputs_center, targets_center) + else: + loss = criterion(outputs, targets) + loss.backward() optimizer.step() @@ -666,7 +679,7 @@ def export_from_checkpoint(checkpoint_path, output_path=None): def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32): - """Run patch-based inference to match training distribution""" + """Run sliding-window inference to match WGSL shader behavior""" if not os.path.exists(checkpoint_path): print(f"Error: Checkpoint '{checkpoint_path}' not found") @@ -690,32 +703,20 @@ def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=3 # Load image print(f"Loading input image: {input_path}") img = Image.open(input_path).convert('RGBA') + img_tensor = transforms.ToTensor()(img).unsqueeze(0) # [1,4,H,W] W, H = img.size - # Tile into patches - print(f"Processing {patch_size}×{patch_size} patches...") - output = np.zeros((H, W, 3), dtype=np.float32) + # Process full image with sliding window (matches WGSL shader) + print(f"Processing full image ({W}×{H}) with sliding window...") + with torch.no_grad(): + output_tensor = model(img_tensor) # [1,3,H,W] - for y in range(0, H, patch_size): - for x in range(0, W, patch_size): - x_end = min(x + patch_size, W) - y_end = min(y + patch_size, H) - - # Extract patch - patch = img.crop((x, y, x_end, y_end)) - patch_tensor = transforms.ToTensor()(patch).unsqueeze(0) # [1,4,h,w] - - # Inference - with torch.no_grad(): - out_patch = model(patch_tensor) # [1,3,h,w] - - # Write to output - out_np = out_patch.squeeze(0).permute(1, 2, 0).numpy() - output[y:y_end, x:x_end] = out_np + # Convert to numpy + output = output_tensor.squeeze(0).permute(1, 2, 0).numpy() # Save print(f"Saving output to: {output_path}") - os.makedirs(os.path.dirname(output_path), exist_ok=True) + os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else '.', exist_ok=True) output_img = Image.fromarray((output * 255).astype(np.uint8)) output_img.save(output_path) print("Done!") -- cgit v1.2.3