From eaf0bd855306e70ca03f2d6579b4d6551aff6482 Mon Sep 17 00:00:00 2001 From: skal Date: Thu, 12 Feb 2026 12:11:53 +0100 Subject: TODO: 8-bit weight quantization for 2× size reduction MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add QAT (quantization-aware training) notes - Requires training with fake quantization - Target: ~1.6 KB weights (vs 3.2 KB f16) - Shader unpacking needs adaptation (4× u8 per u32) --- checkpoints/checkpoint_epoch_85.pth | Bin 0 -> 24343 bytes training/export_cnn_v2_weights.py | 2 ++ training/train_cnn_v2.py | 8 +++++++- workspaces/main/shaders/cnn_v2_compute.wgsl | 1 + 4 files changed, 10 insertions(+), 1 deletion(-) create mode 100644 checkpoints/checkpoint_epoch_85.pth diff --git a/checkpoints/checkpoint_epoch_85.pth b/checkpoints/checkpoint_epoch_85.pth new file mode 100644 index 0000000..57f8ae6 Binary files /dev/null and b/checkpoints/checkpoint_epoch_85.pth differ diff --git a/training/export_cnn_v2_weights.py b/training/export_cnn_v2_weights.py index e3d1724..723f572 100755 --- a/training/export_cnn_v2_weights.py +++ b/training/export_cnn_v2_weights.py @@ -94,6 +94,8 @@ def export_weights_binary(checkpoint_path, output_path): weight_offset += len(layer2_flat) # Convert to f16 + # TODO: Use 8-bit quantization for 2× size reduction + # Requires quantization-aware training (QAT) to maintain accuracy all_weights_f16 = np.array(all_weights, dtype=np.float16) # Pack f16 pairs into u32 for storage buffer diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py index 5c93f20..8cac51a 100755 --- a/training/train_cnn_v2.py +++ b/training/train_cnn_v2.py @@ -52,7 +52,13 @@ def compute_static_features(rgb, depth=None): class CNNv2(nn.Module): - """CNN v2 with parametric static features.""" + """CNN v2 with parametric static features. + + TODO: Add quantization-aware training (QAT) for 8-bit weights + - Use torch.quantization.QuantStub/DeQuantStub + - Train with fake quantization to adapt to 8-bit precision + - Target: ~1.6 KB weights (vs 3.2 KB with f16) + """ def __init__(self, kernels=[1, 3, 5], channels=[16, 8, 4]): super().__init__() diff --git a/workspaces/main/shaders/cnn_v2_compute.wgsl b/workspaces/main/shaders/cnn_v2_compute.wgsl index f9eb556..b19a692 100644 --- a/workspaces/main/shaders/cnn_v2_compute.wgsl +++ b/workspaces/main/shaders/cnn_v2_compute.wgsl @@ -46,6 +46,7 @@ fn pack_channels(values: array) -> vec4 { // Get weight from storage buffer (f16 packed as u32 pairs) // Buffer layout: [header: 4 u32][layer_info: N×5 u32][weights: packed f16] +// TODO: Support 8-bit quantized weights (4× per u32) for 2× size reduction fn get_weight(idx: u32) -> f32 { // Skip header (16 bytes = 4 u32) and layer info // Weights start after header + layer_info, but weight_offset already accounts for this -- cgit v1.2.3