Diffstat (limited to 'doc')
| -rw-r--r-- | doc/CNN_BIAS_FIX_2026-02.md | 85 |
1 files changed, 85 insertions, 0 deletions
diff --git a/doc/CNN_BIAS_FIX_2026-02.md b/doc/CNN_BIAS_FIX_2026-02.md
new file mode 100644
index 0000000..26db8eb
--- /dev/null
+++ b/doc/CNN_BIAS_FIX_2026-02.md
@@ -0,0 +1,85 @@
+# CNN Bias Accumulation Fix (2026-02-11)
+
+## Problem
+The bias was added once per kernel position in the shader convolution loops instead of once per output, causing a mismatch between PyTorch training and WGSL inference.
+
+## Root Cause
+**Location**: `training/train_cnn.py:381, 398`
+
+When exporting weights to WGSL, the bias was replicated for every kernel position. The shader loops over the positions, doing:
+```wgsl
+sum += dot(weights[pos], rgbd) + dot(weights[pos+1], in1);  // in1.w = 1.0
+```
+
+For a 3×3 kernel (9 positions), the bias was added 9×. For a 5×5 kernel, it was added 25×.
+
+## Fix
+Divide the bias by `num_positions` during export:
+```python
+# Final layer (7→1)
+v1.append(f"{bias[0] / num_positions:.6f}")
+
+# Inner layers (7→4)
+v1.append(f"{bias[out_c] / num_positions:.6f}")
+```
+
+The shader now accumulates (bias / num_positions) × num_positions, which equals the original bias.
+
+---
+
+## Additional Improvements
+
+### 1. RGBA Output Support
+**train_cnn.py**: Now saves a 4-channel RGBA PNG that preserves the alpha channel of the input:
+```python
+alpha = img_tensor[0, 3:4, :, :].permute(1, 2, 0).numpy()
+output_rgba = np.concatenate([output, alpha], axis=2)
+Image.fromarray((output_rgba * 255).astype(np.uint8), mode='RGBA')
+```
+
+Intermediate layer outputs are also saved as RGBA when they have 4 channels.
+
+### 2. Debug Hex Output
+**Both tools** support `--debug-hex` to print the first 8 pixels as hex:
+```bash
+./training/train_cnn.py --infer input.png --export-only checkpoint.pth --debug-hex
+./build/cnn_test input.png output.png --debug-hex
+```
+
+Output format: `[0] 0xRRGGBBAA` for pixel-level comparison.
+
+### 3. Cleanup
+Removed sRGB/linear_png debug code from `cnn_test.cc` (simplified PNG saving).
+
+---
+
+## Files Modified
+- `training/train_cnn.py`: Bias fix, RGBA output, --debug-hex
+- `tools/cnn_test.cc`: --debug-hex, remove linear_png
+- `workspaces/main/shaders/cnn/cnn_weights_generated.wgsl`: Regenerated with fixed bias
+
+## Testing
+```bash
+# Train with fixed export
+./training/train_cnn.py --input training/input/ --target training/output/ \
+    --layers 3 --kernel_sizes 3,3,3 --epochs 5000
+
+# Generate ground truth
+./training/train_cnn.py --infer input.png --export-only checkpoint.pth \
+    --output ground_truth.png --debug-hex
+
+# Run GPU tool
+./build/cnn_test input.png tool_output.png --debug-hex
+
+# Compare hex output for first 8 pixels
+```
+
+---
+
+## Status
+✅ Bias accumulation bug fixed
+✅ RGBA output with alpha preservation
+✅ Debug hex comparison tool
+✅ Weights regenerated
+
+Commit: `8ff8c56`
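The bias arithmetic described under "Root Cause" and "Fix" can be checked with a small standalone sketch. This is not code from `train_cnn.py` or the shader. The function names `reference_output` and `shader_style_output` are hypothetical, and the model is reduced to one scalar weight and one scalar input per kernel position, with the bias riding along on a constant 1.0 input as in the WGSL snippet.

```python
def reference_output(weights, inputs, bias):
    """PyTorch-style result at one pixel: the bias is added exactly once."""
    return sum(w * x for w, x in zip(weights, inputs)) + bias


def shader_style_output(weights, inputs, exported_bias):
    """Mimic the shader loop: the exported bias is added at every position."""
    total = 0.0
    for w, x in zip(weights, inputs):
        # The constant 1.0 input carries the bias term on each iteration.
        total += w * x + exported_bias * 1.0
    return total


num_positions = 9  # 3x3 kernel
weights = [0.1 * (i + 1) for i in range(num_positions)]
inputs = [0.5] * num_positions
bias = 0.25

expected = reference_output(weights, inputs, bias)
buggy = shader_style_output(weights, inputs, bias)                  # bias replicated
fixed = shader_style_output(weights, inputs, bias / num_positions)  # bias pre-divided

print(f"expected={expected:.6f} buggy={buggy:.6f} fixed={fixed:.6f}")
```

In this simplified model the replicated bias overshoots by `(num_positions - 1) * bias`, while the pre-divided bias reproduces the reference value up to floating-point error. Because the export formats each value with `:.6f`, the reconstructed bias can additionally differ from the original by up to about `num_positions * 5e-7`.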

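For the final step in "Testing" (comparing the hex output of the first 8 pixels), the following is a minimal sketch of how such a comparison might be scripted. It is not the implementation behind `--debug-hex` in either tool. The helpers `format_pixels_hex` and `first_mismatch` are hypothetical, pixels are assumed to be 8-bit `(r, g, b, a)` tuples, and uppercase hex digits are an assumption, since the doc only specifies the `[0] 0xRRGGBBAA` layout.

```python
def format_pixels_hex(pixels, count=8):
    """Format the first `count` RGBA pixels as '[i] 0xRRGGBBAA' strings."""
    return [
        f"[{i}] 0x{r:02X}{g:02X}{b:02X}{a:02X}"
        for i, (r, g, b, a) in enumerate(pixels[:count])
    ]


def first_mismatch(lines_a, lines_b):
    """Return the index of the first differing line over the common prefix,
    or None if no line in that prefix differs."""
    for i, (a, b) in enumerate(zip(lines_a, lines_b)):
        if a != b:
            return i
    return None


# Made-up sample pixels, for illustration only.
ground_truth = [(255, 0, 16, 255), (12, 34, 56, 255)]
gpu_output = [(255, 0, 16, 255), (12, 35, 56, 255)]

truth_lines = format_pixels_hex(ground_truth)
gpu_lines = format_pixels_hex(gpu_output)
mismatch = first_mismatch(truth_lines, gpu_lines)

print("\n".join(truth_lines))
print("first mismatch at pixel:", mismatch)
```

A difference of one in a single channel, as in the sample data, is the kind of discrepancy that is easy to miss visually but shows up immediately in the hex lines.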