From 8ff8c56cd68d9e785cf6cb36ce1fc2bdc54ac15a Mon Sep 17 00:00:00 2001 From: skal Date: Wed, 11 Feb 2026 23:13:43 +0100 Subject: fix: CNN bias accumulation and output format improvements - Fix bias division bug: divide by num_positions to compensate for shader loop accumulation (affects all layers) - train_cnn.py: Save RGBA output preserving alpha channel from input - Add --debug-hex flag to both tools for pixel-level debugging - Remove sRGB/linear_png debug code from cnn_test - Regenerate weights with corrected bias export Co-Authored-By: Claude Sonnet 4.5 --- training/train_cnn.py | 57 +++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 48 insertions(+), 9 deletions(-) (limited to 'training/train_cnn.py') diff --git a/training/train_cnn.py b/training/train_cnn.py index 1ea42a3..c775325 100755 --- a/training/train_cnn.py +++ b/training/train_cnn.py @@ -378,7 +378,7 @@ def export_weights_to_wgsl(model, output_path, kernel_sizes): v0 = [f"{weights[0, in_c, row, col]:.6f}" for in_c in range(4)] # Second vec4: [w4, w5, w6, bias] (uv, gray, 1) v1 = [f"{weights[0, in_c, row, col]:.6f}" for in_c in range(4, 7)] - v1.append(f"{bias[0]:.6f}") + v1.append(f"{bias[0] / num_positions:.6f}") f.write(f" vec4({', '.join(v0)}),\n") f.write(f" vec4({', '.join(v1)})") f.write(",\n" if pos < num_positions-1 else "\n") @@ -395,7 +395,7 @@ def export_weights_to_wgsl(model, output_path, kernel_sizes): v0 = [f"{weights[out_c, in_c, row, col]:.6f}" for in_c in range(4)] # Second vec4: [w4, w5, w6, bias] (uv, gray, 1) v1 = [f"{weights[out_c, in_c, row, col]:.6f}" for in_c in range(4, 7)] - v1.append(f"{bias[out_c]:.6f}") + v1.append(f"{bias[out_c] / num_positions:.6f}") idx = (pos * 4 + out_c) * 2 f.write(f" vec4({', '.join(v0)}),\n") f.write(f" vec4({', '.join(v1)})") @@ -776,7 +776,7 @@ def export_from_checkpoint(checkpoint_path, output_path=None): print("Export complete!") -def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32, save_intermediates=None): +def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32, save_intermediates=None, zero_weights=False, debug_hex=False): """Run sliding-window inference to match WGSL shader behavior""" if not os.path.exists(checkpoint_path): @@ -796,6 +796,15 @@ def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=3 kernel_sizes=checkpoint['kernel_sizes'] ) model.load_state_dict(checkpoint['model_state']) + + # Debug: Zero out all weights and biases + if zero_weights: + print("DEBUG: Zeroing out all weights and biases") + for layer in model.layers: + with torch.no_grad(): + layer.weight.zero_() + layer.bias.zero_() + model.eval() # Load image @@ -815,10 +824,23 @@ def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=3 # Convert to numpy output = output_tensor.squeeze(0).permute(1, 2, 0).numpy() - # Save final output + # Append alpha channel from input + alpha = img_tensor[0, 3:4, :, :].permute(1, 2, 0).numpy() # [H,W,1] + output_rgba = np.concatenate([output, alpha], axis=2) # [H,W,4] + + # Debug: print first 8 pixels as hex + if debug_hex: + output_u8 = (output_rgba * 255).astype(np.uint8) + print("First 8 pixels (RGBA hex):") + for i in range(min(8, output_u8.shape[0] * output_u8.shape[1])): + y, x = i // output_u8.shape[1], i % output_u8.shape[1] + r, g, b, a = output_u8[y, x] + print(f" [{i}] 0x{r:02X}{g:02X}{b:02X}{a:02X}") + + # Save final output as RGBA print(f"Saving output to: {output_path}") os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else '.', exist_ok=True) - output_img = Image.fromarray((output * 255).astype(np.uint8)) + output_img = Image.fromarray((output_rgba * 255).astype(np.uint8), mode='RGBA') output_img.save(output_path) # Save intermediates if requested @@ -828,10 +850,25 @@ def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=3 for layer_idx, layer_tensor in enumerate(intermediates): # Convert [-1,1] to [0,1] for visualization layer_data = (layer_tensor.squeeze(0).permute(1, 2, 0).numpy() + 1.0) * 0.5 - # Take first channel for 4-channel intermediate layers + layer_u8 = (layer_data.clip(0, 1) * 255).astype(np.uint8) + + # Debug: print first 8 pixels as hex + if debug_hex: + print(f"Layer {layer_idx} first 8 pixels (RGBA hex):") + for i in range(min(8, layer_u8.shape[0] * layer_u8.shape[1])): + y, x = i // layer_u8.shape[1], i % layer_u8.shape[1] + if layer_u8.shape[2] == 4: + r, g, b, a = layer_u8[y, x] + print(f" [{i}] 0x{r:02X}{g:02X}{b:02X}{a:02X}") + else: + r, g, b = layer_u8[y, x] + print(f" [{i}] 0x{r:02X}{g:02X}{b:02X}") + + # Save all 4 channels for intermediate layers if layer_data.shape[2] == 4: - layer_data = layer_data[:, :, :3] # Show RGB only - layer_img = Image.fromarray((layer_data.clip(0, 1) * 255).astype(np.uint8)) + layer_img = Image.fromarray(layer_u8, mode='RGBA') + else: + layer_img = Image.fromarray(layer_u8) layer_path = os.path.join(save_intermediates, f'layer_{layer_idx}.png') layer_img.save(layer_path) print(f" Saved layer {layer_idx} to {layer_path}") @@ -861,6 +898,8 @@ def main(): parser.add_argument('--early-stop-patience', type=int, default=0, help='Stop if loss changes less than eps over N epochs (default: 0 = disabled)') parser.add_argument('--early-stop-eps', type=float, default=1e-6, help='Loss change threshold for early stopping (default: 1e-6)') parser.add_argument('--save-intermediates', help='Directory to save intermediate layer outputs (inference only)') + parser.add_argument('--zero-weights', action='store_true', help='Zero out all weights/biases during inference (debug only)') + parser.add_argument('--debug-hex', action='store_true', help='Print first 8 pixels as hex (debug only)') args = parser.parse_args() @@ -872,7 +911,7 @@ def main(): sys.exit(1) output_path = args.output or 'inference_output.png' patch_size = args.patch_size or 32 - infer_from_checkpoint(checkpoint, args.infer, output_path, patch_size, args.save_intermediates) + infer_from_checkpoint(checkpoint, args.infer, output_path, patch_size, args.save_intermediates, args.zero_weights, args.debug_hex) return # Export-only mode -- cgit v1.2.3