Diffstat (limited to 'training')
 -rwxr-xr-x  training/train_cnn.py | 45 +++++++++++++++++++++++-----------------------
 1 file changed, 23 insertions(+), 22 deletions(-)
diff --git a/training/train_cnn.py b/training/train_cnn.py
index ea1c4d7..89c50d5 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -536,8 +536,13 @@ def train(args):
     else:
         print(f"Warning: Checkpoint file '{args.resume}' not found, starting from scratch")
 
+    # Compute valid center region (exclude conv padding borders)
+    num_layers = args.layers
+    border = num_layers  # Each 3x3 layer needs 1px, accumulates across layers
+
     # Training loop
     print(f"\nTraining for {args.epochs} epochs (starting from epoch {start_epoch})...")
+    print(f"Computing loss on center region only (excluding {border}px border)")
     for epoch in range(start_epoch, args.epochs):
         epoch_loss = 0.0
         for batch_idx, (inputs, targets) in enumerate(dataloader):
@@ -545,7 +550,15 @@ def train(args):
 
             optimizer.zero_grad()
             outputs = model(inputs)
-            loss = criterion(outputs, targets)
+
+            # Only compute loss on center pixels with valid neighborhoods
+            if border > 0 and outputs.shape[2] > 2*border and outputs.shape[3] > 2*border:
+                outputs_center = outputs[:, :, border:-border, border:-border]
+                targets_center = targets[:, :, border:-border, border:-border]
+                loss = criterion(outputs_center, targets_center)
+            else:
+                loss = criterion(outputs, targets)
+
             loss.backward()
             optimizer.step()
 
@@ -666,7 +679,7 @@ def export_from_checkpoint(checkpoint_path, output_path=None):
 def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32):
-    """Run patch-based inference to match training distribution"""
+    """Run sliding-window inference to match WGSL shader behavior"""
     if not os.path.exists(checkpoint_path):
         print(f"Error: Checkpoint '{checkpoint_path}' not found")
@@ -690,32 +703,20 @@ def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=3
     # Load image
     print(f"Loading input image: {input_path}")
     img = Image.open(input_path).convert('RGBA')
+    img_tensor = transforms.ToTensor()(img).unsqueeze(0)  # [1,4,H,W]
     W, H = img.size
 
-    # Tile into patches
-    print(f"Processing {patch_size}×{patch_size} patches...")
-    output = np.zeros((H, W, 3), dtype=np.float32)
+    # Process full image with sliding window (matches WGSL shader)
+    print(f"Processing full image ({W}×{H}) with sliding window...")
+    with torch.no_grad():
+        output_tensor = model(img_tensor)  # [1,3,H,W]
 
-    for y in range(0, H, patch_size):
-        for x in range(0, W, patch_size):
-            x_end = min(x + patch_size, W)
-            y_end = min(y + patch_size, H)
-
-            # Extract patch
-            patch = img.crop((x, y, x_end, y_end))
-            patch_tensor = transforms.ToTensor()(patch).unsqueeze(0)  # [1,4,h,w]
-
-            # Inference
-            with torch.no_grad():
-                out_patch = model(patch_tensor)  # [1,3,h,w]
-
-            # Write to output
-            out_np = out_patch.squeeze(0).permute(1, 2, 0).numpy()
-            output[y:y_end, x:x_end] = out_np
+    # Convert to numpy
+    output = output_tensor.squeeze(0).permute(1, 2, 0).numpy()
 
     # Save
     print(f"Saving output to: {output_path}")
-    os.makedirs(os.path.dirname(output_path), exist_ok=True)
+    os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else '.', exist_ok=True)
     output_img = Image.fromarray((output * 255).astype(np.uint8))
     output_img.save(output_path)
     print("Done!")