Diffstat (limited to 'training/train_cnn.py')
| -rwxr-xr-x | training/train_cnn.py | 28 |
1 files changed, 16 insertions, 12 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py
index c775325..4171dcb 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -218,7 +218,10 @@ class PatchDataset(Dataset):
 
 class SimpleCNN(nn.Module):
-    """CNN for RGBD→grayscale with 7-channel input (RGBD + UV + gray)"""
+    """CNN for RGBD→RGB with 7-channel input (RGBD + UV + gray)
+
+    Internally computes grayscale, expands to 3-channel RGB output.
+    """
 
     def __init__(self, num_layers=1, kernel_sizes=None):
         super(SimpleCNN, self).__init__()
@@ -272,11 +275,11 @@ class SimpleCNN(nn.Module):
             if return_intermediates:
                 intermediates.append(out.clone())
 
-        # Final layer (grayscale output)
+        # Final layer (grayscale→RGB)
         final_input = torch.cat([out, x_coords, y_coords, gray], dim=1)
-        out = self.layers[-1](final_input)  # [B,1,H,W]
+        out = self.layers[-1](final_input)  # [B,1,H,W] grayscale
         out = torch.sigmoid(out)  # Map to [0,1] with smooth gradients
-        final_out = out.expand(-1, 3, -1, -1)
+        final_out = out.expand(-1, 3, -1, -1)  # [B,3,H,W] expand to RGB
 
         if return_intermediates:
             return final_out, intermediates
@@ -777,7 +780,10 @@ def export_from_checkpoint(checkpoint_path, output_path=None):
 
 def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32, save_intermediates=None, zero_weights=False, debug_hex=False):
-    """Run sliding-window inference to match WGSL shader behavior"""
+    """Run sliding-window inference to match WGSL shader behavior
+
+    Outputs RGBA PNG (RGB from model + alpha from input).
+    """
 
     if not os.path.exists(checkpoint_path):
         print(f"Error: Checkpoint '{checkpoint_path}' not found")
@@ -819,14 +825,12 @@ def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=3
     if save_intermediates:
         output_tensor, intermediates = model(img_tensor, return_intermediates=True)
     else:
-        output_tensor = model(img_tensor)  # [1,3,H,W]
-
-        # Convert to numpy
-        output = output_tensor.squeeze(0).permute(1, 2, 0).numpy()
+        output_tensor = model(img_tensor)  # [1,3,H,W] RGB
 
-    # Append alpha channel from input
-    alpha = img_tensor[0, 3:4, :, :].permute(1, 2, 0).numpy()  # [H,W,1]
-    output_rgba = np.concatenate([output, alpha], axis=2)  # [H,W,4]
+    # Convert to numpy and append alpha
+    output = output_tensor.squeeze(0).permute(1, 2, 0).numpy()  # [H,W,3] RGB
+    alpha = img_tensor[0, 3:4, :, :].permute(1, 2, 0).numpy()  # [H,W,1] alpha from input
+    output_rgba = np.concatenate([output, alpha], axis=2)  # [H,W,4] RGBA
 
     # Debug: print first 8 pixels as hex
     if debug_hex:
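
For reference, a minimal sketch of the two behaviors the updated comments describe: expanding the single-channel model output to RGB, and appending an alpha channel taken from the input. The random tensors below are placeholders for illustration only, not the actual SimpleCNN model or image tensor from train_cnn.py, and the 32x32 size is an arbitrary assumption.

import torch
import numpy as np

# Grayscale output expanded to RGB (same values copied into R, G, B), as in SimpleCNN.forward
out = torch.sigmoid(torch.randn(1, 1, 32, 32))   # [B,1,H,W] placeholder grayscale in [0,1]
rgb = out.expand(-1, 3, -1, -1)                  # [B,3,H,W] RGB

# RGBA assembly, as in infer_from_checkpoint: alpha comes from the input's 4th (depth/alpha) channel
img_tensor = torch.rand(1, 7, 32, 32)                          # placeholder 7-channel input
output = rgb.squeeze(0).permute(1, 2, 0).numpy()               # [H,W,3] RGB
alpha = img_tensor[0, 3:4, :, :].permute(1, 2, 0).numpy()      # [H,W,1] alpha from input
output_rgba = np.concatenate([output, alpha], axis=2)          # [H,W,4] RGBA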
