author     skal <pascal.massimino@gmail.com>   2026-02-13 12:32:36 +0100
committer  skal <pascal.massimino@gmail.com>   2026-02-13 12:32:36 +0100
commit     561d1dc446db7d1d3e02b92b43abedf1a5017850 (patch)
tree       ef9302dc1f9b6b9f8a12225580f2a3b07602656b
parent     c27b34279c0d1c2a8f1dbceb0e154b585b5c6916 (diff)
CNN v2: Refactor to uniform 12D→4D architecture
**Architecture changes:**
- Static features (8D): p0-p3 (parametric) + uv_x, uv_y, sin(10×uv_x), bias
- Input RGBD (4D): fed separately to Layer 0 (later layers consume the previous layer's 4D output)
- All layers: uniform 12D→4D (4 prev/input + 8 static → 4 output)
- Bias integrated in static features (bias=False in PyTorch)

**Weight calculations:**
- 3 layers × (12 × 3×3 × 4) = 1296 weights
- f16: 2.6 KB (vs old variable arch: ~6.4 KB)

**Updated files:**

*Training (Python):*
- train_cnn_v2.py: Uniform model, takes input_rgbd + static_features
- export_cnn_v2_weights.py: Binary export for storage buffers
- export_cnn_v2_shader.py: Per-layer shader export (debugging)

*Shaders (WGSL):*
- cnn_v2_static.wgsl: p0-p3 parametric features (mips/gradients)
- cnn_v2_compute.wgsl: 12D input, 4D output, vec4 packing

*Tools:*
- HTML tool (cnn_v2_test): Updated for 12D→4D, layer visualization

*Docs:*
- CNN_V2.md: Updated architecture, training, validation sections
- HOWTO.md: Reference HTML tool for validation

*Removed:*
- validate_cnn_v2.sh: Obsolete (used CNN v1 tool)

All code consistent with bias=False (bias in static features as 1.0).

handoff(Claude): CNN v2 architecture finalized and documented
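A quick back-of-envelope check of the weight counts quoted above (a minimal sketch, assuming the uniform 3-layer, 3×3, 12D→4D configuration):

```python
# Weight count and storage size for the uniform architecture
num_layers, in_ch, out_ch, k = 3, 12, 4, 3
per_layer = in_ch * k * k * out_ch     # 12 × 3×3 × 4 = 432
total = num_layers * per_layer         # 1296
print(total, total * 2, total * 4)     # 1296 weights, 2592 B (f16), 5184 B (f32)
```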
-rw-r--r--   doc/CNN_V2.md                                        232
-rw-r--r--   doc/HOWTO.md                                           7
-rwxr-xr-x   scripts/validate_cnn_v2.sh                            60
-rw-r--r--   tools/cnn_v2_test/index.html                          65
-rwxr-xr-x   training/export_cnn_v2_shader.py                     127
-rwxr-xr-x   training/export_cnn_v2_weights.py                     85
-rwxr-xr-x   training/train_cnn_v2.py                             134
-rw-r--r--   workspaces/main/shaders/cnn_v2/cnn_v2_compute.wgsl    80
-rw-r--r--   workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl     23
9 files changed, 350 insertions, 463 deletions
diff --git a/doc/CNN_V2.md b/doc/CNN_V2.md
index 588c3db..4612d7a 100644
--- a/doc/CNN_V2.md
+++ b/doc/CNN_V2.md
@@ -40,10 +40,10 @@ Input RGBD → Static Features Compute → CNN Layers → Output RGBA
- Lifetime: Entire frame (all CNN layer passes)
**CNN Layers:**
-- Input Layer: 7D static features → C₀ channels
-- Inner Layers: (7D + Cᵢ₋₁) → Cᵢ channels
-- Output Layer: (7D + Cₙ) → 4D RGBA
-- Storage: `texture_storage_2d<rgba32uint>` (8×f16 per texel recommended)
+- Layer 0: input RGBD (4D) + static (8D) = 12D → 4 channels
+- Layer 1+: previous output (4D) + static (8D) = 12D → 4 channels
+- All layers: uniform 12D input, 4D output (ping-pong buffer)
+- Storage: `texture_storage_2d<rgba32uint>` (4 channels as 2×f16 pairs)
---
@@ -54,11 +54,13 @@ Input RGBD → Static Features Compute → CNN Layers → Output RGBA
**8 float16 values per pixel:**
```wgsl
-// Slot 0-3: RGBD (core pixel data)
-let r = rgba.r; // Red channel
-let g = rgba.g; // Green channel
-let b = rgba.b; // Blue channel
-let d = depth; // Depth value
+// Slot 0-3: Parametric features (p0, p1, p2, p3)
+// Can be: mip1/2 RGBD, grayscale, gradients, etc.
+// Distinct from input image RGBD (fed only to Layer 0)
+let p0 = ...; // Parametric feature 0 (e.g., mip1.r or grayscale)
+let p1 = ...; // Parametric feature 1
+let p2 = ...; // Parametric feature 2
+let p3 = ...; // Parametric feature 3
// Slot 4-5: UV coordinates (normalized screen space)
let uv_x = coord.x / resolution.x; // Horizontal position [0,1]
@@ -70,18 +72,20 @@ let sin10_x = sin(10.0 * uv_x); // Periodic feature (frequency=10)
// Slot 7: Bias dimension (always 1.0)
let bias = 1.0; // Learned bias per output channel
-// Packed storage: [R, G, B, D, uv.x, uv.y, sin(10*uv.x), 1.0]
+// Packed storage: [p0, p1, p2, p3, uv.x, uv.y, sin(10*uv.x), 1.0]
```
### Feature Rationale
| Feature | Dimension | Purpose | Priority |
|---------|-----------|---------|----------|
-| RGBD | 4D | Core pixel information | Essential |
+| p0-p3 | 4D | Parametric auxiliary features (mips, gradients, etc.) | Essential |
| UV coords | 2D | Spatial position awareness | Essential |
| sin(10\*uv.x) | 1D | Periodic position encoding | Medium |
| Bias | 1D | Learned bias (standard NN) | Essential |
+**Note:** Input image RGBD (mip 0) fed only to Layer 0. Subsequent layers see static features + previous layer output.
+
**Why bias as static feature:**
- Simpler shader code (single weight array)
- Standard NN formulation: y = Wx (x includes bias term)
@@ -113,68 +117,65 @@ Requires quantization-aware training.
### Example 3-Layer Network
```
-Input: 7D static → 16 channels (1×1 kernel, pointwise)
-Layer1: (7+16)D → 8 channels (3×3 kernel, spatial)
-Layer2: (7+8)D → 4 channels (5×5 kernel, large receptive field)
+Layer 0: input RGBD (4D) + static (8D) = 12D → 4 channels (3×3 kernel)
+Layer 1: previous (4D) + static (8D) = 12D → 4 channels (3×3 kernel)
+Layer 2: previous (4D) + static (8D) = 12D → 4 channels (3×3 kernel, output)
```
### Weight Calculations
-**Per-layer weights:**
+**Per-layer weights (uniform 12D→4D, 3×3 kernels):**
```
-Input: 7 × 1 × 1 × 16 = 112 weights
-Layer1: (7+16) × 3 × 3 × 8 = 1656 weights
-Layer2: (7+8) × 5 × 5 × 4 = 1500 weights
-Total: 3268 weights
+Layer 0: 12 × 3 × 3 × 4 = 432 weights
+Layer 1: 12 × 3 × 3 × 4 = 432 weights
+Layer 2: 12 × 3 × 3 × 4 = 432 weights
+Total: 1296 weights
```
**Storage sizes:**
-- f32: 3268 × 4 = 13,072 bytes (~12.8 KB)
-- f16: 3268 × 2 = 6,536 bytes (~6.4 KB) ✓ **recommended**
+- f32: 1296 × 4 = 5,184 bytes (~5.1 KB)
+- f16: 1296 × 2 = 2,592 bytes (~2.5 KB) ✓ **recommended**
**Comparison to v1:**
- v1: ~800 weights (3.2 KB f32)
-- v2: ~3268 weights (6.4 KB f16)
-- **Growth: 2× size for parametric features**
+- v2: ~1296 weights (2.5 KB f16)
+- **Uniform architecture, smaller than v1 f32**
### Kernel Size Guidelines
**1×1 kernel (pointwise):**
- No spatial context, channel mixing only
-- Weights: `(7 + C_in) × C_out`
-- Use for: Input layer, bottleneck layers
+- Weights: `12 × 4 = 48` per layer
+- Use for: Fast inference, channel remapping
**3×3 kernel (standard conv):**
-- Local spatial context
-- Weights: `(7 + C_in) × 9 × C_out`
-- Use for: Most inner layers
+- Local spatial context (recommended)
+- Weights: `12 × 9 × 4 = 432` per layer
+- Use for: Most layers (balanced quality/size)
**5×5 kernel (large receptive field):**
- Wide spatial context
-- Weights: `(7 + C_in) × 25 × C_out`
-- Use for: Output layer, detail enhancement
+- Weights: `12 × 25 × 4 = 1200` per layer
+- Use for: Output layer, fine detail enhancement
-### Channel Storage (8×f16 per texel)
+### Channel Storage (4×f16 per texel)
```wgsl
@group(0) @binding(1) var layer_input: texture_2d<u32>;
-fn unpack_channels(coord: vec2<i32>) -> array<f32, 8> {
+fn unpack_channels(coord: vec2<i32>) -> vec4<f32> {
let packed = textureLoad(layer_input, coord, 0);
- return array(
- unpack2x16float(packed.x).x, unpack2x16float(packed.x).y,
- unpack2x16float(packed.y).x, unpack2x16float(packed.y).y,
- unpack2x16float(packed.z).x, unpack2x16float(packed.z).y,
- unpack2x16float(packed.w).x, unpack2x16float(packed.w).y
- );
+ let v0 = unpack2x16float(packed.x); // [ch0, ch1]
+ let v1 = unpack2x16float(packed.y); // [ch2, ch3]
+ return vec4<f32>(v0.x, v0.y, v1.x, v1.y);
}
-fn pack_channels(values: array<f32, 8>) -> vec4<u32> {
- return vec4(
- pack2x16float(vec2(values[0], values[1])),
- pack2x16float(vec2(values[2], values[3])),
- pack2x16float(vec2(values[4], values[5])),
- pack2x16float(vec2(values[6], values[7]))
+fn pack_channels(values: vec4<f32>) -> vec4<u32> {
+ return vec4<u32>(
+ pack2x16float(vec2(values.x, values.y)),
+ pack2x16float(vec2(values.z, values.w)),
+ 0u, // Unused
+ 0u // Unused
);
}
```
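For reference, a minimal NumPy sketch of the same 2×f16 packing scheme (an editorial mirror of the WGSL `pack2x16float`/`unpack2x16float` builtins, assuming IEEE half-precision; not part of the codebase):

```python
import numpy as np

def pack2x16float(v0: float, v1: float) -> int:
    """Two f16 values in one u32, v0 in the low 16 bits (as in WGSL)."""
    h = np.array([v0, v1], dtype=np.float16).view(np.uint16)
    return int(h[0]) | (int(h[1]) << 16)

def unpack2x16float(u: int):
    h = np.array([u & 0xFFFF, u >> 16], dtype=np.uint16).view(np.float16)
    return float(h[0]), float(h[1])

# Round-trip the 4 channels exactly as pack_channels/unpack_channels do
packed = [pack2x16float(0.25, 0.5), pack2x16float(0.75, 1.0), 0, 0]
assert unpack2x16float(packed[0]) == (0.25, 0.5)  # exact in f16
```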
@@ -189,11 +190,11 @@ fn pack_channels(values: array<f32, 8>) -> vec4<u32> {
```python
def compute_static_features(rgb, depth):
- """Generate 7D static features + bias dimension."""
+ """Generate parametric features (8D: p0-p3 + spatial)."""
h, w = rgb.shape[:2]
- # RGBD channels
- r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]
+ # Parametric features (example: use input RGBD, but could be mips/gradients)
+ p0, p1, p2, p3 = rgb[..., 0], rgb[..., 1], rgb[..., 2], depth
# UV coordinates (normalized)
uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0)
@@ -203,56 +204,51 @@ def compute_static_features(rgb, depth):
sin10_x = np.sin(10.0 * uv_x)
# Bias dimension (always 1.0)
- bias = np.ones_like(r)
+ bias = np.ones_like(p0)
- # Stack: [R, G, B, D, uv.x, uv.y, sin10_x, bias]
- return np.stack([r, g, b, depth, uv_x, uv_y, sin10_x, bias], axis=-1)
+ # Stack: [p0, p1, p2, p3, uv.x, uv.y, sin10_x, bias]
+ return np.stack([p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias], axis=-1)
```
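A hypothetical smoke test for the feature stack, assuming the `compute_static_features` definition above and `import numpy as np`:

```python
rgb = np.random.rand(64, 64, 3).astype(np.float32)
depth = np.random.rand(64, 64).astype(np.float32)

feat = compute_static_features(rgb, depth)
assert feat.shape == (64, 64, 8)        # [p0..p3, uv.x, uv.y, sin10_x, bias]
assert np.all(feat[..., 7] == 1.0)      # bias plane is constant 1.0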
**Network Definition:**
```python
class CNNv2(nn.Module):
- def __init__(self, kernels=[1,3,5], channels=[16,8,4]):
+ def __init__(self, kernel_size=3, num_layers=3):
super().__init__()
+ self.layers = nn.ModuleList()
- # Input layer: 8D (7 features + bias) → channels[0]
- self.layer0 = nn.Conv2d(8, channels[0], kernel_size=kernels[0],
- padding=kernels[0]//2, bias=False)
-
- # Inner layers: (7 features + bias + C_prev) → C_next
- in_ch_1 = 8 + channels[0] # static + layer0 output
- self.layer1 = nn.Conv2d(in_ch_1, channels[1], kernel_size=kernels[1],
- padding=kernels[1]//2, bias=False)
-
- # Output layer: (7 features + bias + C_last) → 4 (RGBA)
- in_ch_2 = 8 + channels[1]
- self.layer2 = nn.Conv2d(in_ch_2, 4, kernel_size=kernels[2],
- padding=kernels[2]//2, bias=False)
+ # All layers: 12D input (4 prev + 8 static) → 4D output
+ for i in range(num_layers):
+ self.layers.append(
+ nn.Conv2d(12, 4, kernel_size=kernel_size,
+ padding=kernel_size//2, bias=False)
+ )
- def forward(self, static_features, layer0_input=None):
- # Layer 0: Use full 8D static features (includes bias)
- x0 = self.layer0(static_features)
- x0 = F.relu(x0)
+ def forward(self, input_rgbd, static_features):
+ # Layer 0: input RGBD (4D) + static (8D) = 12D
+ x = torch.cat([input_rgbd, static_features], dim=1)
+ x = self.layers[0](x)
+ x = torch.clamp(x, 0, 1) # Output layer 0 (4 channels)
- # Layer 1: Concatenate static + layer0 output
- x1_input = torch.cat([static_features, x0], dim=1)
- x1 = self.layer1(x1_input)
- x1 = F.relu(x1)
+ # Layer 1+: previous output (4D) + static (8D) = 12D
+ for i in range(1, len(self.layers)):
+ x_input = torch.cat([x, static_features], dim=1)
+ x = self.layers[i](x_input)
+ if i < len(self.layers) - 1:
+ x = F.relu(x)
+ else:
+ x = torch.clamp(x, 0, 1) # Final output [0,1]
- # Layer 2: Concatenate static + layer1 output
- x2_input = torch.cat([static_features, x1], dim=1)
- output = self.layer2(x2_input)
-
- return torch.sigmoid(output) # RGBA output [0,1]
+ return x # RGBA output
```
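A shape and parameter-count sanity check for the model above (a sketch, assuming the usual `torch`/`torch.nn` imports from the training script):

```python
model = CNNv2(kernel_size=3, num_layers=3)
assert sum(p.numel() for p in model.parameters()) == 1296  # 3 × 432, bias=False

input_rgbd = torch.rand(1, 4, 64, 64)    # mip-0 RGBD
static_feat = torch.rand(1, 8, 64, 64)   # p0-p3 + uv + sin10_x + bias
out = model(input_rgbd, static_feat)
assert out.shape == (1, 4, 64, 64)       # uniform 4-channel output in [0,1]
```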
**Training Configuration:**
```python
# Hyperparameters
-kernels = [1, 3, 5] # Per-layer kernel sizes
-channels = [16, 8, 4] # Per-layer output channels
+kernel_size = 3 # Uniform 3×3 kernels
+num_layers = 3 # Number of CNN layers
learning_rate = 1e-3
batch_size = 16
epochs = 5000
@@ -260,11 +256,14 @@ epochs = 5000
# Training loop (standard PyTorch f32)
for epoch in range(epochs):
for rgb_batch, depth_batch, target_batch in dataloader:
- # Compute static features
+ # Compute static features (8D)
static_feat = compute_static_features(rgb_batch, depth_batch)
+ # Input RGBD (4D)
+ input_rgbd = torch.cat([rgb_batch, depth_batch.unsqueeze(1)], dim=1)
+
# Forward pass
- output = model(static_feat)
+ output = model(input_rgbd, static_feat)
loss = criterion(output, target_batch)
# Backward pass
@@ -279,9 +278,9 @@ for epoch in range(epochs):
torch.save({
'state_dict': model.state_dict(), # f32 weights
'config': {
- 'kernels': [1, 3, 5],
- 'channels': [16, 8, 4],
- 'features': ['R', 'G', 'B', 'D', 'uv.x', 'uv.y', 'sin10_x', 'bias']
+ 'kernel_size': 3,
+ 'num_layers': 3,
+ 'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin10_x', 'bias']
},
'epoch': epoch,
'loss': loss.item()
@@ -342,59 +341,36 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) {
- Training uses f32 throughout (PyTorch standard)
- Export converts to np.float16, then back to f32 for WGSL literals
- **Expected discrepancy:** <0.1% MSE (acceptable)
-- Validation via `validate_cnn_v2.sh` compares outputs
+- Validation via HTML tool (see below)
---
## Validation Workflow
-### Script: `scripts/validate_cnn_v2.sh`
-
-**End-to-end pipeline:**
-```bash
-./scripts/validate_cnn_v2.sh checkpoints/checkpoint_epoch_5000.pth
-```
+### HTML Tool: `tools/cnn_v2_test/index.html`
-**Steps automated:**
-1. Export checkpoint → .wgsl shaders
-2. Rebuild `cnn_test` tool
-3. Process test images with CNN v2
-4. Display input/output results
+**WebGPU-based testing tool** with layer visualization.
**Usage:**
-```bash
-# Basic usage
-./scripts/validate_cnn_v2.sh checkpoint.pth
-
-# Custom paths
-./scripts/validate_cnn_v2.sh checkpoint.pth \
- -i my_test_images/ \
- -o results/ \
- -b build_release
+1. Open `tools/cnn_v2_test/index.html` in browser
+2. Drop `.bin` weights file (from `export_cnn_v2_weights.py`)
+3. Drop PNG test image
+4. View results with layer inspection
-# Skip rebuild (iterate on checkpoint only)
-./scripts/validate_cnn_v2.sh checkpoint.pth --skip-build
+**Features:**
+- Live CNN inference with WebGPU
+- Layer-by-layer visualization (static features + all CNN layers)
+- Weight visualization (per-layer kernels)
+- View modes: CNN output, original, diff (×10)
+- Blend control for comparing with original
-# Skip export (iterate on test images only)
-./scripts/validate_cnn_v2.sh checkpoint.pth --skip-export
-
-# Show help
-./scripts/validate_cnn_v2.sh --help
+**Export weights:**
+```bash
+./training/export_cnn_v2_weights.py checkpoints/checkpoint_epoch_100.pth \
+ --output-weights workspaces/main/cnn_v2_weights.bin
```
-**Options:**
-- `-b, --build-dir DIR` - Build directory (default: build)
-- `-w, --workspace NAME` - Workspace name (default: main)
-- `-i, --images DIR` - Test images directory (default: training/validation)
-- `-o, --output DIR` - Output directory (default: validation_results)
-- `--skip-build` - Use existing cnn_test binary
-- `--skip-export` - Use existing .wgsl shaders
-- `-h, --help` - Show full usage
-
-**Output:**
-- Input images: `<test_images_dir>/*.png`
-- Output images: `<output_dir>/*_output.png`
-- Opens results directory in system file browser
+See `doc/CNN_V2_WEB_TOOL.md` for detailed documentation.
---
@@ -460,7 +436,7 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) {
### Phase 4: Tools & Validation
-- [ ] `scripts/validate_cnn_v2.sh` - End-to-end validation
+- [x] HTML validation tool - WebGPU inference with layer visualization
- [ ] Command-line argument parsing
- [ ] Shader export orchestration
- [ ] Build orchestration
@@ -506,8 +482,8 @@ training/train_cnn_v2.py # Training script
training/export_cnn_v2_shader.py # Shader generator
training/validation/ # Test images directory
-# Scripts
-scripts/validate_cnn_v2.sh # End-to-end validation
+# Validation
+tools/cnn_v2_test/index.html # WebGPU validation tool
# Documentation
doc/CNN_V2.md # This file
diff --git a/doc/HOWTO.md b/doc/HOWTO.md
index 1ae1d94..e909a5d 100644
--- a/doc/HOWTO.md
+++ b/doc/HOWTO.md
@@ -179,14 +179,9 @@ Storage buffer architecture allows dynamic layer count.
**TODO:** 8-bit quantization for 2× size reduction (~1.6 KB). Requires quantization-aware training (QAT).
-# Options:
-# -i DIR Test images directory (default: training/validation)
-# -o DIR Output directory (default: validation_results)
-# --skip-build Use existing cnn_test binary
-# -h Show all options
```
-See `scripts/validate_cnn_v2.sh --help` for full usage. See `doc/CNN_V2.md` for design details.
+**Validation:** Use HTML tool (`tools/cnn_v2_test/index.html`) for CNN v2 validation. See `doc/CNN_V2_WEB_TOOL.md`.
---
diff --git a/scripts/validate_cnn_v2.sh b/scripts/validate_cnn_v2.sh
deleted file mode 100755
index 06a4e01..0000000
--- a/scripts/validate_cnn_v2.sh
+++ /dev/null
@@ -1,60 +0,0 @@
-#!/bin/bash
-# CNN v2 Validation - End-to-end pipeline
-
-set -e
-PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
-BUILD_DIR="$PROJECT_ROOT/build"
-WORKSPACE="main"
-
-usage() {
- echo "Usage: $0 <checkpoint.pth> [options]"
- echo "Options:"
- echo " -i DIR Test images (default: training/validation)"
- echo " -o DIR Output (default: validation_results)"
- echo " --skip-build Skip rebuild"
- exit 1
-}
-
-[ $# -eq 0 ] && usage
-CHECKPOINT="$1"
-shift
-
-TEST_IMAGES="$PROJECT_ROOT/training/validation"
-OUTPUT="$PROJECT_ROOT/validation_results"
-SKIP_BUILD=false
-
-while [[ $# -gt 0 ]]; do
- case $1 in
- -i) TEST_IMAGES="$2"; shift 2 ;;
- -o) OUTPUT="$2"; shift 2 ;;
- --skip-build) SKIP_BUILD=true; shift ;;
- -h) usage ;;
- *) usage ;;
- esac
-done
-
-echo "=== CNN v2 Validation ==="
-echo "Checkpoint: $CHECKPOINT"
-
-# Export
-echo "[1/3] Exporting shaders..."
-python3 "$PROJECT_ROOT/training/export_cnn_v2_shader.py" "$CHECKPOINT" \
- --output-dir "$PROJECT_ROOT/workspaces/$WORKSPACE/shaders"
-
-# Build
-if [ "$SKIP_BUILD" = false ]; then
- echo "[2/3] Building..."
- cmake --build "$BUILD_DIR" -j4 --target cnn_test >/dev/null 2>&1
-fi
-
-# Process
-echo "[3/3] Processing images..."
-mkdir -p "$OUTPUT"
-count=0
-for img in "$TEST_IMAGES"/*.png; do
- [ -f "$img" ] || continue
- name=$(basename "$img" .png)
- "$BUILD_DIR/cnn_test" "$img" "$OUTPUT/${name}_output.png" 2>/dev/null && count=$((count+1))
-done
-
-echo "Done! Processed $count images → $OUTPUT"
diff --git a/tools/cnn_v2_test/index.html b/tools/cnn_v2_test/index.html
index 9ce3d8c..199deea 100644
--- a/tools/cnn_v2_test/index.html
+++ b/tools/cnn_v2_test/index.html
@@ -3,6 +3,12 @@
<!--
CNN v2 Testing Tool - WebGPU-based inference validator
+ Architecture:
+ - Static features (8D): p0-p3 (parametric), uv_x, uv_y, sin(10*uv_x), bias
+ - Layer 0: input RGBD (4D) + static (8D) = 12D → 4 channels
+ - Layer 1+: previous (4D) + static (8D) = 12D → 4 channels
+ - All layers: uniform 12D input, 4D output (ping-pong buffer)
+
Features:
- Side panel: .bin metadata display, weight statistics per layer
- Layer inspection: 4-channel grayscale split, intermediate layer visualization
@@ -318,21 +324,19 @@ fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {
return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
}
-fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> {
+fn unpack_layer_channels(coord: vec2<i32>) -> vec4<f32> {
let packed = textureLoad(layer_input, coord, 0);
let v0 = unpack2x16float(packed.x);
let v1 = unpack2x16float(packed.y);
- let v2 = unpack2x16float(packed.z);
- let v3 = unpack2x16float(packed.w);
- return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+ return vec4<f32>(v0.x, v0.y, v1.x, v1.y);
}
-fn pack_channels(values: array<f32, 8>) -> vec4<u32> {
+fn pack_channels(values: vec4<f32>) -> vec4<u32> {
return vec4<u32>(
- pack2x16float(vec2<f32>(values[0], values[1])),
- pack2x16float(vec2<f32>(values[2], values[3])),
- pack2x16float(vec2<f32>(values[4], values[5])),
- pack2x16float(vec2<f32>(values[6], values[7]))
+ pack2x16float(vec2<f32>(values.x, values.y)),
+ pack2x16float(vec2<f32>(values.z, values.w)),
+ 0u,
+ 0u
);
}
@@ -350,16 +354,16 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) {
if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) { return; }
let kernel_size = params.kernel_size;
- let in_channels = params.in_channels;
- let out_channels = params.out_channels;
+ let in_channels = params.in_channels; // Always 12 (4 prev + 8 static)
+ let out_channels = params.out_channels; // Always 4
let weight_offset = params.weight_offset;
let is_output = params.is_output_layer != 0u;
let kernel_radius = i32(kernel_size / 2u);
let static_feat = unpack_static_features(coord);
- var output: array<f32, 8>;
- for (var c: u32 = 0u; c < out_channels && c < 8u; c++) {
+ var output: vec4<f32> = vec4<f32>(0.0);
+ for (var c: u32 = 0u; c < 4u; c++) {
var sum: f32 = 0.0;
for (var ky: i32 = -kernel_radius; ky <= kernel_radius; ky++) {
for (var kx: i32 = -kernel_radius; kx <= kernel_radius; kx++) {
@@ -375,19 +379,20 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) {
let kx_idx = u32(kx + kernel_radius);
let spatial_idx = ky_idx * kernel_size + kx_idx;
- for (var i: u32 = 0u; i < 8u; i++) {
+ // Previous layer channels (4D)
+ for (var i: u32 = 0u; i < 4u; i++) {
let w_idx = weight_offset +
c * in_channels * kernel_size * kernel_size +
i * kernel_size * kernel_size + spatial_idx;
- sum += get_weight(w_idx) * static_local[i];
+ sum += get_weight(w_idx) * layer_local[i];
}
- let prev_channels = in_channels - 8u;
- for (var i: u32 = 0u; i < prev_channels && i < 8u; i++) {
+ // Static features (8D)
+ for (var i: u32 = 0u; i < 8u; i++) {
let w_idx = weight_offset +
c * in_channels * kernel_size * kernel_size +
- (8u + i) * kernel_size * kernel_size + spatial_idx;
- sum += get_weight(w_idx) * layer_local[i];
+ (4u + i) * kernel_size * kernel_size + spatial_idx;
+ sum += get_weight(w_idx) * static_local[i];
}
}
}
@@ -399,17 +404,13 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) {
}
}
- for (var c: u32 = out_channels; c < 8u; c++) {
- output[c] = 0.0;
- }
-
if (is_output) {
let original = textureLoad(original_input, coord, 0).rgb;
- let result_rgb = vec3<f32>(output[0], output[1], output[2]);
+ let result_rgb = vec3<f32>(output.x, output.y, output.z);
let blended = mix(original, result_rgb, params.blend_amount);
- output[0] = blended.r;
- output[1] = blended.g;
- output[2] = blended.b;
+ output.x = blended.r;
+ output.y = blended.g;
+ output.z = blended.b;
}
textureStore(output_tex, coord, pack_channels(output));
@@ -1013,7 +1014,7 @@ class CNNTester {
</div>
`;
- html += '<div style="font-size: 9px; color: #808080; margin-bottom: 8px; padding-bottom: 8px; border-bottom: 1px solid #404040;">Static features (7D input) + ${this.weights.layers.length} CNN layers. Showing first 4 of 8 channels.</div>';
+ html += `<div style="font-size: 9px; color: #808080; margin-bottom: 8px; padding-bottom: 8px; border-bottom: 1px solid #404040;">Static features (8D: p0-p3 + spatial) + ${this.weights.layers.length} CNN layers. All layers: 12D→4D.</div>`;
html += '<div class="layer-buttons">';
for (let i = 0; i < this.layerOutputs.length; i++) {
@@ -1116,10 +1117,10 @@ class CNNTester {
this.log(`Visualizing ${layerName} activations (${width}×${height})`);
// Update channel labels based on layer type
- // Static features: 8 channels total (R,G,B,D,UV_X,UV_Y,sin,bias), showing first 4
- // CNN layers: Up to 8 channels per layer, showing first 4
+ // Static features: 8 channels (p0,p1,p2,p3,uv_x,uv_y,sin10_x,bias)
+ // CNN layers: 4 channels per layer (uniform)
const channelLabels = layerIdx === 0
- ? ['Ch0 (R)', 'Ch1 (G)', 'Ch2 (B)', 'Ch3 (D)']
+ ? ['Ch0 (p0)', 'Ch1 (p1)', 'Ch2 (p2)', 'Ch3 (p3)']
: ['Ch0', 'Ch1', 'Ch2', 'Ch3'];
for (let c = 0; c < 4; c++) {
@@ -1169,7 +1170,7 @@ class CNNTester {
continue;
}
- const vizScale = layerIdx === 0 ? 1.0 : 0.2; // Static: 1.0, CNN layers: 0.2 (assumes ~5 max)
+ const vizScale = layerIdx === 0 ? 1.0 : 0.5; // Static: 1.0, CNN layers: 0.5 (4 channels [0,1])
const paramsBuffer = this.device.createBuffer({
size: 8,
usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
diff --git a/training/export_cnn_v2_shader.py b/training/export_cnn_v2_shader.py
index add28d2..ad5749c 100755
--- a/training/export_cnn_v2_shader.py
+++ b/training/export_cnn_v2_shader.py
@@ -1,8 +1,11 @@
#!/usr/bin/env python3
-"""CNN v2 Shader Export Script
+"""CNN v2 Shader Export Script - Uniform 12D→4D Architecture
Converts PyTorch checkpoints to WGSL compute shaders with f16 weights.
Generates one shader per layer with embedded weight arrays.
+
+Note: Storage buffer approach (export_cnn_v2_weights.py) is preferred for size.
+ This script is for debugging/testing with per-layer shaders.
"""
import argparse
@@ -11,16 +14,13 @@ import torch
from pathlib import Path
-def export_layer_shader(layer_idx, weights, kernel_size, in_channels, out_channels,
- output_dir, is_output_layer=False):
+def export_layer_shader(layer_idx, weights, kernel_size, output_dir, is_output_layer=False):
"""Generate WGSL compute shader for a single CNN layer.
Args:
- layer_idx: Layer index (0, 1, 2)
- weights: (out_ch, in_ch, k, k) weight tensor
- kernel_size: Kernel size (1, 3, 5, etc.)
- in_channels: Input channels (includes 8D static features)
- out_channels: Output channels
+ layer_idx: Layer index (0, 1, 2, ...)
+ weights: (4, 12, k, k) weight tensor (uniform 12D→4D)
+ kernel_size: Kernel size (3, 5, etc.)
output_dir: Output directory path
is_output_layer: True if this is the final RGBA output layer
"""
@@ -39,12 +39,12 @@ def export_layer_shader(layer_idx, weights, kernel_size, in_channels, out_channe
if is_output_layer:
activation = "output[c] = clamp(sum, 0.0, 1.0); // Sigmoid approximation"
- shader_code = f"""// CNN v2 Layer {layer_idx} - Auto-generated
-// Kernel: {kernel_size}×{kernel_size}, In: {in_channels}, Out: {out_channels}
+ shader_code = f"""// CNN v2 Layer {layer_idx} - Auto-generated (uniform 12D→4D)
+// Kernel: {kernel_size}×{kernel_size}, In: 12D (4 prev + 8 static), Out: 4D
const KERNEL_SIZE: u32 = {kernel_size}u;
-const IN_CHANNELS: u32 = {in_channels}u;
-const OUT_CHANNELS: u32 = {out_channels}u;
+const IN_CHANNELS: u32 = 12u; // 4 (input/prev) + 8 (static)
+const OUT_CHANNELS: u32 = 4u; // Uniform output
const KERNEL_RADIUS: i32 = {radius};
// Weights quantized to float16 (stored as f32 in WGSL)
@@ -65,21 +65,19 @@ fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {{
return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
}}
-fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> {{
+fn unpack_layer_channels(coord: vec2<i32>) -> vec4<f32> {{
let packed = textureLoad(layer_input, coord, 0);
let v0 = unpack2x16float(packed.x);
let v1 = unpack2x16float(packed.y);
- let v2 = unpack2x16float(packed.z);
- let v3 = unpack2x16float(packed.w);
- return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+ return vec4<f32>(v0.x, v0.y, v1.x, v1.y);
}}
-fn pack_channels(values: array<f32, 8>) -> vec4<u32> {{
+fn pack_channels(values: vec4<f32>) -> vec4<u32> {{
return vec4<u32>(
- pack2x16float(vec2<f32>(values[0], values[1])),
- pack2x16float(vec2<f32>(values[2], values[3])),
- pack2x16float(vec2<f32>(values[4], values[5])),
- pack2x16float(vec2<f32>(values[6], values[7]))
+ pack2x16float(vec2<f32>(values.x, values.y)),
+ pack2x16float(vec2<f32>(values.z, values.w)),
+ 0u, // Unused
+ 0u // Unused
);
}}
@@ -95,9 +93,9 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) {{
// Load static features (always available)
let static_feat = unpack_static_features(coord);
- // Convolution
- var output: array<f32, OUT_CHANNELS>;
- for (var c: u32 = 0u; c < OUT_CHANNELS; c++) {{
+ // Convolution: 12D input (4 prev + 8 static) → 4D output
+ var output: vec4<f32> = vec4<f32>(0.0);
+ for (var c: u32 = 0u; c < 4u; c++) {{
var sum: f32 = 0.0;
for (var ky: i32 = -KERNEL_RADIUS; ky <= KERNEL_RADIUS; ky++) {{
@@ -110,28 +108,27 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) {{
clamp(sample_coord.y, 0, i32(dims.y) - 1)
);
- // Load input features
+ // Load features at this spatial location
let static_local = unpack_static_features(clamped);
- let layer_local = unpack_layer_channels(clamped);
+ let layer_local = unpack_layer_channels(clamped); // 4D
// Weight index calculation
let ky_idx = u32(ky + KERNEL_RADIUS);
let kx_idx = u32(kx + KERNEL_RADIUS);
let spatial_idx = ky_idx * KERNEL_SIZE + kx_idx;
- // Accumulate: static features (8D)
- for (var i: u32 = 0u; i < 8u; i++) {{
- let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE +
+ // Accumulate: previous/input channels (4D)
+ for (var i: u32 = 0u; i < 4u; i++) {{
+ let w_idx = c * 12u * KERNEL_SIZE * KERNEL_SIZE +
i * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
- sum += weights[w_idx] * static_local[i];
+ sum += weights[w_idx] * layer_local[i];
}}
- // Accumulate: layer input channels (if layer_idx > 0)
- let prev_channels = IN_CHANNELS - 8u;
- for (var i: u32 = 0u; i < prev_channels; i++) {{
- let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE +
- (8u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
- sum += weights[w_idx] * layer_local[i];
+ // Accumulate: static features (8D)
+ for (var i: u32 = 0u; i < 8u; i++) {{
+ let w_idx = c * 12u * KERNEL_SIZE * KERNEL_SIZE +
+ (4u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
+ sum += weights[w_idx] * static_local[i];
}}
}}
}}
@@ -162,53 +159,37 @@ def export_checkpoint(checkpoint_path, output_dir):
state_dict = checkpoint['model_state_dict']
config = checkpoint['config']
+ kernel_size = config.get('kernel_size', 3)
+ num_layers = config.get('num_layers', 3)
+
print(f"Configuration:")
- print(f" Kernels: {config['kernels']}")
- print(f" Channels: {config['channels']}")
- print(f" Features: {config['features']}")
+ print(f" Kernel size: {kernel_size}×{kernel_size}")
+ print(f" Layers: {num_layers}")
+ print(f" Architecture: uniform 12D→4D")
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
print(f"\nExporting shaders to {output_dir}/")
- # Layer 0: 8 → channels[0]
- layer0_weights = state_dict['layer0.weight'].detach().numpy()
- export_layer_shader(
- layer_idx=0,
- weights=layer0_weights,
- kernel_size=config['kernels'][0],
- in_channels=8,
- out_channels=config['channels'][0],
- output_dir=output_dir,
- is_output_layer=False
- )
+ # All layers uniform: 12D→4D
+ for i in range(num_layers):
+ layer_key = f'layers.{i}.weight'
+ if layer_key not in state_dict:
+ raise ValueError(f"Missing weights for layer {i}: {layer_key}")
- # Layer 1: (8 + channels[0]) → channels[1]
- layer1_weights = state_dict['layer1.weight'].detach().numpy()
- export_layer_shader(
- layer_idx=1,
- weights=layer1_weights,
- kernel_size=config['kernels'][1],
- in_channels=8 + config['channels'][0],
- out_channels=config['channels'][1],
- output_dir=output_dir,
- is_output_layer=False
- )
+ layer_weights = state_dict[layer_key].detach().numpy()
+ is_output = (i == num_layers - 1)
- # Layer 2: (8 + channels[1]) → 4 (RGBA)
- layer2_weights = state_dict['layer2.weight'].detach().numpy()
- export_layer_shader(
- layer_idx=2,
- weights=layer2_weights,
- kernel_size=config['kernels'][2],
- in_channels=8 + config['channels'][1],
- out_channels=4,
- output_dir=output_dir,
- is_output_layer=True
- )
+ export_layer_shader(
+ layer_idx=i,
+ weights=layer_weights,
+ kernel_size=kernel_size,
+ output_dir=output_dir,
+ is_output_layer=is_output
+ )
- print(f"\nExport complete! Generated 3 shader files.")
+ print(f"\nExport complete! Generated {num_layers} shader files.")
def main():
diff --git a/training/export_cnn_v2_weights.py b/training/export_cnn_v2_weights.py
index 8a2fcdc..07254fc 100755
--- a/training/export_cnn_v2_weights.py
+++ b/training/export_cnn_v2_weights.py
@@ -45,53 +45,38 @@ def export_weights_binary(checkpoint_path, output_path):
state_dict = checkpoint['model_state_dict']
config = checkpoint['config']
+ kernel_size = config.get('kernel_size', 3)
+ num_layers = config.get('num_layers', 3)
+
print(f"Configuration:")
- print(f" Kernels: {config['kernels']}")
- print(f" Channels: {config['channels']}")
+ print(f" Kernel size: {kernel_size}×{kernel_size}")
+ print(f" Layers: {num_layers}")
+ print(f" Architecture: uniform 12D→4D (bias=False)")
- # Collect layer info
+ # Collect layer info - all layers uniform 12D→4D
layers = []
all_weights = []
weight_offset = 0
- # Layer 0: 8 → channels[0]
- layer0_weights = state_dict['layer0.weight'].detach().numpy()
- layer0_flat = layer0_weights.flatten()
- layers.append({
- 'kernel_size': config['kernels'][0],
- 'in_channels': 8,
- 'out_channels': config['channels'][0],
- 'weight_offset': weight_offset,
- 'weight_count': len(layer0_flat)
- })
- all_weights.extend(layer0_flat)
- weight_offset += len(layer0_flat)
+ for i in range(num_layers):
+ layer_key = f'layers.{i}.weight'
+ if layer_key not in state_dict:
+ raise ValueError(f"Missing weights for layer {i}: {layer_key}")
+
+ layer_weights = state_dict[layer_key].detach().numpy()
+ layer_flat = layer_weights.flatten()
- # Layer 1: (8 + channels[0]) → channels[1]
- layer1_weights = state_dict['layer1.weight'].detach().numpy()
- layer1_flat = layer1_weights.flatten()
- layers.append({
- 'kernel_size': config['kernels'][1],
- 'in_channels': 8 + config['channels'][0],
- 'out_channels': config['channels'][1],
- 'weight_offset': weight_offset,
- 'weight_count': len(layer1_flat)
- })
- all_weights.extend(layer1_flat)
- weight_offset += len(layer1_flat)
+ layers.append({
+ 'kernel_size': kernel_size,
+ 'in_channels': 12, # 4 (input/prev) + 8 (static)
+ 'out_channels': 4, # Uniform output
+ 'weight_offset': weight_offset,
+ 'weight_count': len(layer_flat)
+ })
+ all_weights.extend(layer_flat)
+ weight_offset += len(layer_flat)
- # Layer 2: (8 + channels[1]) → 4 (RGBA output)
- layer2_weights = state_dict['layer2.weight'].detach().numpy()
- layer2_flat = layer2_weights.flatten()
- layers.append({
- 'kernel_size': config['kernels'][2],
- 'in_channels': 8 + config['channels'][1],
- 'out_channels': 4,
- 'weight_offset': weight_offset,
- 'weight_count': len(layer2_flat)
- })
- all_weights.extend(layer2_flat)
- weight_offset += len(layer2_flat)
+ print(f" Layer {i}: 12D→4D, {len(layer_flat)} weights")
# Convert to f16
# TODO: Use 8-bit quantization for 2× size reduction
@@ -183,21 +168,19 @@ fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {
return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
}
-fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> {
+fn unpack_layer_channels(coord: vec2<i32>) -> vec4<f32> {
let packed = textureLoad(layer_input, coord, 0);
let v0 = unpack2x16float(packed.x);
let v1 = unpack2x16float(packed.y);
- let v2 = unpack2x16float(packed.z);
- let v3 = unpack2x16float(packed.w);
- return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+ return vec4<f32>(v0.x, v0.y, v1.x, v1.y);
}
-fn pack_channels(values: array<f32, 8>) -> vec4<u32> {
+fn pack_channels(values: vec4<f32>) -> vec4<u32> {
return vec4<u32>(
- pack2x16float(vec2<f32>(values[0], values[1])),
- pack2x16float(vec2<f32>(values[2], values[3])),
- pack2x16float(vec2<f32>(values[4], values[5])),
- pack2x16float(vec2<f32>(values[6], values[7]))
+ pack2x16float(vec2<f32>(values.x, values.y)),
+ pack2x16float(vec2<f32>(values.z, values.w)),
+ 0u, // Unused
+ 0u // Unused
);
}
@@ -238,9 +221,9 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) {
let out_channels = weights[layer0_info_base + 2u];
let weight_offset = weights[layer0_info_base + 3u];
- // Convolution (simplified - expand to full kernel loop)
- var output: array<f32, 8>;
- for (var c: u32 = 0u; c < min(out_channels, 8u); c++) {
+ // Convolution: 12D input (4 prev + 8 static) → 4D output
+ var output: vec4<f32> = vec4<f32>(0.0);
+ for (var c: u32 = 0u; c < 4u; c++) {
output[c] = 0.0; // TODO: Actual convolution
}
diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py
index 758b044..8b3b91c 100755
--- a/training/train_cnn_v2.py
+++ b/training/train_cnn_v2.py
@@ -1,11 +1,11 @@
#!/usr/bin/env python3
-"""CNN v2 Training Script - Parametric Static Features
+"""CNN v2 Training Script - Uniform 12D→4D Architecture
-Trains a multi-layer CNN with 7D static feature input:
-- RGBD (4D)
-- UV coordinates (2D)
-- sin(10*uv.x) position encoding (1D)
-- Bias dimension (1D, always 1.0)
+Architecture:
+- Static features (8D): p0-p3 (parametric), uv_x, uv_y, sin(10×uv_x), bias
+- Input RGBD (4D): original image mip 0
+- All layers: input RGBD (4D) + static (8D) = 12D → 4 channels
+- Uniform layer structure with bias=False (bias in static features)
"""
import argparse
@@ -21,20 +21,26 @@ import cv2
def compute_static_features(rgb, depth=None):
- """Generate 7D static features + bias dimension.
+ """Generate 8D static features (parametric + spatial).
Args:
rgb: (H, W, 3) RGB image [0, 1]
depth: (H, W) depth map [0, 1], optional
Returns:
- (H, W, 8) static features tensor
+ (H, W, 8) static features: [p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias]
+
+ Note: p0-p3 are parametric features (can be mips, gradients, etc.)
+ For training, we use RGBD as default, but could use mip1/2
"""
h, w = rgb.shape[:2]
- # RGBD channels
- r, g, b = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
- d = depth if depth is not None else np.zeros((h, w), dtype=np.float32)
+ # Parametric features (p0-p3) - using RGBD as default
+ # TODO: Experiment with mip1 grayscale, gradients, etc.
+ p0 = rgb[:, :, 0].astype(np.float32)
+ p1 = rgb[:, :, 1].astype(np.float32)
+ p2 = rgb[:, :, 2].astype(np.float32)
+ p3 = depth if depth is not None else np.zeros((h, w), dtype=np.float32)
# UV coordinates (normalized [0, 1])
uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0).astype(np.float32)
@@ -43,65 +49,64 @@ def compute_static_features(rgb, depth=None):
# Multi-frequency position encoding
sin10_x = np.sin(10.0 * uv_x).astype(np.float32)
- # Bias dimension (always 1.0)
+ # Bias dimension (always 1.0) - replaces Conv2d bias parameter
bias = np.ones((h, w), dtype=np.float32)
- # Stack: [R, G, B, D, uv.x, uv.y, sin10_x, bias]
- features = np.stack([r, g, b, d, uv_x, uv_y, sin10_x, bias], axis=-1)
+ # Stack: [p0, p1, p2, p3, uv.x, uv.y, sin10_x, bias]
+ features = np.stack([p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias], axis=-1)
return features
class CNNv2(nn.Module):
- """CNN v2 with parametric static features.
+ """CNN v2 - Uniform 12D→4D Architecture
+
+ All layers: input RGBD (4D) + static (8D) = 12D → 4 channels
+ Uses bias=False (bias integrated in static features as 1.0)
TODO: Add quantization-aware training (QAT) for 8-bit weights
- Use torch.quantization.QuantStub/DeQuantStub
- Train with fake quantization to adapt to 8-bit precision
- - Target: ~1.6 KB weights (vs 3.2 KB with f16)
+ - Target: ~1.3 KB weights (vs 2.6 KB with f16)
"""
- def __init__(self, kernels=[1, 3, 5], channels=[16, 8, 4]):
+ def __init__(self, kernel_size=3, num_layers=3):
super().__init__()
- self.kernels = kernels
- self.channels = channels
-
- # Input layer: 8D (7 features + bias) → channels[0]
- self.layer0 = nn.Conv2d(8, channels[0], kernel_size=kernels[0],
- padding=kernels[0]//2, bias=False)
-
- # Inner layers: (8 + C_prev) → C_next
- in_ch_1 = 8 + channels[0]
- self.layer1 = nn.Conv2d(in_ch_1, channels[1], kernel_size=kernels[1],
- padding=kernels[1]//2, bias=False)
+ self.kernel_size = kernel_size
+ self.num_layers = num_layers
+ self.layers = nn.ModuleList()
- # Output layer: (8 + C_last) → 4 (RGBA)
- in_ch_2 = 8 + channels[1]
- self.layer2 = nn.Conv2d(in_ch_2, 4, kernel_size=kernels[2],
- padding=kernels[2]//2, bias=False)
+ # All layers: 12D input (4 RGBD + 8 static) → 4D output
+ for _ in range(num_layers):
+ self.layers.append(
+ nn.Conv2d(12, 4, kernel_size=kernel_size,
+ padding=kernel_size//2, bias=False)
+ )
- def forward(self, static_features):
- """Forward pass with static feature concatenation.
+ def forward(self, input_rgbd, static_features):
+ """Forward pass with uniform 12D→4D layers.
Args:
+ input_rgbd: (B, 4, H, W) input image RGBD (mip 0)
static_features: (B, 8, H, W) static features
Returns:
(B, 4, H, W) RGBA output [0, 1]
"""
- # Layer 0: Use full 8D static features
- x0 = self.layer0(static_features)
- x0 = F.relu(x0)
+ # Layer 0: input RGBD (4D) + static (8D) = 12D
+ x = torch.cat([input_rgbd, static_features], dim=1)
+ x = self.layers[0](x)
+ x = torch.clamp(x, 0, 1) # Output [0,1] for layer 0
- # Layer 1: Concatenate static + layer0 output
- x1_input = torch.cat([static_features, x0], dim=1)
- x1 = self.layer1(x1_input)
- x1 = F.relu(x1)
+ # Layer 1+: previous (4D) + static (8D) = 12D
+ for i in range(1, self.num_layers):
+ x_input = torch.cat([x, static_features], dim=1)
+ x = self.layers[i](x_input)
+ if i < self.num_layers - 1:
+ x = F.relu(x)
+ else:
+ x = torch.clamp(x, 0, 1) # Final output [0,1]
- # Layer 2: Concatenate static + layer1 output
- x2_input = torch.cat([static_features, x1], dim=1)
- output = self.layer2(x2_input)
-
- return torch.sigmoid(output)
+ return x
class PatchDataset(Dataset):
@@ -214,14 +219,18 @@ class PatchDataset(Dataset):
# Compute static features for patch
static_feat = compute_static_features(input_patch.astype(np.float32))
+ # Input RGBD (mip 0) - add depth channel
+ input_rgbd = np.concatenate([input_patch, np.zeros((self.patch_size, self.patch_size, 1))], axis=-1)
+
# Convert to tensors (C, H, W)
+ input_rgbd = torch.from_numpy(input_rgbd.astype(np.float32)).permute(2, 0, 1)
static_feat = torch.from_numpy(static_feat).permute(2, 0, 1)
target = torch.from_numpy(target_patch.astype(np.float32)).permute(2, 0, 1)
# Pad target to 4 channels (RGBA)
target = F.pad(target, (0, 0, 0, 0, 0, 1), value=1.0)
- return static_feat, target
+ return input_rgbd, static_feat, target
class ImagePairDataset(Dataset):
@@ -252,14 +261,19 @@ class ImagePairDataset(Dataset):
# Compute static features
static_feat = compute_static_features(input_img.astype(np.float32))
+ # Input RGBD (mip 0) - add depth channel
+ h, w = input_img.shape[:2]
+ input_rgbd = np.concatenate([input_img, np.zeros((h, w, 1))], axis=-1)
+
# Convert to tensors (C, H, W)
+ input_rgbd = torch.from_numpy(input_rgbd.astype(np.float32)).permute(2, 0, 1)
static_feat = torch.from_numpy(static_feat).permute(2, 0, 1)
target = torch.from_numpy(target_img.astype(np.float32)).permute(2, 0, 1)
# Pad target to 4 channels (RGBA)
target = F.pad(target, (0, 0, 0, 0, 0, 1), value=1.0)
- return static_feat, target
+ return input_rgbd, static_feat, target
def train(args):
@@ -282,9 +296,10 @@ def train(args):
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
# Create model
- model = CNNv2(kernels=args.kernel_sizes, channels=args.channels).to(device)
+ model = CNNv2(kernel_size=args.kernel_size, num_layers=args.num_layers).to(device)
total_params = sum(p.numel() for p in model.parameters())
- print(f"Model: {args.channels} channels, {args.kernel_sizes} kernels, {total_params} weights")
+ weights_per_layer = 12 * args.kernel_size * args.kernel_size * 4
+ print(f"Model: {args.num_layers} layers, {args.kernel_size}×{args.kernel_size} kernels, {total_params} weights ({weights_per_layer}/layer)")
# Optimizer and loss
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
@@ -298,12 +313,13 @@ def train(args):
model.train()
epoch_loss = 0.0
- for static_feat, target in dataloader:
+ for input_rgbd, static_feat, target in dataloader:
+ input_rgbd = input_rgbd.to(device)
static_feat = static_feat.to(device)
target = target.to(device)
optimizer.zero_grad()
- output = model(static_feat)
+ output = model(input_rgbd, static_feat)
loss = criterion(output, target)
loss.backward()
optimizer.step()
@@ -327,9 +343,9 @@ def train(args):
'optimizer_state_dict': optimizer.state_dict(),
'loss': avg_loss,
'config': {
- 'kernels': args.kernel_sizes,
- 'channels': args.channels,
- 'features': ['R', 'G', 'B', 'D', 'uv.x', 'uv.y', 'sin10_x', 'bias']
+ 'kernel_size': args.kernel_size,
+ 'num_layers': args.num_layers,
+ 'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin10_x', 'bias']
}
}, checkpoint_path)
print(f" → Saved checkpoint: {checkpoint_path}")
@@ -361,10 +377,10 @@ def main():
# Mix salient points with random samples for better generalization
# Model architecture
- parser.add_argument('--kernel-sizes', type=int, nargs=3, default=[1, 3, 5],
- help='Kernel sizes for 3 layers (default: 1 3 5)')
- parser.add_argument('--channels', type=int, nargs=3, default=[16, 8, 4],
- help='Output channels for 3 layers (default: 16 8 4)')
+ parser.add_argument('--kernel-size', type=int, default=3,
+ help='Kernel size (uniform for all layers, default: 3)')
+ parser.add_argument('--num-layers', type=int, default=3,
+ help='Number of CNN layers (default: 3)')
# Training parameters
parser.add_argument('--epochs', type=int, default=5000, help='Training epochs')
diff --git a/workspaces/main/shaders/cnn_v2/cnn_v2_compute.wgsl b/workspaces/main/shaders/cnn_v2/cnn_v2_compute.wgsl
index 1e1704d..5c4b113 100644
--- a/workspaces/main/shaders/cnn_v2/cnn_v2_compute.wgsl
+++ b/workspaces/main/shaders/cnn_v2/cnn_v2_compute.wgsl
@@ -1,6 +1,6 @@
-// CNN v2 Compute Shader - Storage Buffer Version
-// Processes single layer per dispatch with weights from storage buffer
-// Multi-layer execution handled by C++ with ping-pong buffers
+// CNN v2 Compute Shader - Uniform 12D→4D Architecture
+// All layers: input/previous (4D) + static (8D) = 12D → 4 channels
+// Storage buffer weights, ping-pong execution
// Push constants for layer parameters (passed per dispatch)
struct LayerParams {
@@ -12,12 +12,12 @@ struct LayerParams {
blend_amount: f32, // [0,1] blend with original
}
-@group(0) @binding(0) var static_features: texture_2d<u32>; // 8-channel static features
-@group(0) @binding(1) var layer_input: texture_2d<u32>; // Previous layer output (8-channel packed)
-@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>; // Current layer output
+@group(0) @binding(0) var static_features: texture_2d<u32>; // 8D static features (p0-p3 + spatial)
+@group(0) @binding(1) var layer_input: texture_2d<u32>; // 4D previous/input (RGBD or prev layer)
+@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>; // 4D output
@group(0) @binding(3) var<storage, read> weights_buffer: array<u32>; // Packed f16 weights
@group(0) @binding(4) var<uniform> params: LayerParams;
-@group(0) @binding(5) var original_input: texture_2d<f32>; // Original RGB input for blending
+@group(0) @binding(5) var original_input: texture_2d<f32>; // Original RGB for blending
fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {
let packed = textureLoad(static_features, coord, 0);
@@ -28,21 +28,19 @@ fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {
return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
}
-fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> {
+fn unpack_layer_channels(coord: vec2<i32>) -> vec4<f32> {
let packed = textureLoad(layer_input, coord, 0);
let v0 = unpack2x16float(packed.x);
let v1 = unpack2x16float(packed.y);
- let v2 = unpack2x16float(packed.z);
- let v3 = unpack2x16float(packed.w);
- return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+ return vec4<f32>(v0.x, v0.y, v1.x, v1.y);
}
-fn pack_channels(values: array<f32, 8>) -> vec4<u32> {
+fn pack_channels(values: vec4<f32>) -> vec4<u32> {
return vec4<u32>(
- pack2x16float(vec2<f32>(values[0], values[1])),
- pack2x16float(vec2<f32>(values[2], values[3])),
- pack2x16float(vec2<f32>(values[4], values[5])),
- pack2x16float(vec2<f32>(values[6], values[7]))
+ pack2x16float(vec2<f32>(values.x, values.y)),
+ pack2x16float(vec2<f32>(values.z, values.w)),
+ 0u, // Unused
+ 0u // Unused
);
}
@@ -68,19 +66,19 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) {
}
let kernel_size = params.kernel_size;
- let in_channels = params.in_channels;
- let out_channels = params.out_channels;
+ let in_channels = params.in_channels; // Always 12 (4 prev + 8 static)
+ let out_channels = params.out_channels; // Always 4
let weight_offset = params.weight_offset;
let is_output = params.is_output_layer != 0u;
let kernel_radius = i32(kernel_size / 2u);
- // Load static features (always 8D)
+ // Load static features (8D) and previous/input layer (4D)
let static_feat = unpack_static_features(coord);
- // Convolution per output channel
- var output: array<f32, 8>;
- for (var c: u32 = 0u; c < out_channels && c < 8u; c++) {
+ // Convolution: 12D input → 4D output
+ var output: vec4<f32> = vec4<f32>(0.0);
+ for (var c: u32 = 0u; c < 4u; c++) {
var sum: f32 = 0.0;
// Convolve over kernel
@@ -94,55 +92,49 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) {
clamp(sample_coord.y, 0, i32(dims.y) - 1)
);
- // Load input features at this spatial location
+ // Load features at this spatial location
let static_local = unpack_static_features(clamped);
- let layer_local = unpack_layer_channels(clamped);
+ let layer_local = unpack_layer_channels(clamped); // 4D
// Weight index calculation
let ky_idx = u32(ky + kernel_radius);
let kx_idx = u32(kx + kernel_radius);
let spatial_idx = ky_idx * kernel_size + kx_idx;
- // Accumulate: static features (always 8 channels)
- for (var i: u32 = 0u; i < 8u; i++) {
+ // Accumulate: previous/input channels (4D)
+ for (var i: u32 = 0u; i < 4u; i++) {
let w_idx = weight_offset +
- c * in_channels * kernel_size * kernel_size +
+ c * 12u * kernel_size * kernel_size +
i * kernel_size * kernel_size + spatial_idx;
- sum += get_weight(w_idx) * static_local[i];
+ sum += get_weight(w_idx) * layer_local[i];
}
- // Accumulate: previous layer channels (in_channels - 8)
- let prev_channels = in_channels - 8u;
- for (var i: u32 = 0u; i < prev_channels && i < 8u; i++) {
+ // Accumulate: static features (8D)
+ for (var i: u32 = 0u; i < 8u; i++) {
let w_idx = weight_offset +
- c * in_channels * kernel_size * kernel_size +
- (8u + i) * kernel_size * kernel_size + spatial_idx;
- sum += get_weight(w_idx) * layer_local[i];
+ c * 12u * kernel_size * kernel_size +
+ (4u + i) * kernel_size * kernel_size + spatial_idx;
+ sum += get_weight(w_idx) * static_local[i];
}
}
}
// Activation
if (is_output) {
- output[c] = clamp(sum, 0.0, 1.0); // Sigmoid approximation
+ output[c] = clamp(sum, 0.0, 1.0);
} else {
output[c] = max(0.0, sum); // ReLU
}
}
- // Zero unused channels
- for (var c: u32 = out_channels; c < 8u; c++) {
- output[c] = 0.0;
- }
-
// Blend with original on final layer
if (is_output) {
let original = textureLoad(original_input, coord, 0).rgb;
- let result_rgb = vec3<f32>(output[0], output[1], output[2]);
+ let result_rgb = vec3<f32>(output.x, output.y, output.z);
let blended = mix(original, result_rgb, params.blend_amount);
- output[0] = blended.r;
- output[1] = blended.g;
- output[2] = blended.b;
+ output.x = blended.r;
+ output.y = blended.g;
+ output.z = blended.b;
}
textureStore(output_tex, coord, pack_channels(output));
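The `w_idx` arithmetic above assumes the PyTorch `Conv2d` weight layout `(out, in, ky, kx)` flattened row-major, with the 4 previous/input channels at indices 0-3 and the 8 static features at 4-11 (matching `torch.cat([input_rgbd, static_features], dim=1)`). A quick NumPy check of that indexing, as a sketch:

```python
import numpy as np

K = 3
w = np.arange(4 * 12 * K * K).reshape(4, 12, K, K)   # (out, in, ky, kx)
flat = w.flatten()                                    # row-major, as exported

c, i, ky, kx = 2, 5, 1, 0                             # arbitrary test indices
w_idx = c * 12 * K * K + i * K * K + (ky * K + kx)    # same formula as the shader
assert flat[w_idx] == w[c, i, ky, kx]
```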
diff --git a/workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl b/workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl
index dd07f19..7a9e6de 100644
--- a/workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl
+++ b/workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl
@@ -1,5 +1,7 @@
// CNN v2 Static Features Compute Shader
-// Generates 7D features + bias: [R, G, B, D, uv.x, uv.y, sin10_x, 1.0]
+// Generates 8D parametric features: [p0, p1, p2, p3, uv.x, uv.y, sin10_x, bias]
+// p0-p3: Parametric features (currently RGBD from mip0, could be mip1/2, gradients, etc.)
+// Note: Input image RGBD (mip0) fed separately to Layer 0
@group(0) @binding(0) var input_tex: texture_2d<f32>;
@group(0) @binding(1) var input_tex_mip1: texture_2d<f32>;
@@ -16,14 +18,14 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) {
return;
}
- // Sample RGBA from mip 0
+ // Parametric features (p0-p3)
+ // TODO: Experiment with mip1 grayscale, Sobel gradients, etc.
+ // For now, use RGBD from mip 0 (same as input, but could differ)
let rgba = textureLoad(input_tex, coord, 0);
- let r = rgba.r;
- let g = rgba.g;
- let b = rgba.b;
-
- // Sample depth
- let d = textureLoad(depth_tex, coord, 0).r;
+ let p0 = rgba.r;
+ let p1 = rgba.g;
+ let p2 = rgba.b;
+ let p3 = textureLoad(depth_tex, coord, 0).r;
// UV coordinates (normalized [0,1], bottom-left origin)
let uv_x = f32(coord.x) / f32(dims.x);
@@ -36,9 +38,10 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) {
let bias = 1.0;
// Pack 8×f16 into 4×u32 (rgba32uint)
+ // [p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias]
let packed = vec4<u32>(
- pack2x16float(vec2<f32>(r, g)),
- pack2x16float(vec2<f32>(b, d)),
+ pack2x16float(vec2<f32>(p0, p1)),
+ pack2x16float(vec2<f32>(p2, p3)),
pack2x16float(vec2<f32>(uv_x, uv_y)),
pack2x16float(vec2<f32>(sin10_x, bias))
);