| author | skal <pascal.massimino@gmail.com> | 2026-02-13 12:32:36 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-02-13 12:32:36 +0100 |
| commit | 561d1dc446db7d1d3e02b92b43abedf1a5017850 (patch) | |
| tree | ef9302dc1f9b6b9f8a12225580f2a3b07602656b | |
| parent | c27b34279c0d1c2a8f1dbceb0e154b585b5c6916 (diff) | |
CNN v2: Refactor to uniform 12D→4D architecture
**Architecture changes:**
- Static features (8D): p0-p3 (parametric) + uv_x, uv_y, sin(10×uv_x), bias
- Input RGBD (4D): passed as a separate 4D input consumed by Layer 0; later layers take the previous layer's 4D output in its place
- All layers: uniform 12D→4D (4 prev/input + 8 static → 4 output)
- Bias integrated in static features (bias=False in PyTorch)
**Weight calculations:**
- 3 layers × (12 × 3×3 × 4) = 1296 weights
- f16: 1296 × 2 = 2,592 bytes (~2.5 KB) (vs old variable arch: ~6.4 KB)
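
The arithmetic behind these figures, as a quick illustrative check (not part of the commit):

```python
# Per-layer weight count for the uniform 12D→4D architecture:
# in_channels × kernel_height × kernel_width × out_channels
per_layer = 12 * 3 * 3 * 4          # 432 weights
total = 3 * per_layer               # 1296 weights for 3 layers
f16_bytes = total * 2               # 2592 bytes ≈ 2.5 KB
f32_bytes = total * 4               # 5184 bytes ≈ 5.1 KB
print(per_layer, total, f16_bytes, f32_bytes)
```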
**Updated files:**
*Training (Python):*
- train_cnn_v2.py: Uniform model, takes input_rgbd + static_features
- export_cnn_v2_weights.py: Binary export for storage buffers
- export_cnn_v2_shader.py: Per-layer shader export (debugging)
*Shaders (WGSL):*
- cnn_v2_static.wgsl: p0-p3 parametric features (mips/gradients)
- cnn_v2_compute.wgsl: 12D input, 4D output, vec4 packing
*Tools:*
- HTML tool (cnn_v2_test): Updated for 12D→4D, layer visualization
*Docs:*
- CNN_V2.md: Updated architecture, training, validation sections
- HOWTO.md: Reference HTML tool for validation
*Removed:*
- validate_cnn_v2.sh: Obsolete (used CNN v1 tool)
All code is consistent with bias=False (the bias term is carried in the static features as a constant 1.0).
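
The bias-as-static-feature trick is standard: with a constant-1.0 channel in the input, a bias=False convolution learns its bias through that channel's weights. A minimal PyTorch sketch (illustrative only, not from this commit; a 1×1 kernel is used so border padding doesn't disturb the ones channel):

```python
import torch
import torch.nn as nn

x = torch.rand(1, 11, 8, 8)                      # 11 "real" channels
ones = torch.ones(1, 1, 8, 8)                    # bias channel, always 1.0

conv = nn.Conv2d(12, 4, kernel_size=1, bias=False)
y = conv(torch.cat([x, ones], dim=1))

# Equivalent conv with an explicit bias: fold the ones-channel weights in.
ref = nn.Conv2d(11, 4, kernel_size=1, bias=True)
with torch.no_grad():
    ref.weight.copy_(conv.weight[:, :11])        # same weights on real channels
    ref.bias.copy_(conv.weight[:, 11, 0, 0])     # ones-channel weight acts as bias
print(torch.allclose(y, ref(x), atol=1e-6))      # True
```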
handoff(Claude): CNN v2 architecture finalized and documented
| -rw-r--r-- | doc/CNN_V2.md | 232 |
| -rw-r--r-- | doc/HOWTO.md | 7 |
| -rwxr-xr-x | scripts/validate_cnn_v2.sh | 60 |
| -rw-r--r-- | tools/cnn_v2_test/index.html | 65 |
| -rwxr-xr-x | training/export_cnn_v2_shader.py | 127 |
| -rwxr-xr-x | training/export_cnn_v2_weights.py | 85 |
| -rwxr-xr-x | training/train_cnn_v2.py | 134 |
| -rw-r--r-- | workspaces/main/shaders/cnn_v2/cnn_v2_compute.wgsl | 80 |
| -rw-r--r-- | workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl | 23 |
9 files changed, 350 insertions, 463 deletions
diff --git a/doc/CNN_V2.md b/doc/CNN_V2.md index 588c3db..4612d7a 100644 --- a/doc/CNN_V2.md +++ b/doc/CNN_V2.md @@ -40,10 +40,10 @@ Input RGBD → Static Features Compute → CNN Layers → Output RGBA - Lifetime: Entire frame (all CNN layer passes) **CNN Layers:** -- Input Layer: 7D static features → C₀ channels -- Inner Layers: (7D + Cᵢ₋₁) → Cᵢ channels -- Output Layer: (7D + Cₙ) → 4D RGBA -- Storage: `texture_storage_2d<rgba32uint>` (8×f16 per texel recommended) +- Layer 0: input RGBD (4D) + static (8D) = 12D → 4 channels +- Layer 1+: previous output (4D) + static (8D) = 12D → 4 channels +- All layers: uniform 12D input, 4D output (ping-pong buffer) +- Storage: `texture_storage_2d<rgba32uint>` (4 channels as 2×f16 pairs) --- @@ -54,11 +54,13 @@ Input RGBD → Static Features Compute → CNN Layers → Output RGBA **8 float16 values per pixel:** ```wgsl -// Slot 0-3: RGBD (core pixel data) -let r = rgba.r; // Red channel -let g = rgba.g; // Green channel -let b = rgba.b; // Blue channel -let d = depth; // Depth value +// Slot 0-3: Parametric features (p0, p1, p2, p3) +// Can be: mip1/2 RGBD, grayscale, gradients, etc. +// Distinct from input image RGBD (fed only to Layer 0) +let p0 = ...; // Parametric feature 0 (e.g., mip1.r or grayscale) +let p1 = ...; // Parametric feature 1 +let p2 = ...; // Parametric feature 2 +let p3 = ...; // Parametric feature 3 // Slot 4-5: UV coordinates (normalized screen space) let uv_x = coord.x / resolution.x; // Horizontal position [0,1] @@ -70,18 +72,20 @@ let sin10_x = sin(10.0 * uv_x); // Periodic feature (frequency=10) // Slot 7: Bias dimension (always 1.0) let bias = 1.0; // Learned bias per output channel -// Packed storage: [R, G, B, D, uv.x, uv.y, sin(10*uv.x), 1.0] +// Packed storage: [p0, p1, p2, p3, uv.x, uv.y, sin(10*uv.x), 1.0] ``` ### Feature Rationale | Feature | Dimension | Purpose | Priority | |---------|-----------|---------|----------| -| RGBD | 4D | Core pixel information | Essential | +| p0-p3 | 4D | Parametric auxiliary features (mips, gradients, etc.) | Essential | | UV coords | 2D | Spatial position awareness | Essential | | sin(10\*uv.x) | 1D | Periodic position encoding | Medium | | Bias | 1D | Learned bias (standard NN) | Essential | +**Note:** Input image RGBD (mip 0) fed only to Layer 0. Subsequent layers see static features + previous layer output. + **Why bias as static feature:** - Simpler shader code (single weight array) - Standard NN formulation: y = Wx (x includes bias term) @@ -113,68 +117,65 @@ Requires quantization-aware training. 
### Example 3-Layer Network ``` -Input: 7D static → 16 channels (1×1 kernel, pointwise) -Layer1: (7+16)D → 8 channels (3×3 kernel, spatial) -Layer2: (7+8)D → 4 channels (5×5 kernel, large receptive field) +Layer 0: input RGBD (4D) + static (8D) = 12D → 4 channels (3×3 kernel) +Layer 1: previous (4D) + static (8D) = 12D → 4 channels (3×3 kernel) +Layer 2: previous (4D) + static (8D) = 12D → 4 channels (3×3 kernel, output) ``` ### Weight Calculations -**Per-layer weights:** +**Per-layer weights (uniform 12D→4D, 3×3 kernels):** ``` -Input: 7 × 1 × 1 × 16 = 112 weights -Layer1: (7+16) × 3 × 3 × 8 = 1656 weights -Layer2: (7+8) × 5 × 5 × 4 = 1500 weights -Total: 3268 weights +Layer 0: 12 × 3 × 3 × 4 = 432 weights +Layer 1: 12 × 3 × 3 × 4 = 432 weights +Layer 2: 12 × 3 × 3 × 4 = 432 weights +Total: 1296 weights ``` **Storage sizes:** -- f32: 3268 × 4 = 13,072 bytes (~12.8 KB) -- f16: 3268 × 2 = 6,536 bytes (~6.4 KB) ✓ **recommended** +- f32: 1296 × 4 = 5,184 bytes (~5.1 KB) +- f16: 1296 × 2 = 2,592 bytes (~2.5 KB) ✓ **recommended** **Comparison to v1:** - v1: ~800 weights (3.2 KB f32) -- v2: ~3268 weights (6.4 KB f16) -- **Growth: 2× size for parametric features** +- v2: ~1296 weights (2.5 KB f16) +- **Uniform architecture, smaller than v1 f32** ### Kernel Size Guidelines **1×1 kernel (pointwise):** - No spatial context, channel mixing only -- Weights: `(7 + C_in) × C_out` -- Use for: Input layer, bottleneck layers +- Weights: `12 × 4 = 48` per layer +- Use for: Fast inference, channel remapping **3×3 kernel (standard conv):** -- Local spatial context -- Weights: `(7 + C_in) × 9 × C_out` -- Use for: Most inner layers +- Local spatial context (recommended) +- Weights: `12 × 9 × 4 = 432` per layer +- Use for: Most layers (balanced quality/size) **5×5 kernel (large receptive field):** - Wide spatial context -- Weights: `(7 + C_in) × 25 × C_out` -- Use for: Output layer, detail enhancement +- Weights: `12 × 25 × 4 = 1200` per layer +- Use for: Output layer, fine detail enhancement -### Channel Storage (8×f16 per texel) +### Channel Storage (4×f16 per texel) ```wgsl @group(0) @binding(1) var layer_input: texture_2d<u32>; -fn unpack_channels(coord: vec2<i32>) -> array<f32, 8> { +fn unpack_channels(coord: vec2<i32>) -> vec4<f32> { let packed = textureLoad(layer_input, coord, 0); - return array( - unpack2x16float(packed.x).x, unpack2x16float(packed.x).y, - unpack2x16float(packed.y).x, unpack2x16float(packed.y).y, - unpack2x16float(packed.z).x, unpack2x16float(packed.z).y, - unpack2x16float(packed.w).x, unpack2x16float(packed.w).y - ); + let v0 = unpack2x16float(packed.x); // [ch0, ch1] + let v1 = unpack2x16float(packed.y); // [ch2, ch3] + return vec4<f32>(v0.x, v0.y, v1.x, v1.y); } -fn pack_channels(values: array<f32, 8>) -> vec4<u32> { - return vec4( - pack2x16float(vec2(values[0], values[1])), - pack2x16float(vec2(values[2], values[3])), - pack2x16float(vec2(values[4], values[5])), - pack2x16float(vec2(values[6], values[7])) +fn pack_channels(values: vec4<f32>) -> vec4<u32> { + return vec4<u32>( + pack2x16float(vec2(values.x, values.y)), + pack2x16float(vec2(values.z, values.w)), + 0u, // Unused + 0u // Unused ); } ``` @@ -189,11 +190,11 @@ fn pack_channels(values: array<f32, 8>) -> vec4<u32> { ```python def compute_static_features(rgb, depth): - """Generate 7D static features + bias dimension.""" + """Generate parametric features (8D: p0-p3 + spatial).""" h, w = rgb.shape[:2] - # RGBD channels - r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2] + # Parametric features (example: use input RGBD, but could 
be mips/gradients) + p0, p1, p2, p3 = rgb[..., 0], rgb[..., 1], rgb[..., 2], depth # UV coordinates (normalized) uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0) @@ -203,56 +204,51 @@ def compute_static_features(rgb, depth): sin10_x = np.sin(10.0 * uv_x) # Bias dimension (always 1.0) - bias = np.ones_like(r) + bias = np.ones_like(p0) - # Stack: [R, G, B, D, uv.x, uv.y, sin10_x, bias] - return np.stack([r, g, b, depth, uv_x, uv_y, sin10_x, bias], axis=-1) + # Stack: [p0, p1, p2, p3, uv.x, uv.y, sin10_x, bias] + return np.stack([p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias], axis=-1) ``` **Network Definition:** ```python class CNNv2(nn.Module): - def __init__(self, kernels=[1,3,5], channels=[16,8,4]): + def __init__(self, kernel_size=3, num_layers=3): super().__init__() + self.layers = nn.ModuleList() - # Input layer: 8D (7 features + bias) → channels[0] - self.layer0 = nn.Conv2d(8, channels[0], kernel_size=kernels[0], - padding=kernels[0]//2, bias=False) - - # Inner layers: (7 features + bias + C_prev) → C_next - in_ch_1 = 8 + channels[0] # static + layer0 output - self.layer1 = nn.Conv2d(in_ch_1, channels[1], kernel_size=kernels[1], - padding=kernels[1]//2, bias=False) - - # Output layer: (7 features + bias + C_last) → 4 (RGBA) - in_ch_2 = 8 + channels[1] - self.layer2 = nn.Conv2d(in_ch_2, 4, kernel_size=kernels[2], - padding=kernels[2]//2, bias=False) + # All layers: 12D input (4 prev + 8 static) → 4D output + for i in range(num_layers): + self.layers.append( + nn.Conv2d(12, 4, kernel_size=kernel_size, + padding=kernel_size//2, bias=False) + ) - def forward(self, static_features, layer0_input=None): - # Layer 0: Use full 8D static features (includes bias) - x0 = self.layer0(static_features) - x0 = F.relu(x0) + def forward(self, input_rgbd, static_features): + # Layer 0: input RGBD (4D) + static (8D) = 12D + x = torch.cat([input_rgbd, static_features], dim=1) + x = self.layers[0](x) + x = torch.clamp(x, 0, 1) # Output layer 0 (4 channels) - # Layer 1: Concatenate static + layer0 output - x1_input = torch.cat([static_features, x0], dim=1) - x1 = self.layer1(x1_input) - x1 = F.relu(x1) + # Layer 1+: previous output (4D) + static (8D) = 12D + for i in range(1, len(self.layers)): + x_input = torch.cat([x, static_features], dim=1) + x = self.layers[i](x_input) + if i < len(self.layers) - 1: + x = F.relu(x) + else: + x = torch.clamp(x, 0, 1) # Final output [0,1] - # Layer 2: Concatenate static + layer1 output - x2_input = torch.cat([static_features, x1], dim=1) - output = self.layer2(x2_input) - - return torch.sigmoid(output) # RGBA output [0,1] + return x # RGBA output ``` **Training Configuration:** ```python # Hyperparameters -kernels = [1, 3, 5] # Per-layer kernel sizes -channels = [16, 8, 4] # Per-layer output channels +kernel_size = 3 # Uniform 3×3 kernels +num_layers = 3 # Number of CNN layers learning_rate = 1e-3 batch_size = 16 epochs = 5000 @@ -260,11 +256,14 @@ epochs = 5000 # Training loop (standard PyTorch f32) for epoch in range(epochs): for rgb_batch, depth_batch, target_batch in dataloader: - # Compute static features + # Compute static features (8D) static_feat = compute_static_features(rgb_batch, depth_batch) + # Input RGBD (4D) + input_rgbd = torch.cat([rgb_batch, depth_batch.unsqueeze(1)], dim=1) + # Forward pass - output = model(static_feat) + output = model(input_rgbd, static_feat) loss = criterion(output, target_batch) # Backward pass @@ -279,9 +278,9 @@ for epoch in range(epochs): torch.save({ 'state_dict': model.state_dict(), # f32 weights 'config': { - 'kernels': [1, 
3, 5], - 'channels': [16, 8, 4], - 'features': ['R', 'G', 'B', 'D', 'uv.x', 'uv.y', 'sin10_x', 'bias'] + 'kernel_size': 3, + 'num_layers': 3, + 'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin10_x', 'bias'] }, 'epoch': epoch, 'loss': loss.item() @@ -342,59 +341,36 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) { - Training uses f32 throughout (PyTorch standard) - Export converts to np.float16, then back to f32 for WGSL literals - **Expected discrepancy:** <0.1% MSE (acceptable) -- Validation via `validate_cnn_v2.sh` compares outputs +- Validation via HTML tool (see below) --- ## Validation Workflow -### Script: `scripts/validate_cnn_v2.sh` - -**End-to-end pipeline:** -```bash -./scripts/validate_cnn_v2.sh checkpoints/checkpoint_epoch_5000.pth -``` +### HTML Tool: `tools/cnn_v2_test/index.html` -**Steps automated:** -1. Export checkpoint → .wgsl shaders -2. Rebuild `cnn_test` tool -3. Process test images with CNN v2 -4. Display input/output results +**WebGPU-based testing tool** with layer visualization. **Usage:** -```bash -# Basic usage -./scripts/validate_cnn_v2.sh checkpoint.pth - -# Custom paths -./scripts/validate_cnn_v2.sh checkpoint.pth \ - -i my_test_images/ \ - -o results/ \ - -b build_release +1. Open `tools/cnn_v2_test/index.html` in browser +2. Drop `.bin` weights file (from `export_cnn_v2_weights.py`) +3. Drop PNG test image +4. View results with layer inspection -# Skip rebuild (iterate on checkpoint only) -./scripts/validate_cnn_v2.sh checkpoint.pth --skip-build +**Features:** +- Live CNN inference with WebGPU +- Layer-by-layer visualization (static features + all CNN layers) +- Weight visualization (per-layer kernels) +- View modes: CNN output, original, diff (×10) +- Blend control for comparing with original -# Skip export (iterate on test images only) -./scripts/validate_cnn_v2.sh checkpoint.pth --skip-export - -# Show help -./scripts/validate_cnn_v2.sh --help +**Export weights:** +```bash +./training/export_cnn_v2_weights.py checkpoints/checkpoint_epoch_100.pth \ + --output-weights workspaces/main/cnn_v2_weights.bin ``` -**Options:** -- `-b, --build-dir DIR` - Build directory (default: build) -- `-w, --workspace NAME` - Workspace name (default: main) -- `-i, --images DIR` - Test images directory (default: training/validation) -- `-o, --output DIR` - Output directory (default: validation_results) -- `--skip-build` - Use existing cnn_test binary -- `--skip-export` - Use existing .wgsl shaders -- `-h, --help` - Show full usage - -**Output:** -- Input images: `<test_images_dir>/*.png` -- Output images: `<output_dir>/*_output.png` -- Opens results directory in system file browser +See `doc/CNN_V2_WEB_TOOL.md` for detailed documentation --- @@ -460,7 +436,7 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) { ### Phase 4: Tools & Validation -- [ ] `scripts/validate_cnn_v2.sh` - End-to-end validation +- [x] HTML validation tool - WebGPU inference with layer visualization - [ ] Command-line argument parsing - [ ] Shader export orchestration - [ ] Build orchestration @@ -506,8 +482,8 @@ training/train_cnn_v2.py # Training script training/export_cnn_v2_shader.py # Shader generator training/validation/ # Test images directory -# Scripts -scripts/validate_cnn_v2.sh # End-to-end validation +# Validation +tools/cnn_v2_test/index.html # WebGPU validation tool # Documentation doc/CNN_V2.md # This file diff --git a/doc/HOWTO.md b/doc/HOWTO.md index 1ae1d94..e909a5d 100644 --- a/doc/HOWTO.md +++ b/doc/HOWTO.md @@ -179,14 +179,9 @@ Storage buffer architecture 
allows dynamic layer count. **TODO:** 8-bit quantization for 2× size reduction (~1.6 KB). Requires quantization-aware training (QAT). -# Options: -# -i DIR Test images directory (default: training/validation) -# -o DIR Output directory (default: validation_results) -# --skip-build Use existing cnn_test binary -# -h Show all options ``` -See `scripts/validate_cnn_v2.sh --help` for full usage. See `doc/CNN_V2.md` for design details. +**Validation:** Use HTML tool (`tools/cnn_v2_test/index.html`) for CNN v2 validation. See `doc/CNN_V2_WEB_TOOL.md`. --- diff --git a/scripts/validate_cnn_v2.sh b/scripts/validate_cnn_v2.sh deleted file mode 100755 index 06a4e01..0000000 --- a/scripts/validate_cnn_v2.sh +++ /dev/null @@ -1,60 +0,0 @@ -#!/bin/bash -# CNN v2 Validation - End-to-end pipeline - -set -e -PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" -BUILD_DIR="$PROJECT_ROOT/build" -WORKSPACE="main" - -usage() { - echo "Usage: $0 <checkpoint.pth> [options]" - echo "Options:" - echo " -i DIR Test images (default: training/validation)" - echo " -o DIR Output (default: validation_results)" - echo " --skip-build Skip rebuild" - exit 1 -} - -[ $# -eq 0 ] && usage -CHECKPOINT="$1" -shift - -TEST_IMAGES="$PROJECT_ROOT/training/validation" -OUTPUT="$PROJECT_ROOT/validation_results" -SKIP_BUILD=false - -while [[ $# -gt 0 ]]; do - case $1 in - -i) TEST_IMAGES="$2"; shift 2 ;; - -o) OUTPUT="$2"; shift 2 ;; - --skip-build) SKIP_BUILD=true; shift ;; - -h) usage ;; - *) usage ;; - esac -done - -echo "=== CNN v2 Validation ===" -echo "Checkpoint: $CHECKPOINT" - -# Export -echo "[1/3] Exporting shaders..." -python3 "$PROJECT_ROOT/training/export_cnn_v2_shader.py" "$CHECKPOINT" \ - --output-dir "$PROJECT_ROOT/workspaces/$WORKSPACE/shaders" - -# Build -if [ "$SKIP_BUILD" = false ]; then - echo "[2/3] Building..." - cmake --build "$BUILD_DIR" -j4 --target cnn_test >/dev/null 2>&1 -fi - -# Process -echo "[3/3] Processing images..." -mkdir -p "$OUTPUT" -count=0 -for img in "$TEST_IMAGES"/*.png; do - [ -f "$img" ] || continue - name=$(basename "$img" .png) - "$BUILD_DIR/cnn_test" "$img" "$OUTPUT/${name}_output.png" 2>/dev/null && count=$((count+1)) -done - -echo "Done! 
Processed $count images → $OUTPUT" diff --git a/tools/cnn_v2_test/index.html b/tools/cnn_v2_test/index.html index 9ce3d8c..199deea 100644 --- a/tools/cnn_v2_test/index.html +++ b/tools/cnn_v2_test/index.html @@ -3,6 +3,12 @@ <!-- CNN v2 Testing Tool - WebGPU-based inference validator + Architecture: + - Static features (8D): p0-p3 (parametric), uv_x, uv_y, sin(10*uv_x), bias + - Layer 0: input RGBD (4D) + static (8D) = 12D → 4 channels + - Layer 1+: previous (4D) + static (8D) = 12D → 4 channels + - All layers: uniform 12D input, 4D output (ping-pong buffer) + Features: - Side panel: .bin metadata display, weight statistics per layer - Layer inspection: 4-channel grayscale split, intermediate layer visualization @@ -318,21 +324,19 @@ fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> { return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y); } -fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> { +fn unpack_layer_channels(coord: vec2<i32>) -> vec4<f32> { let packed = textureLoad(layer_input, coord, 0); let v0 = unpack2x16float(packed.x); let v1 = unpack2x16float(packed.y); - let v2 = unpack2x16float(packed.z); - let v3 = unpack2x16float(packed.w); - return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y); + return vec4<f32>(v0.x, v0.y, v1.x, v1.y); } -fn pack_channels(values: array<f32, 8>) -> vec4<u32> { +fn pack_channels(values: vec4<f32>) -> vec4<u32> { return vec4<u32>( - pack2x16float(vec2<f32>(values[0], values[1])), - pack2x16float(vec2<f32>(values[2], values[3])), - pack2x16float(vec2<f32>(values[4], values[5])), - pack2x16float(vec2<f32>(values[6], values[7])) + pack2x16float(vec2<f32>(values.x, values.y)), + pack2x16float(vec2<f32>(values.z, values.w)), + 0u, + 0u ); } @@ -350,16 +354,16 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) { if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) { return; } let kernel_size = params.kernel_size; - let in_channels = params.in_channels; - let out_channels = params.out_channels; + let in_channels = params.in_channels; // Always 12 (4 prev + 8 static) + let out_channels = params.out_channels; // Always 4 let weight_offset = params.weight_offset; let is_output = params.is_output_layer != 0u; let kernel_radius = i32(kernel_size / 2u); let static_feat = unpack_static_features(coord); - var output: array<f32, 8>; - for (var c: u32 = 0u; c < out_channels && c < 8u; c++) { + var output: vec4<f32> = vec4<f32>(0.0); + for (var c: u32 = 0u; c < 4u; c++) { var sum: f32 = 0.0; for (var ky: i32 = -kernel_radius; ky <= kernel_radius; ky++) { for (var kx: i32 = -kernel_radius; kx <= kernel_radius; kx++) { @@ -375,19 +379,20 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) { let kx_idx = u32(kx + kernel_radius); let spatial_idx = ky_idx * kernel_size + kx_idx; - for (var i: u32 = 0u; i < 8u; i++) { + // Previous layer channels (4D) + for (var i: u32 = 0u; i < 4u; i++) { let w_idx = weight_offset + c * in_channels * kernel_size * kernel_size + i * kernel_size * kernel_size + spatial_idx; - sum += get_weight(w_idx) * static_local[i]; + sum += get_weight(w_idx) * layer_local[i]; } - let prev_channels = in_channels - 8u; - for (var i: u32 = 0u; i < prev_channels && i < 8u; i++) { + // Static features (8D) + for (var i: u32 = 0u; i < 8u; i++) { let w_idx = weight_offset + c * in_channels * kernel_size * kernel_size + - (8u + i) * kernel_size * kernel_size + spatial_idx; - sum += get_weight(w_idx) * layer_local[i]; + (4u + i) * kernel_size * kernel_size + spatial_idx; + sum += get_weight(w_idx) * 
static_local[i]; } } } @@ -399,17 +404,13 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) { } } - for (var c: u32 = out_channels; c < 8u; c++) { - output[c] = 0.0; - } - if (is_output) { let original = textureLoad(original_input, coord, 0).rgb; - let result_rgb = vec3<f32>(output[0], output[1], output[2]); + let result_rgb = vec3<f32>(output.x, output.y, output.z); let blended = mix(original, result_rgb, params.blend_amount); - output[0] = blended.r; - output[1] = blended.g; - output[2] = blended.b; + output.x = blended.r; + output.y = blended.g; + output.z = blended.b; } textureStore(output_tex, coord, pack_channels(output)); @@ -1013,7 +1014,7 @@ class CNNTester { </div> `; - html += '<div style="font-size: 9px; color: #808080; margin-bottom: 8px; padding-bottom: 8px; border-bottom: 1px solid #404040;">Static features (7D input) + ${this.weights.layers.length} CNN layers. Showing first 4 of 8 channels.</div>'; + html += `<div style="font-size: 9px; color: #808080; margin-bottom: 8px; padding-bottom: 8px; border-bottom: 1px solid #404040;">Static features (8D: p0-p3 + spatial) + ${this.weights.layers.length} CNN layers. All layers: 12D→4D.</div>`; html += '<div class="layer-buttons">'; for (let i = 0; i < this.layerOutputs.length; i++) { @@ -1116,10 +1117,10 @@ class CNNTester { this.log(`Visualizing ${layerName} activations (${width}×${height})`); // Update channel labels based on layer type - // Static features: 8 channels total (R,G,B,D,UV_X,UV_Y,sin,bias), showing first 4 - // CNN layers: Up to 8 channels per layer, showing first 4 + // Static features: 8 channels (p0,p1,p2,p3,uv_x,uv_y,sin10_x,bias) + // CNN layers: 4 channels per layer (uniform) const channelLabels = layerIdx === 0 - ? ['Ch0 (R)', 'Ch1 (G)', 'Ch2 (B)', 'Ch3 (D)'] + ? ['Ch0 (p0)', 'Ch1 (p1)', 'Ch2 (p2)', 'Ch3 (p3)'] : ['Ch0', 'Ch1', 'Ch2', 'Ch3']; for (let c = 0; c < 4; c++) { @@ -1169,7 +1170,7 @@ class CNNTester { continue; } - const vizScale = layerIdx === 0 ? 1.0 : 0.2; // Static: 1.0, CNN layers: 0.2 (assumes ~5 max) + const vizScale = layerIdx === 0 ? 1.0 : 0.5; // Static: 1.0, CNN layers: 0.5 (4 channels [0,1]) const paramsBuffer = this.device.createBuffer({ size: 8, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST diff --git a/training/export_cnn_v2_shader.py b/training/export_cnn_v2_shader.py index add28d2..ad5749c 100755 --- a/training/export_cnn_v2_shader.py +++ b/training/export_cnn_v2_shader.py @@ -1,8 +1,11 @@ #!/usr/bin/env python3 -"""CNN v2 Shader Export Script +"""CNN v2 Shader Export Script - Uniform 12D→4D Architecture Converts PyTorch checkpoints to WGSL compute shaders with f16 weights. Generates one shader per layer with embedded weight arrays. + +Note: Storage buffer approach (export_cnn_v2_weights.py) is preferred for size. + This script is for debugging/testing with per-layer shaders. """ import argparse @@ -11,16 +14,13 @@ import torch from pathlib import Path -def export_layer_shader(layer_idx, weights, kernel_size, in_channels, out_channels, - output_dir, is_output_layer=False): +def export_layer_shader(layer_idx, weights, kernel_size, output_dir, is_output_layer=False): """Generate WGSL compute shader for a single CNN layer. Args: - layer_idx: Layer index (0, 1, 2) - weights: (out_ch, in_ch, k, k) weight tensor - kernel_size: Kernel size (1, 3, 5, etc.) - in_channels: Input channels (includes 8D static features) - out_channels: Output channels + layer_idx: Layer index (0, 1, 2, ...) 
+ weights: (4, 12, k, k) weight tensor (uniform 12D→4D) + kernel_size: Kernel size (3, 5, etc.) output_dir: Output directory path is_output_layer: True if this is the final RGBA output layer """ @@ -39,12 +39,12 @@ def export_layer_shader(layer_idx, weights, kernel_size, in_channels, out_channe if is_output_layer: activation = "output[c] = clamp(sum, 0.0, 1.0); // Sigmoid approximation" - shader_code = f"""// CNN v2 Layer {layer_idx} - Auto-generated -// Kernel: {kernel_size}×{kernel_size}, In: {in_channels}, Out: {out_channels} + shader_code = f"""// CNN v2 Layer {layer_idx} - Auto-generated (uniform 12D→4D) +// Kernel: {kernel_size}×{kernel_size}, In: 12D (4 prev + 8 static), Out: 4D const KERNEL_SIZE: u32 = {kernel_size}u; -const IN_CHANNELS: u32 = {in_channels}u; -const OUT_CHANNELS: u32 = {out_channels}u; +const IN_CHANNELS: u32 = 12u; // 4 (input/prev) + 8 (static) +const OUT_CHANNELS: u32 = 4u; // Uniform output const KERNEL_RADIUS: i32 = {radius}; // Weights quantized to float16 (stored as f32 in WGSL) @@ -65,21 +65,19 @@ fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {{ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y); }} -fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> {{ +fn unpack_layer_channels(coord: vec2<i32>) -> vec4<f32> {{ let packed = textureLoad(layer_input, coord, 0); let v0 = unpack2x16float(packed.x); let v1 = unpack2x16float(packed.y); - let v2 = unpack2x16float(packed.z); - let v3 = unpack2x16float(packed.w); - return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y); + return vec4<f32>(v0.x, v0.y, v1.x, v1.y); }} -fn pack_channels(values: array<f32, 8>) -> vec4<u32> {{ +fn pack_channels(values: vec4<f32>) -> vec4<u32> {{ return vec4<u32>( - pack2x16float(vec2<f32>(values[0], values[1])), - pack2x16float(vec2<f32>(values[2], values[3])), - pack2x16float(vec2<f32>(values[4], values[5])), - pack2x16float(vec2<f32>(values[6], values[7])) + pack2x16float(vec2<f32>(values.x, values.y)), + pack2x16float(vec2<f32>(values.z, values.w)), + 0u, // Unused + 0u // Unused ); }} @@ -95,9 +93,9 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) {{ // Load static features (always available) let static_feat = unpack_static_features(coord); - // Convolution - var output: array<f32, OUT_CHANNELS>; - for (var c: u32 = 0u; c < OUT_CHANNELS; c++) {{ + // Convolution: 12D input (4 prev + 8 static) → 4D output + var output: vec4<f32> = vec4<f32>(0.0); + for (var c: u32 = 0u; c < 4u; c++) {{ var sum: f32 = 0.0; for (var ky: i32 = -KERNEL_RADIUS; ky <= KERNEL_RADIUS; ky++) {{ @@ -110,28 +108,27 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) {{ clamp(sample_coord.y, 0, i32(dims.y) - 1) ); - // Load input features + // Load features at this spatial location let static_local = unpack_static_features(clamped); - let layer_local = unpack_layer_channels(clamped); + let layer_local = unpack_layer_channels(clamped); // 4D // Weight index calculation let ky_idx = u32(ky + KERNEL_RADIUS); let kx_idx = u32(kx + KERNEL_RADIUS); let spatial_idx = ky_idx * KERNEL_SIZE + kx_idx; - // Accumulate: static features (8D) - for (var i: u32 = 0u; i < 8u; i++) {{ - let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE + + // Accumulate: previous/input channels (4D) + for (var i: u32 = 0u; i < 4u; i++) {{ + let w_idx = c * 12u * KERNEL_SIZE * KERNEL_SIZE + i * KERNEL_SIZE * KERNEL_SIZE + spatial_idx; - sum += weights[w_idx] * static_local[i]; + sum += weights[w_idx] * layer_local[i]; }} - // Accumulate: layer input channels (if layer_idx > 
0) - let prev_channels = IN_CHANNELS - 8u; - for (var i: u32 = 0u; i < prev_channels; i++) {{ - let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE + - (8u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx; - sum += weights[w_idx] * layer_local[i]; + // Accumulate: static features (8D) + for (var i: u32 = 0u; i < 8u; i++) {{ + let w_idx = c * 12u * KERNEL_SIZE * KERNEL_SIZE + + (4u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx; + sum += weights[w_idx] * static_local[i]; }} }} }} @@ -162,53 +159,37 @@ def export_checkpoint(checkpoint_path, output_dir): state_dict = checkpoint['model_state_dict'] config = checkpoint['config'] + kernel_size = config.get('kernel_size', 3) + num_layers = config.get('num_layers', 3) + print(f"Configuration:") - print(f" Kernels: {config['kernels']}") - print(f" Channels: {config['channels']}") - print(f" Features: {config['features']}") + print(f" Kernel size: {kernel_size}×{kernel_size}") + print(f" Layers: {num_layers}") + print(f" Architecture: uniform 12D→4D") output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) print(f"\nExporting shaders to {output_dir}/") - # Layer 0: 8 → channels[0] - layer0_weights = state_dict['layer0.weight'].detach().numpy() - export_layer_shader( - layer_idx=0, - weights=layer0_weights, - kernel_size=config['kernels'][0], - in_channels=8, - out_channels=config['channels'][0], - output_dir=output_dir, - is_output_layer=False - ) + # All layers uniform: 12D→4D + for i in range(num_layers): + layer_key = f'layers.{i}.weight' + if layer_key not in state_dict: + raise ValueError(f"Missing weights for layer {i}: {layer_key}") - # Layer 1: (8 + channels[0]) → channels[1] - layer1_weights = state_dict['layer1.weight'].detach().numpy() - export_layer_shader( - layer_idx=1, - weights=layer1_weights, - kernel_size=config['kernels'][1], - in_channels=8 + config['channels'][0], - out_channels=config['channels'][1], - output_dir=output_dir, - is_output_layer=False - ) + layer_weights = state_dict[layer_key].detach().numpy() + is_output = (i == num_layers - 1) - # Layer 2: (8 + channels[1]) → 4 (RGBA) - layer2_weights = state_dict['layer2.weight'].detach().numpy() - export_layer_shader( - layer_idx=2, - weights=layer2_weights, - kernel_size=config['kernels'][2], - in_channels=8 + config['channels'][1], - out_channels=4, - output_dir=output_dir, - is_output_layer=True - ) + export_layer_shader( + layer_idx=i, + weights=layer_weights, + kernel_size=kernel_size, + output_dir=output_dir, + is_output_layer=is_output + ) - print(f"\nExport complete! Generated 3 shader files.") + print(f"\nExport complete! 
Generated {num_layers} shader files.") def main(): diff --git a/training/export_cnn_v2_weights.py b/training/export_cnn_v2_weights.py index 8a2fcdc..07254fc 100755 --- a/training/export_cnn_v2_weights.py +++ b/training/export_cnn_v2_weights.py @@ -45,53 +45,38 @@ def export_weights_binary(checkpoint_path, output_path): state_dict = checkpoint['model_state_dict'] config = checkpoint['config'] + kernel_size = config.get('kernel_size', 3) + num_layers = config.get('num_layers', 3) + print(f"Configuration:") - print(f" Kernels: {config['kernels']}") - print(f" Channels: {config['channels']}") + print(f" Kernel size: {kernel_size}×{kernel_size}") + print(f" Layers: {num_layers}") + print(f" Architecture: uniform 12D→4D (bias=False)") - # Collect layer info + # Collect layer info - all layers uniform 12D→4D layers = [] all_weights = [] weight_offset = 0 - # Layer 0: 8 → channels[0] - layer0_weights = state_dict['layer0.weight'].detach().numpy() - layer0_flat = layer0_weights.flatten() - layers.append({ - 'kernel_size': config['kernels'][0], - 'in_channels': 8, - 'out_channels': config['channels'][0], - 'weight_offset': weight_offset, - 'weight_count': len(layer0_flat) - }) - all_weights.extend(layer0_flat) - weight_offset += len(layer0_flat) + for i in range(num_layers): + layer_key = f'layers.{i}.weight' + if layer_key not in state_dict: + raise ValueError(f"Missing weights for layer {i}: {layer_key}") + + layer_weights = state_dict[layer_key].detach().numpy() + layer_flat = layer_weights.flatten() - # Layer 1: (8 + channels[0]) → channels[1] - layer1_weights = state_dict['layer1.weight'].detach().numpy() - layer1_flat = layer1_weights.flatten() - layers.append({ - 'kernel_size': config['kernels'][1], - 'in_channels': 8 + config['channels'][0], - 'out_channels': config['channels'][1], - 'weight_offset': weight_offset, - 'weight_count': len(layer1_flat) - }) - all_weights.extend(layer1_flat) - weight_offset += len(layer1_flat) + layers.append({ + 'kernel_size': kernel_size, + 'in_channels': 12, # 4 (input/prev) + 8 (static) + 'out_channels': 4, # Uniform output + 'weight_offset': weight_offset, + 'weight_count': len(layer_flat) + }) + all_weights.extend(layer_flat) + weight_offset += len(layer_flat) - # Layer 2: (8 + channels[1]) → 4 (RGBA output) - layer2_weights = state_dict['layer2.weight'].detach().numpy() - layer2_flat = layer2_weights.flatten() - layers.append({ - 'kernel_size': config['kernels'][2], - 'in_channels': 8 + config['channels'][1], - 'out_channels': 4, - 'weight_offset': weight_offset, - 'weight_count': len(layer2_flat) - }) - all_weights.extend(layer2_flat) - weight_offset += len(layer2_flat) + print(f" Layer {i}: 12D→4D, {len(layer_flat)} weights") # Convert to f16 # TODO: Use 8-bit quantization for 2× size reduction @@ -183,21 +168,19 @@ fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> { return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y); } -fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> { +fn unpack_layer_channels(coord: vec2<i32>) -> vec4<f32> { let packed = textureLoad(layer_input, coord, 0); let v0 = unpack2x16float(packed.x); let v1 = unpack2x16float(packed.y); - let v2 = unpack2x16float(packed.z); - let v3 = unpack2x16float(packed.w); - return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y); + return vec4<f32>(v0.x, v0.y, v1.x, v1.y); } -fn pack_channels(values: array<f32, 8>) -> vec4<u32> { +fn pack_channels(values: vec4<f32>) -> vec4<u32> { return vec4<u32>( - pack2x16float(vec2<f32>(values[0], values[1])), - 
pack2x16float(vec2<f32>(values[2], values[3])), - pack2x16float(vec2<f32>(values[4], values[5])), - pack2x16float(vec2<f32>(values[6], values[7])) + pack2x16float(vec2<f32>(values.x, values.y)), + pack2x16float(vec2<f32>(values.z, values.w)), + 0u, // Unused + 0u // Unused ); } @@ -238,9 +221,9 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) { let out_channels = weights[layer0_info_base + 2u]; let weight_offset = weights[layer0_info_base + 3u]; - // Convolution (simplified - expand to full kernel loop) - var output: array<f32, 8>; - for (var c: u32 = 0u; c < min(out_channels, 8u); c++) { + // Convolution: 12D input (4 prev + 8 static) → 4D output + var output: vec4<f32> = vec4<f32>(0.0); + for (var c: u32 = 0u; c < 4u; c++) { output[c] = 0.0; // TODO: Actual convolution } diff --git a/training/train_cnn_v2.py b/training/train_cnn_v2.py index 758b044..8b3b91c 100755 --- a/training/train_cnn_v2.py +++ b/training/train_cnn_v2.py @@ -1,11 +1,11 @@ #!/usr/bin/env python3 -"""CNN v2 Training Script - Parametric Static Features +"""CNN v2 Training Script - Uniform 12D→4D Architecture -Trains a multi-layer CNN with 7D static feature input: -- RGBD (4D) -- UV coordinates (2D) -- sin(10*uv.x) position encoding (1D) -- Bias dimension (1D, always 1.0) +Architecture: +- Static features (8D): p0-p3 (parametric), uv_x, uv_y, sin(10×uv_x), bias +- Input RGBD (4D): original image mip 0 +- All layers: input RGBD (4D) + static (8D) = 12D → 4 channels +- Uniform layer structure with bias=False (bias in static features) """ import argparse @@ -21,20 +21,26 @@ import cv2 def compute_static_features(rgb, depth=None): - """Generate 7D static features + bias dimension. + """Generate 8D static features (parametric + spatial). Args: rgb: (H, W, 3) RGB image [0, 1] depth: (H, W) depth map [0, 1], optional Returns: - (H, W, 8) static features tensor + (H, W, 8) static features: [p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias] + + Note: p0-p3 are parametric features (can be mips, gradients, etc.) + For training, we use RGBD as default, but could use mip1/2 """ h, w = rgb.shape[:2] - # RGBD channels - r, g, b = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2] - d = depth if depth is not None else np.zeros((h, w), dtype=np.float32) + # Parametric features (p0-p3) - using RGBD as default + # TODO: Experiment with mip1 grayscale, gradients, etc. + p0 = rgb[:, :, 0].astype(np.float32) + p1 = rgb[:, :, 1].astype(np.float32) + p2 = rgb[:, :, 2].astype(np.float32) + p3 = depth if depth is not None else np.zeros((h, w), dtype=np.float32) # UV coordinates (normalized [0, 1]) uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0).astype(np.float32) @@ -43,65 +49,64 @@ def compute_static_features(rgb, depth=None): # Multi-frequency position encoding sin10_x = np.sin(10.0 * uv_x).astype(np.float32) - # Bias dimension (always 1.0) + # Bias dimension (always 1.0) - replaces Conv2d bias parameter bias = np.ones((h, w), dtype=np.float32) - # Stack: [R, G, B, D, uv.x, uv.y, sin10_x, bias] - features = np.stack([r, g, b, d, uv_x, uv_y, sin10_x, bias], axis=-1) + # Stack: [p0, p1, p2, p3, uv.x, uv.y, sin10_x, bias] + features = np.stack([p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias], axis=-1) return features class CNNv2(nn.Module): - """CNN v2 with parametric static features. 
+ """CNN v2 - Uniform 12D→4D Architecture + + All layers: input RGBD (4D) + static (8D) = 12D → 4 channels + Uses bias=False (bias integrated in static features as 1.0) TODO: Add quantization-aware training (QAT) for 8-bit weights - Use torch.quantization.QuantStub/DeQuantStub - Train with fake quantization to adapt to 8-bit precision - - Target: ~1.6 KB weights (vs 3.2 KB with f16) + - Target: ~1.3 KB weights (vs 2.6 KB with f16) """ - def __init__(self, kernels=[1, 3, 5], channels=[16, 8, 4]): + def __init__(self, kernel_size=3, num_layers=3): super().__init__() - self.kernels = kernels - self.channels = channels - - # Input layer: 8D (7 features + bias) → channels[0] - self.layer0 = nn.Conv2d(8, channels[0], kernel_size=kernels[0], - padding=kernels[0]//2, bias=False) - - # Inner layers: (8 + C_prev) → C_next - in_ch_1 = 8 + channels[0] - self.layer1 = nn.Conv2d(in_ch_1, channels[1], kernel_size=kernels[1], - padding=kernels[1]//2, bias=False) + self.kernel_size = kernel_size + self.num_layers = num_layers + self.layers = nn.ModuleList() - # Output layer: (8 + C_last) → 4 (RGBA) - in_ch_2 = 8 + channels[1] - self.layer2 = nn.Conv2d(in_ch_2, 4, kernel_size=kernels[2], - padding=kernels[2]//2, bias=False) + # All layers: 12D input (4 RGBD + 8 static) → 4D output + for _ in range(num_layers): + self.layers.append( + nn.Conv2d(12, 4, kernel_size=kernel_size, + padding=kernel_size//2, bias=False) + ) - def forward(self, static_features): - """Forward pass with static feature concatenation. + def forward(self, input_rgbd, static_features): + """Forward pass with uniform 12D→4D layers. Args: + input_rgbd: (B, 4, H, W) input image RGBD (mip 0) static_features: (B, 8, H, W) static features Returns: (B, 4, H, W) RGBA output [0, 1] """ - # Layer 0: Use full 8D static features - x0 = self.layer0(static_features) - x0 = F.relu(x0) + # Layer 0: input RGBD (4D) + static (8D) = 12D + x = torch.cat([input_rgbd, static_features], dim=1) + x = self.layers[0](x) + x = torch.clamp(x, 0, 1) # Output [0,1] for layer 0 - # Layer 1: Concatenate static + layer0 output - x1_input = torch.cat([static_features, x0], dim=1) - x1 = self.layer1(x1_input) - x1 = F.relu(x1) + # Layer 1+: previous (4D) + static (8D) = 12D + for i in range(1, self.num_layers): + x_input = torch.cat([x, static_features], dim=1) + x = self.layers[i](x_input) + if i < self.num_layers - 1: + x = F.relu(x) + else: + x = torch.clamp(x, 0, 1) # Final output [0,1] - # Layer 2: Concatenate static + layer1 output - x2_input = torch.cat([static_features, x1], dim=1) - output = self.layer2(x2_input) - - return torch.sigmoid(output) + return x class PatchDataset(Dataset): @@ -214,14 +219,18 @@ class PatchDataset(Dataset): # Compute static features for patch static_feat = compute_static_features(input_patch.astype(np.float32)) + # Input RGBD (mip 0) - add depth channel + input_rgbd = np.concatenate([input_patch, np.zeros((self.patch_size, self.patch_size, 1))], axis=-1) + # Convert to tensors (C, H, W) + input_rgbd = torch.from_numpy(input_rgbd.astype(np.float32)).permute(2, 0, 1) static_feat = torch.from_numpy(static_feat).permute(2, 0, 1) target = torch.from_numpy(target_patch.astype(np.float32)).permute(2, 0, 1) # Pad target to 4 channels (RGBA) target = F.pad(target, (0, 0, 0, 0, 0, 1), value=1.0) - return static_feat, target + return input_rgbd, static_feat, target class ImagePairDataset(Dataset): @@ -252,14 +261,19 @@ class ImagePairDataset(Dataset): # Compute static features static_feat = compute_static_features(input_img.astype(np.float32)) + # 
Input RGBD (mip 0) - add depth channel + h, w = input_img.shape[:2] + input_rgbd = np.concatenate([input_img, np.zeros((h, w, 1))], axis=-1) + # Convert to tensors (C, H, W) + input_rgbd = torch.from_numpy(input_rgbd.astype(np.float32)).permute(2, 0, 1) static_feat = torch.from_numpy(static_feat).permute(2, 0, 1) target = torch.from_numpy(target_img.astype(np.float32)).permute(2, 0, 1) # Pad target to 4 channels (RGBA) target = F.pad(target, (0, 0, 0, 0, 0, 1), value=1.0) - return static_feat, target + return input_rgbd, static_feat, target def train(args): @@ -282,9 +296,10 @@ def train(args): dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True) # Create model - model = CNNv2(kernels=args.kernel_sizes, channels=args.channels).to(device) + model = CNNv2(kernel_size=args.kernel_size, num_layers=args.num_layers).to(device) total_params = sum(p.numel() for p in model.parameters()) - print(f"Model: {args.channels} channels, {args.kernel_sizes} kernels, {total_params} weights") + weights_per_layer = 12 * args.kernel_size * args.kernel_size * 4 + print(f"Model: {args.num_layers} layers, {args.kernel_size}×{args.kernel_size} kernels, {total_params} weights ({weights_per_layer}/layer)") # Optimizer and loss optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) @@ -298,12 +313,13 @@ def train(args): model.train() epoch_loss = 0.0 - for static_feat, target in dataloader: + for input_rgbd, static_feat, target in dataloader: + input_rgbd = input_rgbd.to(device) static_feat = static_feat.to(device) target = target.to(device) optimizer.zero_grad() - output = model(static_feat) + output = model(input_rgbd, static_feat) loss = criterion(output, target) loss.backward() optimizer.step() @@ -327,9 +343,9 @@ def train(args): 'optimizer_state_dict': optimizer.state_dict(), 'loss': avg_loss, 'config': { - 'kernels': args.kernel_sizes, - 'channels': args.channels, - 'features': ['R', 'G', 'B', 'D', 'uv.x', 'uv.y', 'sin10_x', 'bias'] + 'kernel_size': args.kernel_size, + 'num_layers': args.num_layers, + 'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin10_x', 'bias'] } }, checkpoint_path) print(f" → Saved checkpoint: {checkpoint_path}") @@ -361,10 +377,10 @@ def main(): # Mix salient points with random samples for better generalization # Model architecture - parser.add_argument('--kernel-sizes', type=int, nargs=3, default=[1, 3, 5], - help='Kernel sizes for 3 layers (default: 1 3 5)') - parser.add_argument('--channels', type=int, nargs=3, default=[16, 8, 4], - help='Output channels for 3 layers (default: 16 8 4)') + parser.add_argument('--kernel-size', type=int, default=3, + help='Kernel size (uniform for all layers, default: 3)') + parser.add_argument('--num-layers', type=int, default=3, + help='Number of CNN layers (default: 3)') # Training parameters parser.add_argument('--epochs', type=int, default=5000, help='Training epochs') diff --git a/workspaces/main/shaders/cnn_v2/cnn_v2_compute.wgsl b/workspaces/main/shaders/cnn_v2/cnn_v2_compute.wgsl index 1e1704d..5c4b113 100644 --- a/workspaces/main/shaders/cnn_v2/cnn_v2_compute.wgsl +++ b/workspaces/main/shaders/cnn_v2/cnn_v2_compute.wgsl @@ -1,6 +1,6 @@ -// CNN v2 Compute Shader - Storage Buffer Version -// Processes single layer per dispatch with weights from storage buffer -// Multi-layer execution handled by C++ with ping-pong buffers +// CNN v2 Compute Shader - Uniform 12D→4D Architecture +// All layers: input/previous (4D) + static (8D) = 12D → 4 channels +// Storage buffer weights, ping-pong execution // Push constants for 
layer parameters (passed per dispatch) struct LayerParams { @@ -12,12 +12,12 @@ struct LayerParams { blend_amount: f32, // [0,1] blend with original } -@group(0) @binding(0) var static_features: texture_2d<u32>; // 8-channel static features -@group(0) @binding(1) var layer_input: texture_2d<u32>; // Previous layer output (8-channel packed) -@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>; // Current layer output +@group(0) @binding(0) var static_features: texture_2d<u32>; // 8D static features (p0-p3 + spatial) +@group(0) @binding(1) var layer_input: texture_2d<u32>; // 4D previous/input (RGBD or prev layer) +@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>; // 4D output @group(0) @binding(3) var<storage, read> weights_buffer: array<u32>; // Packed f16 weights @group(0) @binding(4) var<uniform> params: LayerParams; -@group(0) @binding(5) var original_input: texture_2d<f32>; // Original RGB input for blending +@group(0) @binding(5) var original_input: texture_2d<f32>; // Original RGB for blending fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> { let packed = textureLoad(static_features, coord, 0); @@ -28,21 +28,19 @@ fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> { return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y); } -fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> { +fn unpack_layer_channels(coord: vec2<i32>) -> vec4<f32> { let packed = textureLoad(layer_input, coord, 0); let v0 = unpack2x16float(packed.x); let v1 = unpack2x16float(packed.y); - let v2 = unpack2x16float(packed.z); - let v3 = unpack2x16float(packed.w); - return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y); + return vec4<f32>(v0.x, v0.y, v1.x, v1.y); } -fn pack_channels(values: array<f32, 8>) -> vec4<u32> { +fn pack_channels(values: vec4<f32>) -> vec4<u32> { return vec4<u32>( - pack2x16float(vec2<f32>(values[0], values[1])), - pack2x16float(vec2<f32>(values[2], values[3])), - pack2x16float(vec2<f32>(values[4], values[5])), - pack2x16float(vec2<f32>(values[6], values[7])) + pack2x16float(vec2<f32>(values.x, values.y)), + pack2x16float(vec2<f32>(values.z, values.w)), + 0u, // Unused + 0u // Unused ); } @@ -68,19 +66,19 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) { } let kernel_size = params.kernel_size; - let in_channels = params.in_channels; - let out_channels = params.out_channels; + let in_channels = params.in_channels; // Always 12 (4 prev + 8 static) + let out_channels = params.out_channels; // Always 4 let weight_offset = params.weight_offset; let is_output = params.is_output_layer != 0u; let kernel_radius = i32(kernel_size / 2u); - // Load static features (always 8D) + // Load static features (8D) and previous/input layer (4D) let static_feat = unpack_static_features(coord); - // Convolution per output channel - var output: array<f32, 8>; - for (var c: u32 = 0u; c < out_channels && c < 8u; c++) { + // Convolution: 12D input → 4D output + var output: vec4<f32> = vec4<f32>(0.0); + for (var c: u32 = 0u; c < 4u; c++) { var sum: f32 = 0.0; // Convolve over kernel @@ -94,55 +92,49 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) { clamp(sample_coord.y, 0, i32(dims.y) - 1) ); - // Load input features at this spatial location + // Load features at this spatial location let static_local = unpack_static_features(clamped); - let layer_local = unpack_layer_channels(clamped); + let layer_local = unpack_layer_channels(clamped); // 4D // Weight index calculation let ky_idx = u32(ky + 
kernel_radius); let kx_idx = u32(kx + kernel_radius); let spatial_idx = ky_idx * kernel_size + kx_idx; - // Accumulate: static features (always 8 channels) - for (var i: u32 = 0u; i < 8u; i++) { + // Accumulate: previous/input channels (4D) + for (var i: u32 = 0u; i < 4u; i++) { let w_idx = weight_offset + - c * in_channels * kernel_size * kernel_size + + c * 12u * kernel_size * kernel_size + i * kernel_size * kernel_size + spatial_idx; - sum += get_weight(w_idx) * static_local[i]; + sum += get_weight(w_idx) * layer_local[i]; } - // Accumulate: previous layer channels (in_channels - 8) - let prev_channels = in_channels - 8u; - for (var i: u32 = 0u; i < prev_channels && i < 8u; i++) { + // Accumulate: static features (8D) + for (var i: u32 = 0u; i < 8u; i++) { let w_idx = weight_offset + - c * in_channels * kernel_size * kernel_size + - (8u + i) * kernel_size * kernel_size + spatial_idx; - sum += get_weight(w_idx) * layer_local[i]; + c * 12u * kernel_size * kernel_size + + (4u + i) * kernel_size * kernel_size + spatial_idx; + sum += get_weight(w_idx) * static_local[i]; } } } // Activation if (is_output) { - output[c] = clamp(sum, 0.0, 1.0); // Sigmoid approximation + output[c] = clamp(sum, 0.0, 1.0); } else { output[c] = max(0.0, sum); // ReLU } } - // Zero unused channels - for (var c: u32 = out_channels; c < 8u; c++) { - output[c] = 0.0; - } - // Blend with original on final layer if (is_output) { let original = textureLoad(original_input, coord, 0).rgb; - let result_rgb = vec3<f32>(output[0], output[1], output[2]); + let result_rgb = vec3<f32>(output.x, output.y, output.z); let blended = mix(original, result_rgb, params.blend_amount); - output[0] = blended.r; - output[1] = blended.g; - output[2] = blended.b; + output.x = blended.r; + output.y = blended.g; + output.z = blended.b; } textureStore(output_tex, coord, pack_channels(output)); diff --git a/workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl b/workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl index dd07f19..7a9e6de 100644 --- a/workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl +++ b/workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl @@ -1,5 +1,7 @@ // CNN v2 Static Features Compute Shader -// Generates 7D features + bias: [R, G, B, D, uv.x, uv.y, sin10_x, 1.0] +// Generates 8D parametric features: [p0, p1, p2, p3, uv.x, uv.y, sin10_x, bias] +// p0-p3: Parametric features (currently RGBD from mip0, could be mip1/2, gradients, etc.) +// Note: Input image RGBD (mip0) fed separately to Layer 0 @group(0) @binding(0) var input_tex: texture_2d<f32>; @group(0) @binding(1) var input_tex_mip1: texture_2d<f32>; @@ -16,14 +18,14 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) { return; } - // Sample RGBA from mip 0 + // Parametric features (p0-p3) + // TODO: Experiment with mip1 grayscale, Sobel gradients, etc. 
+ // For now, use RGBD from mip 0 (same as input, but could differ) let rgba = textureLoad(input_tex, coord, 0); - let r = rgba.r; - let g = rgba.g; - let b = rgba.b; - - // Sample depth - let d = textureLoad(depth_tex, coord, 0).r; + let p0 = rgba.r; + let p1 = rgba.g; + let p2 = rgba.b; + let p3 = textureLoad(depth_tex, coord, 0).r; // UV coordinates (normalized [0,1], bottom-left origin) let uv_x = f32(coord.x) / f32(dims.x); @@ -36,9 +38,10 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) { let bias = 1.0; // Pack 8×f16 into 4×u32 (rgba32uint) + // [p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias] let packed = vec4<u32>( - pack2x16float(vec2<f32>(r, g)), - pack2x16float(vec2<f32>(b, d)), + pack2x16float(vec2<f32>(p0, p1)), + pack2x16float(vec2<f32>(p2, p3)), pack2x16float(vec2<f32>(uv_x, uv_y)), pack2x16float(vec2<f32>(sin10_x, bias)) ); |
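
For reference, the 2×f16-per-u32 texel packing used by `pack_channels`/`unpack_layer_channels` above can be reproduced host-side, which is useful when validating shader outputs against the Python model. A NumPy sketch mirroring WGSL's `pack2x16float`/`unpack2x16float` (illustrative only, not part of the commit):

```python
import numpy as np

def pack2x16float(a, b):
    """Pack two floats as IEEE 754 half precision into one u32 (low half = a)."""
    halves = np.array([a, b], dtype=np.float16).view(np.uint16)
    return int(halves[0]) | (int(halves[1]) << 16)

def unpack2x16float(u):
    """Inverse: split a u32 into two f16 values and widen to Python floats."""
    halves = np.array([u & 0xFFFF, u >> 16], dtype=np.uint16).view(np.float16)
    return float(halves[0]), float(halves[1])

u = pack2x16float(0.5, 1.0)
print(unpack2x16float(u))  # (0.5, 1.0)
```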
