diff options
Diffstat (limited to 'cnn_v2/docs/CNN_V2.md')
| -rw-r--r-- | cnn_v2/docs/CNN_V2.md | 813 |
1 files changed, 813 insertions, 0 deletions
diff --git a/cnn_v2/docs/CNN_V2.md b/cnn_v2/docs/CNN_V2.md new file mode 100644 index 0000000..b7fd6f8 --- /dev/null +++ b/cnn_v2/docs/CNN_V2.md @@ -0,0 +1,813 @@ +# CNN v2: Parametric Static Features + +**Technical Design Document** + +--- + +## Overview + +CNN v2 extends the original CNN post-processing effect with parametric static features, enabling richer spatial and frequency-domain inputs for improved visual quality. + +**Key improvements over v1:** +- 7D static feature input (vs 4D RGB) +- Multi-frequency position encoding (NeRF-style) +- Configurable mip-level for p0-p3 parametric features (0-3) +- Per-layer configurable kernel sizes (1×1, 3×3, 5×5) +- Variable channel counts per layer +- Float16 weight storage (~3.2 KB for 3-layer model) +- Bias integrated as static feature dimension +- Storage buffer architecture (dynamic layer count) +- Binary weight format v2 for runtime loading +- Sigmoid activation for layer 0 and final layer (smooth [0,1] mapping) + +**Status:** ✅ Complete. Sigmoid activation, stable training, validation tools operational. + +**Breaking Change:** +- Models trained with `clamp()` incompatible. Retrain required. + +**TODO:** +- 8-bit quantization with QAT for 2× size reduction (~1.6 KB) + +--- + +## Architecture + +### Pipeline Overview + +``` +Input RGBD → Static Features Compute → CNN Layers → Output RGBA + └─ computed once/frame ─┘ └─ multi-pass ─┘ +``` + +**Detailed Data Flow:** + +``` + ┌─────────────────────────────────────────┐ + │ Static Features (computed once) │ + │ 8D: p0,p1,p2,p3,uv_x,uv_y,sin10x,bias │ + └──────────────┬──────────────────────────┘ + │ + │ 8D (broadcast to all layers) + ├───────────────────────────┐ + │ │ + ┌──────────────┐ │ │ + │ Input RGBD │──────────────┤ │ + │ 4D │ 4D │ │ + └──────────────┘ │ │ + ▼ │ + ┌────────────┐ │ + │ Layer 0 │ (12D input) │ + │ (CNN) │ = 4D + 8D │ + │ 12D → 4D │ │ + └─────┬──────┘ │ + │ 4D output │ + │ │ + ├───────────────────────────┘ + │ │ + ▼ │ + ┌────────────┐ │ + │ Layer 1 │ (12D input) │ + │ (CNN) │ = 4D + 8D │ + │ 12D → 4D │ │ + └─────┬──────┘ │ + │ 4D output │ + │ │ + ├───────────────────────────┘ + ▼ │ + ... │ + │ │ + ▼ │ + ┌────────────┐ │ + │ Layer N │ (12D input) │ + │ (output) │◄──────────────────┘ + │ 12D → 4D │ + └─────┬──────┘ + │ 4D (RGBA) + ▼ + Output +``` + +**Key Points:** +- Static features computed once, broadcast to all CNN layers +- Each layer: previous 4D output + 8D static → 12D input → 4D output +- Ping-pong buffering between layers +- Layer 0 special case: uses input RGBD instead of previous layer output + +**Static Features Texture:** +- Name: `static_features` +- Format: `texture_storage_2d<rgba32uint, write>` (4×u32) +- Data: 8 float16 values packed via `pack2x16float()` +- Computed once per frame, read by all CNN layers +- Lifetime: Entire frame (all CNN layer passes) + +**CNN Layers:** +- Layer 0: input RGBD (4D) + static (8D) = 12D → 4 channels +- Layer 1+: previous output (4D) + static (8D) = 12D → 4 channels +- All layers: uniform 12D input, 4D output (ping-pong buffer) +- Storage: `texture_storage_2d<rgba32uint>` (4 channels as 2×f16 pairs) + +**Activation Functions:** +- Layer 0 & final layer: `sigmoid(x)` for smooth [0,1] mapping +- Middle layers: `ReLU` (max(0, x)) +- Rationale: Sigmoid prevents gradient blocking at boundaries, enabling better convergence +- Breaking change: Models trained with `clamp(x, 0, 1)` are incompatible, retrain required + +--- + +## Static Features (7D + 1 bias) + +### Feature Layout + +**8 float16 values per pixel:** + +```wgsl +// Slot 0-3: Parametric features (p0, p1, p2, p3) +// Sampled from configurable mip level (0=original, 1=half, 2=quarter, 3=eighth) +// Training sets mip_level via --mip-level flag, stored in binary format v2 +let p0 = ...; // RGB.r from selected mip level +let p1 = ...; // RGB.g from selected mip level +let p2 = ...; // RGB.b from selected mip level +let p3 = ...; // Depth or RGB channel from mip level + +// Slot 4-5: UV coordinates (normalized screen space) +let uv_x = coord.x / resolution.x; // Horizontal position [0,1] +let uv_y = coord.y / resolution.y; // Vertical position [0,1] + +// Slot 6: Multi-frequency position encoding +let sin20_y = sin(20.0 * uv_y); // Periodic feature (frequency=20, vertical) + +// Slot 7: Bias dimension (always 1.0) +let bias = 1.0; // Learned bias per output channel + +// Packed storage: [p0, p1, p2, p3, uv.x, uv.y, sin(20*uv.y), 1.0] +``` + +### Input Channel Mapping + +**Weight tensor layout (12 input channels per layer):** + +| Input Channel | Feature | Description | +|--------------|---------|-------------| +| 0-3 | Previous layer output | 4D RGBA from prior CNN layer (or input RGBD for Layer 0) | +| 4-11 | Static features | 8D: p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias | + +**Static feature channel details:** +- Channel 4 → p0 (RGB.r from mip level) +- Channel 5 → p1 (RGB.g from mip level) +- Channel 6 → p2 (RGB.b from mip level) +- Channel 7 → p3 (depth or RGB channel from mip level) +- Channel 8 → p4 (uv_x: normalized horizontal position) +- Channel 9 → p5 (uv_y: normalized vertical position) +- Channel 10 → p6 (sin(20*uv_y): periodic encoding) +- Channel 11 → p7 (bias: constant 1.0) + +**Note:** When generating identity weights, p4-p7 correspond to input channels 8-11, not 4-7. + +### Feature Rationale + +| Feature | Dimension | Purpose | Priority | +|---------|-----------|---------|----------| +| p0-p3 | 4D | Parametric auxiliary features (mips, gradients, etc.) | Essential | +| UV coords | 2D | Spatial position awareness | Essential | +| sin(20\*uv.y) | 1D | Periodic position encoding (vertical) | Medium | +| Bias | 1D | Learned bias (standard NN) | Essential | + +**Note:** Input image RGBD (mip 0) fed only to Layer 0. Subsequent layers see static features + previous layer output. + +**Why bias as static feature:** +- Simpler shader code (single weight array) +- Standard NN formulation: y = Wx (x includes bias term) +- Saves 56-112 bytes (no separate bias buffer) +- 7 features sufficient for initial implementation + +### Future Feature Extensions + +**Option: Additional encodings:** +- `sin(40*uv.y)` - Higher frequency encoding +- `gray_mip1` - Multi-scale luminance +- `dx`, `dy` - Sobel gradients +- `variance` - Local texture measure +- `laplacian` - Edge detection + +**Option: uint8 packing (16+ features):** +```wgsl +// texture_storage_2d<rgba8unorm> stores 16 uint8 values +// Trade precision for feature count +// [R, G, B, D, uv.x, uv.y, sin10.x, sin10.y, +// sin20.x, sin20.y, dx, dy, gray_mip1, gray_mip2, var, bias] +``` +Requires quantization-aware training. + +--- + +## Layer Structure + +### Example 3-Layer Network + +``` +Layer 0: input RGBD (4D) + static (8D) = 12D → 4 channels (3×3 kernel) +Layer 1: previous (4D) + static (8D) = 12D → 4 channels (3×3 kernel) +Layer 2: previous (4D) + static (8D) = 12D → 4 channels (3×3 kernel, output RGBA) +``` + +**Output:** 4 channels (RGBA). Training targets preserve alpha from target images. + +### Weight Calculations + +**Per-layer weights (uniform 12D→4D, 3×3 kernels):** +``` +Layer 0: 12 × 3 × 3 × 4 = 432 weights +Layer 1: 12 × 3 × 3 × 4 = 432 weights +Layer 2: 12 × 3 × 3 × 4 = 432 weights +Total: 1296 weights +``` + +**Storage sizes:** +- f32: 1296 × 4 = 5,184 bytes (~5.1 KB) +- f16: 1296 × 2 = 2,592 bytes (~2.5 KB) ✓ **recommended** + +**Comparison to v1:** +- v1: ~800 weights (3.2 KB f32) +- v2: ~1296 weights (2.5 KB f16) +- **Uniform architecture, smaller than v1 f32** + +### Kernel Size Guidelines + +**1×1 kernel (pointwise):** +- No spatial context, channel mixing only +- Weights: `12 × 4 = 48` per layer +- Use for: Fast inference, channel remapping + +**3×3 kernel (standard conv):** +- Local spatial context (recommended) +- Weights: `12 × 9 × 4 = 432` per layer +- Use for: Most layers (balanced quality/size) + +**5×5 kernel (large receptive field):** +- Wide spatial context +- Weights: `12 × 25 × 4 = 1200` per layer +- Use for: Output layer, fine detail enhancement + +### Channel Storage (4×f16 per texel) + +```wgsl +@group(0) @binding(1) var layer_input: texture_2d<u32>; + +fn unpack_channels(coord: vec2<i32>) -> vec4<f32> { + let packed = textureLoad(layer_input, coord, 0); + let v0 = unpack2x16float(packed.x); // [ch0, ch1] + let v1 = unpack2x16float(packed.y); // [ch2, ch3] + return vec4<f32>(v0.x, v0.y, v1.x, v1.y); +} + +fn pack_channels(values: vec4<f32>) -> vec4<u32> { + return vec4<u32>( + pack2x16float(vec2(values.x, values.y)), + pack2x16float(vec2(values.z, values.w)), + 0u, // Unused + 0u // Unused + ); +} +``` + +--- + +## Training Workflow + +### Script: `training/train_cnn_v2.py` + +**Static Feature Extraction:** + +```python +def compute_static_features(rgb, depth, mip_level=0): + """Generate parametric features (8D: p0-p3 + spatial). + + Args: + mip_level: 0=original, 1=half res, 2=quarter res, 3=eighth res + """ + h, w = rgb.shape[:2] + + # Generate mip level for p0-p3 (downsample then upsample) + if mip_level > 0: + mip_rgb = rgb.copy() + for _ in range(mip_level): + mip_rgb = cv2.pyrDown(mip_rgb) + for _ in range(mip_level): + mip_rgb = cv2.pyrUp(mip_rgb) + if mip_rgb.shape[:2] != (h, w): + mip_rgb = cv2.resize(mip_rgb, (w, h), interpolation=cv2.INTER_LINEAR) + else: + mip_rgb = rgb + + # Parametric features from mip level + p0, p1, p2, p3 = mip_rgb[..., 0], mip_rgb[..., 1], mip_rgb[..., 2], depth + + # UV coordinates (normalized) + uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0) + uv_y = np.linspace(0, 1, h)[:, None].repeat(w, axis=1) + + # Multi-frequency position encoding + sin10_x = np.sin(10.0 * uv_x) + + # Bias dimension (always 1.0) + bias = np.ones_like(p0) + + # Stack: [p0, p1, p2, p3, uv.x, uv.y, sin10_x, bias] + return np.stack([p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias], axis=-1) +``` + +**Network Definition:** + +```python +class CNNv2(nn.Module): + def __init__(self, kernel_sizes, num_layers=3): + super().__init__() + if isinstance(kernel_sizes, int): + kernel_sizes = [kernel_sizes] * num_layers + self.kernel_sizes = kernel_sizes + self.layers = nn.ModuleList() + + # All layers: 12D input (4 prev + 8 static) → 4D output + for kernel_size in kernel_sizes: + self.layers.append( + nn.Conv2d(12, 4, kernel_size=kernel_size, + padding=kernel_size//2, bias=False) + ) + + def forward(self, input_rgbd, static_features): + # Layer 0: input RGBD (4D) + static (8D) = 12D + x = torch.cat([input_rgbd, static_features], dim=1) + x = self.layers[0](x) + x = torch.sigmoid(x) # Soft [0,1] for layer 0 + + # Layer 1+: previous output (4D) + static (8D) = 12D + for i in range(1, len(self.layers)): + x_input = torch.cat([x, static_features], dim=1) + x = self.layers[i](x_input) + if i < len(self.layers) - 1: + x = F.relu(x) + else: + x = torch.sigmoid(x) # Soft [0,1] for final layer + + return x # RGBA output +``` + +**Training Configuration:** + +```python +# Hyperparameters +kernel_sizes = [3, 3, 3] # Per-layer kernel sizes (e.g., [1,3,5]) +num_layers = 3 # Number of CNN layers +mip_level = 0 # Mip level for p0-p3: 0=orig, 1=half, 2=quarter, 3=eighth +grayscale_loss = False # Compute loss on grayscale (Y) instead of RGBA +learning_rate = 1e-3 +batch_size = 16 +epochs = 5000 + +# Dataset: Input RGB, Target RGBA (preserves alpha channel from image) +# Model outputs RGBA, loss compares all 4 channels (or grayscale if --grayscale-loss) + +# Training loop (standard PyTorch f32) +for epoch in range(epochs): + for rgb_batch, depth_batch, target_batch in dataloader: + # Compute static features (8D) with mip level + static_feat = compute_static_features(rgb_batch, depth_batch, mip_level) + + # Input RGBD (4D) + input_rgbd = torch.cat([rgb_batch, depth_batch.unsqueeze(1)], dim=1) + + # Forward pass + output = model(input_rgbd, static_feat) + + # Loss computation (grayscale or RGBA) + if grayscale_loss: + # Convert RGBA to grayscale: Y = 0.299*R + 0.587*G + 0.114*B + output_gray = 0.299 * output[:, 0:1] + 0.587 * output[:, 1:2] + 0.114 * output[:, 2:3] + target_gray = 0.299 * target[:, 0:1] + 0.587 * target[:, 1:2] + 0.114 * target[:, 2:3] + loss = criterion(output_gray, target_gray) + else: + loss = criterion(output, target_batch) + + # Backward pass + optimizer.zero_grad() + loss.backward() + optimizer.step() +``` + +**Checkpoint Format:** + +```python +torch.save({ + 'state_dict': model.state_dict(), # f32 weights + 'config': { + 'kernel_sizes': [3, 3, 3], # Per-layer kernel sizes + 'num_layers': 3, + 'mip_level': 0, # Mip level used for p0-p3 + 'grayscale_loss': False, # Whether grayscale loss was used + 'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin10_x', 'bias'] + }, + 'epoch': epoch, + 'loss': loss.item() +}, f'checkpoints/checkpoint_epoch_{epoch}.pth') +``` + +--- + +## Export Workflow + +### Script: `training/export_cnn_v2_shader.py` + +**Process:** +1. Load checkpoint (f32 PyTorch weights) +2. Extract layer configs (kernels, channels) +3. Quantize weights to float16: `weights_f16 = weights_f32.astype(np.float16)` +4. Generate WGSL shader per layer +5. Write to `workspaces/<workspace>/shaders/cnn_v2/cnn_v2_*.wgsl` + +**Example Generated Shader:** + +```wgsl +// cnn_v2_layer_0.wgsl - Auto-generated from checkpoint_epoch_5000.pth + +const KERNEL_SIZE: u32 = 1u; +const IN_CHANNELS: u32 = 8u; // 7 features + bias +const OUT_CHANNELS: u32 = 16u; + +// Weights quantized to float16 (stored as f32 in shader) +const weights: array<f32, 128> = array( + 0.123047, -0.089844, 0.234375, 0.456055, ... +); + +@group(0) @binding(0) var static_features: texture_2d<u32>; +@group(0) @binding(1) var output_texture: texture_storage_2d<rgba32uint, write>; + +@compute @workgroup_size(8, 8) +fn main(@builtin(global_invocation_id) id: vec3<u32>) { + // Load static features (8D) + let static_feat = get_static_features(vec2<i32>(id.xy)); + + // Convolution (1×1 kernel = pointwise) + var output: array<f32, OUT_CHANNELS>; + for (var c: u32 = 0u; c < OUT_CHANNELS; c++) { + var sum: f32 = 0.0; + for (var k: u32 = 0u; k < IN_CHANNELS; k++) { + sum += weights[c * IN_CHANNELS + k] * static_feat[k]; + } + output[c] = max(0.0, sum); // ReLU activation + } + + // Pack and store (8×f16 per texel) + textureStore(output_texture, vec2<i32>(id.xy), pack_f16x8(output)); +} +``` + +**Float16 Quantization:** +- Training uses f32 throughout (PyTorch standard) +- Export converts to np.float16, then back to f32 for WGSL literals +- **Expected discrepancy:** <0.1% MSE (acceptable) +- Validation via HTML tool (see below) + +--- + +## Validation Workflow + +### HTML Tool: `tools/cnn_v2_test/index.html` + +**WebGPU-based testing tool** with layer visualization. + +**Usage:** +1. Open `tools/cnn_v2_test/index.html` in browser +2. Drop `.bin` weights file (from `export_cnn_v2_weights.py`) +3. Drop PNG test image +4. View results with layer inspection + +**Features:** +- Live CNN inference with WebGPU +- Layer-by-layer visualization (static features + all CNN layers) +- Weight visualization (per-layer kernels) +- View modes: CNN output, original, diff (×10) +- Blend control for comparing with original + +**Export weights:** +```bash +./training/export_cnn_v2_weights.py checkpoints/checkpoint_epoch_100.pth \ + --output-weights workspaces/main/cnn_v2_weights.bin +``` + +See `doc/CNN_V2_WEB_TOOL.md` for detailed documentation + +--- + +## Implementation Checklist + +### Phase 1: Shaders (Core Infrastructure) + +- [ ] `workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl` - Static features compute + - [ ] RGBD sampling from framebuffer + - [ ] UV coordinate calculation + - [ ] sin(10\*uv.x) computation + - [ ] Bias dimension (constant 1.0) + - [ ] Float16 packing via `pack2x16float()` + - [ ] Output to `texture_storage_2d<rgba32uint>` + +- [ ] `workspaces/main/shaders/cnn_v2/cnn_v2_layer_template.wgsl` - Layer template + - [ ] Static features unpacking + - [ ] Previous layer unpacking (8×f16) + - [ ] Convolution implementation (1×1, 3×3, 5×5) + - [ ] ReLU activation + - [ ] Output packing (8×f16) + - [ ] Proper padding handling + +### Phase 2: C++ Effect Class + +- [ ] `src/effects/cnn_v2_effect.h` - Header + - [ ] Class declaration inheriting from `PostProcessEffect` + - [ ] Static features texture member + - [ ] Layer textures vector + - [ ] Pipeline and bind group members + +- [ ] `src/effects/cnn_v2_effect.cc` - Implementation + - [ ] Constructor: Load shaders, create textures + - [ ] `init()`: Create pipelines, bind groups + - [ ] `render()`: Multi-pass execution + - [ ] Pass 0: Compute static features + - [ ] Pass 1-N: CNN layers + - [ ] Final: Composite to output + - [ ] Proper resource cleanup + +- [ ] Integration + - [ ] Add to `src/gpu/demo_effects.h` includes + - [ ] Add `cnn_v2_effect.cc` to `CMakeLists.txt` (headless + normal) + - [ ] Add shaders to `workspaces/main/assets.txt` + - [ ] Add to `src/tests/gpu/test_demo_effects.cc` + +### Phase 3: Training Pipeline + +- [ ] `training/train_cnn_v2.py` - Training script + - [ ] Static feature extraction function + - [ ] CNNv2 PyTorch model class + - [ ] Patch-based dataloader + - [ ] Training loop with checkpointing + - [ ] Command-line argument parsing + - [ ] Inference mode (ground truth generation) + +- [ ] `training/export_cnn_v2_shader.py` - Export script + - [ ] Checkpoint loading + - [ ] Weight extraction and f16 quantization + - [ ] Per-layer WGSL generation + - [ ] File output to workspace shaders/ + - [ ] Metadata preservation + +### Phase 4: Tools & Validation + +- [x] HTML validation tool - WebGPU inference with layer visualization + - [ ] Command-line argument parsing + - [ ] Shader export orchestration + - [ ] Build orchestration + - [ ] Batch image processing + - [ ] Results display + +- [ ] `src/tools/cnn_test_main.cc` - Tool updates + - [ ] Add `--cnn-version v2` flag + - [ ] CNNv2Effect instantiation path + - [ ] Static features pass execution + - [ ] Multi-layer processing + +### Phase 5: Documentation + +- [ ] `doc/HOWTO.md` - Usage guide + - [ ] Training section (CNN v2) + - [ ] Export section + - [ ] Validation section + - [ ] Examples + +- [ ] `README.md` - Project overview update + - [ ] Mention CNN v2 capability + +--- + +## File Structure + +### New Files + +``` +# Shaders (generated by export script) +workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl # Static features compute +workspaces/main/shaders/cnn_v2/cnn_v2_layer_0.wgsl # Input layer (generated) +workspaces/main/shaders/cnn_v2/cnn_v2_layer_1.wgsl # Inner layer (generated) +workspaces/main/shaders/cnn_v2/cnn_v2_layer_2.wgsl # Output layer (generated) + +# C++ implementation +src/effects/cnn_v2_effect.h # Effect class header +src/effects/cnn_v2_effect.cc # Effect implementation + +# Python training/export +training/train_cnn_v2.py # Training script +training/export_cnn_v2_shader.py # Shader generator +training/validation/ # Test images directory + +# Validation +tools/cnn_v2_test/index.html # WebGPU validation tool + +# Documentation +doc/CNN_V2.md # This file +``` + +### Modified Files + +``` +src/gpu/demo_effects.h # Add CNNv2Effect include +CMakeLists.txt # Add cnn_v2_effect.cc +workspaces/main/assets.txt # Add cnn_v2 shaders +workspaces/main/timeline.seq # Optional: add CNNv2Effect +src/tests/gpu/test_demo_effects.cc # Add CNNv2 test case +src/tools/cnn_test_main.cc # Add --cnn-version v2 +doc/HOWTO.md # Add CNN v2 sections +TODO.md # Add CNN v2 task +``` + +### Unchanged (v1 Preserved) + +``` +training/train_cnn.py # Original training +src/effects/cnn_effect.* # Original effect +workspaces/main/shaders/cnn_*.wgsl # Original v1 shaders +``` + +--- + +## Performance Characteristics + +### Static Features Compute +- **Cost:** ~0.1ms @ 1080p +- **Frequency:** Once per frame +- **Operations:** sin(), texture sampling, packing + +### CNN Layers (Example 3-layer) +- **Layer0 (1×1, 8→16):** ~0.3ms +- **Layer1 (3×3, 23→8):** ~0.8ms +- **Layer2 (5×5, 15→4):** ~1.2ms +- **Total:** ~2.4ms @ 1080p + +### Memory Usage +- Static features: 1920×1080×8×2 = 33 MB (f16) +- Layer buffers: 1920×1080×16×2 = 66 MB (max 16 channels) +- Weights: ~6.4 KB (f16, in shader code) +- **Total GPU memory:** ~100 MB + +--- + +## Size Budget + +### CNN v1 vs v2 + +| Metric | v1 | v2 | Delta | +|--------|----|----|-------| +| Weights (count) | 800 | 3268 | +2468 | +| Storage (f32) | 3.2 KB | 13.1 KB | +9.9 KB | +| Storage (f16) | N/A | 6.5 KB | +6.5 KB | +| Shader code | ~500 lines | ~800 lines | +300 lines | + +### Mitigation Strategies + +**Reduce channels:** +- [16,8,4] → [8,4,4] saves ~50% weights +- [16,8,4] → [4,4,4] saves ~60% weights + +**Smaller kernels:** +- [1,3,5] → [1,3,3] saves ~30% weights +- [1,3,5] → [1,1,3] saves ~50% weights + +**Quantization:** +- int8 weights: saves 75% (requires QAT training) +- 4-bit weights: saves 87.5% (extreme, needs research) + +**Target:** Keep CNN v2 under 10 KB for 64k demo constraint + +--- + +## Future Extensions + +### Flexible Feature Layout (Binary Format v3) + +**TODO:** Support arbitrary feature vector layouts and ordering in binary format. + +**Current Limitation:** +- Feature layout hardcoded: `[p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias]` +- Shader must match training script exactly +- Experimentation requires shader recompilation + +**Proposed Enhancement:** +- Add feature descriptor to binary format header +- Specify feature types, sources, and ordering +- Runtime shader generation or dynamic feature indexing +- Examples: `[R, G, B, dx, dy, uv_x, bias]` or `[mip1.r, mip2.g, laplacian, uv_x, sin20_x, bias]` + +**Benefits:** +- Training experiments without C++/shader changes +- A/B test different feature combinations +- Single binary format, multiple architectures +- Faster iteration on feature engineering + +**Implementation Options:** +1. **Static approach:** Generate shader code from descriptor at load time +2. **Dynamic approach:** Array-based indexing with feature map uniform +3. **Hybrid:** Precompile common layouts, fallback to dynamic + +See `doc/CNN_V2_BINARY_FORMAT.md` for proposed descriptor format. + +--- + +### More Features (uint8 Packing) + +```wgsl +// 16 uint8 features per texel (texture_storage_2d<rgba8unorm>) +// [R, G, B, D, uv.x, uv.y, sin10.x, sin10.y, +// sin20.x, sin20.y, dx, dy, gray_mip1, gray_mip2, variance, bias] +``` +- Trade precision for quantity +- Requires quantization-aware training + +### Temporal Features + +- Previous frame RGBA (motion awareness) +- Optical flow vectors +- Requires multi-frame buffer + +### Learned Position Encodings + +- Replace hand-crafted sin(10\*uv) with learned embeddings +- Requires separate embedding network +- Similar to NeRF position encoding + +### Dynamic Architecture + +- Runtime kernel size selection based on scene +- Conditional layer execution (skip connections) +- Layer pruning for performance + +--- + +## References + +- **v1 Implementation:** `src/effects/cnn_effect.*` +- **Training Guide:** `doc/HOWTO.md` (CNN Training section) +- **Test Tool:** `doc/CNN_TEST_TOOL.md` +- **Shader System:** `doc/SEQUENCE.md` +- **Size Measurement:** `doc/SIZE_MEASUREMENT.md` + +--- + +## Appendix: Design Decisions + +### Why Bias as Static Feature? + +**Alternatives considered:** +1. Separate bias array per layer (Option B) +2. Bias as static feature = 1.0 (Option A, chosen) + +**Decision rationale:** +- Simpler shader code (fewer bindings) +- Standard NN formulation (augmented input) +- Saves 56-112 bytes per model +- 7 features sufficient for v1 implementation +- Can extend to uint8 packing if >7 features needed + +### Why Float16 for Weights? + +**Alternatives considered:** +1. Keep f32 (larger, more accurate) +2. Use f16 (smaller, GPU-native) +3. Use int8 (smallest, needs QAT) + +**Decision rationale:** +- f16 saves 50% vs f32 (critical for 64k target) +- GPU-native support (pack2x16float in WGSL) +- <0.1% accuracy loss (acceptable) +- Simpler than int8 quantization + +### Why Multi-Frequency Position Encoding? + +**Inspiration:** NeRF (Neural Radiance Fields) + +**Benefits:** +- Helps network learn high-frequency details +- Better than raw UV coordinates +- Small footprint (1D per frequency) + +**Future:** Add sin(20\*uv), sin(40\*uv) if >7 features available + +--- + +## Related Documentation + +- `doc/CNN_V2_BINARY_FORMAT.md` - Binary weight file specification (.bin format) +- `doc/CNN_V2_WEB_TOOL.md` - WebGPU testing tool with layer visualization +- `doc/CNN_TEST_TOOL.md` - C++ offline validation tool (deprecated) +- `doc/HOWTO.md` - Training and validation workflows + +--- + +**Document Version:** 1.0 +**Last Updated:** 2026-02-12 +**Status:** Design approved, ready for implementation |
