diff options
Diffstat (limited to 'training/train_cnn_v2.py')
| -rwxr-xr-x | training/train_cnn_v2.py | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py index 5c93f20..8cac51a 100755 --- a/training/train_cnn_v2.py +++ b/training/train_cnn_v2.py @@ -52,7 +52,13 @@ def compute_static_features(rgb, depth=None): class CNNv2(nn.Module): - """CNN v2 with parametric static features.""" + """CNN v2 with parametric static features. + + TODO: Add quantization-aware training (QAT) for 8-bit weights + - Use torch.quantization.QuantStub/DeQuantStub + - Train with fake quantization to adapt to 8-bit precision + - Target: ~1.6 KB weights (vs 3.2 KB with f16) + """ def __init__(self, kernels=[1, 3, 5], channels=[16, 8, 4]): super().__init__() |
