Diffstat (limited to 'doc/CNN_V2.md')
-rw-r--r--   doc/CNN_V2.md   232
1 files changed, 104 insertions, 128 deletions
diff --git a/doc/CNN_V2.md b/doc/CNN_V2.md
index 588c3db..4612d7a 100644
--- a/doc/CNN_V2.md
+++ b/doc/CNN_V2.md
@@ -40,10 +40,10 @@ Input RGBD → Static Features Compute → CNN Layers → Output RGBA
 - Lifetime: Entire frame (all CNN layer passes)

 **CNN Layers:**
-- Input Layer: 7D static features → C₀ channels
-- Inner Layers: (7D + Cᵢ₋₁) → Cᵢ channels
-- Output Layer: (7D + Cₙ) → 4D RGBA
-- Storage: `texture_storage_2d<rgba32uint>` (8×f16 per texel recommended)
+- Layer 0: input RGBD (4D) + static (8D) = 12D → 4 channels
+- Layer 1+: previous output (4D) + static (8D) = 12D → 4 channels
+- All layers: uniform 12D input, 4D output (ping-pong buffer)
+- Storage: `texture_storage_2d<rgba32uint>` (4 channels as 2×f16 pairs)

 ---

@@ -54,11 +54,13 @@ Input RGBD → Static Features Compute → CNN Layers → Output RGBA
 **8 float16 values per pixel:**

 ```wgsl
-// Slot 0-3: RGBD (core pixel data)
-let r = rgba.r;  // Red channel
-let g = rgba.g;  // Green channel
-let b = rgba.b;  // Blue channel
-let d = depth;   // Depth value
+// Slot 0-3: Parametric features (p0, p1, p2, p3)
+// Can be: mip1/2 RGBD, grayscale, gradients, etc.
+// Distinct from input image RGBD (fed only to Layer 0)
+let p0 = ...;  // Parametric feature 0 (e.g., mip1.r or grayscale)
+let p1 = ...;  // Parametric feature 1
+let p2 = ...;  // Parametric feature 2
+let p3 = ...;  // Parametric feature 3

 // Slot 4-5: UV coordinates (normalized screen space)
 let uv_x = coord.x / resolution.x;  // Horizontal position [0,1]
@@ -70,18 +72,20 @@ let sin10_x = sin(10.0 * uv_x);  // Periodic feature (frequency=10)
 // Slot 7: Bias dimension (always 1.0)
 let bias = 1.0;  // Learned bias per output channel

-// Packed storage: [R, G, B, D, uv.x, uv.y, sin(10*uv.x), 1.0]
+// Packed storage: [p0, p1, p2, p3, uv.x, uv.y, sin(10*uv.x), 1.0]
 ```

 ### Feature Rationale

 | Feature | Dimension | Purpose | Priority |
 |---------|-----------|---------|----------|
-| RGBD | 4D | Core pixel information | Essential |
+| p0-p3 | 4D | Parametric auxiliary features (mips, gradients, etc.) | Essential |
 | UV coords | 2D | Spatial position awareness | Essential |
 | sin(10\*uv.x) | 1D | Periodic position encoding | Medium |
 | Bias | 1D | Learned bias (standard NN) | Essential |

+**Note:** Input image RGBD (mip 0) fed only to Layer 0. Subsequent layers see static features + previous layer output.
+
 **Why bias as static feature:**
 - Simpler shader code (single weight array)
 - Standard NN formulation: y = Wx (x includes bias term)
@@ -113,68 +117,65 @@ Requires quantization-aware training.
 ### Example 3-Layer Network

 ```
-Input:  7D static → 16 channels (1×1 kernel, pointwise)
-Layer1: (7+16)D → 8 channels (3×3 kernel, spatial)
-Layer2: (7+8)D → 4 channels (5×5 kernel, large receptive field)
+Layer 0: input RGBD (4D) + static (8D) = 12D → 4 channels (3×3 kernel)
+Layer 1: previous (4D) + static (8D) = 12D → 4 channels (3×3 kernel)
+Layer 2: previous (4D) + static (8D) = 12D → 4 channels (3×3 kernel, output)
 ```

 ### Weight Calculations

-**Per-layer weights:**
+**Per-layer weights (uniform 12D→4D, 3×3 kernels):**
 ```
-Input:  7 × 1 × 1 × 16 = 112 weights
-Layer1: (7+16) × 3 × 3 × 8 = 1656 weights
-Layer2: (7+8) × 5 × 5 × 4 = 1500 weights
-Total:  3268 weights
+Layer 0: 12 × 3 × 3 × 4 = 432 weights
+Layer 1: 12 × 3 × 3 × 4 = 432 weights
+Layer 2: 12 × 3 × 3 × 4 = 432 weights
+Total:   1296 weights
 ```

 **Storage sizes:**
-- f32: 3268 × 4 = 13,072 bytes (~12.8 KB)
-- f16: 3268 × 2 = 6,536 bytes (~6.4 KB) ✓ **recommended**
+- f32: 1296 × 4 = 5,184 bytes (~5.1 KB)
+- f16: 1296 × 2 = 2,592 bytes (~2.5 KB) ✓ **recommended**

 **Comparison to v1:**
 - v1: ~800 weights (3.2 KB f32)
-- v2: ~3268 weights (6.4 KB f16)
-- **Growth: 2× size for parametric features**
+- v2: ~1296 weights (2.5 KB f16)
+- **Uniform architecture, smaller than v1 f32**

 ### Kernel Size Guidelines

 **1×1 kernel (pointwise):**
 - No spatial context, channel mixing only
-- Weights: `(7 + C_in) × C_out`
-- Use for: Input layer, bottleneck layers
+- Weights: `12 × 4 = 48` per layer
+- Use for: Fast inference, channel remapping

 **3×3 kernel (standard conv):**
-- Local spatial context
-- Weights: `(7 + C_in) × 9 × C_out`
-- Use for: Most inner layers
+- Local spatial context (recommended)
+- Weights: `12 × 9 × 4 = 432` per layer
+- Use for: Most layers (balanced quality/size)

 **5×5 kernel (large receptive field):**
 - Wide spatial context
-- Weights: `(7 + C_in) × 25 × C_out`
-- Use for: Output layer, detail enhancement
+- Weights: `12 × 25 × 4 = 1200` per layer
+- Use for: Output layer, fine detail enhancement

-### Channel Storage (8×f16 per texel)
+### Channel Storage (4×f16 per texel)

 ```wgsl
 @group(0) @binding(1) var layer_input: texture_2d<u32>;

-fn unpack_channels(coord: vec2<i32>) -> array<f32, 8> {
+fn unpack_channels(coord: vec2<i32>) -> vec4<f32> {
     let packed = textureLoad(layer_input, coord, 0);
-    return array(
-        unpack2x16float(packed.x).x, unpack2x16float(packed.x).y,
-        unpack2x16float(packed.y).x, unpack2x16float(packed.y).y,
-        unpack2x16float(packed.z).x, unpack2x16float(packed.z).y,
-        unpack2x16float(packed.w).x, unpack2x16float(packed.w).y
-    );
+    let v0 = unpack2x16float(packed.x);  // [ch0, ch1]
+    let v1 = unpack2x16float(packed.y);  // [ch2, ch3]
+    return vec4<f32>(v0.x, v0.y, v1.x, v1.y);
 }

-fn pack_channels(values: array<f32, 8>) -> vec4<u32> {
-    return vec4(
-        pack2x16float(vec2(values[0], values[1])),
-        pack2x16float(vec2(values[2], values[3])),
-        pack2x16float(vec2(values[4], values[5])),
-        pack2x16float(vec2(values[6], values[7]))
+fn pack_channels(values: vec4<f32>) -> vec4<u32> {
+    return vec4<u32>(
+        pack2x16float(vec2(values.x, values.y)),
+        pack2x16float(vec2(values.z, values.w)),
+        0u,  // Unused
+        0u   // Unused
     );
 }
 ```
@@ -189,11 +190,11 @@ fn pack_channels(values: array<f32, 8>) -> vec4<u32> {
 ```python
 def compute_static_features(rgb, depth):
-    """Generate 7D static features + bias dimension."""
+    """Generate parametric features (8D: p0-p3 + spatial)."""
     h, w = rgb.shape[:2]

-    # RGBD channels
-    r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]
+    # Parametric features (example: use input RGBD, but could be mips/gradients)
+    p0, p1, p2, p3 = rgb[..., 0], rgb[..., 1], rgb[..., 2], depth

     # UV coordinates (normalized)
     uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0)
@@ -203,56 +204,51 @@ def compute_static_features(rgb, depth):
     sin10_x = np.sin(10.0 * uv_x)

     # Bias dimension (always 1.0)
-    bias = np.ones_like(r)
+    bias = np.ones_like(p0)

-    # Stack: [R, G, B, D, uv.x, uv.y, sin10_x, bias]
-    return np.stack([r, g, b, depth, uv_x, uv_y, sin10_x, bias], axis=-1)
+    # Stack: [p0, p1, p2, p3, uv.x, uv.y, sin10_x, bias]
+    return np.stack([p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias], axis=-1)
 ```

 **Network Definition:**

 ```python
 class CNNv2(nn.Module):
-    def __init__(self, kernels=[1,3,5], channels=[16,8,4]):
+    def __init__(self, kernel_size=3, num_layers=3):
         super().__init__()
+        self.layers = nn.ModuleList()

-        # Input layer: 8D (7 features + bias) → channels[0]
-        self.layer0 = nn.Conv2d(8, channels[0], kernel_size=kernels[0],
-                                padding=kernels[0]//2, bias=False)
-
-        # Inner layers: (7 features + bias + C_prev) → C_next
-        in_ch_1 = 8 + channels[0]  # static + layer0 output
-        self.layer1 = nn.Conv2d(in_ch_1, channels[1], kernel_size=kernels[1],
-                                padding=kernels[1]//2, bias=False)
-
-        # Output layer: (7 features + bias + C_last) → 4 (RGBA)
-        in_ch_2 = 8 + channels[1]
-        self.layer2 = nn.Conv2d(in_ch_2, 4, kernel_size=kernels[2],
-                                padding=kernels[2]//2, bias=False)
+        # All layers: 12D input (4 prev + 8 static) → 4D output
+        for i in range(num_layers):
+            self.layers.append(
+                nn.Conv2d(12, 4, kernel_size=kernel_size,
+                          padding=kernel_size//2, bias=False)
+            )

-    def forward(self, static_features, layer0_input=None):
-        # Layer 0: Use full 8D static features (includes bias)
-        x0 = self.layer0(static_features)
-        x0 = F.relu(x0)
+    def forward(self, input_rgbd, static_features):
+        # Layer 0: input RGBD (4D) + static (8D) = 12D
+        x = torch.cat([input_rgbd, static_features], dim=1)
+        x = self.layers[0](x)
+        x = torch.clamp(x, 0, 1)  # Output layer 0 (4 channels)

-        # Layer 1: Concatenate static + layer0 output
-        x1_input = torch.cat([static_features, x0], dim=1)
-        x1 = self.layer1(x1_input)
-        x1 = F.relu(x1)
+        # Layer 1+: previous output (4D) + static (8D) = 12D
+        for i in range(1, len(self.layers)):
+            x_input = torch.cat([x, static_features], dim=1)
+            x = self.layers[i](x_input)
+            if i < len(self.layers) - 1:
+                x = F.relu(x)
+            else:
+                x = torch.clamp(x, 0, 1)  # Final output [0,1]

-        # Layer 2: Concatenate static + layer1 output
-        x2_input = torch.cat([static_features, x1], dim=1)
-        output = self.layer2(x2_input)
-
-        return torch.sigmoid(output)  # RGBA output [0,1]
+        return x  # RGBA output
 ```

 **Training Configuration:**

 ```python
 # Hyperparameters
-kernels = [1, 3, 5]      # Per-layer kernel sizes
-channels = [16, 8, 4]    # Per-layer output channels
+kernel_size = 3    # Uniform 3×3 kernels
+num_layers = 3     # Number of CNN layers
 learning_rate = 1e-3
 batch_size = 16
 epochs = 5000
@@ -260,11 +256,14 @@ epochs = 5000
 # Training loop (standard PyTorch f32)
 for epoch in range(epochs):
     for rgb_batch, depth_batch, target_batch in dataloader:
-        # Compute static features
+        # Compute static features (8D)
         static_feat = compute_static_features(rgb_batch, depth_batch)

+        # Input RGBD (4D)
+        input_rgbd = torch.cat([rgb_batch, depth_batch.unsqueeze(1)], dim=1)
+
         # Forward pass
-        output = model(static_feat)
+        output = model(input_rgbd, static_feat)
         loss = criterion(output, target_batch)

         # Backward pass
@@ -279,9 +278,9 @@ for epoch in range(epochs):
 torch.save({
     'state_dict': model.state_dict(),  # f32 weights
     'config': {
-        'kernels': [1, 3, 5],
-        'channels': [16, 8, 4],
-        'features': ['R', 'G', 'B', 'D', 'uv.x', 'uv.y', 'sin10_x', 'bias']
+        'kernel_size': 3,
+        'num_layers': 3,
+        'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin10_x', 'bias']
     },
     'epoch': epoch,
     'loss': loss.item()
@@ -342,59 +341,36 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) {
 - Training uses f32 throughout (PyTorch standard)
 - Export converts to np.float16, then back to f32 for WGSL literals
 - **Expected discrepancy:** <0.1% MSE (acceptable)
-- Validation via `validate_cnn_v2.sh` compares outputs
+- Validation via HTML tool (see below)

 ---

 ## Validation Workflow

-### Script: `scripts/validate_cnn_v2.sh`
-
-**End-to-end pipeline:**
-```bash
-./scripts/validate_cnn_v2.sh checkpoints/checkpoint_epoch_5000.pth
-```
+### HTML Tool: `tools/cnn_v2_test/index.html`

-**Steps automated:**
-1. Export checkpoint → .wgsl shaders
-2. Rebuild `cnn_test` tool
-3. Process test images with CNN v2
-4. Display input/output results
+**WebGPU-based testing tool** with layer visualization.

 **Usage:**
-```bash
-# Basic usage
-./scripts/validate_cnn_v2.sh checkpoint.pth
-
-# Custom paths
-./scripts/validate_cnn_v2.sh checkpoint.pth \
-    -i my_test_images/ \
-    -o results/ \
-    -b build_release
+1. Open `tools/cnn_v2_test/index.html` in browser
+2. Drop `.bin` weights file (from `export_cnn_v2_weights.py`)
+3. Drop PNG test image
+4. View results with layer inspection

-# Skip rebuild (iterate on checkpoint only)
-./scripts/validate_cnn_v2.sh checkpoint.pth --skip-build
+**Features:**
+- Live CNN inference with WebGPU
+- Layer-by-layer visualization (static features + all CNN layers)
+- Weight visualization (per-layer kernels)
+- View modes: CNN output, original, diff (×10)
+- Blend control for comparing with original

-# Skip export (iterate on test images only)
-./scripts/validate_cnn_v2.sh checkpoint.pth --skip-export
-
-# Show help
-./scripts/validate_cnn_v2.sh --help
+**Export weights:**
+```bash
+./training/export_cnn_v2_weights.py checkpoints/checkpoint_epoch_100.pth \
+    --output-weights workspaces/main/cnn_v2_weights.bin
 ```

-**Options:**
-- `-b, --build-dir DIR` - Build directory (default: build)
-- `-w, --workspace NAME` - Workspace name (default: main)
-- `-i, --images DIR` - Test images directory (default: training/validation)
-- `-o, --output DIR` - Output directory (default: validation_results)
-- `--skip-build` - Use existing cnn_test binary
-- `--skip-export` - Use existing .wgsl shaders
-- `-h, --help` - Show full usage
-
-**Output:**
-- Input images: `<test_images_dir>/*.png`
-- Output images: `<output_dir>/*_output.png`
-- Opens results directory in system file browser
+See `doc/CNN_V2_WEB_TOOL.md` for detailed documentation

 ---

@@ -460,7 +436,7 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) {
 ### Phase 4: Tools & Validation

-- [ ] `scripts/validate_cnn_v2.sh` - End-to-end validation
+- [x] HTML validation tool - WebGPU inference with layer visualization
 - [ ] Command-line argument parsing
 - [ ] Shader export orchestration
 - [ ] Build orchestration
@@ -506,8 +482,8 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) {
 training/train_cnn_v2.py           # Training script
 training/export_cnn_v2_shader.py   # Shader generator
 training/validation/               # Test images directory

-# Scripts
-scripts/validate_cnn_v2.sh         # End-to-end validation
+# Validation
+tools/cnn_v2_test/index.html       # WebGPU validation tool

 # Documentation
 doc/CNN_V2.md                      # This file
