diff options
Diffstat (limited to 'training/train_cnn.py')
| -rwxr-xr-x | training/train_cnn.py | 77 |
1 files changed, 60 insertions, 17 deletions
diff --git a/training/train_cnn.py b/training/train_cnn.py index 1ea42a3..4171dcb 100755 --- a/training/train_cnn.py +++ b/training/train_cnn.py @@ -218,7 +218,10 @@ class PatchDataset(Dataset): class SimpleCNN(nn.Module): - """CNN for RGBD→grayscale with 7-channel input (RGBD + UV + gray)""" + """CNN for RGBD→RGB with 7-channel input (RGBD + UV + gray) + + Internally computes grayscale, expands to 3-channel RGB output. + """ def __init__(self, num_layers=1, kernel_sizes=None): super(SimpleCNN, self).__init__() @@ -272,11 +275,11 @@ class SimpleCNN(nn.Module): if return_intermediates: intermediates.append(out.clone()) - # Final layer (grayscale output) + # Final layer (grayscale→RGB) final_input = torch.cat([out, x_coords, y_coords, gray], dim=1) - out = self.layers[-1](final_input) # [B,1,H,W] + out = self.layers[-1](final_input) # [B,1,H,W] grayscale out = torch.sigmoid(out) # Map to [0,1] with smooth gradients - final_out = out.expand(-1, 3, -1, -1) + final_out = out.expand(-1, 3, -1, -1) # [B,3,H,W] expand to RGB if return_intermediates: return final_out, intermediates @@ -378,7 +381,7 @@ def export_weights_to_wgsl(model, output_path, kernel_sizes): v0 = [f"{weights[0, in_c, row, col]:.6f}" for in_c in range(4)] # Second vec4: [w4, w5, w6, bias] (uv, gray, 1) v1 = [f"{weights[0, in_c, row, col]:.6f}" for in_c in range(4, 7)] - v1.append(f"{bias[0]:.6f}") + v1.append(f"{bias[0] / num_positions:.6f}") f.write(f" vec4<f32>({', '.join(v0)}),\n") f.write(f" vec4<f32>({', '.join(v1)})") f.write(",\n" if pos < num_positions-1 else "\n") @@ -395,7 +398,7 @@ def export_weights_to_wgsl(model, output_path, kernel_sizes): v0 = [f"{weights[out_c, in_c, row, col]:.6f}" for in_c in range(4)] # Second vec4: [w4, w5, w6, bias] (uv, gray, 1) v1 = [f"{weights[out_c, in_c, row, col]:.6f}" for in_c in range(4, 7)] - v1.append(f"{bias[out_c]:.6f}") + v1.append(f"{bias[out_c] / num_positions:.6f}") idx = (pos * 4 + out_c) * 2 f.write(f" vec4<f32>({', '.join(v0)}),\n") f.write(f" vec4<f32>({', '.join(v1)})") @@ -776,8 +779,11 @@ def export_from_checkpoint(checkpoint_path, output_path=None): print("Export complete!") -def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32, save_intermediates=None): - """Run sliding-window inference to match WGSL shader behavior""" +def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32, save_intermediates=None, zero_weights=False, debug_hex=False): + """Run sliding-window inference to match WGSL shader behavior + + Outputs RGBA PNG (RGB from model + alpha from input). + """ if not os.path.exists(checkpoint_path): print(f"Error: Checkpoint '{checkpoint_path}' not found") @@ -796,6 +802,15 @@ def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=3 kernel_sizes=checkpoint['kernel_sizes'] ) model.load_state_dict(checkpoint['model_state']) + + # Debug: Zero out all weights and biases + if zero_weights: + print("DEBUG: Zeroing out all weights and biases") + for layer in model.layers: + with torch.no_grad(): + layer.weight.zero_() + layer.bias.zero_() + model.eval() # Load image @@ -810,15 +825,26 @@ def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=3 if save_intermediates: output_tensor, intermediates = model(img_tensor, return_intermediates=True) else: - output_tensor = model(img_tensor) # [1,3,H,W] + output_tensor = model(img_tensor) # [1,3,H,W] RGB - # Convert to numpy - output = output_tensor.squeeze(0).permute(1, 2, 0).numpy() + # Convert to numpy and append alpha + output = output_tensor.squeeze(0).permute(1, 2, 0).numpy() # [H,W,3] RGB + alpha = img_tensor[0, 3:4, :, :].permute(1, 2, 0).numpy() # [H,W,1] alpha from input + output_rgba = np.concatenate([output, alpha], axis=2) # [H,W,4] RGBA - # Save final output + # Debug: print first 8 pixels as hex + if debug_hex: + output_u8 = (output_rgba * 255).astype(np.uint8) + print("First 8 pixels (RGBA hex):") + for i in range(min(8, output_u8.shape[0] * output_u8.shape[1])): + y, x = i // output_u8.shape[1], i % output_u8.shape[1] + r, g, b, a = output_u8[y, x] + print(f" [{i}] 0x{r:02X}{g:02X}{b:02X}{a:02X}") + + # Save final output as RGBA print(f"Saving output to: {output_path}") os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else '.', exist_ok=True) - output_img = Image.fromarray((output * 255).astype(np.uint8)) + output_img = Image.fromarray((output_rgba * 255).astype(np.uint8), mode='RGBA') output_img.save(output_path) # Save intermediates if requested @@ -828,10 +854,25 @@ def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=3 for layer_idx, layer_tensor in enumerate(intermediates): # Convert [-1,1] to [0,1] for visualization layer_data = (layer_tensor.squeeze(0).permute(1, 2, 0).numpy() + 1.0) * 0.5 - # Take first channel for 4-channel intermediate layers + layer_u8 = (layer_data.clip(0, 1) * 255).astype(np.uint8) + + # Debug: print first 8 pixels as hex + if debug_hex: + print(f"Layer {layer_idx} first 8 pixels (RGBA hex):") + for i in range(min(8, layer_u8.shape[0] * layer_u8.shape[1])): + y, x = i // layer_u8.shape[1], i % layer_u8.shape[1] + if layer_u8.shape[2] == 4: + r, g, b, a = layer_u8[y, x] + print(f" [{i}] 0x{r:02X}{g:02X}{b:02X}{a:02X}") + else: + r, g, b = layer_u8[y, x] + print(f" [{i}] 0x{r:02X}{g:02X}{b:02X}") + + # Save all 4 channels for intermediate layers if layer_data.shape[2] == 4: - layer_data = layer_data[:, :, :3] # Show RGB only - layer_img = Image.fromarray((layer_data.clip(0, 1) * 255).astype(np.uint8)) + layer_img = Image.fromarray(layer_u8, mode='RGBA') + else: + layer_img = Image.fromarray(layer_u8) layer_path = os.path.join(save_intermediates, f'layer_{layer_idx}.png') layer_img.save(layer_path) print(f" Saved layer {layer_idx} to {layer_path}") @@ -861,6 +902,8 @@ def main(): parser.add_argument('--early-stop-patience', type=int, default=0, help='Stop if loss changes less than eps over N epochs (default: 0 = disabled)') parser.add_argument('--early-stop-eps', type=float, default=1e-6, help='Loss change threshold for early stopping (default: 1e-6)') parser.add_argument('--save-intermediates', help='Directory to save intermediate layer outputs (inference only)') + parser.add_argument('--zero-weights', action='store_true', help='Zero out all weights/biases during inference (debug only)') + parser.add_argument('--debug-hex', action='store_true', help='Print first 8 pixels as hex (debug only)') args = parser.parse_args() @@ -872,7 +915,7 @@ def main(): sys.exit(1) output_path = args.output or 'inference_output.png' patch_size = args.patch_size or 32 - infer_from_checkpoint(checkpoint, args.infer, output_path, patch_size, args.save_intermediates) + infer_from_checkpoint(checkpoint, args.infer, output_path, patch_size, args.save_intermediates, args.zero_weights, args.debug_hex) return # Export-only mode |
