summaryrefslogtreecommitdiff
path: root/training/train_cnn.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/train_cnn.py')
-rwxr-xr-xtraining/train_cnn.py28
1 files changed, 16 insertions, 12 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py
index c775325..4171dcb 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -218,7 +218,10 @@ class PatchDataset(Dataset):
class SimpleCNN(nn.Module):
- """CNN for RGBD→grayscale with 7-channel input (RGBD + UV + gray)"""
+ """CNN for RGBD→RGB with 7-channel input (RGBD + UV + gray)
+
+ Internally computes grayscale, expands to 3-channel RGB output.
+ """
def __init__(self, num_layers=1, kernel_sizes=None):
super(SimpleCNN, self).__init__()
@@ -272,11 +275,11 @@ class SimpleCNN(nn.Module):
if return_intermediates:
intermediates.append(out.clone())
- # Final layer (grayscale output)
+ # Final layer (grayscale→RGB)
final_input = torch.cat([out, x_coords, y_coords, gray], dim=1)
- out = self.layers[-1](final_input) # [B,1,H,W]
+ out = self.layers[-1](final_input) # [B,1,H,W] grayscale
out = torch.sigmoid(out) # Map to [0,1] with smooth gradients
- final_out = out.expand(-1, 3, -1, -1)
+ final_out = out.expand(-1, 3, -1, -1) # [B,3,H,W] expand to RGB
if return_intermediates:
return final_out, intermediates
@@ -777,7 +780,10 @@ def export_from_checkpoint(checkpoint_path, output_path=None):
def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32, save_intermediates=None, zero_weights=False, debug_hex=False):
- """Run sliding-window inference to match WGSL shader behavior"""
+ """Run sliding-window inference to match WGSL shader behavior
+
+ Outputs RGBA PNG (RGB from model + alpha from input).
+ """
if not os.path.exists(checkpoint_path):
print(f"Error: Checkpoint '{checkpoint_path}' not found")
@@ -819,14 +825,12 @@ def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=3
if save_intermediates:
output_tensor, intermediates = model(img_tensor, return_intermediates=True)
else:
- output_tensor = model(img_tensor) # [1,3,H,W]
-
- # Convert to numpy
- output = output_tensor.squeeze(0).permute(1, 2, 0).numpy()
+ output_tensor = model(img_tensor) # [1,3,H,W] RGB
- # Append alpha channel from input
- alpha = img_tensor[0, 3:4, :, :].permute(1, 2, 0).numpy() # [H,W,1]
- output_rgba = np.concatenate([output, alpha], axis=2) # [H,W,4]
+ # Convert to numpy and append alpha
+ output = output_tensor.squeeze(0).permute(1, 2, 0).numpy() # [H,W,3] RGB
+ alpha = img_tensor[0, 3:4, :, :].permute(1, 2, 0).numpy() # [H,W,1] alpha from input
+ output_rgba = np.concatenate([output, alpha], axis=2) # [H,W,4] RGBA
# Debug: print first 8 pixels as hex
if debug_hex: