author     skal <pascal.massimino@gmail.com>  2026-02-12 12:11:53 +0100
committer  skal <pascal.massimino@gmail.com>  2026-02-12 12:11:53 +0100
commit     eaf0bd855306e70ca03f2d6579b4d6551aff6482 (patch)
tree       62316af1143db1e59e1ad62e70b9844e324cda55
parent     e8344bc84ec0f571e5c5aafffe7c914abe226bd6 (diff)
TODO: 8-bit weight quantization for 2× size reduction
- Add QAT (quantization-aware training) notes
- Requires training with fake quantization
- Target: ~1.6 KB weights (vs 3.2 KB f16; arithmetic checked below)
- Shader unpacking needs adaptation (4× u8 per u32)
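The size target follows from the weight count; a quick sanity computation (weight count inferred here from the stated 3.2 KB f16 figure, not read from the checkpoint):

    n_weights = 3200 // 2      # ~3.2 KB total / 2 bytes per f16 -> ~1600 weights
    f16_bytes = n_weights * 2  # current format: f16 pairs packed into u32 words
    u8_bytes = n_weights * 1   # proposed format: 4 u8 per u32 word
    print(f16_bytes, u8_bytes) # 3200 1600 -> the stated 2x reduction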
-rw-r--r--   checkpoints/checkpoint_epoch_85.pth           bin 0 -> 24343 bytes
-rwxr-xr-x   training/export_cnn_v2_weights.py             2
-rwxr-xr-x   training/train_cnn_v2.py                      8
-rw-r--r--   workspaces/main/shaders/cnn_v2_compute.wgsl   1
4 files changed, 10 insertions, 1 deletions
diff --git a/checkpoints/checkpoint_epoch_85.pth b/checkpoints/checkpoint_epoch_85.pth
new file mode 100644
index 0000000..57f8ae6
--- /dev/null
+++ b/checkpoints/checkpoint_epoch_85.pth
Binary files differ
diff --git a/training/export_cnn_v2_weights.py b/training/export_cnn_v2_weights.py
index e3d1724..723f572 100755
--- a/training/export_cnn_v2_weights.py
+++ b/training/export_cnn_v2_weights.py
@@ -94,6 +94,8 @@ def export_weights_binary(checkpoint_path, output_path):
weight_offset += len(layer2_flat)
# Convert to f16
+ # TODO: Use 8-bit quantization for 2× size reduction
+ # Requires quantization-aware training (QAT) to maintain accuracy
all_weights_f16 = np.array(all_weights, dtype=np.float16)
# Pack f16 pairs into u32 for storage buffer
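For context, a minimal sketch of the 8-bit export path this TODO describes, assuming symmetric per-tensor quantization with 4 u8 packed per u32 word; the function name and scale scheme are illustrative assumptions, not the repo's actual format:

    import numpy as np

    def quantize_and_pack_u8(weights_f32):
        # Symmetric per-tensor quantization: map max |w| to 127.
        w = np.asarray(weights_f32, dtype=np.float32)
        max_abs = float(np.abs(w).max())
        scale = max_abs / 127.0 if max_abs > 0 else 1.0
        q = np.clip(np.round(w / scale), -128, 127).astype(np.int8)
        # Pad to a multiple of 4 so bytes divide evenly into u32 words.
        q = np.append(q, np.zeros((-q.size) % 4, dtype=np.int8))
        packed = q.view(np.uint8).view(np.uint32)  # 4 weights per word
        return packed, np.float32(scale)

The scale would have to travel alongside the weights (e.g. in the buffer header) so the shader can dequantize.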
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
index 5c93f20..8cac51a 100755
--- a/training/train_cnn_v2.py
+++ b/training/train_cnn_v2.py
@@ -52,7 +52,13 @@ def compute_static_features(rgb, depth=None):
class CNNv2(nn.Module):
- """CNN v2 with parametric static features."""
+ """CNN v2 with parametric static features.
+
+ TODO: Add quantization-aware training (QAT) for 8-bit weights
+ - Use torch.quantization.QuantStub/DeQuantStub
+ - Train with fake quantization to adapt to 8-bit precision
+ - Target: ~1.6 KB weights (vs 3.2 KB with f16)
+ """
def __init__(self, kernels=[1, 3, 5], channels=[16, 8, 4]):
super().__init__()
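A hedged sketch of the QAT flow the docstring TODO outlines, using PyTorch's eager-mode quantization API; the CNNv2QAT wrapper is hypothetical, not part of this repo:

    import torch
    import torch.nn as nn

    class CNNv2QAT(nn.Module):
        """Hypothetical wrapper marking the f32 <-> int8 boundaries."""
        def __init__(self, model):
            super().__init__()
            self.quant = torch.quantization.QuantStub()
            self.model = model
            self.dequant = torch.quantization.DeQuantStub()

        def forward(self, x):
            return self.dequant(self.model(self.quant(x)))

    qat_model = CNNv2QAT(CNNv2())
    qat_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    torch.quantization.prepare_qat(qat_model, inplace=True)  # inserts fake-quant ops
    # ... run the usual training loop; weights adapt to 8-bit precision ...
    int8_model = torch.quantization.convert(qat_model.eval())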
diff --git a/workspaces/main/shaders/cnn_v2_compute.wgsl b/workspaces/main/shaders/cnn_v2_compute.wgsl
index f9eb556..b19a692 100644
--- a/workspaces/main/shaders/cnn_v2_compute.wgsl
+++ b/workspaces/main/shaders/cnn_v2_compute.wgsl
@@ -46,6 +46,7 @@ fn pack_channels(values: array<f32, 8>) -> vec4<u32> {
// Get weight from storage buffer (f16 packed as u32 pairs)
// Buffer layout: [header: 4 u32][layer_info: N×5 u32][weights: packed f16]
+// TODO: Support 8-bit quantized weights (4× per u32) for 2× size reduction
fn get_weight(idx: u32) -> f32 {
// Skip header (16 bytes = 4 u32) and layer info
// Weights start after header + layer_info, but weight_offset already accounts for this
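To keep the examples in one language, here is a Python mirror of the unpacking an 8-bit get_weight() would have to perform, assuming the little-endian byte order and per-tensor scale from the export sketch above (in WGSL the byte select could use extractBits):

    def get_weight_q8(packed_u32, idx, scale):
        word = int(packed_u32[idx // 4])         # 4 weights per u32 word
        byte = (word >> ((idx % 4) * 8)) & 0xFF  # pick the idx-th byte (little-endian)
        q = byte - 256 if byte >= 128 else byte  # sign-extend u8 -> i8
        return q * scale                         # dequantize back to f32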