diff options
| author | skal <pascal.massimino@gmail.com> | 2026-02-12 12:11:53 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-02-12 12:11:53 +0100 |
| commit | eaf0bd855306e70ca03f2d6579b4d6551aff6482 (patch) | |
| tree | 62316af1143db1e59e1ad62e70b9844e324cda55 /training | |
| parent | e8344bc84ec0f571e5c5aafffe7c914abe226bd6 (diff) | |
TODO: 8-bit weight quantization for 2× size reduction
- Add QAT (quantization-aware training) notes
- Requires training with fake quantization
- Target: ~1.6 KB weights (vs 3.2 KB f16)
- Shader unpacking needs adaptation (4× u8 per u32)
Diffstat (limited to 'training')
| -rwxr-xr-x | training/export_cnn_v2_weights.py | 2 | ||||
| -rwxr-xr-x | training/train_cnn_v2.py | 8 |
2 files changed, 9 insertions, 1 deletion
```diff
diff --git a/training/export_cnn_v2_weights.py b/training/export_cnn_v2_weights.py
index e3d1724..723f572 100755
--- a/training/export_cnn_v2_weights.py
+++ b/training/export_cnn_v2_weights.py
@@ -94,6 +94,8 @@ def export_weights_binary(checkpoint_path, output_path):
     weight_offset += len(layer2_flat)
 
     # Convert to f16
+    # TODO: Use 8-bit quantization for 2× size reduction
+    # Requires quantization-aware training (QAT) to maintain accuracy
     all_weights_f16 = np.array(all_weights, dtype=np.float16)
 
     # Pack f16 pairs into u32 for storage buffer
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
index 5c93f20..8cac51a 100755
--- a/training/train_cnn_v2.py
+++ b/training/train_cnn_v2.py
@@ -52,7 +52,13 @@ def compute_static_features(rgb, depth=None):
 
 
 class CNNv2(nn.Module):
-    """CNN v2 with parametric static features."""
+    """CNN v2 with parametric static features.
+
+    TODO: Add quantization-aware training (QAT) for 8-bit weights
+    - Use torch.quantization.QuantStub/DeQuantStub
+    - Train with fake quantization to adapt to 8-bit precision
+    - Target: ~1.6 KB weights (vs 3.2 KB with f16)
+    """
 
     def __init__(self, kernels=[1, 3, 5], channels=[16, 8, 4]):
         super().__init__()
```
