# CNN v2: Parametric Static Features

**Technical Design Document**

---

## Overview

CNN v2 extends the original CNN post-processing effect with parametric static features, enabling richer spatial and frequency-domain inputs for improved visual quality.

**Key improvements over v1:**
- 7D static feature input (vs 4D RGB)
- Multi-frequency position encoding (NeRF-style)
- Configurable mip-level for p0-p3 parametric features (0-3)
- Per-layer configurable kernel sizes (1×1, 3×3, 5×5)
- Variable channel counts per layer
- Float16 weight storage (~3.2 KB for 3-layer model)
- Bias integrated as static feature dimension
- Storage buffer architecture (dynamic layer count)
- Binary weight format v2 for runtime loading
- Sigmoid activation for layer 0 and final layer (smooth [0,1] mapping)

**Status:** ✅ Complete. Sigmoid activation, stable training, validation tools operational.

**Breaking Change:**
- Models trained with `clamp()` are incompatible; retraining is required.

**TODO:**
- 8-bit quantization with QAT for 2× size reduction (~1.6 KB)

---

## Architecture

### Pipeline Overview

```
Input RGBD → Static Features Compute → CNN Layers → Output RGBA
             └─ computed once/frame ─┘ └─ multi-pass ─┘
```

**Detailed Data Flow:**

```
        ┌─────────────────────────────────────────┐
        │ Static Features (computed once)         │
        │ 8D: p0,p1,p2,p3,uv_x,uv_y,sin10x,bias   │
        └─────────────────┬───────────────────────┘
                          │ 8D (broadcast to all layers)
                          │
  ┌──────────────┐        │
  │ Input RGBD   │        │
  │ 4D           │        │
  └──────┬───────┘        │
         │ 4D             │
         ▼                │
   ┌────────────┐         │
   │  Layer 0   │◄────────┤ (12D input = 4D + 8D)
   │  12D → 4D  │         │
   └─────┬──────┘         │
         │ 4D output      │
         ▼                │
   ┌────────────┐         │
   │  Layer 1   │◄────────┤ (12D input = 4D + 8D)
   │  12D → 4D  │         │
   └─────┬──────┘         │
         │ 4D output      │
         ▼                │
        ...               │
         │                │
         ▼                │
   ┌────────────┐         │
   │  Layer N   │◄────────┘ (12D input = 4D + 8D, output layer)
   │  12D → 4D  │
   └─────┬──────┘
         │ 4D (RGBA)
         ▼
       Output
```

**Key Points:**
- Static features computed once, broadcast to all CNN layers
- Each layer: previous 4D output + 8D static → 12D input → 4D output
- Ping-pong buffering between layers
- Layer 0 special case: uses input RGBD instead of previous layer output

**Static Features Texture:**
- Name: `static_features`
- Format: `texture_storage_2d<rgba32uint, write>` (4×u32)
- Data: 8 float16 values packed via `pack2x16float()`
- Computed once per frame, read by all CNN layers
- Lifetime: entire frame (all CNN layer passes)

**CNN Layers:**
- Layer 0: input RGBD (4D) + static (8D) = 12D → 4 channels
- Layer 1+: previous output (4D) + static (8D) = 12D → 4 channels
- All layers: uniform 12D input, 4D output (ping-pong buffer)
- Storage: `texture_storage_2d<rgba32uint>` (4 channels as 2×f16 pairs)

**Activation Functions:**
- Layer 0 & final layer: `sigmoid(x)` for smooth [0,1] mapping
- Middle layers: `ReLU` (max(0, x))
- Rationale: sigmoid prevents gradient blocking at the [0,1] boundaries, enabling better convergence
- Breaking change: models trained with `clamp(x, 0, 1)` are incompatible; retraining is required
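The texel packing described above can be prototyped on the CPU before touching the shaders. Below is a minimal NumPy sketch (illustrative, not part of the codebase) that mirrors `pack2x16float()`: it packs the eight float16 static features of one pixel into the four `u32` components of an `rgba32uint` texel and unpacks them again. It assumes a little-endian host.

```python
import numpy as np

def pack_8xf16(features: np.ndarray) -> np.ndarray:
    """Pack 8 float16 values into 4 uint32 words (one rgba32uint texel).

    Mirrors WGSL pack2x16float(): each u32 holds two consecutive f16 values,
    first value in the low 16 bits (little-endian host assumed).
    """
    assert features.shape[-1] == 8
    f16 = features.astype(np.float16)          # quantize to f16
    return f16.view(np.uint32).reshape(*features.shape[:-1], 4)

def unpack_8xf16(texel: np.ndarray) -> np.ndarray:
    """Inverse of pack_8xf16(): 4 uint32 words back to 8 float32 values."""
    return texel.view(np.float16).astype(np.float32).reshape(*texel.shape[:-1], 8)

# Round-trip check on one pixel's static feature vector
feat = np.array([0.2, 0.5, 0.7, 1.0, 0.25, 0.75, np.sin(10 * 0.25), 1.0], np.float32)
packed = pack_8xf16(feat)
assert np.allclose(unpack_8xf16(packed), feat, atol=1e-3)  # within f16 precision
```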
---

## Static Features (7D + 1 bias)

### Feature Layout

**8 float16 values per pixel:**

```wgsl
// Slot 0-3: Parametric features (p0, p1, p2, p3)
// Sampled from configurable mip level (0=original, 1=half, 2=quarter, 3=eighth)
// Training sets mip_level via --mip-level flag, stored in binary format v2
let p0 = ...; // RGB.r from selected mip level
let p1 = ...; // RGB.g from selected mip level
let p2 = ...; // RGB.b from selected mip level
let p3 = ...; // Depth or RGB channel from mip level

// Slot 4-5: UV coordinates (normalized screen space)
let uv_x = coord.x / resolution.x; // Horizontal position [0,1]
let uv_y = coord.y / resolution.y; // Vertical position [0,1]

// Slot 6: Multi-frequency position encoding
let sin20_y = sin(20.0 * uv_y); // Periodic feature (frequency=20, vertical)

// Slot 7: Bias dimension (always 1.0)
let bias = 1.0; // Learned bias per output channel

// Packed storage: [p0, p1, p2, p3, uv.x, uv.y, sin(20*uv.y), 1.0]
```

### Input Channel Mapping

**Weight tensor layout (12 input channels per layer):**

| Input Channel | Feature | Description |
|---------------|---------|-------------|
| 0-3 | Previous layer output | 4D RGBA from prior CNN layer (or input RGBD for Layer 0) |
| 4-11 | Static features | 8D: p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias |

**Static feature channel details:**
- Channel 4 → p0 (RGB.r from mip level)
- Channel 5 → p1 (RGB.g from mip level)
- Channel 6 → p2 (RGB.b from mip level)
- Channel 7 → p3 (depth or RGB channel from mip level)
- Channel 8 → p4 (uv_x: normalized horizontal position)
- Channel 9 → p5 (uv_y: normalized vertical position)
- Channel 10 → p6 (sin(20*uv_y): periodic encoding)
- Channel 11 → p7 (bias: constant 1.0)

**Note:** When generating identity weights, p4-p7 correspond to input channels 8-11, not 4-7 (see the sketch below).
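To illustrate this mapping, the following NumPy sketch (a hypothetical helper, not an existing script) builds pass-through identity weights for a 12-input, 4-output layer: each output channel copies one of the previous-layer channels 0-3 at the kernel center and ignores the static channels 4-11.

```python
import numpy as np

def make_identity_weights(kernel_size: int = 1) -> np.ndarray:
    """Identity weights for one CNN v2 layer, shape (out=4, in=12, k, k).

    Output channel c copies input channel c (previous layer / input RGBD).
    Static feature channels 4-11 (p0-p3, uv, sin, bias) get zero weights.
    """
    w = np.zeros((4, 12, kernel_size, kernel_size), dtype=np.float32)
    center = kernel_size // 2
    for c in range(4):
        w[c, c, center, center] = 1.0  # pass through channel c at the kernel center
    return w

w = make_identity_weights(3)
assert w.shape == (4, 12, 3, 3) and w.sum() == 4.0
```

Note that with sigmoid on layer 0 and the final layer these weights do not reproduce the input exactly; they are only meant to exercise the channel ordering.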
### Feature Rationale

| Feature | Dimension | Purpose | Priority |
|---------|-----------|---------|----------|
| p0-p3 | 4D | Parametric auxiliary features (mips, gradients, etc.) | Essential |
| UV coords | 2D | Spatial position awareness | Essential |
| sin(20\*uv.y) | 1D | Periodic position encoding (vertical) | Medium |
| Bias | 1D | Learned bias (standard NN) | Essential |

**Note:** Input image RGBD (mip 0) is fed only to Layer 0. Subsequent layers see static features + previous layer output.

**Why bias as static feature:**
- Simpler shader code (single weight array)
- Standard NN formulation: y = Wx (x includes the bias term)
- Saves 56-112 bytes (no separate bias buffer)
- 7 features are sufficient for the initial implementation

### Future Feature Extensions

**Option: additional encodings:**
- `sin(40*uv.y)` - higher-frequency encoding
- `gray_mip1` - multi-scale luminance
- `dx`, `dy` - Sobel gradients
- `variance` - local texture measure
- `laplacian` - edge detection

**Option: uint8 packing (16+ features):**
```wgsl
// texture_storage_2d<rgba32uint> stores 16 uint8 values (4 packed per u32)
// Trade precision for feature count
// [R, G, B, D, uv.x, uv.y, sin10.x, sin10.y,
//  sin20.x, sin20.y, dx, dy, gray_mip1, gray_mip2, var, bias]
```
Requires quantization-aware training.

---

## Layer Structure

### Example 3-Layer Network

```
Layer 0: input RGBD (4D) + static (8D) = 12D → 4 channels (3×3 kernel)
Layer 1: previous (4D)   + static (8D) = 12D → 4 channels (3×3 kernel)
Layer 2: previous (4D)   + static (8D) = 12D → 4 channels (3×3 kernel, output RGBA)
```

**Output:** 4 channels (RGBA). Training targets preserve the alpha channel from the target images.

### Weight Calculations

**Per-layer weights (uniform 12D→4D, 3×3 kernels):**
```
Layer 0: 12 × 3 × 3 × 4 = 432 weights
Layer 1: 12 × 3 × 3 × 4 = 432 weights
Layer 2: 12 × 3 × 3 × 4 = 432 weights
Total:                   1296 weights
```

**Storage sizes:**
- f32: 1296 × 4 = 5,184 bytes (~5.1 KB)
- f16: 1296 × 2 = 2,592 bytes (~2.5 KB) ✓ **recommended**

**Comparison to v1:**
- v1: ~800 weights (3.2 KB f32)
- v2: ~1296 weights (2.5 KB f16)
- **Uniform architecture, smaller than v1 f32**

### Kernel Size Guidelines

**1×1 kernel (pointwise):**
- No spatial context, channel mixing only
- Weights: `12 × 4 = 48` per layer
- Use for: fast inference, channel remapping

**3×3 kernel (standard conv):**
- Local spatial context (recommended)
- Weights: `12 × 9 × 4 = 432` per layer
- Use for: most layers (balanced quality/size)

**5×5 kernel (large receptive field):**
- Wide spatial context
- Weights: `12 × 25 × 4 = 1200` per layer
- Use for: output layer, fine detail enhancement

### Channel Storage (4×f16 per texel)

```wgsl
@group(0) @binding(1) var layer_input: texture_2d<u32>;

fn unpack_channels(coord: vec2<i32>) -> vec4<f32> {
    let packed = textureLoad(layer_input, coord, 0);
    let v0 = unpack2x16float(packed.x); // [ch0, ch1]
    let v1 = unpack2x16float(packed.y); // [ch2, ch3]
    return vec4<f32>(v0.x, v0.y, v1.x, v1.y);
}

fn pack_channels(values: vec4<f32>) -> vec4<u32> {
    return vec4<u32>(
        pack2x16float(vec2(values.x, values.y)),
        pack2x16float(vec2(values.z, values.w)),
        0u, // Unused
        0u  // Unused
    );
}
```
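The weight counts and storage sizes above follow directly from the per-layer kernel sizes, since every layer is 12D → 4D. A small helper (illustrative only, not an existing script) reproduces the numbers for arbitrary configurations:

```python
def cnn_v2_weight_budget(kernel_sizes, in_channels=12, out_channels=4):
    """Return (total_weights, f32_bytes, f16_bytes) for a CNN v2 network.

    Every layer takes 12 inputs (4 previous + 8 static) and emits 4 channels,
    so the only per-layer variable is the kernel size.
    """
    total = sum(in_channels * k * k * out_channels for k in kernel_sizes)
    return total, total * 4, total * 2

# Uniform 3-layer network with 3x3 kernels (the example above)
print(cnn_v2_weight_budget([3, 3, 3]))  # (1296, 5184, 2592)

# Mixed kernels, e.g. pointwise first layer and 5x5 output layer
print(cnn_v2_weight_budget([1, 3, 5]))  # 48 + 432 + 1200 = 1680 weights
```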
---

## Training Workflow

### Script: `training/train_cnn_v2.py`

**Static Feature Extraction:**

```python
def compute_static_features(rgb, depth, mip_level=0):
    """Generate parametric features (8D: p0-p3 + spatial).

    Args:
        mip_level: 0=original, 1=half res, 2=quarter res, 3=eighth res
    """
    h, w = rgb.shape[:2]

    # Generate mip level for p0-p3 (downsample then upsample)
    if mip_level > 0:
        mip_rgb = rgb.copy()
        for _ in range(mip_level):
            mip_rgb = cv2.pyrDown(mip_rgb)
        for _ in range(mip_level):
            mip_rgb = cv2.pyrUp(mip_rgb)
        if mip_rgb.shape[:2] != (h, w):
            mip_rgb = cv2.resize(mip_rgb, (w, h), interpolation=cv2.INTER_LINEAR)
    else:
        mip_rgb = rgb

    # Parametric features from mip level
    p0, p1, p2, p3 = mip_rgb[..., 0], mip_rgb[..., 1], mip_rgb[..., 2], depth

    # UV coordinates (normalized)
    uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0)
    uv_y = np.linspace(0, 1, h)[:, None].repeat(w, axis=1)

    # Multi-frequency position encoding
    sin10_x = np.sin(10.0 * uv_x)

    # Bias dimension (always 1.0)
    bias = np.ones_like(p0)

    # Stack: [p0, p1, p2, p3, uv.x, uv.y, sin10_x, bias]
    return np.stack([p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias], axis=-1)
```

**Network Definition:**

```python
class CNNv2(nn.Module):
    def __init__(self, kernel_sizes, num_layers=3):
        super().__init__()
        if isinstance(kernel_sizes, int):
            kernel_sizes = [kernel_sizes] * num_layers
        self.kernel_sizes = kernel_sizes
        self.layers = nn.ModuleList()

        # All layers: 12D input (4 prev + 8 static) → 4D output
        for kernel_size in kernel_sizes:
            self.layers.append(
                nn.Conv2d(12, 4, kernel_size=kernel_size,
                          padding=kernel_size // 2, bias=False)
            )

    def forward(self, input_rgbd, static_features):
        # Layer 0: input RGBD (4D) + static (8D) = 12D
        x = torch.cat([input_rgbd, static_features], dim=1)
        x = self.layers[0](x)
        x = torch.sigmoid(x)  # Soft [0,1] for layer 0

        # Layer 1+: previous output (4D) + static (8D) = 12D
        for i in range(1, len(self.layers)):
            x_input = torch.cat([x, static_features], dim=1)
            x = self.layers[i](x_input)
            if i < len(self.layers) - 1:
                x = F.relu(x)
            else:
                x = torch.sigmoid(x)  # Soft [0,1] for final layer

        return x  # RGBA output
```
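A quick shape check with dummy tensors helps confirm the tensor layout before training. The snippet below is illustrative and assumes the `CNNv2` class defined above is in scope.

```python
import torch

model = CNNv2(kernel_sizes=[3, 3, 3])
input_rgbd = torch.rand(2, 4, 64, 64)          # batch of 2, RGBD, 64x64 patch
static_feat = torch.rand(2, 8, 64, 64)         # p0-p3, uv_x, uv_y, sin, bias
out = model(input_rgbd, static_feat)

assert out.shape == (2, 4, 64, 64)             # RGBA output, same resolution
assert out.min() >= 0.0 and out.max() <= 1.0   # final sigmoid keeps output in [0,1]
```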
**Training Configuration:**

```python
# Hyperparameters
kernel_sizes = [3, 3, 3]   # Per-layer kernel sizes (e.g., [1, 3, 5])
num_layers = 3             # Number of CNN layers
mip_level = 0              # Mip level for p0-p3: 0=orig, 1=half, 2=quarter, 3=eighth
grayscale_loss = False     # Compute loss on grayscale (Y) instead of RGBA
learning_rate = 1e-3
batch_size = 16
epochs = 5000

# Dataset: input RGB, target RGBA (preserves alpha channel from image)
# Model outputs RGBA; loss compares all 4 channels (or grayscale if --grayscale-loss)

# Training loop (standard PyTorch f32)
for epoch in range(epochs):
    for rgb_batch, depth_batch, target_batch in dataloader:
        # Compute static features (8D) with mip level
        static_feat = compute_static_features(rgb_batch, depth_batch, mip_level)

        # Input RGBD (4D)
        input_rgbd = torch.cat([rgb_batch, depth_batch.unsqueeze(1)], dim=1)

        # Forward pass
        output = model(input_rgbd, static_feat)

        # Loss computation (grayscale or RGBA)
        if grayscale_loss:
            # Convert RGBA to grayscale: Y = 0.299*R + 0.587*G + 0.114*B
            output_gray = 0.299 * output[:, 0:1] + 0.587 * output[:, 1:2] + 0.114 * output[:, 2:3]
            target_gray = 0.299 * target_batch[:, 0:1] + 0.587 * target_batch[:, 1:2] + 0.114 * target_batch[:, 2:3]
            loss = criterion(output_gray, target_gray)
        else:
            loss = criterion(output, target_batch)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```

**Checkpoint Format:**

```python
torch.save({
    'state_dict': model.state_dict(),  # f32 weights
    'config': {
        'kernel_sizes': [3, 3, 3],     # Per-layer kernel sizes
        'num_layers': 3,
        'mip_level': 0,                # Mip level used for p0-p3
        'grayscale_loss': False,       # Whether grayscale loss was used
        'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin10_x', 'bias']
    },
    'epoch': epoch,
    'loss': loss.item()
}, f'checkpoints/checkpoint_epoch_{epoch}.pth')
```
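The `config` block is what allows downstream tools to rebuild the network without guessing its architecture. A minimal loading sketch (illustrative; the checkpoint path is an example):

```python
import torch

ckpt = torch.load('checkpoints/checkpoint_epoch_5000.pth', map_location='cpu')
cfg = ckpt['config']

# Rebuild the model exactly as it was trained
model = CNNv2(kernel_sizes=cfg['kernel_sizes'], num_layers=cfg['num_layers'])
model.load_state_dict(ckpt['state_dict'])
model.eval()

# mip_level must match at inference time, otherwise p0-p3 are sampled wrong
print(cfg['mip_level'], cfg['features'])
```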
---

## Export Workflow

### Script: `training/export_cnn_v2_shader.py`

**Process:**
1. Load checkpoint (f32 PyTorch weights)
2. Extract layer configs (kernels, channels)
3. Quantize weights to float16: `weights_f16 = weights_f32.astype(np.float16)`
4. Generate WGSL shader per layer
5. Write to `workspaces/<workspace>/shaders/cnn_v2/cnn_v2_*.wgsl`

**Example Generated Shader:**

```wgsl
// cnn_v2_layer_0.wgsl - Auto-generated from checkpoint_epoch_5000.pth

const KERNEL_SIZE: u32 = 1u;
const IN_CHANNELS: u32 = 8u;  // 7 features + bias
const OUT_CHANNELS: u32 = 16u;

// Weights quantized to float16 (stored as f32 in shader)
const weights: array<f32, 128> = array(
    0.123047, -0.089844, 0.234375, 0.456055, ...
);

@group(0) @binding(0) var static_features: texture_2d<u32>;
@group(0) @binding(1) var output_texture: texture_storage_2d<rgba32uint, write>;

@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    // Load static features (8D)
    let static_feat = get_static_features(vec2<i32>(id.xy));

    // Convolution (1×1 kernel = pointwise)
    var output: array<f32, OUT_CHANNELS>;
    for (var c: u32 = 0u; c < OUT_CHANNELS; c++) {
        var sum: f32 = 0.0;
        for (var k: u32 = 0u; k < IN_CHANNELS; k++) {
            sum += weights[c * IN_CHANNELS + k] * static_feat[k];
        }
        output[c] = max(0.0, sum); // ReLU activation
    }

    // Pack and store (8×f16 per texel)
    textureStore(output_texture, vec2<i32>(id.xy), pack_f16x8(output));
}
```

**Float16 Quantization:**
- Training uses f32 throughout (PyTorch standard)
- Export converts to `np.float16`, then back to f32 for WGSL literals
- **Expected discrepancy:** <0.1% MSE (acceptable)
- Validation via HTML tool (see below)
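The expected discrepancy can be measured directly on a checkpoint by applying the same f32 → f16 → f32 round trip as the exporter. A sketch of the measurement (illustrative; the checkpoint path is an example):

```python
import numpy as np
import torch

ckpt = torch.load('checkpoints/checkpoint_epoch_5000.pth', map_location='cpu')

for name, w in ckpt['state_dict'].items():
    w32 = w.numpy().astype(np.float32)
    w16 = w32.astype(np.float16).astype(np.float32)  # same round trip as the exporter
    rel_mse = np.mean((w32 - w16) ** 2) / (np.mean(w32 ** 2) + 1e-12)
    print(f'{name}: relative MSE {rel_mse:.2e}')      # expected well under 1e-3
```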
---

## Validation Workflow

### HTML Tool: `tools/cnn_v2_test/index.html`

**WebGPU-based testing tool** with layer visualization.

**Usage:**
1. Open `tools/cnn_v2_test/index.html` in a browser
2. Drop the `.bin` weights file (from `export_cnn_v2_weights.py`)
3. Drop a PNG test image
4. View results with layer inspection

**Features:**
- Live CNN inference with WebGPU
- Layer-by-layer visualization (static features + all CNN layers)
- Weight visualization (per-layer kernels)
- View modes: CNN output, original, diff (×10)
- Blend control for comparing with the original

**Export weights:**
```bash
./training/export_cnn_v2_weights.py checkpoints/checkpoint_epoch_100.pth \
    --output-weights workspaces/main/cnn_v2_weights.bin
```

See `doc/CNN_V2_WEB_TOOL.md` for detailed documentation.

---

## Implementation Checklist

### Phase 1: Shaders (Core Infrastructure)

- [ ] `workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl` - Static features compute
  - [ ] RGBD sampling from framebuffer
  - [ ] UV coordinate calculation
  - [ ] sin(10\*uv.x) computation
  - [ ] Bias dimension (constant 1.0)
  - [ ] Float16 packing via `pack2x16float()`
  - [ ] Output to `texture_storage_2d<rgba32uint>`

- [ ] `workspaces/main/shaders/cnn_v2/cnn_v2_layer_template.wgsl` - Layer template
  - [ ] Static features unpacking
  - [ ] Previous layer unpacking (8×f16)
  - [ ] Convolution implementation (1×1, 3×3, 5×5)
  - [ ] ReLU activation
  - [ ] Output packing (8×f16)
  - [ ] Proper padding handling

### Phase 2: C++ Effect Class

- [ ] `src/effects/cnn_v2_effect.h` - Header
  - [ ] Class declaration inheriting from `PostProcessEffect`
  - [ ] Static features texture member
  - [ ] Layer textures vector
  - [ ] Pipeline and bind group members

- [ ] `src/effects/cnn_v2_effect.cc` - Implementation
  - [ ] Constructor: load shaders, create textures
  - [ ] `init()`: create pipelines, bind groups
  - [ ] `render()`: multi-pass execution
    - [ ] Pass 0: compute static features
    - [ ] Pass 1-N: CNN layers
    - [ ] Final: composite to output
  - [ ] Proper resource cleanup

- [ ] Integration
  - [ ] Add to `src/gpu/demo_effects.h` includes
  - [ ] Add `cnn_v2_effect.cc` to `CMakeLists.txt` (headless + normal)
  - [ ] Add shaders to `workspaces/main/assets.txt`
  - [ ] Add to `src/tests/gpu/test_demo_effects.cc`

### Phase 3: Training Pipeline

- [ ] `training/train_cnn_v2.py` - Training script
  - [ ] Static feature extraction function
  - [ ] CNNv2 PyTorch model class
  - [ ] Patch-based dataloader
  - [ ] Training loop with checkpointing
  - [ ] Command-line argument parsing
  - [ ] Inference mode (ground truth generation)

- [ ] `training/export_cnn_v2_shader.py` - Export script
  - [ ] Checkpoint loading
  - [ ] Weight extraction and f16 quantization
  - [ ] Per-layer WGSL generation
  - [ ] File output to workspace shaders/
  - [ ] Metadata preservation

### Phase 4: Tools & Validation

- [x] HTML validation tool - WebGPU inference with layer visualization
  - [ ] Command-line argument parsing
  - [ ] Shader export orchestration
  - [ ] Build orchestration
  - [ ] Batch image processing
  - [ ] Results display

- [ ] `src/tools/cnn_test_main.cc` - Tool updates
  - [ ] Add `--cnn-version v2` flag
  - [ ] CNNv2Effect instantiation path
  - [ ] Static features pass execution
  - [ ] Multi-layer processing

### Phase 5: Documentation

- [ ] `doc/HOWTO.md` - Usage guide
  - [ ] Training section (CNN v2)
  - [ ] Export section
  - [ ] Validation section
  - [ ] Examples

- [ ] `README.md` - Project overview update
  - [ ] Mention CNN v2 capability

---

## File Structure

### New Files

```
# Shaders (generated by export script)
workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl    # Static features compute
workspaces/main/shaders/cnn_v2/cnn_v2_layer_0.wgsl   # Input layer (generated)
workspaces/main/shaders/cnn_v2/cnn_v2_layer_1.wgsl   # Inner layer (generated)
workspaces/main/shaders/cnn_v2/cnn_v2_layer_2.wgsl   # Output layer (generated)

# C++ implementation
src/effects/cnn_v2_effect.h          # Effect class header
src/effects/cnn_v2_effect.cc         # Effect implementation

# Python training/export
training/train_cnn_v2.py             # Training script
training/export_cnn_v2_shader.py     # Shader generator
training/validation/                 # Test images directory

# Validation
tools/cnn_v2_test/index.html         # WebGPU validation tool

# Documentation
doc/CNN_V2.md                        # This file
```

### Modified Files

```
src/gpu/demo_effects.h               # Add CNNv2Effect include
CMakeLists.txt                       # Add cnn_v2_effect.cc
workspaces/main/assets.txt           # Add cnn_v2 shaders
workspaces/main/timeline.seq         # Optional: add CNNv2Effect
src/tests/gpu/test_demo_effects.cc   # Add CNNv2 test case
src/tools/cnn_test_main.cc           # Add --cnn-version v2
doc/HOWTO.md                         # Add CNN v2 sections
TODO.md                              # Add CNN v2 task
```

### Unchanged (v1 Preserved)

```
training/train_cnn.py                # Original training
src/effects/cnn_effect.*             # Original effect
workspaces/main/shaders/cnn_*.wgsl   # Original v1 shaders
```

---

## Performance Characteristics

### Static Features Compute
- **Cost:** ~0.1 ms @ 1080p
- **Frequency:** once per frame
- **Operations:** sin(), texture sampling, packing

### CNN Layers (Example 3-layer)
- **Layer 0 (1×1, 8→16):** ~0.3 ms
- **Layer 1 (3×3, 23→8):** ~0.8 ms
- **Layer 2 (5×5, 15→4):** ~1.2 ms
- **Total:** ~2.4 ms @ 1080p

### Memory Usage
- Static features: 1920 × 1080 × 8 × 2 = 33 MB (f16)
- Layer buffers: 1920 × 1080 × 16 × 2 = 66 MB (max 16 channels)
- Weights: ~6.4 KB (f16, in shader code)
- **Total GPU memory:** ~100 MB
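These buffer sizes are simple products of resolution, channel count, and bytes per channel; a quick check of the figures above (illustrative):

```python
def buffer_bytes(width, height, channels, bytes_per_channel=2):
    """Size of one full-resolution feature buffer (f16 = 2 bytes per channel)."""
    return width * height * channels * bytes_per_channel

static_features = buffer_bytes(1920, 1080, 8)   # 33,177,600 bytes ≈ 33 MB
layer_buffer    = buffer_bytes(1920, 1080, 16)  # 66,355,200 bytes ≈ 66 MB
print(static_features / 1e6, layer_buffer / 1e6)
```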
---

## Size Budget

### CNN v1 vs v2

| Metric | v1 | v2 | Delta |
|--------|----|----|-------|
| Weights (count) | 800 | 3268 | +2468 |
| Storage (f32) | 3.2 KB | 13.1 KB | +9.9 KB |
| Storage (f16) | N/A | 6.5 KB | +6.5 KB |
| Shader code | ~500 lines | ~800 lines | +300 lines |

### Mitigation Strategies

**Reduce channels:**
- [16,8,4] → [8,4,4] saves ~50% weights
- [16,8,4] → [4,4,4] saves ~60% weights

**Smaller kernels:**
- [1,3,5] → [1,3,3] saves ~30% weights
- [1,3,5] → [1,1,3] saves ~50% weights

**Quantization:**
- int8 weights: saves 75% (requires QAT training)
- 4-bit weights: saves 87.5% (extreme, needs research)

**Target:** keep CNN v2 under 10 KB for the 64k demo constraint

---

## Future Extensions

### Flexible Feature Layout (Binary Format v3)

**TODO:** Support arbitrary feature vector layouts and ordering in the binary format.

**Current Limitation:**
- Feature layout hardcoded: `[p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias]`
- Shader must match the training script exactly
- Experimentation requires shader recompilation

**Proposed Enhancement:**
- Add a feature descriptor to the binary format header
- Specify feature types, sources, and ordering
- Runtime shader generation or dynamic feature indexing
- Examples: `[R, G, B, dx, dy, uv_x, bias]` or `[mip1.r, mip2.g, laplacian, uv_x, sin20_x, bias]`

**Benefits:**
- Training experiments without C++/shader changes
- A/B test different feature combinations
- Single binary format, multiple architectures
- Faster iteration on feature engineering

**Implementation Options:**
1. **Static approach:** Generate shader code from the descriptor at load time
2. **Dynamic approach:** Array-based indexing with a feature map uniform
3. **Hybrid:** Precompile common layouts, fall back to dynamic

See `doc/CNN_V2_BINARY_FORMAT.md` for the proposed descriptor format.

---

### More Features (uint8 Packing)

```wgsl
// 16 uint8 features per texel (texture_storage_2d<rgba32uint>, 4 packed per u32)
// [R, G, B, D, uv.x, uv.y, sin10.x, sin10.y,
//  sin20.x, sin20.y, dx, dy, gray_mip1, gray_mip2, variance, bias]
```
- Trade precision for quantity
- Requires quantization-aware training

### Temporal Features

- Previous frame RGBA (motion awareness)
- Optical flow vectors
- Requires a multi-frame buffer

### Learned Position Encodings

- Replace hand-crafted sin(10\*uv) with learned embeddings
- Requires a separate embedding network
- Similar to NeRF position encoding

### Dynamic Architecture

- Runtime kernel size selection based on scene
- Conditional layer execution (skip connections)
- Layer pruning for performance

---

## References

- **v1 Implementation:** `src/effects/cnn_effect.*`
- **Training Guide:** `doc/HOWTO.md` (CNN Training section)
- **Test Tool:** `doc/CNN_TEST_TOOL.md`
- **Shader System:** `doc/SEQUENCE.md`
- **Size Measurement:** `doc/SIZE_MEASUREMENT.md`

---

## Appendix: Design Decisions

### Why Bias as Static Feature?

**Alternatives considered:**
1. Separate bias array per layer (Option B)
2. Bias as static feature = 1.0 (Option A, chosen)

**Decision rationale:**
- Simpler shader code (fewer bindings)
- Standard NN formulation (augmented input)
- Saves 56-112 bytes per model
- 7 features sufficient for the v1 implementation
- Can extend to uint8 packing if >7 features are needed

### Why Float16 for Weights?

**Alternatives considered:**
1. Keep f32 (larger, more accurate)
2. Use f16 (smaller, GPU-native)
3. Use int8 (smallest, needs QAT)

**Decision rationale:**
- f16 saves 50% vs f32 (critical for the 64k target)
- GPU-native support (`pack2x16float` in WGSL)
- <0.1% accuracy loss (acceptable)
- Simpler than int8 quantization

### Why Multi-Frequency Position Encoding?

**Inspiration:** NeRF (Neural Radiance Fields)

**Benefits:**
- Helps the network learn high-frequency details
- Better than raw UV coordinates
- Small footprint (1D per frequency)

**Future:** add sin(20\*uv), sin(40\*uv) if >7 features become available

---

## Related Documentation

- `doc/CNN_V2_BINARY_FORMAT.md` - Binary weight file specification (.bin format)
- `doc/CNN_V2_WEB_TOOL.md` - WebGPU testing tool with layer visualization
- `doc/CNN_TEST_TOOL.md` - C++ offline validation tool (deprecated)
- `doc/HOWTO.md` - Training and validation workflows

---

**Document Version:** 1.0
**Last Updated:** 2026-02-12
**Status:** Design approved, ready for implementation
