# CNN Bias Accumulation Fix (2026-02-11)

## Problem

Bias was being added multiple times in the shader convolution loops (once per kernel position). This caused a mismatch between PyTorch training and WGSL inference.

## Root Cause

**Location**: `training/train_cnn.py:381, 398`

When exporting weights to WGSL, the bias was replicated for every kernel position. The shader loops over kernel positions doing:

```wgsl
sum += dot(weights[pos], rgbd) + dot(weights[pos+1], in1); // in1.w = 1.0
```

Because `in1.w` is always 1.0, the bias term exported with each position is added on every iteration. For a 3×3 kernel (9 positions) the bias is added 9×, and for a 5×5 kernel it is added 25×.

## Fix

Divide the bias by `num_positions` during export:

```python
# num_positions = kernel_size² (9 for 3×3, 25 for 5×5)
# Final layer (7→1)
v1.append(f"{bias[0] / num_positions:.6f}")
# Inner layers (7→4)
v1.append(f"{bias[out_c] / num_positions:.6f}")
```

The shader then accumulates `(bias / num_positions) × num_positions`, which equals the original bias and matches PyTorch.

---

## Additional Improvements

### 1. RGBA Output Support

**train_cnn.py**: Now saves a 4-channel RGBA PNG, preserving the alpha channel of the input:

```python
import numpy as np
from PIL import Image

alpha = img_tensor[0, 3:4, :, :].permute(1, 2, 0).numpy()  # (H, W, 1) input alpha
output_rgba = np.concatenate([output, alpha], axis=2)       # output: (H, W, 3) RGB
# Clip before quantizing so out-of-range values don't wrap around in uint8
Image.fromarray((np.clip(output_rgba, 0, 1) * 255).astype(np.uint8), mode='RGBA')
```

Intermediate layers are also saved as RGBA when they have 4 channels.

### 2. Debug Hex Output

**Both tools** support `--debug-hex`, which prints the first 8 pixels as hex:

```bash
./training/train_cnn.py --infer input.png --export-only checkpoint.pth --debug-hex
./build/cnn_test input.png output.png --debug-hex
```

Each line has the form `[0] 0xRRGGBBAA`, giving the pixel index followed by its RGBA value in hex. This allows a pixel-level comparison between the two tools.

### 3. Cleanup

Removed the sRGB/`linear_png` debug code from `cnn_test.cc`, which simplifies PNG saving.
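The arithmetic in the Root Cause and Fix sections can be checked with a small pure-Python simulation of the shader loop. This is a sketch, not the project's export code. `shader_sum`, `pytorch_sum`, and `exported_bias` are hypothetical names. Each kernel position is modeled as 7 input channels, following the 7→4 / 7→1 layer comments. A separate `bias_slot` stands in for the exported weight that multiplies `in1.w = 1.0`.

```python
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def shader_sum(weights, inputs, bias_slot):
    # Mirrors the WGSL loop (simplified model): every kernel position adds its
    # data term plus bias_slot * in1.w, and in1.w is always 1.0.
    total = 0.0
    for w, x in zip(weights, inputs):
        total += dot(w, x) + bias_slot * 1.0
    return total


def pytorch_sum(weights, inputs, bias):
    # Conv2d semantics: the bias is added exactly once per output pixel.
    return sum(dot(w, x) for w, x in zip(weights, inputs)) + bias


def exported_bias(bias, num_positions, fixed=True):
    # Old export wrote the raw bias for every position. The fix divides it,
    # then rounds to 6 decimals like the f"{...:.6f}" export format.
    value = bias / num_positions if fixed else bias
    return float(f"{value:.6f}")


if __name__ == "__main__":
    random.seed(0)
    bias = 0.25
    for k in (3, 5):
        n = k * k  # num_positions
        weights = [[random.uniform(-1, 1) for _ in range(7)] for _ in range(n)]
        inputs = [[random.uniform(0, 1) for _ in range(7)] for _ in range(n)]
        ref = pytorch_sum(weights, inputs, bias)
        old = shader_sum(weights, inputs, exported_bias(bias, n, fixed=False))
        new = shader_sum(weights, inputs, exported_bias(bias, n, fixed=True))
        print(f"{k}x{k}: old error {old - ref:+.6f}, fixed error {new - ref:+.6f}")
```

With the old export, the simulated shader result exceeds the PyTorch reference by (num_positions − 1) × bias: 8 extra copies for 3×3 and 24 for 5×5. With the divided bias, the two agree up to the 6-decimal rounding of the exported value.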
---

## Files Modified

- `training/train_cnn.py`: bias fix, RGBA output, `--debug-hex`
- `tools/cnn_test.cc`: `--debug-hex`, removed `linear_png`
- `workspaces/main/shaders/cnn/cnn_weights_generated.wgsl`: regenerated with the fixed bias

## Testing

```bash
# Train with the fixed export
./training/train_cnn.py --input training/input/ --target training/output/ \
    --layers 3 --kernel_sizes 3,3,3 --epochs 5000

# Generate ground truth
./training/train_cnn.py --infer input.png --export-only checkpoint.pth \
    --output ground_truth.png --debug-hex

# Run the GPU tool
./build/cnn_test input.png tool_output.png --debug-hex

# Compare the hex output of the first 8 pixels from both runs
```

---

## Status

- ✅ Bias accumulation bug fixed
- ✅ RGBA output with alpha preservation
- ✅ Debug hex comparison tool
- ✅ Weights regenerated

Commit: `8ff8c56`
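For the last step of the Testing procedure, a small helper can diff the two hex dumps so you don't have to compare them by eye. This is a sketch that is not part of the repository. `parse_hex_dump`, `compare_dumps`, and the ±1 tolerance are hypothetical. It assumes both tools print lines in exactly the `[i] 0xRRGGBBAA` form described under Debug Hex Output, and it ignores any other text.

```python
import re

# Matches lines like "[0] 0xRRGGBBAA" (assumed --debug-hex format).
HEX_LINE = re.compile(r"\[(\d+)\]\s+0x([0-9A-Fa-f]{8})")


def parse_hex_dump(text):
    """Return {pixel_index: (r, g, b, a)} parsed from a --debug-hex dump."""
    pixels = {}
    for match in HEX_LINE.finditer(text):
        value = int(match.group(2), 16)
        pixels[int(match.group(1))] = (
            (value >> 24) & 0xFF,  # R
            (value >> 16) & 0xFF,  # G
            (value >> 8) & 0xFF,   # B
            value & 0xFF,          # A
        )
    return pixels


def compare_dumps(reference, candidate, tolerance=1):
    """List (index, ref_rgba, cand_rgba) for pixels that are missing from one
    dump or whose channels differ by more than `tolerance`."""
    ref, cand = parse_hex_dump(reference), parse_hex_dump(candidate)
    mismatches = []
    for idx in sorted(ref.keys() | cand.keys()):
        a, b = ref.get(idx), cand.get(idx)
        if a is None or b is None or max(abs(x - y) for x, y in zip(a, b)) > tolerance:
            mismatches.append((idx, a, b))
    return mismatches


if __name__ == "__main__":
    ground_truth = "[0] 0x7F8090FF\n[1] 0x10203040\n"
    tool_output = "[0] 0x7F8191FF\n[1] 0x10203940\n"
    for idx, want, got in compare_dumps(ground_truth, tool_output):
        print(f"pixel {idx}: ground truth {want}, tool {got}")
    # prints: pixel 1: ground truth (16, 32, 48, 64), tool (16, 32, 57, 64)
```

The default tolerance of 1 is an assumption meant to absorb off-by-one rounding when float outputs are quantized to 8 bits on the CPU and GPU paths. Pass `tolerance=0` to require an exact match.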