# CNN Bias Accumulation Fix (2026-02-11)

## Problem
The bias was added once per kernel position in the shader's convolution loops instead of once per output value, causing a mismatch between PyTorch training and WGSL inference.

## Root Cause
**Location**: `training/train_cnn.py:381, 398`

When exporting weights to WGSL, the bias was replicated into the weights of every kernel position, alongside the weights for `in1`, whose `.w` component is the constant 1.0. The shader loops over positions doing:
```wgsl
sum += dot(weights[pos], rgbd) + dot(weights[pos+1], in1);  // in1.w = 1.0
```

For a 3×3 kernel (9 positions), the bias was added 9×; for a 5×5 kernel, 25×.
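
Written out for one output channel at one pixel with N kernel positions, the shader computes the sum below. The symbols are introduced here for illustration only: w_p and v_p are the two weight vectors at position p (`weights[pos]` and `weights[pos+1]`), x_p is `rgbd`, y_p is `in1` with its `.w` component fixed at 1.0, and b is the bias stored in the `.w` slot of v_p.

```math
s = \sum_{p=1}^{N} \left( \mathbf{w}_p \cdot \mathbf{x}_p + \mathbf{v}_p \cdot \mathbf{y}_p \right)
  = \sum_{p=1}^{N} \left( \mathbf{w}_p \cdot \mathbf{x}_p + \mathbf{v}_p^{xyz} \cdot \mathbf{y}_p^{xyz} \right) + N\,b
```

PyTorch adds b once, so the shader output was offset by (N - 1)·b: 8b for 3×3 and 24b for 5×5.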

## Fix
Divide the bias by `num_positions` (the number of kernel positions, e.g. 9 for 3×3) during export:
```python
# Final layer (7→1)
v1.append(f"{bias[0] / num_positions:.6f}")

# Inner layers (7→4)
v1.append(f"{bias[out_c] / num_positions:.6f}")
```

The shader then accumulates `num_positions × (bias / num_positions)`, which equals the original bias and matches PyTorch's single bias add.
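
The arithmetic can be checked numerically. The sketch below is a minimal NumPy model of one output channel at one pixel, assuming the packing implied by the WGSL loop above: per kernel position, one vec4 of weights dotted with `rgbd` and a second vec4 dotted with `in1`, whose `.w` slot carries the bias while `in1.w = 1.0`. The helper names (`shader_sum`, `pack_second_vec`) and the random data are hypothetical; only the bias-in-`.w` trick and the division by `num_positions` come from this fix.

```python
import numpy as np

def shader_sum(w0, w1, rgbd, in1):
    """Mimic the WGSL loop: sum over positions of dot(w0[p], rgbd[p]) + dot(w1[p], in1[p])."""
    return float(sum(np.dot(w0[p], rgbd[p]) + np.dot(w1[p], in1[p]) for p in range(len(w0))))

def pack_second_vec(weights_in1, bias, num_positions, fixed):
    """Build the second vec4 per position: 3 weights for in1.xyz, bias term in .w."""
    w1 = np.zeros((num_positions, 4))
    w1[:, :3] = weights_in1
    # Old export: full bias at every position. Fixed export: bias / num_positions.
    w1[:, 3] = bias / num_positions if fixed else bias
    return w1

rng = np.random.default_rng(0)
num_positions = 9                                # 3x3 kernel
bias = 0.25
w0 = rng.normal(size=(num_positions, 4))         # weights for rgbd
w_in1 = rng.normal(size=(num_positions, 3))      # weights for in1.xyz
rgbd = rng.uniform(size=(num_positions, 4))      # input samples per position
in1 = np.concatenate([rng.uniform(size=(num_positions, 3)),
                      np.ones((num_positions, 1))], axis=1)  # in1.w = 1.0

# Reference: plain convolution with the bias added exactly once (PyTorch semantics)
reference = float(np.sum(w0 * rgbd) + np.sum(w_in1 * in1[:, :3]) + bias)

buggy = shader_sum(w0, pack_second_vec(w_in1, bias, num_positions, fixed=False), rgbd, in1)
fixed = shader_sum(w0, pack_second_vec(w_in1, bias, num_positions, fixed=True), rgbd, in1)

print(f"buggy - reference = {buggy - reference:.6f}")  # equals (num_positions - 1) * bias
print(f"fixed - reference = {fixed - reference:.6f}")  # zero up to float rounding
```

Because the export writes each value with `:.6f`, every per-position bias is rounded to 6 decimals, so the recovered bias can differ from the trained value by up to `num_positions × 0.5e-6` (4.5e-6 for 3×3), well below the 1/255 step of an 8-bit output.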

---

## Additional Improvements

### 1. RGBA Output Support
**train_cnn.py**: Now saves a 4-channel RGBA PNG, carrying the alpha channel over from the input:
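```python
import numpy as np
from PIL import Image

alpha = img_tensor[0, 3:4, :, :].permute(1, 2, 0).numpy()  # (1, 4, H, W) input -> (H, W, 1) alpha
output_rgba = np.concatenate([output, alpha], axis=2)       # (H, W, 3) RGB + alpha -> (H, W, 4)
output_rgba = np.clip(output_rgba, 0.0, 1.0)                # saturate instead of wrapping in the uint8 cast
Image.fromarray((output_rgba * 255).astype(np.uint8), mode='RGBA')
```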

Intermediate layer outputs are also saved as RGBA when they have 4 channels.
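
Saving intermediate layers follows the same pattern. The sketch below is a hypothetical NumPy helper, assuming each layer output is a (C, H, W) float array roughly in [0, 1]. The grayscale fallback for layers that do not have 4 channels is an assumption, since only the 4-channel case is specified above.

```python
import numpy as np

def layer_to_uint8(layer):
    """Convert a (C, H, W) float layer to a uint8 array plus a PIL mode string.

    Hypothetical helper: 4-channel layers become (H, W, 4) RGBA; any other
    channel count falls back to channel 0 as (H, W) grayscale "L" (assumption).
    """
    layer = np.clip(np.asarray(layer, dtype=np.float32), 0.0, 1.0)
    if layer.shape[0] == 4:
        hwc = np.transpose(layer, (1, 2, 0))  # CHW -> HWC
        return (hwc * 255).astype(np.uint8), "RGBA"
    return (layer[0] * 255).astype(np.uint8), "L"
```

Usage: `arr, mode = layer_to_uint8(x)` followed by `Image.fromarray(arr).save(path)`. Pillow infers `RGBA` from an (H, W, 4) uint8 array and `L` from an (H, W) one, so the returned mode mostly serves logging.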

### 2. Debug Hex Output
**Both tools** support `--debug-hex`, which prints the first 8 pixels as hex:
```bash
./training/train_cnn.py --infer input.png --export-only checkpoint.pth --debug-hex
./build/cnn_test input.png output.png --debug-hex
```

Output format: `[0] 0xRRGGBBAA` for pixel-level comparison.
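
The `[i] 0xRRGGBBAA` format can be reproduced from any saved PNG, which helps when only one tool's output is at hand. This sketch assumes the 8 pixels are taken in row-major order from the top-left and printed with uppercase hex digits; neither detail is stated above, and `debug_hex_lines` is a hypothetical name, not a function in either tool.

```python
import numpy as np

def debug_hex_lines(rgba, count=8):
    """Format the first `count` pixels (row-major) of an (H, W, 4) uint8 image as '[i] 0xRRGGBBAA'."""
    pixels = np.asarray(rgba, dtype=np.uint8).reshape(-1, 4)[:count].tolist()  # plain Python ints
    return [f"[{i}] 0x{r:02X}{g:02X}{b:02X}{a:02X}" for i, (r, g, b, a) in enumerate(pixels)]

# Example: a 1x2 image with an opaque red pixel and a half-transparent gray pixel
img = np.array([[[255, 0, 0, 255], [128, 128, 128, 128]]], dtype=np.uint8)
for line in debug_hex_lines(img):
    print(line)  # [0] 0xFF0000FF, then [1] 0x80808080
```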

### 3. Cleanup
Removed the sRGB/`linear_png` debug code from `cnn_test.cc`, simplifying PNG saving.

---

## Files Modified
- `training/train_cnn.py`: Bias fix, RGBA output, `--debug-hex`
- `tools/cnn_test.cc`: `--debug-hex`, removed `linear_png`
- `workspaces/main/shaders/cnn/cnn_weights_generated.wgsl`: Regenerated with fixed bias

## Testing
```bash
# Train with fixed export
./training/train_cnn.py --input training/input/ --target training/output/ \
  --layers 3 --kernel_sizes 3,3,3 --epochs 5000

# Generate ground truth
./training/train_cnn.py --infer input.png --export-only checkpoint.pth \
  --output ground_truth.png --debug-hex

# Run GPU tool
./build/cnn_test input.png tool_output.png --debug-hex

# Compare hex output for first 8 pixels
```
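
The final comparison step is done by eye above. Assuming each tool's `--debug-hex` output has been captured as text (for example by redirecting stdout to a file), a small script can parse both dumps and report the largest per-channel difference. The regex relies only on the `[i] 0xRRGGBBAA` format and skips any other lines; `parse_hex_dump` and `max_channel_diff` are hypothetical helpers.

```python
import re

HEX_LINE = re.compile(r"\[(\d+)\]\s+0x([0-9A-Fa-f]{8})")

def parse_hex_dump(text):
    """Map pixel index -> (r, g, b, a) from lines like '[0] 0xRRGGBBAA'; other text is ignored."""
    pixels = {}
    for m in HEX_LINE.finditer(text):
        value = int(m.group(2), 16)
        pixels[int(m.group(1))] = tuple((value >> shift) & 0xFF for shift in (24, 16, 8, 0))
    return pixels

def max_channel_diff(text_a, text_b):
    """Largest absolute per-channel difference over pixel indices present in both dumps."""
    a, b = parse_hex_dump(text_a), parse_hex_dump(text_b)
    common = sorted(a.keys() & b.keys())
    if not common:
        raise ValueError("no common pixel indices found")
    return max(abs(x - y) for i in common for x, y in zip(a[i], b[i]))

truth = "[0] 0xFF0000FF\n[1] 0x80808080"
tool = "[0] 0xFE0000FF\n[1] 0x80808080"
print(max_channel_diff(truth, tool))  # 1
```

A difference of 0 on all 8 pixels is the goal. A difference of exactly 1 can come from quantization alone if the two tools convert floats to 8 bits differently (truncation as in `(x * 255).astype(np.uint8)` versus rounding), whereas the pre-fix bias error, an extra (N - 1)·bias per output, would typically show up as much larger offsets.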

---

## Status
✅ Bias accumulation bug fixed
✅ RGBA output with alpha preservation
✅ Debug hex comparison tool
✅ Weights regenerated

Commit: `8ff8c56`