Diffstat (limited to 'cnn_v2')
-rw-r--r--   cnn_v2/README.md                             60
-rw-r--r--   cnn_v2/docs/CNN_V2.md                       813
-rw-r--r--   cnn_v2/docs/CNN_V2_BINARY_FORMAT.md         235
-rw-r--r--   cnn_v2/docs/CNN_V2_DEBUG_TOOLS.md           143
-rw-r--r--   cnn_v2/docs/CNN_V2_WEB_TOOL.md              348
-rwxr-xr-x   cnn_v2/scripts/train_cnn_v2_full.sh         428
-rw-r--r--   cnn_v2/shaders/cnn_v2_compute.wgsl          143
-rw-r--r--   cnn_v2/shaders/cnn_v2_layer_0.wgsl          174
-rw-r--r--   cnn_v2/shaders/cnn_v2_layer_1.wgsl          174
-rw-r--r--   cnn_v2/shaders/cnn_v2_layer_2.wgsl          156
-rw-r--r--   cnn_v2/shaders/cnn_v2_layer_template.wgsl    68
-rw-r--r--   cnn_v2/shaders/cnn_v2_static.wgsl            75
-rw-r--r--   cnn_v2/src/cnn_v2_effect.cc                 497
-rw-r--r--   cnn_v2/src/cnn_v2_effect.h                   89
-rw-r--r--   cnn_v2/tools/cnn_v2_test/README.md          251
-rw-r--r--   cnn_v2/tools/cnn_v2_test/index.html        2014
-rwxr-xr-x   cnn_v2/training/export_cnn_v2_shader.py     218
-rwxr-xr-x   cnn_v2/training/export_cnn_v2_weights.py    288
-rwxr-xr-x   cnn_v2/training/gen_identity_weights.py     175
-rwxr-xr-x   cnn_v2/training/train_cnn_v2.py             472
20 files changed, 6821 insertions, 0 deletions
diff --git a/cnn_v2/README.md b/cnn_v2/README.md
new file mode 100644
index 0000000..ef0cf44
--- /dev/null
+++ b/cnn_v2/README.md
@@ -0,0 +1,60 @@
+# CNN v2: Parametric Post-Processing Neural Network
+
+**Architecture:** 3-layer compute-shader CNN with storage-buffer weights (~3.2 KB)
+**Features:** 7D static features + bias (RGBD + UV + sin encoding), sigmoid activation
+
+## Quick Start
+
+```bash
+./cnn_v2/scripts/train_cnn_v2_full.sh
+```
+
+## Documentation
+
+- [CNN_V2.md](docs/CNN_V2.md) - Architecture and implementation details
+- [CNN_V2_BINARY_FORMAT.md](docs/CNN_V2_BINARY_FORMAT.md) - Weight format specification
+- [CNN_V2_WEB_TOOL.md](docs/CNN_V2_WEB_TOOL.md) - Validation tool documentation
+- [CNN_V2_DEBUG_TOOLS.md](docs/CNN_V2_DEBUG_TOOLS.md) - Debugging and analysis tools
+
+## Integration
+
+- **C++:** `cnn_v2/src/cnn_v2_effect.{h,cc}`
+- **Assets:** `workspaces/main/assets.txt` (lines 47-49)
+- **Test:** `src/tests/gpu/test_demo_effects.cc` (line 93)
+
+## Directory Structure
+
+```
+cnn_v2/
+├── README.md # This file
+├── src/
+│ ├── cnn_v2_effect.h # Effect header
+│ └── cnn_v2_effect.cc # Effect implementation
+├── shaders/ # WGSL shaders (6 files)
+├── weights/ # Binary weights (3 files)
+├── training/ # Python training scripts (4 files)
+├── scripts/ # Shell scripts (train_cnn_v2_full.sh)
+├── tools/ # Validation tools (HTML)
+└── docs/ # Documentation (4 markdown files)
+```
+
+## Training Pipeline
+
+1. **Train model:** `./cnn_v2/scripts/train_cnn_v2_full.sh`
+2. **Export weights:** Automatic (binary format, ~3.2 KB)
+3. **Validate:** HTML tool at `cnn_v2/tools/cnn_v2_test/index.html`
+
+For detailed training options: `./cnn_v2/scripts/train_cnn_v2_full.sh --help`
+
+## Key Features
+
+- **Parametric static features:** 7D input (RGBD + UV + sin encoding + bias)
+- **Storage buffer architecture:** Dynamic layer count, compact binary format
+- **Sigmoid activation:** Smooth gradients, prevents training collapse
+- **Patch-based training:** Sample-efficient, focuses on salient regions
+- **Sub-10KB target:** Achieved with 3-layer model (~3.2 KB)
+
+## Next Steps
+
+- **8-bit quantization:** 2× size reduction (~1.6 KB) via quantization-aware training (QAT)
+- **CNN v3:** U-Net architecture for enhanced quality (separate directory)
diff --git a/cnn_v2/docs/CNN_V2.md b/cnn_v2/docs/CNN_V2.md
new file mode 100644
index 0000000..b7fd6f8
--- /dev/null
+++ b/cnn_v2/docs/CNN_V2.md
@@ -0,0 +1,813 @@
+# CNN v2: Parametric Static Features
+
+**Technical Design Document**
+
+---
+
+## Overview
+
+CNN v2 extends the original CNN post-processing effect with parametric static features, enabling richer spatial and frequency-domain inputs for improved visual quality.
+
+**Key improvements over v1:**
+- 7D static feature input (vs 4D RGB)
+- Multi-frequency position encoding (NeRF-style)
+- Configurable mip-level for p0-p3 parametric features (0-3)
+- Per-layer configurable kernel sizes (1×1, 3×3, 5×5)
+- Variable channel counts per layer
+- Float16 weight storage (~3.2 KB for 3-layer model)
+- Bias integrated as static feature dimension
+- Storage buffer architecture (dynamic layer count)
+- Binary weight format v2 for runtime loading
+- Sigmoid activation for layer 0 and final layer (smooth [0,1] mapping)
+
+**Status:** ✅ Complete. Sigmoid activation, stable training, validation tools operational.
+
+**Breaking Change:**
+- Models trained with `clamp()` incompatible. Retrain required.
+
+**TODO:**
+- 8-bit quantization with QAT for 2× size reduction (~1.6 KB)
+
+---
+
+## Architecture
+
+### Pipeline Overview
+
+```
+Input RGBD → Static Features Compute → CNN Layers → Output RGBA
+ └─ computed once/frame ─┘ └─ multi-pass ─┘
+```
+
+**Detailed Data Flow:**
+
+```
+ ┌─────────────────────────────────────────┐
+ │ Static Features (computed once) │
+ │ 8D: p0,p1,p2,p3,uv_x,uv_y,sin10x,bias │
+ └──────────────┬──────────────────────────┘
+ │
+ │ 8D (broadcast to all layers)
+ ├───────────────────────────┐
+ │ │
+ ┌──────────────┐ │ │
+ │ Input RGBD │──────────────┤ │
+ │ 4D │ 4D │ │
+ └──────────────┘ │ │
+ ▼ │
+ ┌────────────┐ │
+ │ Layer 0 │ (12D input) │
+ │ (CNN) │ = 4D + 8D │
+ │ 12D → 4D │ │
+ └─────┬──────┘ │
+ │ 4D output │
+ │ │
+ ├───────────────────────────┘
+ │ │
+ ▼ │
+ ┌────────────┐ │
+ │ Layer 1 │ (12D input) │
+ │ (CNN) │ = 4D + 8D │
+ │ 12D → 4D │ │
+ └─────┬──────┘ │
+ │ 4D output │
+ │ │
+ ├───────────────────────────┘
+ ▼ │
+ ... │
+ │ │
+ ▼ │
+ ┌────────────┐ │
+ │ Layer N │ (12D input) │
+ │ (output) │◄──────────────────┘
+ │ 12D → 4D │
+ └─────┬──────┘
+ │ 4D (RGBA)
+ ▼
+ Output
+```
+
+**Key Points:**
+- Static features computed once, broadcast to all CNN layers
+- Each layer: previous 4D output + 8D static → 12D input → 4D output
+- Ping-pong buffering between layers
+- Layer 0 special case: uses input RGBD instead of previous layer output
+
+**Static Features Texture:**
+- Name: `static_features`
+- Format: `texture_storage_2d<rgba32uint, write>` (4×u32)
+- Data: 8 float16 values packed via `pack2x16float()`
+- Computed once per frame, read by all CNN layers
+- Lifetime: Entire frame (all CNN layer passes)
+
+**CNN Layers:**
+- Layer 0: input RGBD (4D) + static (8D) = 12D → 4 channels
+- Layer 1+: previous output (4D) + static (8D) = 12D → 4 channels
+- All layers: uniform 12D input, 4D output (ping-pong buffer)
+- Storage: `texture_storage_2d<rgba32uint>` (4 channels as 2×f16 pairs)
+
+**Activation Functions:**
+- Layer 0 & final layer: `sigmoid(x)` for smooth [0,1] mapping
+- Middle layers: `ReLU` (max(0, x))
+- Rationale: Sigmoid prevents gradient blocking at boundaries, enabling better convergence
+- Breaking change: Models trained with `clamp(x, 0, 1)` are incompatible, retrain required
+
+---
+
+## Static Features (7D + 1 bias)
+
+### Feature Layout
+
+**8 float16 values per pixel:**
+
+```wgsl
+// Slot 0-3: Parametric features (p0, p1, p2, p3)
+// Sampled from configurable mip level (0=original, 1=half, 2=quarter, 3=eighth)
+// Training sets mip_level via --mip-level flag, stored in binary format v2
+let p0 = ...; // RGB.r from selected mip level
+let p1 = ...; // RGB.g from selected mip level
+let p2 = ...; // RGB.b from selected mip level
+let p3 = ...; // Depth or RGB channel from mip level
+
+// Slot 4-5: UV coordinates (normalized screen space)
+let uv_x = coord.x / resolution.x; // Horizontal position [0,1]
+let uv_y = coord.y / resolution.y; // Vertical position [0,1]
+
+// Slot 6: Multi-frequency position encoding
+let sin20_y = sin(20.0 * uv_y); // Periodic feature (frequency=20, vertical)
+
+// Slot 7: Bias dimension (always 1.0)
+let bias = 1.0; // Learned bias per output channel
+
+// Packed storage: [p0, p1, p2, p3, uv.x, uv.y, sin(20*uv.y), 1.0]
+```
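+
+For CPU-side validation (e.g. cross-checking the GPU static-features texture from Python), the same packing can be reproduced on the host. A minimal sketch, assuming numpy and mirroring `pack2x16float()` (component 0 fills the low 16 bits):
+
+```python
+import numpy as np
+
+def pack_static_features(feat8):
+    """Pack 8 float16 feature values into 4 u32 (one rgba32uint texel)."""
+    h = np.asarray(feat8, dtype=np.float16).view(np.uint16).astype(np.uint32)
+    return h[0::2] | (h[1::2] << 16)  # low half = even slot, high half = odd slot
+```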
+
+### Input Channel Mapping
+
+**Weight tensor layout (12 input channels per layer):**
+
+| Input Channel | Feature | Description |
+|--------------|---------|-------------|
+| 0-3 | Previous layer output | 4D RGBA from prior CNN layer (or input RGBD for Layer 0) |
+| 4-11 | Static features | 8D: p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias |
+
+**Static feature channel details:**
+- Channel 4 → p0 (RGB.r from mip level)
+- Channel 5 → p1 (RGB.g from mip level)
+- Channel 6 → p2 (RGB.b from mip level)
+- Channel 7 → p3 (depth or RGB channel from mip level)
+- Channel 8 → p4 (uv_x: normalized horizontal position)
+- Channel 9 → p5 (uv_y: normalized vertical position)
+- Channel 10 → p6 (sin(20*uv_y): periodic encoding)
+- Channel 11 → p7 (bias: constant 1.0)
+
+**Note:** When generating identity weights, p4-p7 correspond to input channels 8-11, not 4-7.
+
+### Feature Rationale
+
+| Feature | Dimension | Purpose | Priority |
+|---------|-----------|---------|----------|
+| p0-p3 | 4D | Parametric auxiliary features (mips, gradients, etc.) | Essential |
+| UV coords | 2D | Spatial position awareness | Essential |
+| sin(20\*uv.y) | 1D | Periodic position encoding (vertical) | Medium |
+| Bias | 1D | Learned bias (standard NN) | Essential |
+
+**Note:** Input image RGBD (mip 0) fed only to Layer 0. Subsequent layers see static features + previous layer output.
+
+**Why bias as static feature:**
+- Simpler shader code (single weight array)
+- Standard NN formulation: y = Wx (x includes bias term)
+- Saves 56-112 bytes (no separate bias buffer)
+- 7 features sufficient for initial implementation
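+
+The augmented-input formulation is easy to sanity-check in isolation. A small standalone numpy illustration (not taken from the training code) showing that a constant 1.0 input makes the last weight column act as a per-output bias:
+
+```python
+import numpy as np
+
+rng = np.random.default_rng(0)
+W = rng.standard_normal((4, 12))             # 4 outputs, 11 features + bias column
+x = rng.standard_normal(11)
+
+y_augmented = W @ np.append(x, 1.0)          # y = W [x; 1]
+y_explicit  = W[:, :-1] @ x + W[:, -1]       # y = W'x + b
+assert np.allclose(y_augmented, y_explicit)
+```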
+
+### Future Feature Extensions
+
+**Option: Additional encodings:**
+- `sin(40*uv.y)` - Higher frequency encoding
+- `gray_mip1` - Multi-scale luminance
+- `dx`, `dy` - Sobel gradients
+- `variance` - Local texture measure
+- `laplacian` - Edge detection
+
+**Option: uint8 packing (16+ features):**
+```wgsl
+// texture_storage_2d<rgba8unorm> stores 16 uint8 values
+// Trade precision for feature count
+// [R, G, B, D, uv.x, uv.y, sin10.x, sin10.y,
+// sin20.x, sin20.y, dx, dy, gray_mip1, gray_mip2, var, bias]
+```
+Requires quantization-aware training.
+
+---
+
+## Layer Structure
+
+### Example 3-Layer Network
+
+```
+Layer 0: input RGBD (4D) + static (8D) = 12D → 4 channels (3×3 kernel)
+Layer 1: previous (4D) + static (8D) = 12D → 4 channels (3×3 kernel)
+Layer 2: previous (4D) + static (8D) = 12D → 4 channels (3×3 kernel, output RGBA)
+```
+
+**Output:** 4 channels (RGBA). Training targets preserve alpha from target images.
+
+### Weight Calculations
+
+**Per-layer weights (uniform 12D→4D, 3×3 kernels):**
+```
+Layer 0: 12 × 3 × 3 × 4 = 432 weights
+Layer 1: 12 × 3 × 3 × 4 = 432 weights
+Layer 2: 12 × 3 × 3 × 4 = 432 weights
+Total: 1296 weights
+```
+
+**Storage sizes:**
+- f32: 1296 × 4 = 5,184 bytes (~5.1 KB)
+- f16: 1296 × 2 = 2,592 bytes (~2.5 KB) ✓ **recommended**
+
+**Comparison to v1:**
+- v1: ~800 weights (3.2 KB f32)
+- v2: ~1296 weights (2.5 KB f16)
+- **Uniform architecture, smaller than v1 f32**
+
+### Kernel Size Guidelines
+
+**1×1 kernel (pointwise):**
+- No spatial context, channel mixing only
+- Weights: `12 × 4 = 48` per layer
+- Use for: Fast inference, channel remapping
+
+**3×3 kernel (standard conv):**
+- Local spatial context (recommended)
+- Weights: `12 × 9 × 4 = 432` per layer
+- Use for: Most layers (balanced quality/size)
+
+**5×5 kernel (large receptive field):**
+- Wide spatial context
+- Weights: `12 × 25 × 4 = 1200` per layer
+- Use for: Output layer, fine detail enhancement
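+
+These counts all follow from `in_channels × kernel² × out_channels`. A small helper (illustrative only) for sizing candidate configurations:
+
+```python
+def conv_weight_count(in_ch, out_ch, kernel):
+    """Weights for one layer; the bias rides in the static features, so no extras."""
+    return in_ch * out_ch * kernel * kernel
+
+layers = [(12, 4, 3)] * 3                              # the uniform 3-layer example above
+total = sum(conv_weight_count(*cfg) for cfg in layers)
+print(total, "weights,", total * 2, "bytes as f16")    # 1296 weights, 2592 bytes
+```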
+
+### Channel Storage (4×f16 per texel)
+
+```wgsl
+@group(0) @binding(1) var layer_input: texture_2d<u32>;
+
+fn unpack_channels(coord: vec2<i32>) -> vec4<f32> {
+ let packed = textureLoad(layer_input, coord, 0);
+ let v0 = unpack2x16float(packed.x); // [ch0, ch1]
+ let v1 = unpack2x16float(packed.y); // [ch2, ch3]
+ return vec4<f32>(v0.x, v0.y, v1.x, v1.y);
+}
+
+fn pack_channels(values: vec4<f32>) -> vec4<u32> {
+ return vec4<u32>(
+ pack2x16float(vec2(values.x, values.y)),
+ pack2x16float(vec2(values.z, values.w)),
+ 0u, // Unused
+ 0u // Unused
+ );
+}
+```
+
+---
+
+## Training Workflow
+
+### Script: `training/train_cnn_v2.py`
+
+**Static Feature Extraction:**
+
+```python
+import cv2
+import numpy as np
+
+def compute_static_features(rgb, depth, mip_level=0):
+ """Generate parametric features (8D: p0-p3 + spatial).
+
+ Args:
+ mip_level: 0=original, 1=half res, 2=quarter res, 3=eighth res
+ """
+ h, w = rgb.shape[:2]
+
+ # Generate mip level for p0-p3 (downsample then upsample)
+ if mip_level > 0:
+ mip_rgb = rgb.copy()
+ for _ in range(mip_level):
+ mip_rgb = cv2.pyrDown(mip_rgb)
+ for _ in range(mip_level):
+ mip_rgb = cv2.pyrUp(mip_rgb)
+ if mip_rgb.shape[:2] != (h, w):
+ mip_rgb = cv2.resize(mip_rgb, (w, h), interpolation=cv2.INTER_LINEAR)
+ else:
+ mip_rgb = rgb
+
+ # Parametric features from mip level
+ p0, p1, p2, p3 = mip_rgb[..., 0], mip_rgb[..., 1], mip_rgb[..., 2], depth
+
+ # UV coordinates (normalized)
+ uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0)
+ uv_y = np.linspace(0, 1, h)[:, None].repeat(w, axis=1)
+
+ # Multi-frequency position encoding
+ sin10_x = np.sin(10.0 * uv_x)
+
+ # Bias dimension (always 1.0)
+ bias = np.ones_like(p0)
+
+ # Stack: [p0, p1, p2, p3, uv.x, uv.y, sin10_x, bias]
+ return np.stack([p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias], axis=-1)
+```
+
+**Network Definition:**
+
+```python
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class CNNv2(nn.Module):
+ def __init__(self, kernel_sizes, num_layers=3):
+ super().__init__()
+ if isinstance(kernel_sizes, int):
+ kernel_sizes = [kernel_sizes] * num_layers
+ self.kernel_sizes = kernel_sizes
+ self.layers = nn.ModuleList()
+
+ # All layers: 12D input (4 prev + 8 static) → 4D output
+ for kernel_size in kernel_sizes:
+ self.layers.append(
+ nn.Conv2d(12, 4, kernel_size=kernel_size,
+ padding=kernel_size//2, bias=False)
+ )
+
+ def forward(self, input_rgbd, static_features):
+ # Layer 0: input RGBD (4D) + static (8D) = 12D
+ x = torch.cat([input_rgbd, static_features], dim=1)
+ x = self.layers[0](x)
+ x = torch.sigmoid(x) # Soft [0,1] for layer 0
+
+ # Layer 1+: previous output (4D) + static (8D) = 12D
+ for i in range(1, len(self.layers)):
+ x_input = torch.cat([x, static_features], dim=1)
+ x = self.layers[i](x_input)
+ if i < len(self.layers) - 1:
+ x = F.relu(x)
+ else:
+ x = torch.sigmoid(x) # Soft [0,1] for final layer
+
+ return x # RGBA output
+```
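+
+A quick shape check of the model above (sizes here are hypothetical; any H×W works since the convolutions are padded to preserve resolution):
+
+```python
+import torch
+
+model = CNNv2(kernel_sizes=[3, 3, 3])
+input_rgbd  = torch.rand(2, 4, 64, 64)   # batch of 2 RGBD patches
+static_feat = torch.rand(2, 8, 64, 64)   # p0-p3, uv_x, uv_y, sin, bias
+out = model(input_rgbd, static_feat)
+print(out.shape)                          # torch.Size([2, 4, 64, 64]) -> RGBA
+```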
+
+**Training Configuration:**
+
+```python
+# Hyperparameters
+kernel_sizes = [3, 3, 3] # Per-layer kernel sizes (e.g., [1,3,5])
+num_layers = 3 # Number of CNN layers
+mip_level = 0 # Mip level for p0-p3: 0=orig, 1=half, 2=quarter, 3=eighth
+grayscale_loss = False # Compute loss on grayscale (Y) instead of RGBA
+learning_rate = 1e-3
+batch_size = 16
+epochs = 5000
+
+# Dataset: Input RGB, Target RGBA (preserves alpha channel from image)
+# Model outputs RGBA, loss compares all 4 channels (or grayscale if --grayscale-loss)
+
+# Training loop (standard PyTorch f32)
+for epoch in range(epochs):
+ for rgb_batch, depth_batch, target_batch in dataloader:
+ # Compute static features (8D) with mip level
+ static_feat = compute_static_features(rgb_batch, depth_batch, mip_level)
+
+ # Input RGBD (4D)
+ input_rgbd = torch.cat([rgb_batch, depth_batch.unsqueeze(1)], dim=1)
+
+ # Forward pass
+ output = model(input_rgbd, static_feat)
+
+ # Loss computation (grayscale or RGBA)
+ if grayscale_loss:
+ # Convert RGBA to grayscale: Y = 0.299*R + 0.587*G + 0.114*B
+ output_gray = 0.299 * output[:, 0:1] + 0.587 * output[:, 1:2] + 0.114 * output[:, 2:3]
+            target_gray = 0.299 * target_batch[:, 0:1] + 0.587 * target_batch[:, 1:2] + 0.114 * target_batch[:, 2:3]
+ loss = criterion(output_gray, target_gray)
+ else:
+ loss = criterion(output, target_batch)
+
+ # Backward pass
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+```
+
+**Checkpoint Format:**
+
+```python
+torch.save({
+ 'state_dict': model.state_dict(), # f32 weights
+ 'config': {
+ 'kernel_sizes': [3, 3, 3], # Per-layer kernel sizes
+ 'num_layers': 3,
+ 'mip_level': 0, # Mip level used for p0-p3
+ 'grayscale_loss': False, # Whether grayscale loss was used
+ 'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin10_x', 'bias']
+ },
+ 'epoch': epoch,
+ 'loss': loss.item()
+}, f'checkpoints/checkpoint_epoch_{epoch}.pth')
+```
+
+---
+
+## Export Workflow
+
+### Script: `training/export_cnn_v2_shader.py`
+
+**Process:**
+1. Load checkpoint (f32 PyTorch weights)
+2. Extract layer configs (kernels, channels)
+3. Quantize weights to float16: `weights_f16 = weights_f32.astype(np.float16)`
+4. Generate WGSL shader per layer
+5. Write to `workspaces/<workspace>/shaders/cnn_v2/cnn_v2_*.wgsl`
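+
+A condensed sketch of steps 1-4 (it assumes the checkpoint layout shown earlier under "Checkpoint Format"; `export_cnn_v2_shader.py` is the actual exporter and also handles the template splicing):
+
+```python
+import numpy as np
+import torch
+
+ckpt = torch.load("checkpoints/checkpoint_epoch_5000.pth", map_location="cpu")
+config = ckpt["config"]
+
+for i, kernel in enumerate(config["kernel_sizes"]):
+    w = ckpt["state_dict"][f"layers.{i}.weight"].numpy()        # (out, in, k, k) f32
+    w_f16 = w.astype(np.float16).astype(np.float32).ravel()     # quantize, back to f32 literals
+    literals = ", ".join(f"{v:.6f}" for v in w_f16)
+    decl = f"const weights: array<f32, {w_f16.size}> = array({literals});"
+    # ...splice `decl` into the layer template and write cnn_v2_layer_{i}.wgsl
+```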
+
+**Example Generated Shader:**
+
+```wgsl
+// cnn_v2_layer_0.wgsl - Auto-generated from checkpoint_epoch_5000.pth
+
+const KERNEL_SIZE: u32 = 1u;
+const IN_CHANNELS: u32 = 8u; // 7 features + bias
+const OUT_CHANNELS: u32 = 16u;
+
+// Weights quantized to float16 (stored as f32 in shader)
+const weights: array<f32, 128> = array(
+ 0.123047, -0.089844, 0.234375, 0.456055, ...
+);
+
+@group(0) @binding(0) var static_features: texture_2d<u32>;
+@group(0) @binding(1) var output_texture: texture_storage_2d<rgba32uint, write>;
+
+@compute @workgroup_size(8, 8)
+fn main(@builtin(global_invocation_id) id: vec3<u32>) {
+ // Load static features (8D)
+ let static_feat = get_static_features(vec2<i32>(id.xy));
+
+ // Convolution (1×1 kernel = pointwise)
+ var output: array<f32, OUT_CHANNELS>;
+ for (var c: u32 = 0u; c < OUT_CHANNELS; c++) {
+ var sum: f32 = 0.0;
+ for (var k: u32 = 0u; k < IN_CHANNELS; k++) {
+ sum += weights[c * IN_CHANNELS + k] * static_feat[k];
+ }
+ output[c] = max(0.0, sum); // ReLU activation
+ }
+
+ // Pack and store (8×f16 per texel)
+ textureStore(output_texture, vec2<i32>(id.xy), pack_f16x8(output));
+}
+```
+
+**Float16 Quantization:**
+- Training uses f32 throughout (PyTorch standard)
+- Export converts to np.float16, then back to f32 for WGSL literals
+- **Expected discrepancy:** <0.1% MSE (acceptable)
+- Validation via HTML tool (see below)
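+
+The quantization error is cheap to measure directly from a checkpoint; a minimal check, assuming the checkpoint layout above:
+
+```python
+import numpy as np
+import torch
+
+ckpt = torch.load("checkpoints/checkpoint_epoch_5000.pth", map_location="cpu")
+w = np.concatenate([v.numpy().ravel() for v in ckpt["state_dict"].values()])
+w_rt = w.astype(np.float16).astype(np.float32)        # f32 -> f16 -> f32 round trip
+rel_mse = np.mean((w - w_rt) ** 2) / np.mean(w ** 2)
+print(f"relative MSE from f16 quantization: {rel_mse:.2e}")
+```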
+
+---
+
+## Validation Workflow
+
+### HTML Tool: `tools/cnn_v2_test/index.html`
+
+**WebGPU-based testing tool** with layer visualization.
+
+**Usage:**
+1. Open `tools/cnn_v2_test/index.html` in browser
+2. Drop `.bin` weights file (from `export_cnn_v2_weights.py`)
+3. Drop PNG test image
+4. View results with layer inspection
+
+**Features:**
+- Live CNN inference with WebGPU
+- Layer-by-layer visualization (static features + all CNN layers)
+- Weight visualization (per-layer kernels)
+- View modes: CNN output, original, diff (×10)
+- Blend control for comparing with original
+
+**Export weights:**
+```bash
+./training/export_cnn_v2_weights.py checkpoints/checkpoint_epoch_100.pth \
+ --output-weights workspaces/main/cnn_v2_weights.bin
+```
+
+See `doc/CNN_V2_WEB_TOOL.md` for detailed documentation
+
+---
+
+## Implementation Checklist
+
+### Phase 1: Shaders (Core Infrastructure)
+
+- [ ] `workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl` - Static features compute
+ - [ ] RGBD sampling from framebuffer
+ - [ ] UV coordinate calculation
+ - [ ] sin(10\*uv.x) computation
+ - [ ] Bias dimension (constant 1.0)
+ - [ ] Float16 packing via `pack2x16float()`
+ - [ ] Output to `texture_storage_2d<rgba32uint>`
+
+- [ ] `workspaces/main/shaders/cnn_v2/cnn_v2_layer_template.wgsl` - Layer template
+ - [ ] Static features unpacking
+ - [ ] Previous layer unpacking (8×f16)
+ - [ ] Convolution implementation (1×1, 3×3, 5×5)
+ - [ ] ReLU activation
+ - [ ] Output packing (8×f16)
+ - [ ] Proper padding handling
+
+### Phase 2: C++ Effect Class
+
+- [ ] `src/effects/cnn_v2_effect.h` - Header
+ - [ ] Class declaration inheriting from `PostProcessEffect`
+ - [ ] Static features texture member
+ - [ ] Layer textures vector
+ - [ ] Pipeline and bind group members
+
+- [ ] `src/effects/cnn_v2_effect.cc` - Implementation
+ - [ ] Constructor: Load shaders, create textures
+ - [ ] `init()`: Create pipelines, bind groups
+ - [ ] `render()`: Multi-pass execution
+ - [ ] Pass 0: Compute static features
+ - [ ] Pass 1-N: CNN layers
+ - [ ] Final: Composite to output
+ - [ ] Proper resource cleanup
+
+- [ ] Integration
+ - [ ] Add to `src/gpu/demo_effects.h` includes
+ - [ ] Add `cnn_v2_effect.cc` to `CMakeLists.txt` (headless + normal)
+ - [ ] Add shaders to `workspaces/main/assets.txt`
+ - [ ] Add to `src/tests/gpu/test_demo_effects.cc`
+
+### Phase 3: Training Pipeline
+
+- [ ] `training/train_cnn_v2.py` - Training script
+ - [ ] Static feature extraction function
+ - [ ] CNNv2 PyTorch model class
+ - [ ] Patch-based dataloader
+ - [ ] Training loop with checkpointing
+ - [ ] Command-line argument parsing
+ - [ ] Inference mode (ground truth generation)
+
+- [ ] `training/export_cnn_v2_shader.py` - Export script
+ - [ ] Checkpoint loading
+ - [ ] Weight extraction and f16 quantization
+ - [ ] Per-layer WGSL generation
+ - [ ] File output to workspace shaders/
+ - [ ] Metadata preservation
+
+### Phase 4: Tools & Validation
+
+- [x] HTML validation tool - WebGPU inference with layer visualization
+ - [ ] Command-line argument parsing
+ - [ ] Shader export orchestration
+ - [ ] Build orchestration
+ - [ ] Batch image processing
+ - [ ] Results display
+
+- [ ] `src/tools/cnn_test_main.cc` - Tool updates
+ - [ ] Add `--cnn-version v2` flag
+ - [ ] CNNv2Effect instantiation path
+ - [ ] Static features pass execution
+ - [ ] Multi-layer processing
+
+### Phase 5: Documentation
+
+- [ ] `doc/HOWTO.md` - Usage guide
+ - [ ] Training section (CNN v2)
+ - [ ] Export section
+ - [ ] Validation section
+ - [ ] Examples
+
+- [ ] `README.md` - Project overview update
+ - [ ] Mention CNN v2 capability
+
+---
+
+## File Structure
+
+### New Files
+
+```
+# Shaders (generated by export script)
+workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl # Static features compute
+workspaces/main/shaders/cnn_v2/cnn_v2_layer_0.wgsl # Input layer (generated)
+workspaces/main/shaders/cnn_v2/cnn_v2_layer_1.wgsl # Inner layer (generated)
+workspaces/main/shaders/cnn_v2/cnn_v2_layer_2.wgsl # Output layer (generated)
+
+# C++ implementation
+src/effects/cnn_v2_effect.h # Effect class header
+src/effects/cnn_v2_effect.cc # Effect implementation
+
+# Python training/export
+training/train_cnn_v2.py # Training script
+training/export_cnn_v2_shader.py # Shader generator
+training/validation/ # Test images directory
+
+# Validation
+tools/cnn_v2_test/index.html # WebGPU validation tool
+
+# Documentation
+doc/CNN_V2.md # This file
+```
+
+### Modified Files
+
+```
+src/gpu/demo_effects.h # Add CNNv2Effect include
+CMakeLists.txt # Add cnn_v2_effect.cc
+workspaces/main/assets.txt # Add cnn_v2 shaders
+workspaces/main/timeline.seq # Optional: add CNNv2Effect
+src/tests/gpu/test_demo_effects.cc # Add CNNv2 test case
+src/tools/cnn_test_main.cc # Add --cnn-version v2
+doc/HOWTO.md # Add CNN v2 sections
+TODO.md # Add CNN v2 task
+```
+
+### Unchanged (v1 Preserved)
+
+```
+training/train_cnn.py # Original training
+src/effects/cnn_effect.* # Original effect
+workspaces/main/shaders/cnn_*.wgsl # Original v1 shaders
+```
+
+---
+
+## Performance Characteristics
+
+### Static Features Compute
+- **Cost:** ~0.1ms @ 1080p
+- **Frequency:** Once per frame
+- **Operations:** sin(), texture sampling, packing
+
+### CNN Layers (Example 3-layer)
+- **Layer 0 (1×1, 8→16):** ~0.3ms
+- **Layer 1 (3×3, 23→8):** ~0.8ms
+- **Layer 2 (5×5, 15→4):** ~1.2ms
+- **Total:** ~2.4ms @ 1080p
+
+### Memory Usage
+- Static features: 1920×1080×8×2 = 33 MB (f16)
+- Layer buffers: 1920×1080×16×2 = 66 MB (max 16 channels)
+- Weights: ~6.4 KB (f16, in shader code)
+- **Total GPU memory:** ~100 MB
+
+---
+
+## Size Budget
+
+### CNN v1 vs v2
+
+| Metric | v1 | v2 | Delta |
+|--------|----|----|-------|
+| Weights (count) | 800 | 3268 | +2468 |
+| Storage (f32) | 3.2 KB | 13.1 KB | +9.9 KB |
+| Storage (f16) | N/A | 6.5 KB | +6.5 KB |
+| Shader code | ~500 lines | ~800 lines | +300 lines |
+
+### Mitigation Strategies
+
+**Reduce channels:**
+- [16,8,4] → [8,4,4] saves ~50% weights
+- [16,8,4] → [4,4,4] saves ~60% weights
+
+**Smaller kernels:**
+- [1,3,5] → [1,3,3] saves ~30% weights
+- [1,3,5] → [1,1,3] saves ~50% weights
+
+**Quantization:**
+- int8 weights: saves 75% (requires QAT training)
+- 4-bit weights: saves 87.5% (extreme, needs research)
+
+**Target:** Keep CNN v2 under 10 KB for 64k demo constraint
+
+---
+
+## Future Extensions
+
+### Flexible Feature Layout (Binary Format v3)
+
+**TODO:** Support arbitrary feature vector layouts and ordering in binary format.
+
+**Current Limitation:**
+- Feature layout hardcoded: `[p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias]`
+- Shader must match training script exactly
+- Experimentation requires shader recompilation
+
+**Proposed Enhancement:**
+- Add feature descriptor to binary format header
+- Specify feature types, sources, and ordering
+- Runtime shader generation or dynamic feature indexing
+- Examples: `[R, G, B, dx, dy, uv_x, bias]` or `[mip1.r, mip2.g, laplacian, uv_x, sin20_x, bias]`
+
+**Benefits:**
+- Training experiments without C++/shader changes
+- A/B test different feature combinations
+- Single binary format, multiple architectures
+- Faster iteration on feature engineering
+
+**Implementation Options:**
+1. **Static approach:** Generate shader code from descriptor at load time
+2. **Dynamic approach:** Array-based indexing with feature map uniform
+3. **Hybrid:** Precompile common layouts, fallback to dynamic
+
+See `doc/CNN_V2_BINARY_FORMAT.md` for proposed descriptor format.
+
+---
+
+### More Features (uint8 Packing)
+
+```wgsl
+// 16 uint8 features per texel (texture_storage_2d<rgba8unorm>)
+// [R, G, B, D, uv.x, uv.y, sin10.x, sin10.y,
+// sin20.x, sin20.y, dx, dy, gray_mip1, gray_mip2, variance, bias]
+```
+- Trade precision for quantity
+- Requires quantization-aware training
+
+### Temporal Features
+
+- Previous frame RGBA (motion awareness)
+- Optical flow vectors
+- Requires multi-frame buffer
+
+### Learned Position Encodings
+
+- Replace hand-crafted sin(10\*uv) with learned embeddings
+- Requires separate embedding network
+- Similar to NeRF position encoding
+
+### Dynamic Architecture
+
+- Runtime kernel size selection based on scene
+- Conditional layer execution (skip connections)
+- Layer pruning for performance
+
+---
+
+## References
+
+- **v1 Implementation:** `src/effects/cnn_effect.*`
+- **Training Guide:** `doc/HOWTO.md` (CNN Training section)
+- **Test Tool:** `doc/CNN_TEST_TOOL.md`
+- **Shader System:** `doc/SEQUENCE.md`
+- **Size Measurement:** `doc/SIZE_MEASUREMENT.md`
+
+---
+
+## Appendix: Design Decisions
+
+### Why Bias as Static Feature?
+
+**Alternatives considered:**
+1. Separate bias array per layer (Option B)
+2. Bias as static feature = 1.0 (Option A, chosen)
+
+**Decision rationale:**
+- Simpler shader code (fewer bindings)
+- Standard NN formulation (augmented input)
+- Saves 56-112 bytes per model
+- 7 features sufficient for v1 implementation
+- Can extend to uint8 packing if >7 features needed
+
+### Why Float16 for Weights?
+
+**Alternatives considered:**
+1. Keep f32 (larger, more accurate)
+2. Use f16 (smaller, GPU-native)
+3. Use int8 (smallest, needs QAT)
+
+**Decision rationale:**
+- f16 saves 50% vs f32 (critical for 64k target)
+- GPU-native support (pack2x16float in WGSL)
+- <0.1% accuracy loss (acceptable)
+- Simpler than int8 quantization
+
+### Why Multi-Frequency Position Encoding?
+
+**Inspiration:** NeRF (Neural Radiance Fields)
+
+**Benefits:**
+- Helps network learn high-frequency details
+- Better than raw UV coordinates
+- Small footprint (1D per frequency)
+
+**Future:** Add sin(20\*uv), sin(40\*uv) if >7 features available
+
+---
+
+## Related Documentation
+
+- `doc/CNN_V2_BINARY_FORMAT.md` - Binary weight file specification (.bin format)
+- `doc/CNN_V2_WEB_TOOL.md` - WebGPU testing tool with layer visualization
+- `doc/CNN_TEST_TOOL.md` - C++ offline validation tool (deprecated)
+- `doc/HOWTO.md` - Training and validation workflows
+
+---
+
+**Document Version:** 1.0
+**Last Updated:** 2026-02-12
+**Status:** Design approved, ready for implementation
diff --git a/cnn_v2/docs/CNN_V2_BINARY_FORMAT.md b/cnn_v2/docs/CNN_V2_BINARY_FORMAT.md
new file mode 100644
index 0000000..59c859d
--- /dev/null
+++ b/cnn_v2/docs/CNN_V2_BINARY_FORMAT.md
@@ -0,0 +1,235 @@
+# CNN v2 Binary Weight Format Specification
+
+Binary format for storing trained CNN v2 weights with static feature architecture.
+
+**File Extension:** `.bin`
+**Byte Order:** Little-endian
+**Version:** 2.0 (supports mip-level for parametric features)
+**Backward Compatible:** Version 1.0 files supported (mip_level=0)
+
+---
+
+## File Structure
+
+**Version 2 (current):**
+```
+┌─────────────────────┐
+│ Header (20 bytes) │
+├─────────────────────┤
+│ Layer Info │
+│ (20 bytes × N) │
+├─────────────────────┤
+│ Weight Data │
+│ (variable size) │
+└─────────────────────┘
+```
+
+**Version 1 (legacy):**
+```
+┌─────────────────────┐
+│ Header (16 bytes) │
+├─────────────────────┤
+│ Layer Info │
+│ (20 bytes × N) │
+├─────────────────────┤
+│ Weight Data │
+│ (variable size) │
+└─────────────────────┘
+```
+
+---
+
+## Header
+
+**Version 2 (20 bytes):**
+
+| Offset | Type | Field | Description |
+|--------|------|----------------|--------------------------------------|
+| 0x00 | u32 | magic | Magic number: `0x32_4E_4E_43` ("CNN2") |
+| 0x04 | u32 | version | Format version (2 for current) |
+| 0x08 | u32 | num_layers | Number of CNN layers (excludes static features) |
+| 0x0C | u32 | total_weights | Total f16 weight count across all layers |
+| 0x10 | u32 | mip_level | Mip level for p0-p3 features (0=original, 1=half, 2=quarter, 3=eighth) |
+
+**Version 1 (16 bytes) - Legacy:**
+
+| Offset | Type | Field | Description |
+|--------|------|----------------|--------------------------------------|
+| 0x00 | u32 | magic | Magic number: `0x32_4E_4E_43` ("CNN2") |
+| 0x04 | u32 | version | Format version (1) |
+| 0x08 | u32 | num_layers | Number of CNN layers |
+| 0x0C | u32 | total_weights | Total f16 weight count |
+
+**Note:** Loaders should check version field and handle both formats. Version 1 files treated as mip_level=0.
+
+---
+
+## Layer Info (20 bytes per layer)
+
+Repeated `num_layers` times:
+- **Version 2:** Starting at offset 0x14 (20 bytes)
+- **Version 1:** Starting at offset 0x10 (16 bytes)
+
+| Offset | Type | Field | Description |
+|-------------|------|----------------|--------------------------------------|
+| 0x00 | u32 | kernel_size | Convolution kernel dimension (3, 5, 7, etc.) |
+| 0x04 | u32 | in_channels | Input channel count (every layer's 12D input includes the 8 static features) |
+| 0x08 | u32 | out_channels | Output channel count (max 8) |
+| 0x0C | u32 | weight_offset | Weight array start index (f16 units, relative to weight data section) |
+| 0x10 | u32 | weight_count | Number of f16 weights for this layer |
+
+**Layer Order:** Sequential (Layer 0, Layer 1, Layer 2, ...)
+
+---
+
+## Weight Data (variable size)
+
+Starts at offset:
+- **Version 2:** `20 + (num_layers × 20)`
+- **Version 1:** `16 + (num_layers × 20)`
+
+**Format:** Packed f16 pairs stored as u32
+**Packing:** `u32 = (f16_hi << 16) | f16_lo`
+**Storage:** Sequential by layer, then by output channel, input channel, spatial position
+
+**Weight Indexing:**
+```
+weight_idx = output_ch × (in_channels × kernel_size²) +
+ input_ch × kernel_size² +
+ (ky × kernel_size + kx)
+```
+
+Where:
+- `output_ch` ∈ [0, out_channels)
+- `input_ch` ∈ [0, in_channels)
+- `ky`, `kx` ∈ [0, kernel_size)
+
+**Unpacking f16 from u32:**
+```c
+uint32_t packed = weights_buffer[weight_idx / 2];
+uint16_t f16_bits = (weight_idx % 2 == 0) ? (packed & 0xFFFF) : (packed >> 16);
+```
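+
+Because the file is little-endian, the u32-packed pairs are byte-identical to a flat sequence of little-endian f16 values, so a Python reader can skip the manual unpacking. A short sketch (`weight_bytes` and `layer_offset` are placeholder names for the raw weight-data section and a layer's `weight_offset`):
+
+```python
+import numpy as np
+
+weights = np.frombuffer(weight_bytes, dtype="<f2")   # all f16 weights, in file order
+
+def weight_index(out_ch, in_ch, ky, kx, in_channels, kernel_size):
+    """Flat index within one layer's block, per the formula above."""
+    return (out_ch * in_channels * kernel_size * kernel_size
+            + in_ch * kernel_size * kernel_size
+            + ky * kernel_size + kx)
+
+w = weights[layer_offset + weight_index(0, 4, 1, 1, 12, 3)]  # out 0, static p0, kernel centre
+```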
+
+---
+
+## Example: 3-Layer Network (Version 2)
+
+**Configuration:**
+- Mip level: 0 (original resolution)
+- Layer 0: 12→4, kernel 3×3 (432 weights)
+- Layer 1: 12→4, kernel 3×3 (432 weights)
+- Layer 2: 12→4, kernel 3×3 (432 weights)
+
+**File Layout:**
+```
+Offset Size Content
+------ ---- -------
+0x00 20 Header (magic, version=2, layers=3, weights=1296, mip_level=0)
+0x14 20 Layer 0 info (kernel=3, in=12, out=4, offset=0, count=432)
+0x28 20 Layer 1 info (kernel=3, in=12, out=4, offset=432, count=432)
+0x3C 20 Layer 2 info (kernel=3, in=12, out=4, offset=864, count=432)
+0x50   2592  Weight data (1296 f16 weights packed as 648 u32)
+ ----
+Total: 2672 bytes (~2.6 KB)
+```
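+
+A minimal Python loader for this layout (handles both header versions; field order per the tables above):
+
+```python
+import struct
+import numpy as np
+
+def load_cnn_v2(path):
+    data = open(path, "rb").read()
+    magic, version, num_layers, total_weights = struct.unpack_from("<4I", data, 0)
+    assert magic == 0x324E4E43, "not a CNN v2 weight file"
+    if version == 2:
+        (mip_level,) = struct.unpack_from("<I", data, 16)
+        offset = 20
+    else:                                   # version 1: no mip_level field
+        mip_level, offset = 0, 16
+    layers = []
+    for _ in range(num_layers):
+        k, cin, cout, w_off, w_cnt = struct.unpack_from("<5I", data, offset)
+        layers.append({"kernel": k, "in": cin, "out": cout,
+                       "offset": w_off, "count": w_cnt})
+        offset += 20
+    weights = np.frombuffer(data, dtype="<f2", count=total_weights, offset=offset)
+    return mip_level, layers, weights
+```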
+
+---
+
+## Static Features
+
+Not stored in .bin file (computed at runtime):
+
+**8D Input Features:**
+1. **p0** - Parametric feature 0 (from mip level)
+2. **p1** - Parametric feature 1 (from mip level)
+3. **p2** - Parametric feature 2 (from mip level)
+4. **p3** - Parametric feature 3 (depth or from mip level)
+5. **UV_X** - Normalized x coordinate [0,1]
+6. **UV_Y** - Normalized y coordinate [0,1]
+7. **sin(20 × UV_Y)** - Spatial frequency encoding (vertical, frequency=20)
+8. **1.0** - Bias term
+
+**Mip Level Usage (p0-p3):**
+- `mip_level=0`: RGB from original resolution (mip 0)
+- `mip_level=1`: RGB from half resolution (mip 1), upsampled
+- `mip_level=2`: RGB from quarter resolution (mip 2), upsampled
+- `mip_level=3`: RGB from eighth resolution (mip 3), upsampled
+
+**Layer 0** receives input RGBD (4D) + static features (8D) = 12D input → 4D output.
+**Layer 1+** receive previous layer output (4D) + static features (8D) = 12D input → 4D output.
+
+---
+
+## Validation
+
+**Magic Check:**
+```c
+uint32_t magic;
+fread(&magic, 4, 1, fp);
+if (magic != 0x324E4E43) { error("Invalid CNN v2 file"); }  /* "CNN2" */
+```
+
+**Version Check:**
+```c
+uint32_t version;
+fread(&version, 4, 1, fp);
+if (version != 1 && version != 2) { error("Unsupported version"); }
+uint32_t header_size = (version == 1) ? 16 : 20;
+```
+
+**Size Check:**
+```c
+expected_size = header_size + (num_layers * 20) + (total_weights * 2);
+if (file_size != expected_size) { error("Size mismatch"); }
+```
+
+**Weight Offset Sanity:**
+```c
+// Each layer's offset should match cumulative count
+uint32_t cumulative = 0;
+for (int i = 0; i < num_layers; i++) {
+ if (layers[i].weight_offset != cumulative) { error("Invalid offset"); }
+ cumulative += layers[i].weight_count;
+}
+if (cumulative != total_weights) { error("Total mismatch"); }
+```
+
+---
+
+## Future Extensions
+
+**TODO: Flexible Feature Layout**
+
+Current limitation: Feature vector layout is hardcoded as `[p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias]`.
+
+Proposed enhancement for version 3:
+- Add feature descriptor section to header
+- Specify feature count, types, and ordering
+- Support arbitrary 7D feature combinations (e.g., `[R, G, B, dx, dy, uv_x, bias]`)
+- Allow runtime shader generation based on descriptor
+- Enable experimentation without recompiling shaders
+
+Example descriptor format:
+```
+struct FeatureDescriptor {
+ u32 feature_count; // Number of features (typically 7-8)
+ u32 feature_types[8]; // Type enum per feature
+ u32 feature_sources[8]; // Source enum (mip0, mip1, gradient, etc.)
+ u32 reserved[8]; // Future use
+}
+```
+
+Benefits:
+- Training can experiment with different feature combinations
+- No shader recompilation needed
+- Single binary format supports multiple architectures
+- Easier A/B testing of feature effectiveness
+
+---
+
+## Related Files
+
+- `training/export_cnn_v2_weights.py` - Binary export tool
+- `src/effects/cnn_v2_effect.cc` - C++ loader
+- `tools/cnn_v2_test/index.html` - WebGPU validator
+- `doc/CNN_V2.md` - Architecture design
diff --git a/cnn_v2/docs/CNN_V2_DEBUG_TOOLS.md b/cnn_v2/docs/CNN_V2_DEBUG_TOOLS.md
new file mode 100644
index 0000000..8d1289a
--- /dev/null
+++ b/cnn_v2/docs/CNN_V2_DEBUG_TOOLS.md
@@ -0,0 +1,143 @@
+# CNN v2 Debugging Tools
+
+Tools for investigating CNN v2 mismatch between HTML tool and cnn_test.
+
+---
+
+## Identity Weight Generator
+
+**Purpose:** Generate trivial .bin files with identity passthrough for debugging.
+
+**Script:** `training/gen_identity_weights.py`
+
+**Usage:**
+```bash
+# 1×1 identity (default)
+./training/gen_identity_weights.py workspaces/main/weights/cnn_v2_identity.bin
+
+# 3×3 identity
+./training/gen_identity_weights.py workspaces/main/weights/cnn_v2_identity_3x3.bin --kernel-size 3
+
+# Mix mode: 50-50 blend (0.5*p0+0.5*p4, etc)
+./training/gen_identity_weights.py output.bin --mix
+
+# Static features only: p4→ch0, p5→ch1, p6→ch2, p7→ch3
+./training/gen_identity_weights.py output.bin --p47
+
+# Custom mip level
+./training/gen_identity_weights.py output.bin --kernel-size 1 --mip-level 2
+```
+
+**Output:**
+- Single layer, 12D→4D (4 input channels + 8 static features)
+- Identity mode: Output Ch{0,1,2,3} = Input Ch{0,1,2,3}
+- Mix mode (--mix): Output Ch{i} = 0.5*Input Ch{i} + 0.5*Input Ch{i+4} (50-50 blend, avoids overflow)
+- Static mode (--p47): Output Ch{i} = Input Ch{i+4} (static features only, visualizes p4-p7)
+- Minimal file size (~136 bytes for 1×1, ~904 bytes for 3×3)
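+
+For reference, the identity file is small enough to write by hand. A minimal sketch of what the 1×1 identity variant contains (format per `CNN_V2_BINARY_FORMAT.md`; `gen_identity_weights.py` is the real generator and also implements `--mix`/`--p47`):
+
+```python
+import struct
+import numpy as np
+
+def write_identity_bin(path, mip_level=0):
+    kernel, cin, cout = 1, 12, 4
+    w = np.zeros((cout, cin, kernel, kernel), dtype="<f2")
+    for i in range(cout):
+        w[i, i, 0, 0] = 1.0                  # output ch i = input ch i (RGBD passthrough)
+    with open(path, "wb") as f:
+        f.write(struct.pack("<5I", 0x324E4E43, 2, 1, w.size, mip_level))  # 20-byte v2 header
+        f.write(struct.pack("<5I", kernel, cin, cout, 0, w.size))         # 20-byte layer info
+        f.write(w.tobytes())                                              # 48 f16 = 96 bytes
+    # total: 20 + 20 + 96 = 136 bytes, matching the size quoted above
+```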
+
+**Validation:**
+Load in HTML tool or cnn_test - output should match input (RGB only, ignoring static features).
+
+---
+
+## Composited Layer Visualization
+
+**Purpose:** Save current layer view as single composited image (4 channels side-by-side, grayscale).
+
+**Location:** HTML tool - "Layer Visualization" panel
+
+**Usage:**
+1. Load image + weights in HTML tool
+2. Select layer to visualize (Static 0-3, Static 4-7, Layer 0, Layer 1, etc.)
+3. Click "Save Composited" button
+4. Downloads PNG: `composited_layer{N}_{W}x{H}.png`
+
+**Output:**
+- 4 channels stacked horizontally
+- Grayscale representation
+- Useful for comparing layer activations across tools
+
+---
+
+## Debugging Strategy
+
+### Track a) Binary Conversion Chain
+
+**Hypothesis:** Conversion error in .bin ↔ base64 ↔ Float32Array
+
+**Test:**
+1. Generate identity weights:
+ ```bash
+ ./training/gen_identity_weights.py workspaces/main/weights/test_identity.bin
+ ```
+
+2. Load in HTML tool - output should match input RGB
+
+3. If mismatch:
+ - Check Python export: f16 packing in `export_cnn_v2_weights.py` line 105
+ - Check HTML parsing: `unpackF16()` in `index.html` line 805-815
+ - Check weight indexing: `get_weight()` shader function
+
+**Key locations:**
+- Python: `np.float16` → `view(np.uint32)` (line 105 of export script)
+- JS: `DataView` → `unpackF16()` → manual f16 decode (line 773-803)
+- WGSL: `unpack2x16float()` built-in (line 492 of shader)
+
+### Track b) Layer Visualization
+
+**Purpose:** Confirm layer outputs match between HTML and C++
+
+**Method:**
+1. Run identical input through both tools
+2. Save composited layers from HTML tool
+3. Compare with cnn_test output
+4. Use identity weights to isolate weight loading from computation
+
+### Track c) Trivial Test Case
+
+**Use identity weights to test:**
+- Weight loading (binary parsing)
+- Feature generation (static features)
+- Convolution (should be passthrough)
+- Output packing
+
+**Expected behavior:**
+- Input RGB → Output RGB (exact match)
+- Static features ignored (all zeros in identity matrix)
+
+---
+
+## Known Issues
+
+### ~~Layer 0 Visualization Scale~~ [FIXED]
+
+**Issue:** Layer 0 output displayed at 0.5× brightness (divided by 2).
+
+**Cause:** Line 1530 used `vizScale = 0.5` for all CNN layers, but Layer 0 is clamped [0,1] and doesn't need dimming.
+
+**Fix:** Use scale 1.0 for Layer 0 output (layerIdx=1), 0.5 only for middle layers (ReLU, unbounded).
+
+### Remaining Mismatch
+
+**Current:** HTML tool and cnn_test produce different outputs for same input/weights.
+
+**Suspects:**
+1. F16 unpacking difference (CPU vs GPU vs JS)
+2. Static feature generation (RGBD, UV, sin encoding)
+3. Convolution kernel iteration order
+4. Output packing/unpacking
+
+**Next steps:**
+1. Test with identity weights (eliminates weight loading)
+2. Compare composited layer outputs
+3. Add debug visualization for static features
+4. Hex dump comparison (first 8 pixels) - use `--debug-hex` flag in cnn_test
+
+---
+
+## Related Documentation
+
+- `doc/CNN_V2.md` - CNN v2 architecture
+- `doc/CNN_V2_WEB_TOOL.md` - HTML tool documentation
+- `doc/CNN_TEST_TOOL.md` - cnn_test CLI tool
+- `training/export_cnn_v2_weights.py` - Binary export format
diff --git a/cnn_v2/docs/CNN_V2_WEB_TOOL.md b/cnn_v2/docs/CNN_V2_WEB_TOOL.md
new file mode 100644
index 0000000..b6f5b0b
--- /dev/null
+++ b/cnn_v2/docs/CNN_V2_WEB_TOOL.md
@@ -0,0 +1,348 @@
+# CNN v2 Web Testing Tool
+
+Browser-based WebGPU tool for validating CNN v2 inference with layer visualization and weight inspection.
+
+**Location:** `tools/cnn_v2_test/index.html`
+
+---
+
+## Status (2026-02-13)
+
+**Working:**
+- ✅ WebGPU initialization and device setup
+- ✅ Binary weight file parsing (v1 and v2 formats)
+- ✅ Automatic mip-level detection from binary format v2
+- ✅ Weight statistics (min/max per layer)
+- ✅ UI layout with collapsible panels
+- ✅ Mode switching (Activations/Weights tabs)
+- ✅ Canvas context management (2D for weights, WebGPU for activations)
+- ✅ Weight visualization infrastructure (layer selection, grid layout)
+- ✅ Layer naming matches codebase convention (Layer 0, Layer 1, Layer 2)
+- ✅ Static features split visualization (Static 0-3, Static 4-7)
+- ✅ All layers visible including output layer (Layer 2)
+- ✅ Video playback support (MP4, WebM) with frame-by-frame controls
+- ✅ Video looping (automatic continuous playback)
+- ✅ Mip level selection (p0-p3 features at different resolutions)
+
+**Recent Changes (Latest):**
+- Binary format v2 support: Reads mip_level from 20-byte header
+- Backward compatible: v1 (16-byte header) → mip_level=0
+- Auto-update UI dropdown when loading weights with mip_level
+- Display mip_level in metadata panel
+- Code refactoring: Extracted FULLSCREEN_QUAD_VS shader (reused 3× across pipelines)
+- Added helper methods: `getDimensions()`, `setVideoControlsEnabled()`
+- Improved code organization with section headers and comments
+- Moved Mip Level selector to bottom of left sidebar (removed "Features (p0-p3)" label)
+- Added `loop` attribute to video element for automatic continuous playback
+
+**Previous Fixes:**
+- Fixed Layer 2 not appearing (was excluded from layerOutputs due to isOutput check)
+- Fixed canvas context switching (force clear before recreation)
+- Added Static 0-3 / Static 4-7 buttons to view all 8 static feature channels
+- Aligned naming with train_cnn_v2.py/.wgsl: Layer 0, Layer 1, Layer 2 (not Layer 1, 2, 3)
+- Disabled Static buttons in weights mode (no learnable weights)
+
+**Known Issues:**
+- Layer activation visualization may show black if texture data not properly unpacked
+- Weight kernel display depends on correct 2D context creation after canvas recreation
+
+---
+
+## Architecture
+
+### File Structure
+- Single-file HTML tool (~1100 lines)
+- Embedded shaders: STATIC_SHADER, CNN_SHADER, DISPLAY_SHADER, LAYER_VIZ_SHADER
+- Shared WGSL component: FULLSCREEN_QUAD_VS (reused across render pipelines)
+- **Embedded default weights:** DEFAULT_WEIGHTS_B64 (base64-encoded binary v2)
+ - Current: 4 layers (3×3, 5×5, 3×3, 3×3), 2496 f16 weights, mip_level=2
+ - Source: `workspaces/main/weights/cnn_v2_weights.bin`
+ - Updates: Re-encode binary with `base64 -i <file>` and update constant
+- Pure WebGPU (no external dependencies)
+
+### Code Organization
+
+**Recent Refactoring (2026-02-13):**
+- Extracted `FULLSCREEN_QUAD_VS` constant: Reused fullscreen quad vertex shader (2 triangles covering NDC)
+- Added helper methods to CNNTester class:
+ - `getDimensions()`: Returns current source dimensions (video or image)
+ - `setVideoControlsEnabled(enabled)`: Centralized video control enable/disable
+- Consolidated duplicate vertex shader code (used in mipmap generation, display, layer visualization)
+- Added section headers in JavaScript for better navigation
+- Improved inline comments explaining shader architecture
+
+**Benefits:**
+- Reduced code duplication (~40 lines saved)
+- Easier maintenance (single source of truth for fullscreen quad)
+- Clearer separation of concerns
+
+### Key Components
+
+**1. Weight Parsing**
+- Reads binary format v2: header (20B) + layer info (20B×N) + f16 weights
+- Backward compatible with v1: header (16B), mip_level defaults to 0
+- Computes min/max per layer via f16 unpacking
+- Stores `{ layers[], weights[], mipLevel, fileSize }`
+- Auto-sets UI mip-level dropdown from loaded weights
+
+**2. CNN Pipeline**
+- Static features computation (RGBD + UV + sin + bias → 7D packed)
+- Layer-by-layer convolution with storage buffer weights
+- Ping-pong buffers for intermediate results
+- Copy to persistent textures for visualization
+
+**3. Visualization Modes**
+
+**Activations Mode:**
+- 4 grayscale views per layer (channels 0-3 of up to 8 total)
+- WebGPU compute → unpack f16 → scale → grayscale
+- Auto-scale: Static features = 1.0, CNN layers = 0.2
+- Static features: Shows R,G,B,D (first 4 of 8: RGBD+UV+sin+bias)
+- CNN layers: Shows first 4 output channels
+
+**Weights Mode:**
+- 2D canvas rendering per output channel
+- Shows all input kernels horizontally
+- Normalized by layer min/max → [0, 1] → grayscale
+- 20px cells, 2px padding between kernels
+
+### Texture Management
+
+**Persistent Storage (layerTextures[]):**
+- One texture per layer output (static + all CNN layers)
+- `rgba32uint` format (packed f16 data)
+- `COPY_DST` usage for storing results
+
+**Compute Buffers (computeTextures[]):**
+- 2 textures for ping-pong computation
+- Reused across all layers
+- `COPY_SRC` usage for copying to persistent storage
+
+**Pipeline:**
+```
+Static pass → copy to layerTextures[0]
+For each CNN layer i:
+ Compute (ping-pong) → copy to layerTextures[i+1]
+```
+
+### Layer Indexing
+
+**UI Layer Buttons:**
+- "Static" → layerOutputs[0] (7D input features)
+- "Layer 1" → layerOutputs[1] (CNN layer 1 output, uses weights.layers[0])
+- "Layer 2" → layerOutputs[2] (CNN layer 2 output, uses weights.layers[1])
+- "Layer N" → layerOutputs[N] (CNN layer N output, uses weights.layers[N-1])
+
+**Weights Table:**
+- "Layer 1" → weights.layers[0] (first CNN layer weights)
+- "Layer 2" → weights.layers[1] (second CNN layer weights)
+- "Layer N" → weights.layers[N-1]
+
+**Consistency:** Both UI and weights table use same numbering (1, 2, 3...) for CNN layers.
+
+---
+
+## Known Issues
+
+### Issue #1: Layer Activations Show Black
+
+**Symptom:**
+- All 4 channel canvases render black
+- UV gradient test (debug mode 10) works
+- Raw packed data test (mode 11) shows black
+- Unpacked f16 test (mode 12) shows black
+
+**Diagnosis:**
+- Texture access works (UV gradient visible)
+- Texture data is all zeros (packed.x = 0)
+- Textures being read are empty
+
+**Root Cause:**
+- `copyTextureToTexture` operations may not be executing
+- Possible ordering issue (copies not submitted before visualization)
+- Alternative: textures created with wrong usage flags
+
+**Investigation Steps Taken:**
+1. Added `onSubmittedWorkDone()` wait before visualization
+2. Verified texture creation with `COPY_SRC` and `COPY_DST` flags
+3. Confirmed separate texture allocation per layer (no aliasing)
+4. Added debug shader modes to isolate issue
+
+**Next Steps:**
+- Verify encoder contains copy commands (add debug logging)
+- Check if compute passes actually write data (add known-value test)
+- Test copyTextureToTexture in isolation
+- Consider CPU readback to verify texture contents
+
+### Issue #2: Weight Visualization Empty
+
+**Symptom:**
+- Canvases created with correct dimensions (logged)
+- No visual output (black canvases)
+- Console logs show method execution
+
+**Potential Causes:**
+1. Weight indexing calculation incorrect
+2. Canvas not properly attached to DOM when rendering
+3. 2D context operations not flushing
+4. Min/max normalization producing black (all values equal?)
+
+**Debug Added:**
+- Comprehensive logging of dimensions, indices, ranges
+- Canvas context check before rendering
+
+**Next Steps:**
+- Add test rendering (fixed gradient) to verify 2D context works
+- Log sample weight values to verify data access
+- Check if canvas is visible in DOM inspector
+- Verify min/max calculation produces valid range
+
+---
+
+## UI Layout
+
+### Header
+- Controls: Blend slider, Depth input, View mode display
+- Drop zone for .bin weight files
+
+### Content Area
+
+**Left Sidebar (300px):**
+1. Drop zone for .bin weight files
+2. Weights Info panel (file size, layer table with min/max)
+3. Weights Visualization panel (per-layer kernel display)
+4. **Mip Level selector** (bottom) - Select the mip level used for the p0-p3 static features
+
+**Main Canvas (center):**
+- CNN output display with video controls (Play/Pause, Frame ◄/►)
+- Supports both PNG images and video files (MP4, WebM)
+- Video loops automatically for continuous playback
+
+**Right Sidebar (panels):**
+1. **Layer Visualization Panel** (top, flex: 1)
+ - Layer selection buttons (Static 0-3, Static 4-7, Layer 0, Layer 1, ...)
+ - 2×2 grid of channel views (grayscale activations)
+ - 4× zoom view at bottom
+
+### Footer
+- Status line (GPU timing, dimensions, mode)
+- Console log (scrollable, color-coded)
+
+---
+
+## Shader Details
+
+### LAYER_VIZ_SHADER
+
+**Purpose:** Display single channel from packed layer texture
+
+**Inputs:**
+- `@binding(0) layer_tex: texture_2d<u32>` - Packed f16 layer data
+- `@binding(1) viz_params: vec2<f32>` - (channel_idx, scale)
+
+**Debug Modes:**
+- Channel 10: UV gradient (texture coordinate test)
+- Channel 11: Raw packed u32 data
+- Channel 12: First unpacked f16 value
+
+**Normal Operation:**
+- Unpack all 8 f16 channels from rgba32uint
+- Select channel by index (0-7)
+- Apply scale factor (1.0 for static, 0.2 for CNN)
+- Clamp to [0, 1] and output grayscale
+
+**Scale Rationale:**
+- Static features (RGBD, UV): already in [0, 1] range
+- CNN activations: post-ReLU [0, ~5], need scaling for visibility
+
+---
+
+## Binary Weight Format
+
+See `doc/CNN_V2_BINARY_FORMAT.md` for complete specification.
+
+**Quick Summary:**
+- Header: 20 bytes for v2 (magic, version, layer count, total weights, mip level); 16 bytes for legacy v1
+- Layer info: 20 bytes × N (kernel size, channels, offsets)
+- Weights: Packed f16 pairs as u32
+
+---
+
+## Testing Workflow
+
+### Load & Parse
+1. Drop PNG image → displays original
+2. Drop .bin weights → parses and shows info table
+3. Auto-runs CNN pipeline
+
+### Verify Pipeline
+1. Check console for "Running CNN pipeline"
+2. Verify "Completed in Xms"
+3. Check "Layer visualization ready: N layers"
+
+### Debug Activations
+1. Select "Activations" tab
+2. Click layer buttons to switch
+3. Check console for texture/canvas logs
+4. If black: note which debug modes work (UV vs data)
+
+### Debug Weights
+1. Select "Weights" tab
+2. Click Layer 1 or Layer 2 (Layer 0 has no weights)
+3. Check console for "Visualizing Layer N weights"
+4. Check canvas dimensions logged
+5. Verify weight range is non-trivial (not [0, 0])
+
+---
+
+## Integration with Main Project
+
+**Training Pipeline:**
+```bash
+# Generate weights
+./training/train_cnn_v2.py --export-binary
+
+# Test in browser
+open tools/cnn_v2_test/index.html
+# Drop: workspaces/main/cnn_v2_weights.bin
+# Drop: training/input/test.png
+```
+
+**Validation:**
+- Compare against demo CNNv2Effect (visual check)
+- Verify layer count matches binary file
+- Check weight ranges match training logs
+
+---
+
+## Future Enhancements
+
+- [ ] Fix layer activation visualization (black texture issue)
+- [ ] Fix weight kernel display (empty canvas issue)
+- [ ] Add per-channel auto-scaling (compute min/max from visible data)
+- [ ] Export rendered outputs (download PNG)
+- [ ] Side-by-side comparison with original
+- [ ] Heatmap mode (color-coded activations)
+- [ ] Weight statistics overlay (mean, std, sparsity)
+- [ ] Batch processing (multiple images in sequence)
+- [ ] Integration with Python training (live reload)
+
+---
+
+## Code Metrics
+
+- Total lines: ~1100
+- JavaScript: ~700 lines
+- WGSL shaders: ~300 lines
+- HTML/CSS: ~100 lines
+
+**Dependencies:** None (pure WebGPU + HTML5)
+
+---
+
+## Related Files
+
+- `doc/CNN_V2.md` - CNN v2 architecture and design
+- `doc/CNN_TEST_TOOL.md` - C++ offline testing tool (deprecated)
+- `training/train_cnn_v2.py` - Training script with binary export
+- `workspaces/main/cnn_v2_weights.bin` - Trained weights
diff --git a/cnn_v2/scripts/train_cnn_v2_full.sh b/cnn_v2/scripts/train_cnn_v2_full.sh
new file mode 100755
index 0000000..a21c1ac
--- /dev/null
+++ b/cnn_v2/scripts/train_cnn_v2_full.sh
@@ -0,0 +1,428 @@
+#!/bin/bash
+# Complete CNN v2 Training Pipeline
+# Train → Export → Build → Validate
+# Usage: ./train_cnn_v2_full.sh [OPTIONS]
+#
+# MODES:
+# (none) Run complete pipeline: train → export → build → validate
+# --validate Validate only (skip training, use existing weights)
+# --validate CHECKPOINT Validate with specific checkpoint file
+# --export-only CHECKPOINT Export weights only (skip training, build, validation)
+#
+# TRAINING PARAMETERS:
+# --epochs N Training epochs (default: 200)
+# --batch-size N Batch size (default: 16)
+# --lr FLOAT Learning rate (default: 1e-3)
+# --checkpoint-every N Checkpoint interval (default: 50)
+# --kernel-sizes K Comma-separated kernel sizes (default: 3,3,3)
+# --num-layers N Number of layers (default: 3)
+# --mip-level N Mip level for p0-p3 features: 0-3 (default: 0)
+# --grayscale-loss Compute loss on grayscale instead of RGBA
+#
+# PATCH PARAMETERS:
+# --patch-size N Patch size (default: 8)
+# --patches-per-image N Patches per image (default: 256)
+# --detector TYPE Detector: harris|fast|shi-tomasi|gradient (default: harris)
+# --full-image Use full-image training (disables patch mode)
+# --image-size N Image size for full-image mode (default: 256)
+#
+# DIRECTORIES:
+# --input DIR Input directory (default: training/input)
+# --target DIR Target directory (default: training/target_1)
+# --checkpoint-dir DIR Checkpoint directory (default: checkpoints)
+# --validation-dir DIR Validation directory (default: validation_results)
+#
+# OUTPUT:
+# --output-weights PATH Output binary weights file (default: workspaces/main/weights/cnn_v2_weights.bin)
+#
+# OTHER:
+# --help Show this help message
+#
+# Examples:
+# ./train_cnn_v2_full.sh
+# ./train_cnn_v2_full.sh --epochs 500 --batch-size 32
+# ./train_cnn_v2_full.sh --validate
+# ./train_cnn_v2_full.sh --validate checkpoints/checkpoint_epoch_50.pth
+# ./train_cnn_v2_full.sh --export-only checkpoints/checkpoint_epoch_100.pth
+# ./train_cnn_v2_full.sh --mip-level 1 --kernel-sizes 3,5,3
+
+set -e
+
+# Path resolution for running from any directory
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+PROJECT_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)"
+cd "$SCRIPT_DIR/.."  # cnn_v2/ (training data, checkpoints, validation output live here)
+
+# Helper functions
+export_weights() {
+ python3 "$SCRIPT_DIR/../training/export_cnn_v2_weights.py" "$1" --output-weights "$2" --quiet
+}
+
+find_latest_checkpoint() {
+ ls -t "$CHECKPOINT_DIR"/checkpoint_epoch_*.pth 2>/dev/null | head -1
+}
+
+build_target() {
+ cmake --build build -j4 --target "$1" > /dev/null 2>&1
+}
+
+# Default configuration
+INPUT_DIR="training/input"
+TARGET_DIR="training/target_1"
+CHECKPOINT_DIR="checkpoints"
+VALIDATION_DIR="validation_results"
+EPOCHS=200
+CHECKPOINT_EVERY=50
+BATCH_SIZE=16
+LEARNING_RATE=1e-3
+PATCH_SIZE=8
+PATCHES_PER_IMAGE=256
+DETECTOR="harris"
+KERNEL_SIZES="3,3,3"
+NUM_LAYERS=3
+MIP_LEVEL=0
+GRAYSCALE_LOSS=false
+FULL_IMAGE_MODE=false
+IMAGE_SIZE=256
+OUTPUT_WEIGHTS="${PROJECT_ROOT}/workspaces/main/weights/cnn_v2_weights.bin"
+
+# Parse arguments
+VALIDATE_ONLY=false
+VALIDATE_CHECKPOINT=""
+EXPORT_ONLY=false
+EXPORT_CHECKPOINT=""
+
+if [ "$1" = "--help" ] || [ "$1" = "-h" ]; then
+ head -47 "$0" | grep "^#" | grep -v "^#!/" | sed 's/^# *//'
+ exit 0
+fi
+
+# Parse all arguments
+while [[ $# -gt 0 ]]; do
+ case "$1" in
+ --export-only)
+ EXPORT_ONLY=true
+ if [ -z "$2" ]; then
+ echo "Error: --export-only requires a checkpoint file argument"
+ exit 1
+ fi
+ EXPORT_CHECKPOINT="$2"
+ shift 2
+ ;;
+ --validate)
+ VALIDATE_ONLY=true
+ if [ -n "$2" ] && [[ ! "$2" =~ ^-- ]]; then
+ VALIDATE_CHECKPOINT="$2"
+ shift 2
+ else
+ shift
+ fi
+ ;;
+ --epochs)
+ if [ -z "$2" ]; then
+ echo "Error: --epochs requires a number argument"
+ exit 1
+ fi
+ EPOCHS="$2"
+ shift 2
+ ;;
+ --batch-size)
+ if [ -z "$2" ]; then
+ echo "Error: --batch-size requires a number argument"
+ exit 1
+ fi
+ BATCH_SIZE="$2"
+ shift 2
+ ;;
+ --checkpoint-every)
+ if [ -z "$2" ]; then
+ echo "Error: --checkpoint-every requires a number argument"
+ exit 1
+ fi
+ CHECKPOINT_EVERY="$2"
+ shift 2
+ ;;
+ --kernel-sizes)
+ if [ -z "$2" ]; then
+ echo "Error: --kernel-sizes requires a comma-separated list"
+ exit 1
+ fi
+ KERNEL_SIZES="$2"
+ shift 2
+ ;;
+ --num-layers)
+ if [ -z "$2" ]; then
+ echo "Error: --num-layers requires a number argument"
+ exit 1
+ fi
+ NUM_LAYERS="$2"
+ shift 2
+ ;;
+ --mip-level)
+ if [ -z "$2" ]; then
+ echo "Error: --mip-level requires a level argument (0-3)"
+ exit 1
+ fi
+ MIP_LEVEL="$2"
+ shift 2
+ ;;
+ --grayscale-loss)
+ GRAYSCALE_LOSS=true
+ shift
+ ;;
+ --lr)
+ if [ -z "$2" ]; then
+ echo "Error: --lr requires a float argument"
+ exit 1
+ fi
+ LEARNING_RATE="$2"
+ shift 2
+ ;;
+ --patch-size)
+ if [ -z "$2" ]; then
+ echo "Error: --patch-size requires a number argument"
+ exit 1
+ fi
+ PATCH_SIZE="$2"
+ shift 2
+ ;;
+ --patches-per-image)
+ if [ -z "$2" ]; then
+ echo "Error: --patches-per-image requires a number argument"
+ exit 1
+ fi
+ PATCHES_PER_IMAGE="$2"
+ shift 2
+ ;;
+ --detector)
+ if [ -z "$2" ]; then
+ echo "Error: --detector requires a type argument"
+ exit 1
+ fi
+ DETECTOR="$2"
+ shift 2
+ ;;
+ --full-image)
+ FULL_IMAGE_MODE=true
+ shift
+ ;;
+ --image-size)
+ if [ -z "$2" ]; then
+ echo "Error: --image-size requires a number argument"
+ exit 1
+ fi
+ IMAGE_SIZE="$2"
+ shift 2
+ ;;
+ --input)
+ if [ -z "$2" ]; then
+ echo "Error: --input requires a directory argument"
+ exit 1
+ fi
+ INPUT_DIR="$2"
+ shift 2
+ ;;
+ --target)
+ if [ -z "$2" ]; then
+ echo "Error: --target requires a directory argument"
+ exit 1
+ fi
+ TARGET_DIR="$2"
+ shift 2
+ ;;
+ --checkpoint-dir)
+ if [ -z "$2" ]; then
+ echo "Error: --checkpoint-dir requires a directory argument"
+ exit 1
+ fi
+ CHECKPOINT_DIR="$2"
+ shift 2
+ ;;
+ --validation-dir)
+ if [ -z "$2" ]; then
+ echo "Error: --validation-dir requires a directory argument"
+ exit 1
+ fi
+ VALIDATION_DIR="$2"
+ shift 2
+ ;;
+ --output-weights)
+ if [ -z "$2" ]; then
+ echo "Error: --output-weights requires a file path argument"
+ exit 1
+ fi
+ OUTPUT_WEIGHTS="$2"
+ shift 2
+ ;;
+ *)
+ echo "Unknown option: $1"
+ exit 1
+ ;;
+ esac
+done
+
+# Build training arguments
+if [ "$FULL_IMAGE_MODE" = true ]; then
+ TRAINING_MODE_ARGS="--full-image --image-size $IMAGE_SIZE"
+else
+ TRAINING_MODE_ARGS="--patch-size $PATCH_SIZE --patches-per-image $PATCHES_PER_IMAGE --detector $DETECTOR"
+fi
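+# With the defaults above, TRAINING_MODE_ARGS expands to, for example:
+#   patch mode:      --patch-size 8 --patches-per-image 256 --detector harris
+#   full-image mode: --full-image --image-size 256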
+
+# Handle export-only mode
+if [ "$EXPORT_ONLY" = true ]; then
+ echo "=== CNN v2 Export Weights Only ==="
+ echo "Checkpoint: $EXPORT_CHECKPOINT"
+ echo ""
+
+ if [ ! -f "$EXPORT_CHECKPOINT" ]; then
+ echo "Error: Checkpoint file not found: $EXPORT_CHECKPOINT"
+ exit 1
+ fi
+
+ export_weights "$EXPORT_CHECKPOINT" "$OUTPUT_WEIGHTS" || {
+ echo "Error: Export failed"
+ exit 1
+ }
+
+ echo ""
+ echo "=== Export Complete ==="
+ echo "Output: $OUTPUT_WEIGHTS"
+ exit 0
+fi
+
+if [ "$VALIDATE_ONLY" = true ]; then
+ echo "=== CNN v2 Validation Only ==="
+ echo "Skipping training, using existing weights"
+ echo ""
+else
+ echo "=== CNN v2 Complete Training Pipeline ==="
+ echo "Input: $INPUT_DIR"
+ echo "Target: $TARGET_DIR"
+ echo "Epochs: $EPOCHS"
+ echo "Checkpoint interval: $CHECKPOINT_EVERY"
+ echo "Mip level: $MIP_LEVEL (p0-p3 features)"
+ echo ""
+fi
+
+if [ "$VALIDATE_ONLY" = false ]; then
+ # Step 1: Train model
+ echo "[1/4] Training CNN v2 model..."
+
+python3 "$SCRIPT_DIR/../training/train_cnn_v2.py" \
+ --input "$INPUT_DIR" \
+ --target "$TARGET_DIR" \
+ $TRAINING_MODE_ARGS \
+ --kernel-sizes "$KERNEL_SIZES" \
+ --num-layers "$NUM_LAYERS" \
+ --mip-level "$MIP_LEVEL" \
+ --epochs "$EPOCHS" \
+ --batch-size "$BATCH_SIZE" \
+ --lr "$LEARNING_RATE" \
+ --checkpoint-dir "$CHECKPOINT_DIR" \
+ --checkpoint-every "$CHECKPOINT_EVERY" \
+ $([ "$GRAYSCALE_LOSS" = true ] && echo "--grayscale-loss")
+
+if [ $? -ne 0 ]; then
+ echo "Error: Training failed"
+ exit 1
+fi
+
+echo ""
+echo "Training complete!"
+echo ""
+
+# Step 2: Export final checkpoint to shaders
+FINAL_CHECKPOINT="$CHECKPOINT_DIR/checkpoint_epoch_${EPOCHS}.pth"
+
+if [ ! -f "$FINAL_CHECKPOINT" ]; then
+ echo "Warning: Final checkpoint not found, using latest available..."
+ FINAL_CHECKPOINT=$(find_latest_checkpoint)
+fi
+
+if [ -z "$FINAL_CHECKPOINT" ] || [ ! -f "$FINAL_CHECKPOINT" ]; then
+ echo "Error: No checkpoint found in $CHECKPOINT_DIR"
+ exit 1
+fi
+
+echo "[2/4] Exporting final checkpoint to binary weights..."
+echo "Checkpoint: $FINAL_CHECKPOINT"
+export_weights "$FINAL_CHECKPOINT" "$OUTPUT_WEIGHTS" || {
+ echo "Error: Shader export failed"
+ exit 1
+}
+
+echo ""
+fi # End of training/export section
+
+# Determine which checkpoint to use
+if [ "$VALIDATE_ONLY" = true ]; then
+ FINAL_CHECKPOINT="${VALIDATE_CHECKPOINT:-$(find_latest_checkpoint)}"
+ echo "Using checkpoint: $FINAL_CHECKPOINT"
+ echo ""
+fi
+
+# Step 3: Rebuild with new shaders
+if [ "$VALIDATE_ONLY" = false ]; then
+ echo "[3/4] Rebuilding demo with new shaders..."
+ build_target demo64k || {
+ echo "Error: Build failed"
+ exit 1
+ }
+ echo " → Build complete"
+ echo ""
+fi
+
+# Step 4: Visual assessment - process final checkpoint only
+if [ "$VALIDATE_ONLY" = true ]; then
+ echo "Validation on all input images (using existing weights)..."
+else
+ echo "[4/4] Visual assessment on all input images..."
+fi
+
+mkdir -p "$VALIDATION_DIR"
+echo " Using checkpoint: $FINAL_CHECKPOINT"
+
+# Export weights for validation mode (already exported in step 2 for training mode)
+if [ "$VALIDATE_ONLY" = true ]; then
+ export_weights "$FINAL_CHECKPOINT" "$OUTPUT_WEIGHTS" > /dev/null 2>&1
+fi
+
+# Build cnn_test
+build_target cnn_test
+
+# Process all input images
+echo -n " Processing images: "
+for input_image in "$INPUT_DIR"/*.png; do
+ basename=$(basename "$input_image" .png)
+ echo -n "$basename "
+ build/cnn_test "$input_image" "$VALIDATION_DIR/${basename}_output.png" --weights "$OUTPUT_WEIGHTS" > /dev/null 2>&1
+done
+echo "✓"
+
+# Build demo only if not in validate mode
+[ "$VALIDATE_ONLY" = false ] && build_target demo64k
+
+echo ""
+if [ "$VALIDATE_ONLY" = true ]; then
+ echo "=== Validation Complete ==="
+else
+ echo "=== Training Pipeline Complete ==="
+fi
+echo ""
+echo "Results:"
+if [ "$VALIDATE_ONLY" = false ]; then
+ echo " - Checkpoints: $CHECKPOINT_DIR"
+ echo " - Final weights: $OUTPUT_WEIGHTS"
+fi
+echo " - Validation outputs: $VALIDATION_DIR"
+echo ""
+echo "Opening results directory..."
+open "$VALIDATION_DIR" 2>/dev/null || xdg-open "$VALIDATION_DIR" 2>/dev/null || true
+
+if [ "$VALIDATE_ONLY" = false ]; then
+ echo ""
+ echo "Run demo to see final result:"
+ echo " ./build/demo64k"
+fi
diff --git a/cnn_v2/shaders/cnn_v2_compute.wgsl b/cnn_v2/shaders/cnn_v2_compute.wgsl
new file mode 100644
index 0000000..cdbfd74
--- /dev/null
+++ b/cnn_v2/shaders/cnn_v2_compute.wgsl
@@ -0,0 +1,143 @@
+// CNN v2 Compute Shader - Uniform 12D→4D Architecture
+// All layers: input/previous (4D) + static (8D) = 12D → 4 channels
+// Storage buffer weights, ping-pong execution
+// Per-layer kernel sizes supported via LayerParams
+
+// Push constants for layer parameters (passed per dispatch)
+struct LayerParams {
+ kernel_size: u32,
+ in_channels: u32,
+ out_channels: u32,
+ weight_offset: u32, // Offset in f16 units
+ is_output_layer: u32, // 1 if final layer (sigmoid), 0 otherwise (relu)
+ blend_amount: f32, // [0,1] blend with original
+  is_layer_0: u32,          // 1 if first layer (sigmoid, like the output layer), 0 otherwise
+}
+
+@group(0) @binding(0) var static_features: texture_2d<u32>; // 8D static features (p0-p3 + spatial)
+@group(0) @binding(1) var layer_input: texture_2d<u32>; // 4D previous/input (RGBD or prev layer)
+@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>; // 4D output
+@group(0) @binding(3) var<storage, read> weights_buffer: array<u32>; // Packed f16 weights
+@group(0) @binding(4) var<uniform> params: LayerParams;
+@group(0) @binding(5) var original_input: texture_2d<f32>; // Original RGB for blending
+
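+// Static feature channel order (as packed by cnn_v2_static.wgsl):
+// [p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias]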
+fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {
+ let packed = textureLoad(static_features, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let v2 = unpack2x16float(packed.z);
+ let v3 = unpack2x16float(packed.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}
+
+fn unpack_layer_channels(coord: vec2<i32>) -> vec4<f32> {
+ let packed = textureLoad(layer_input, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ return vec4<f32>(v0.x, v0.y, v1.x, v1.y);
+}
+
+fn pack_channels(values: vec4<f32>) -> vec4<u32> {
+ return vec4<u32>(
+ pack2x16float(vec2<f32>(values.x, values.y)),
+ pack2x16float(vec2<f32>(values.z, values.w)),
+ 0u, // Unused
+ 0u // Unused
+ );
+}
+
+// Get weight from storage buffer (f16 packed as u32 pairs)
+// The binary file is [header][layer_info][packed f16 weights]; the C++ side strips the
+// header and layer info before upload, so this buffer contains only the packed f16
+// weights, and idx / weight_offset are relative to the start of that region.
+// TODO: Support 8-bit quantized weights (4× per u32) for 2× size reduction
+fn get_weight(idx: u32) -> f32 {
+ let pair_idx = idx / 2u;
+ let packed = weights_buffer[pair_idx];
+ let unpacked = unpack2x16float(packed);
+ return select(unpacked.y, unpacked.x, (idx & 1u) == 0u);
+}
+
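+// Weight layout per layer (mirrors the index math in main() below): weights are flattened
+// as [out_channel][in_channel][ky][kx], with input channels 0-3 taken from the previous
+// layer (or input) and channels 4-11 from the static features. For a 3×3 kernel the
+// weight for output channel c, input channel i and kernel offset (ky, kx) in [-1, 1] is at
+//   weight_offset + c*12*9 + i*9 + (ky + 1)*3 + (kx + 1)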
+@compute @workgroup_size(8, 8)
+fn main(@builtin(global_invocation_id) id: vec3<u32>) {
+ let coord = vec2<i32>(id.xy);
+ let dims = textureDimensions(static_features);
+
+ if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {
+ return;
+ }
+
+ let kernel_size = params.kernel_size;
+ let in_channels = params.in_channels; // Always 12 (4 prev + 8 static)
+ let out_channels = params.out_channels; // Always 4
+ let weight_offset = params.weight_offset;
+ let is_output = params.is_output_layer != 0u;
+
+ let kernel_radius = i32(kernel_size / 2u);
+
+ // Load static features (8D) and previous/input layer (4D)
+ let static_feat = unpack_static_features(coord);
+
+ // Convolution: 12D input → 4D output
+ var output: vec4<f32> = vec4<f32>(0.0);
+ for (var c: u32 = 0u; c < 4u; c++) {
+ var sum: f32 = 0.0;
+
+ // Convolve over kernel
+ for (var ky: i32 = -kernel_radius; ky <= kernel_radius; ky++) {
+ for (var kx: i32 = -kernel_radius; kx <= kernel_radius; kx++) {
+ let sample_coord = coord + vec2<i32>(kx, ky);
+
+ // Border handling (clamp)
+ let clamped = vec2<i32>(
+ clamp(sample_coord.x, 0, i32(dims.x) - 1),
+ clamp(sample_coord.y, 0, i32(dims.y) - 1)
+ );
+
+ // Load features at this spatial location
+ let static_local = unpack_static_features(clamped);
+ let layer_local = unpack_layer_channels(clamped); // 4D
+
+ // Weight index calculation
+ let ky_idx = u32(ky + kernel_radius);
+ let kx_idx = u32(kx + kernel_radius);
+ let spatial_idx = ky_idx * kernel_size + kx_idx;
+
+ // Accumulate: previous/input channels (4D)
+ for (var i: u32 = 0u; i < 4u; i++) {
+ let w_idx = weight_offset +
+ c * 12u * kernel_size * kernel_size +
+ i * kernel_size * kernel_size + spatial_idx;
+ sum += get_weight(w_idx) * layer_local[i];
+ }
+
+ // Accumulate: static features (8D)
+ for (var i: u32 = 0u; i < 8u; i++) {
+ let w_idx = weight_offset +
+ c * 12u * kernel_size * kernel_size +
+ (4u + i) * kernel_size * kernel_size + spatial_idx;
+ sum += get_weight(w_idx) * static_local[i];
+ }
+ }
+ }
+
+ // Activation (matches train_cnn_v2.py)
+ if (is_output || params.is_layer_0 != 0u) {
+ output[c] = 1.0 / (1.0 + exp(-sum)); // Sigmoid [0,1]
+ } else {
+ output[c] = max(0.0, sum); // ReLU
+ }
+ }
+
+ // Blend with original on final layer
+ if (is_output) {
+ let original = textureLoad(original_input, coord, 0).rgb;
+ let result_rgb = vec3<f32>(output.x, output.y, output.z);
+ let blended = mix(original, result_rgb, params.blend_amount);
+ output.x = blended.r;
+ output.y = blended.g;
+ output.z = blended.b;
+ }
+
+ textureStore(output_tex, coord, pack_channels(output));
+}
diff --git a/cnn_v2/shaders/cnn_v2_layer_0.wgsl b/cnn_v2/shaders/cnn_v2_layer_0.wgsl
new file mode 100644
index 0000000..8e14957
--- /dev/null
+++ b/cnn_v2/shaders/cnn_v2_layer_0.wgsl
@@ -0,0 +1,174 @@
+// CNN v2 Layer 0 - Auto-generated
+// Kernel: 3×3, In: 8, Out: 8
+
+const KERNEL_SIZE: u32 = 3u;
+const IN_CHANNELS: u32 = 8u;
+const OUT_CHANNELS: u32 = 8u;
+const KERNEL_RADIUS: i32 = 1;
+
+// Weights quantized to float16 (stored as f32 in WGSL)
+const weights: array<f32, 576> = array(
+ 0.057281, -0.041962, 0.003933, 0.026459, 0.304199, 0.067261, 0.191895, 0.047455,
+ 0.074402, 0.201660, 0.158325, 0.150513, 0.219238, 0.260010, 0.319336, 0.208618,
+ 0.050201, 0.090210, 0.086853, 0.181152, 0.060486, 0.167847, 0.161499, 0.265869,
+ 0.163818, 0.100647, 0.243408, -0.008553, -0.010849, 0.046509, -0.060608, -0.022263,
+ 0.094360, -0.043854, -0.005329, -0.093262, 0.032349, 0.007259, 0.039948, -0.018692,
+ -0.000618, 0.052368, -0.038055, 0.118042, -0.084595, 0.044281, -0.107056, 0.089478,
+ -0.076477, 0.017441, 0.088135, 0.076721, -0.063965, 0.001612, 0.062469, 0.067505,
+ 0.035736, 0.115051, -0.117737, -0.076843, -0.008888, -0.002028, -0.061005, 0.081726,
+ 0.115051, -0.028183, 0.043213, -0.079285, -0.040314, -0.047699, -0.051575, -0.052521,
+ 0.071533, 0.084656, 0.051910, 0.090637, -0.104248, -0.066467, -0.032104, -0.006977,
+ 0.075439, -0.004841, 0.084656, -0.034698, 0.035675, -0.101929, -0.035034, -0.036804,
+ 0.069641, -0.026840, -0.017807, -0.088318, -0.125000, -0.042847, -0.003063, 0.007622,
+ 0.076416, 0.094971, -0.019058, 0.083496, -0.085205, 0.036285, -0.077209, 0.082458,
+ 0.056549, 0.038818, 0.092224, -0.002499, 0.069641, 0.097229, 0.069275, -0.111084,
+ -0.092041, -0.020462, -0.061279, -0.032196, -0.088623, 0.032227, -0.117004, -0.125854,
+ -0.015884, 0.093018, -0.070923, -0.117615, -0.081848, -0.115479, 0.033508, -0.026443,
+ -0.009850, -0.063232, 0.098328, -0.000984, 0.039886, -0.085754, -0.108826, 0.030258,
+ 0.091675, 0.024384, -0.118958, -0.077148, -0.122437, -0.002090, -0.089539, 0.096741,
+ 0.095337, 0.108582, -0.101807, 0.152222, 0.206177, 0.050323, -0.111450, -0.104431,
+ -0.037445, 0.276611, 0.244019, 0.171143, 0.131592, 0.056030, 0.141602, 0.014267,
+ -0.025955, -0.019730, 0.155884, 0.072144, 0.176636, -0.010117, 0.141724, 0.103027,
+ -0.253174, -0.229370, -0.105713, -0.005898, 0.075439, -0.002014, -0.010506, -0.108093,
+ -0.016724, 0.108215, 0.053589, -0.044586, 0.030396, -0.077759, 0.058594, -0.018463,
+ 0.027100, 0.030823, -0.026947, -0.014084, 0.121643, 0.116638, -0.010239, 0.106262,
+ -0.109070, -0.044281, -0.045319, -0.021942, 0.083923, 0.114929, 0.154541, 0.078186,
+ -0.047394, 0.007957, 0.099182, -0.030075, 0.103699, 0.080994, -0.085144, 0.047180,
+ 0.099792, 0.081116, 0.084961, 0.151123, 0.000963, 0.029221, 0.073181, 0.086609,
+ 0.149048, -0.052185, -0.158936, 0.146240, 0.020004, 0.063110, 0.111877, 0.037201,
+ 0.087585, 0.134277, 0.058258, -0.075256, 0.141357, 0.045776, 0.171753, 0.186035,
+ 0.093201, 0.202637, 0.018723, -0.047638, 0.072510, 0.132812, 0.182251, 0.191650,
+ 0.163818, 0.146362, 0.124451, -0.082214, 0.094482, -0.007275, 0.029099, -0.040314,
+ -0.017624, -0.018860, -0.108398, -0.111145, 0.058289, -0.106995, -0.091919, 0.069824,
+ -0.084045, -0.105957, 0.065002, -0.012894, 0.042297, -0.081299, -0.112976, 0.012314,
+ 0.015625, -0.100708, -0.039673, 0.092041, 0.037201, 0.089722, 0.064087, 0.000403,
+ 0.120667, -0.012238, -0.055695, 0.010620, -0.022110, -0.008751, 0.038605, 0.075256,
+ 0.041260, 0.128296, -0.072021, 0.020828, -0.072449, 0.051239, 0.034058, 0.122803,
+ -0.062103, 0.156006, -0.111633, 0.043671, 0.209229, 0.006088, 0.141968, 0.209961,
+ 0.122620, -0.004547, 0.107727, 0.115601, 0.003378, 0.375732, 0.068481, 0.037842,
+ 0.159546, -0.014450, 0.073425, 0.168701, -0.052643, 0.060699, 0.333740, 0.033905,
+ -0.060150, 0.053558, 0.165527, -0.052460, -0.047882, 0.080750, 0.110352, -0.057098,
+ 0.057983, -0.018692, 0.019714, -0.056427, -0.053314, -0.001763, 0.027039, 0.003395,
+ -0.131226, -0.068481, -0.086609, 0.065186, 0.084717, 0.036530, 0.043488, 0.013893,
+ -0.076660, 0.081177, 0.037476, -0.124084, -0.070312, -0.027130, -0.009331, -0.128174,
+ -0.075256, 0.098206, -0.046539, -0.045319, 0.083923, -0.050598, 0.063477, 0.007408,
+ 0.026794, -0.090454, -0.083435, 0.129761, 0.044556, 0.051849, 0.115662, 0.071167,
+ 0.004414, 0.048035, -0.148682, 0.098938, 0.200562, 0.111938, 0.208496, 0.200684,
+ -0.050262, 0.119568, 0.062988, 0.072083, 0.123779, 0.369629, 0.317627, 0.187622,
+ 0.157227, 0.183960, 0.031921, 0.142944, 0.080627, 0.218628, 0.264160, 0.156128,
+ 0.084961, 0.029343, 0.057617, 0.089233, 0.041138, 0.044373, 0.074707, 0.025818,
+ 0.113708, -0.045380, -0.114929, 0.104370, -0.012238, -0.174194, -0.169312, -0.070312,
+ -0.005863, 0.027481, 0.053345, -0.016006, -0.057953, -0.010284, 0.034241, -0.041077,
+ -0.002373, 0.034515, 0.078552, -0.066162, -0.035400, 0.072510, 0.060425, -0.037720,
+ -0.025955, 0.118042, -0.071777, 0.133667, 0.012192, -0.080933, 0.093445, 0.052826,
+ -0.037354, -0.052277, 0.124084, 0.029861, 0.137085, 0.053009, -0.034180, -0.011421,
+ 0.089233, 0.172729, 0.146118, 0.003944, 0.279541, 0.162842, 0.112244, 0.204956,
+ 0.059753, 0.117737, 0.330322, 0.185547, 0.194946, 0.404541, 0.274658, 0.177612,
+ 0.153320, 0.189575, 0.032257, 0.285400, 0.158203, 0.048035, 0.476562, 0.301025,
+ -0.179565, 0.160767, 0.137207, 0.102478, -0.060547, 0.060364, -0.091858, 0.064209,
+ 0.082642, 0.044769, -0.096436, -0.103699, -0.021683, 0.007221, -0.048737, 0.071228,
+ -0.069580, 0.066528, -0.122864, -0.008415, -0.094788, 0.040131, -0.091431, -0.029602,
+ -0.112488, -0.074158, -0.004898, -0.006721, -0.118286, -0.047516, 0.069519, 0.121521,
+ -0.004158, 0.167603, -0.092468, -0.049927, 0.006599, 0.097595, 0.064087, 0.083435,
+ 0.026993, 0.071411, 0.020538, 0.022293, 0.022858, 0.124268, 0.098999, -0.031738,
+ 0.019806, -0.087341, -0.096558, -0.099304, -0.113159, 0.021744, -0.080200, -0.056030,
+ 0.089661, -0.055115, -0.115845, -0.040222, 0.035919, 0.027832, 0.034668, 0.072632,
+ 0.071838, -0.081116, 0.050262, -0.037872, 0.054047, -0.096680, -0.102051, -0.044281,
+ 0.078796, -0.095154, -0.013229, 0.031555, -0.058533, -0.114441, -0.008530, 0.112732,
+ -0.057251, 0.096191, -0.008385, 0.052246, -0.016983, 0.092041, 0.013710, 0.012299,
+ -0.109497, 0.025604, -0.121643, -0.023819, 0.039490, -0.090088, -0.013145, -0.101562,
+ -0.115051, 0.050232, -0.047119, -0.055847, -0.017563, 0.103760, 0.116333, -0.061768,
+ -0.083069, -0.030319, 0.078003, -0.010124, 0.044617, -0.045868, 0.103638, 0.032379,
+ -0.093506, -0.048004, -0.022079, -0.004353, -0.048187, -0.025330, -0.070740, -0.014671
+);
+
+@group(0) @binding(0) var static_features: texture_2d<u32>;
+@group(0) @binding(1) var layer_input: texture_2d<u32>;
+@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>;
+
+fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {
+ let packed = textureLoad(static_features, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let v2 = unpack2x16float(packed.z);
+ let v3 = unpack2x16float(packed.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}
+
+fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> {
+ let packed = textureLoad(layer_input, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let v2 = unpack2x16float(packed.z);
+ let v3 = unpack2x16float(packed.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}
+
+fn pack_channels(values: array<f32, 8>) -> vec4<u32> {
+ return vec4<u32>(
+ pack2x16float(vec2<f32>(values[0], values[1])),
+ pack2x16float(vec2<f32>(values[2], values[3])),
+ pack2x16float(vec2<f32>(values[4], values[5])),
+ pack2x16float(vec2<f32>(values[6], values[7]))
+ );
+}
+
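+// Weight layout (as indexed below): weights[c*IN_CHANNELS*9 + i*9 + ky*3 + kx], where ky
+// and kx are the 0-based kernel indices, the 8 static feature channels come first
+// (i = 0..7) and any previous-layer channels follow them; layer 0 has IN_CHANNELS = 8,
+// so the previous-layer loop below runs zero times.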
+@compute @workgroup_size(8, 8)
+fn main(@builtin(global_invocation_id) id: vec3<u32>) {
+ let coord = vec2<i32>(id.xy);
+ let dims = textureDimensions(static_features);
+
+ if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {
+ return;
+ }
+
+ // Load static features (always available)
+ let static_feat = unpack_static_features(coord);
+
+ // Convolution
+ var output: array<f32, OUT_CHANNELS>;
+ for (var c: u32 = 0u; c < OUT_CHANNELS; c++) {
+ var sum: f32 = 0.0;
+
+ for (var ky: i32 = -KERNEL_RADIUS; ky <= KERNEL_RADIUS; ky++) {
+ for (var kx: i32 = -KERNEL_RADIUS; kx <= KERNEL_RADIUS; kx++) {
+ let sample_coord = coord + vec2<i32>(kx, ky);
+
+ // Border handling (clamp)
+ let clamped = vec2<i32>(
+ clamp(sample_coord.x, 0, i32(dims.x) - 1),
+ clamp(sample_coord.y, 0, i32(dims.y) - 1)
+ );
+
+ // Load input features
+ let static_local = unpack_static_features(clamped);
+ let layer_local = unpack_layer_channels(clamped);
+
+ // Weight index calculation
+ let ky_idx = u32(ky + KERNEL_RADIUS);
+ let kx_idx = u32(kx + KERNEL_RADIUS);
+ let spatial_idx = ky_idx * KERNEL_SIZE + kx_idx;
+
+ // Accumulate: static features (8D)
+ for (var i: u32 = 0u; i < 8u; i++) {
+ let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE +
+ i * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
+ sum += weights[w_idx] * static_local[i];
+ }
+
+ // Accumulate: layer input channels (if layer_idx > 0)
+ let prev_channels = IN_CHANNELS - 8u;
+ for (var i: u32 = 0u; i < prev_channels; i++) {
+ let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE +
+ (8u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
+ sum += weights[w_idx] * layer_local[i];
+ }
+ }
+ }
+
+ output[c] = max(0.0, sum); // ReLU
+ }
+
+ // Pack and store
+ textureStore(output_tex, coord, pack_channels(output));
+}
diff --git a/cnn_v2/shaders/cnn_v2_layer_1.wgsl b/cnn_v2/shaders/cnn_v2_layer_1.wgsl
new file mode 100644
index 0000000..f490d13
--- /dev/null
+++ b/cnn_v2/shaders/cnn_v2_layer_1.wgsl
@@ -0,0 +1,174 @@
+// CNN v2 Layer 1 - Auto-generated
+// Kernel: 3×3, In: 16, Out: 4
+
+const KERNEL_SIZE: u32 = 3u;
+const IN_CHANNELS: u32 = 16u;
+const OUT_CHANNELS: u32 = 4u;
+const KERNEL_RADIUS: i32 = 1;
+
+// Weights quantized to float16 (stored as f32 in WGSL)
+const weights: array<f32, 576> = array(
+ 0.337402, 0.638672, -0.481201, 0.699707, 1.127930, -0.018280, -0.062195, 0.148682,
+ -0.655273, 0.448975, 0.969238, -0.280762, 0.817383, 1.271484, 0.421387, -0.163696,
+ 0.305664, -0.454834, 0.354004, 0.932617, -0.411377, 0.581543, 1.263672, 0.422363,
+ -0.380371, 0.152588, -0.668945, -0.063782, 0.060730, 0.022018, -0.075195, -0.049286,
+ 0.068542, 0.057343, -0.009773, 0.006344, -0.080872, -0.179932, -0.297119, 0.098328,
+ 0.061951, -0.088989, 0.047913, 0.093628, -0.091858, -0.068298, 0.102600, -0.044067,
+ -0.054230, -0.031799, 0.050934, -0.300049, -0.202637, -0.203613, -0.294189, -0.361084,
+ 0.277344, -0.213257, -0.239624, 0.193237, -0.215210, -0.295166, 0.298828, -0.065369,
+ 0.148926, 0.024963, 0.272705, 0.368164, 0.173096, 0.061279, 0.291260, 0.151611,
+ 0.411133, 0.216431, -0.179932, 0.506348, 0.319580, 0.059875, -0.134399, -0.150635,
+ -0.275391, 0.029480, 0.115417, 0.063782, 0.018723, -0.073364, -0.019653, 0.066467,
+ -0.086731, 0.113220, 0.110535, 0.011940, -0.094727, 0.262207, 0.180298, 0.141357,
+ 0.249634, 0.199585, 0.120605, 0.403809, 0.242676, -0.028442, 0.251953, 0.130737,
+ 0.152832, -0.306396, -0.324951, -0.176514, 0.161133, 0.333252, -0.195068, 0.250244,
+ 0.569824, 0.011223, -0.186035, 0.048279, -0.325439, 0.272217, 0.144043, -0.142700,
+ 0.447754, 0.434082, 0.124878, -0.157471, -0.120422, -0.281494, 0.338135, 0.266113,
+ -0.301514, 0.424805, 0.541504, -0.195679, 0.054962, 0.061798, -0.323975, 0.056732,
+ 0.072571, -0.087341, 0.052856, -0.057220, 0.023270, 0.071472, 0.014038, 0.083008,
+ -0.050659, 0.020111, 0.035614, -0.038086, -0.042786, 0.060242, -0.050079, -0.044403,
+ -0.059631, 0.075500, 0.056000, 0.010910, -0.064026, -0.016037, -0.050720, 0.050171,
+ -0.075256, -0.014183, 0.047058, -0.086731, 0.027939, 0.063232, -0.024597, -0.039551,
+ 0.000622, -0.048370, -0.001906, 0.058868, -0.074524, 0.019714, -0.036011, 0.028442,
+ 0.009766, -0.060577, -0.007416, -0.014381, 0.002317, -0.023483, 0.014313, 0.057434,
+ 0.063110, 0.030350, -0.027557, 0.023270, 0.055115, -0.003502, 0.012268, -0.054993,
+ -0.084961, -0.022736, 0.076233, 0.027573, -0.068787, -0.036987, -0.018539, -0.049347,
+ 0.032227, 0.033081, 0.050476, 0.043030, 0.023636, -0.039764, -0.018600, 0.073669,
+ 0.032166, -0.047119, -0.033325, -0.038605, 0.034119, -0.076843, 0.005863, -0.049103,
+ 0.065796, -0.056458, 0.054504, -0.008354, -0.018509, -0.057739, -0.075684, -0.053680,
+ 0.036804, 0.020721, -0.056183, 0.021774, -0.043884, 0.033661, -0.029633, 0.027374,
+ -0.087891, 0.030853, -0.040070, 0.013733, -0.082275, -0.072571, -0.055756, 0.002262,
+ 0.004421, -0.012169, -0.078064, -0.063904, -0.051758, -0.033264, -0.059265, -0.062256,
+ 0.063782, -0.088745, -0.026855, 0.062805, -0.036591, 0.037659, -0.012970, 0.025513,
+ -0.000908, 0.027084, 0.001842, -0.080750, -0.049713, -0.069397, -0.046448, -0.031006,
+ 0.012543, 0.009369, -0.080139, -0.034363, 0.003361, -0.052704, 0.041870, 0.059265,
+ 0.029938, 0.000138, 0.049896, 0.068787, 0.040405, -0.073608, 0.047668, 0.015320,
+ -0.033203, -0.016983, 0.034149, -0.010323, 0.029877, 0.078003, -0.054688, -0.021805,
+ -0.019409, 0.010284, 0.089172, -0.050385, 0.024857, -0.041992, 0.016602, 0.082397,
+ 0.081970, 0.096375, 0.060760, -0.006603, 0.029907, 0.012131, 0.104980, 0.034210,
+ 0.074707, -0.028320, -0.020248, 0.114868, -0.036957, 0.040192, 0.002888, 0.034973,
+ -0.038635, -0.018204, -0.058563, 0.029419, 0.013344, 0.027618, 0.073669, -0.038361,
+ 0.080933, 0.044586, -0.013214, 0.022675, 0.084351, 0.081848, 0.027328, 0.043915,
+ 0.040771, 0.078918, 0.054443, -0.049652, 0.073547, 0.103882, 0.065918, 0.070923,
+ -0.037476, -0.011215, -0.021408, 0.094727, 0.042450, 0.032806, -0.064026, 0.023941,
+ 0.011780, 0.041260, -0.038818, 0.079163, 0.079468, 0.053680, 0.047150, 0.003571,
+ 0.054840, 0.045929, -0.041382, -0.033539, 0.069153, 0.046234, 0.119263, -0.006340,
+ -0.050323, 0.030212, 0.069092, 0.045441, 0.096313, -0.024628, -0.088745, 0.009033,
+ -0.016830, 0.028534, -0.042755, -0.031921, 0.013611, -0.029251, -0.051483, -0.005848,
+ -0.032837, -0.058136, 0.075989, -0.008125, 0.108765, -0.004745, -0.003422, 0.079590,
+ 0.090515, -0.019196, -0.006786, 0.059479, -0.041168, 0.093445, 0.075439, -0.025055,
+ 0.067139, 0.011734, 0.031586, 0.029587, 0.098267, 0.025848, 0.095276, 0.003189,
+ 0.105408, 0.018799, -0.102478, 0.033813, 0.004272, 0.020477, 0.033142, 0.009727,
+ -0.021393, 0.120300, 0.088684, -0.037842, -0.094177, 0.017944, 0.020126, -0.002304,
+ -0.016006, 0.018112, 0.072693, -0.072021, -0.171265, -0.053528, -0.093201, 0.024124,
+ -0.050476, -0.023422, -0.071167, 0.046478, 0.034607, 0.076904, 0.013077, -0.082031,
+ 0.091858, -0.001575, 0.083801, 0.078003, 0.019119, -0.004967, 0.027298, 0.027740,
+ 0.032623, 0.048370, 0.029099, 0.093201, 0.049957, -0.007191, 0.059631, 0.008659,
+ 0.042725, -0.009369, 0.089417, 0.074951, -0.024704, 0.005344, 0.123840, 0.080322,
+ 0.096375, 0.070312, -0.010399, 0.033203, -0.009743, -0.030045, -0.039520, 0.042023,
+ -0.017441, 0.073486, 0.049500, -0.039734, 0.009811, 0.093262, -0.069641, 0.099365,
+ -0.010414, 0.048859, 0.099182, -0.007256, -0.023941, -0.021393, -0.005703, 0.025055,
+ 0.054535, 0.093384, -0.033661, 0.073242, 0.055023, 0.037170, -0.009300, 0.048615,
+ 0.019150, 0.019409, -0.080688, -0.050049, 0.104126, -0.023193, 0.044708, 0.111816,
+ 0.061584, 0.042755, -0.013863, -0.008385, -0.039703, 0.070618, -0.016922, -0.040833,
+ 0.051178, -0.060333, -0.004368, -0.009827, 0.051544, 0.072083, 0.068176, 0.148071,
+ 0.159424, 0.017578, 0.089905, -0.006794, 0.066101, -0.051117, 0.088684, -0.002989,
+ -0.066895, 0.089844, 0.012131, -0.020203, 0.011230, 0.000327, 0.073669, 0.060669,
+ 0.091064, 0.075989, 0.051971, 0.045044, 0.033875, 0.040466, -0.029449, 0.128418,
+ -0.000229, -0.026901, 0.052063, 0.000995, -0.032532, 0.105896, -0.001241, 0.114075,
+ 0.047607, 0.090332, 0.063660, 0.016495, 0.124817, 0.090942, 0.021545, 0.007164,
+ 0.074890, 0.118347, 0.047394, 0.052856, 0.104980, 0.009384, 0.034363, 0.019073,
+ 0.072388, -0.013313, 0.119141, 0.021255, 0.103210, 0.058319, 0.186035, -0.010818,
+ 0.037109, -0.044037, -0.075989, -0.001281, 0.017899, 0.030701, -0.080261, 0.082703
+);
+
+@group(0) @binding(0) var static_features: texture_2d<u32>;
+@group(0) @binding(1) var layer_input: texture_2d<u32>;
+@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>;
+
+fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {
+ let packed = textureLoad(static_features, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let v2 = unpack2x16float(packed.z);
+ let v3 = unpack2x16float(packed.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}
+
+fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> {
+ let packed = textureLoad(layer_input, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let v2 = unpack2x16float(packed.z);
+ let v3 = unpack2x16float(packed.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}
+
+fn pack_channels(values: array<f32, 4>) -> vec4<u32> {
+  return vec4<u32>(
+    pack2x16float(vec2<f32>(values[0], values[1])),
+    pack2x16float(vec2<f32>(values[2], values[3])),
+    0u, // Unused (OUT_CHANNELS = 4)
+    0u  // Unused
+  );
+}
+
+@compute @workgroup_size(8, 8)
+fn main(@builtin(global_invocation_id) id: vec3<u32>) {
+ let coord = vec2<i32>(id.xy);
+ let dims = textureDimensions(static_features);
+
+ if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {
+ return;
+ }
+
+ // Load static features (always available)
+ let static_feat = unpack_static_features(coord);
+
+ // Convolution
+ var output: array<f32, OUT_CHANNELS>;
+ for (var c: u32 = 0u; c < OUT_CHANNELS; c++) {
+ var sum: f32 = 0.0;
+
+ for (var ky: i32 = -KERNEL_RADIUS; ky <= KERNEL_RADIUS; ky++) {
+ for (var kx: i32 = -KERNEL_RADIUS; kx <= KERNEL_RADIUS; kx++) {
+ let sample_coord = coord + vec2<i32>(kx, ky);
+
+ // Border handling (clamp)
+ let clamped = vec2<i32>(
+ clamp(sample_coord.x, 0, i32(dims.x) - 1),
+ clamp(sample_coord.y, 0, i32(dims.y) - 1)
+ );
+
+ // Load input features
+ let static_local = unpack_static_features(clamped);
+ let layer_local = unpack_layer_channels(clamped);
+
+ // Weight index calculation
+ let ky_idx = u32(ky + KERNEL_RADIUS);
+ let kx_idx = u32(kx + KERNEL_RADIUS);
+ let spatial_idx = ky_idx * KERNEL_SIZE + kx_idx;
+
+ // Accumulate: static features (8D)
+ for (var i: u32 = 0u; i < 8u; i++) {
+ let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE +
+ i * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
+ sum += weights[w_idx] * static_local[i];
+ }
+
+ // Accumulate: layer input channels (if layer_idx > 0)
+ let prev_channels = IN_CHANNELS - 8u;
+ for (var i: u32 = 0u; i < prev_channels; i++) {
+ let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE +
+ (8u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
+ sum += weights[w_idx] * layer_local[i];
+ }
+ }
+ }
+
+ output[c] = max(0.0, sum); // ReLU
+ }
+
+ // Pack and store
+ textureStore(output_tex, coord, pack_channels(output));
+}
diff --git a/cnn_v2/shaders/cnn_v2_layer_2.wgsl b/cnn_v2/shaders/cnn_v2_layer_2.wgsl
new file mode 100644
index 0000000..2f9836a
--- /dev/null
+++ b/cnn_v2/shaders/cnn_v2_layer_2.wgsl
@@ -0,0 +1,156 @@
+// CNN v2 Layer 2 - Auto-generated
+// Kernel: 3×3, In: 12, Out: 4
+
+const KERNEL_SIZE: u32 = 3u;
+const IN_CHANNELS: u32 = 12u;
+const OUT_CHANNELS: u32 = 4u;
+const KERNEL_RADIUS: i32 = 1;
+
+// Weights quantized to float16 (stored as f32 in WGSL)
+const weights: array<f32, 432> = array(
+ 0.030212, -0.041351, 0.053864, -0.025635, 0.099976, -0.016830, -0.068665, 0.112488,
+ -0.069824, 0.030197, 0.020142, 0.101807, 0.061920, 0.022415, -0.025864, -0.056366,
+ 0.085571, -0.053650, 0.109802, 0.129272, 0.023438, 0.087341, 0.066284, 0.037079,
+ -0.067566, 0.021530, -0.046814, 0.029343, -0.028534, 0.047150, -0.079346, -0.022675,
+ -0.019669, -0.024185, 0.029587, 0.068970, 0.108826, 0.050598, -0.072144, 0.083008,
+ -0.002201, 0.006275, 0.056396, 0.001884, 0.097168, -0.028503, -0.002499, 0.008919,
+ -0.013771, -0.017502, -0.033478, 0.105530, 0.032898, 0.068726, -0.036285, -0.021011,
+ -0.018250, 0.073914, 0.024277, 0.061066, 0.008682, -0.022766, 0.074219, 0.094421,
+ 0.050903, 0.072571, 0.117493, -0.033234, 0.067993, -0.008049, 0.046997, -0.064209,
+ -0.381104, 0.107788, -0.213867, 0.145142, 0.514160, 0.407715, -0.317871, 0.249023,
+ 0.055634, -0.006294, -0.067444, 0.025131, 0.012939, -0.074158, -0.013741, -0.033020,
+ 0.026871, -0.007671, 0.089661, -0.003016, 0.029007, -0.038483, 0.045044, 0.104065,
+ 0.077148, 0.092468, -0.090027, -0.048126, 0.096863, -0.088013, 0.009483, 0.075012,
+ -0.076843, -0.085449, -0.066040, 0.019165, -0.019958, 0.083496, 0.069275, -0.019714,
+ 0.027786, -0.042389, 0.054718, 0.010635, -0.071777, 0.029282, -0.003605, 0.113770,
+ 0.080994, 0.106079, 0.047333, -0.013733, 0.034760, 0.099365, -0.020813, 0.095886,
+ 0.052490, -0.049194, 0.047394, 0.072510, -0.030930, -0.003782, -0.038025, -0.019318,
+ -0.047852, -0.043915, 0.026810, -0.041138, 0.038422, 0.009605, -0.080688, -0.019653,
+ 0.075256, -0.013817, -0.022400, 0.050629, 0.048462, 0.072998, -0.009109, 0.070923,
+ 0.079895, 0.071350, 0.002869, 0.081543, 0.037231, 0.020767, -0.017929, 0.042328,
+ -0.075134, -0.010681, -0.009079, 0.057007, -0.040253, -0.025574, -0.041534, 0.105835,
+ -0.039703, 0.032104, 0.076050, 0.070923, -0.013046, -0.054108, -0.024582, -0.033997,
+ 0.092285, 0.000525, 0.114685, 0.036926, -0.419434, 0.087891, -0.187866, 0.128906,
+ 0.665527, 0.268311, -0.337891, 0.195557, 0.140503, 0.014465, -0.043671, 0.031677,
+ 0.073059, 0.085144, 0.014290, -0.046967, 0.033356, 0.004177, 0.102844, 0.015259,
+ 0.026627, -0.005032, 0.111694, -0.010590, 0.029816, 0.108154, -0.072327, 0.056213,
+ 0.022903, 0.053772, 0.084473, -0.059845, -0.032776, -0.000015, -0.093872, -0.085815,
+ 0.081604, 0.069336, 0.034149, -0.067322, -0.020859, 0.120911, 0.077209, -0.016388,
+ 0.050140, -0.045563, -0.046326, 0.032623, -0.005009, 0.008003, 0.109192, 0.086548,
+ 0.096558, 0.118530, 0.035034, 0.110352, -0.041748, 0.009178, 0.049957, 0.084839,
+ 0.042053, -0.069153, -0.024796, -0.094604, -0.047028, -0.053802, 0.024979, 0.049591,
+ -0.016373, -0.047607, -0.008797, -0.058868, 0.107178, 0.055695, 0.092407, 0.092346,
+ 0.053894, 0.054657, -0.039703, -0.073792, 0.041779, -0.044159, 0.099182, 0.037109,
+ 0.097778, 0.098206, -0.057831, -0.054016, -0.068604, -0.061584, -0.054382, 0.005268,
+ 0.096008, -0.007118, -0.063049, 0.059113, 0.076904, 0.045288, -0.055695, -0.052612,
+ -0.022110, 0.049103, 0.095276, 0.014572, 0.064819, 0.014671, 0.029800, 0.066284,
+ -0.383301, 0.071838, -0.207275, 0.099365, 0.640137, 0.393311, -0.334229, 0.275391,
+ -0.013977, -0.025269, -0.007065, -0.033478, -0.017349, 0.026764, 0.005192, 0.093384,
+ 0.014313, 0.018906, 0.006962, 0.094849, 0.005390, 0.101624, -0.041199, 0.026245,
+ 0.027588, 0.062408, 0.033356, -0.010826, 0.067993, -0.054199, 0.076416, 0.023315,
+ -0.002886, -0.112061, -0.041473, -0.012703, 0.016022, 0.010506, -0.021362, -0.037750,
+ 0.062927, 0.061920, 0.038177, -0.037201, -0.011620, 0.014015, -0.062164, -0.045441,
+ -0.063416, -0.040100, 0.035950, 0.045563, -0.017227, -0.060547, -0.017593, 0.111877,
+ 0.121521, 0.073853, 0.023331, -0.012428, 0.018478, -0.010948, 0.030716, 0.043427,
+ 0.003117, -0.069092, 0.038361, -0.053497, 0.039154, -0.085754, 0.012642, -0.051208,
+ 0.022934, 0.127197, 0.117920, 0.074036, 0.083313, -0.061951, 0.079224, 0.091248,
+ 0.009132, 0.069946, 0.123474, 0.130127, 0.118835, 0.020874, -0.045380, -0.000111,
+ 0.111206, 0.054688, 0.008995, 0.085693, 0.005562, 0.103088, -0.034698, 0.119934,
+ -0.067200, 0.065430, -0.021942, 0.089783, 0.033112, -0.025467, 0.040161, -0.052155,
+ -0.048920, 0.031250, 0.112549, 0.122192, 0.126587, 0.180908, 0.194946, 0.121704,
+ 0.217529, 0.224243, 0.269287, 0.222656, 0.288086, 0.035492, 0.066711, -0.046600,
+ 0.085144, 0.013855, -0.065979, -0.083252, -0.058289, 0.104126, 0.013702, -0.018188,
+ 0.036591, 0.099854, 0.056061, 0.151855, 0.062134, 0.133789, 0.084045, 0.095825,
+ 0.036987, 0.022308, 0.070923, 0.031036, 0.101868, 0.062347, 0.141235, 0.066650
+);
+
+@group(0) @binding(0) var static_features: texture_2d<u32>;
+@group(0) @binding(1) var layer_input: texture_2d<u32>;
+@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>;
+
+fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {
+ let packed = textureLoad(static_features, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let v2 = unpack2x16float(packed.z);
+ let v3 = unpack2x16float(packed.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}
+
+fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> {
+ let packed = textureLoad(layer_input, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let v2 = unpack2x16float(packed.z);
+ let v3 = unpack2x16float(packed.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}
+
+fn pack_channels(values: array<f32, 4>) -> vec4<u32> {
+  return vec4<u32>(
+    pack2x16float(vec2<f32>(values[0], values[1])),
+    pack2x16float(vec2<f32>(values[2], values[3])),
+    0u, // Unused (OUT_CHANNELS = 4)
+    0u  // Unused
+  );
+}
+
+@compute @workgroup_size(8, 8)
+fn main(@builtin(global_invocation_id) id: vec3<u32>) {
+ let coord = vec2<i32>(id.xy);
+ let dims = textureDimensions(static_features);
+
+ if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {
+ return;
+ }
+
+ // Load static features (always available)
+ let static_feat = unpack_static_features(coord);
+
+ // Convolution
+ var output: array<f32, OUT_CHANNELS>;
+ for (var c: u32 = 0u; c < OUT_CHANNELS; c++) {
+ var sum: f32 = 0.0;
+
+ for (var ky: i32 = -KERNEL_RADIUS; ky <= KERNEL_RADIUS; ky++) {
+ for (var kx: i32 = -KERNEL_RADIUS; kx <= KERNEL_RADIUS; kx++) {
+ let sample_coord = coord + vec2<i32>(kx, ky);
+
+ // Border handling (clamp)
+ let clamped = vec2<i32>(
+ clamp(sample_coord.x, 0, i32(dims.x) - 1),
+ clamp(sample_coord.y, 0, i32(dims.y) - 1)
+ );
+
+ // Load input features
+ let static_local = unpack_static_features(clamped);
+ let layer_local = unpack_layer_channels(clamped);
+
+ // Weight index calculation
+ let ky_idx = u32(ky + KERNEL_RADIUS);
+ let kx_idx = u32(kx + KERNEL_RADIUS);
+ let spatial_idx = ky_idx * KERNEL_SIZE + kx_idx;
+
+ // Accumulate: static features (8D)
+ for (var i: u32 = 0u; i < 8u; i++) {
+ let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE +
+ i * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
+ sum += weights[w_idx] * static_local[i];
+ }
+
+ // Accumulate: layer input channels (if layer_idx > 0)
+ let prev_channels = IN_CHANNELS - 8u;
+ for (var i: u32 = 0u; i < prev_channels; i++) {
+ let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE +
+ (8u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
+ sum += weights[w_idx] * layer_local[i];
+ }
+ }
+ }
+
+    output[c] = clamp(sum, 0.0, 1.0); // Output layer: clamp to [0,1] (cheap sigmoid approximation)
+ }
+
+ // Pack and store
+ textureStore(output_tex, coord, pack_channels(output));
+}
diff --git a/cnn_v2/shaders/cnn_v2_layer_template.wgsl b/cnn_v2/shaders/cnn_v2_layer_template.wgsl
new file mode 100644
index 0000000..1bf6819
--- /dev/null
+++ b/cnn_v2/shaders/cnn_v2_layer_template.wgsl
@@ -0,0 +1,68 @@
+// CNN v2 Layer Template (placeholder for generated shaders)
+// This file documents the structure - actual layers generated by export script
+
+// Example: Layer 0 (1×1 kernel, 8→16 channels)
+// const KERNEL_SIZE: u32 = 1u;
+// const IN_CHANNELS: u32 = 8u; // 7 features + bias
+// const OUT_CHANNELS: u32 = 16u;
+// const weights: array<f32, 128> = array(...);
+
+@group(0) @binding(0) var static_features: texture_2d<u32>;
+@group(0) @binding(1) var layer_input: texture_2d<u32>; // Previous layer output
+@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>;
+
+fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {
+ let packed = textureLoad(static_features, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let v2 = unpack2x16float(packed.z);
+ let v3 = unpack2x16float(packed.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}
+
+fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> {
+ let packed = textureLoad(layer_input, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let v2 = unpack2x16float(packed.z);
+ let v3 = unpack2x16float(packed.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}
+
+fn pack_channels(values: array<f32, 8>) -> vec4<u32> {
+ return vec4<u32>(
+ pack2x16float(vec2<f32>(values[0], values[1])),
+ pack2x16float(vec2<f32>(values[2], values[3])),
+ pack2x16float(vec2<f32>(values[4], values[5])),
+ pack2x16float(vec2<f32>(values[6], values[7]))
+ );
+}
+
+@compute @workgroup_size(8, 8)
+fn main(@builtin(global_invocation_id) id: vec3<u32>) {
+ let coord = vec2<i32>(id.xy);
+ let dims = textureDimensions(static_features);
+
+ if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {
+ return;
+ }
+
+ // Load static features (always available)
+ let static_feat = unpack_static_features(coord);
+
+ // Convolution loop (example for generated code)
+ // var output: array<f32, OUT_CHANNELS>;
+ // for (var c: u32 = 0u; c < OUT_CHANNELS; c++) {
+ // var sum: f32 = 0.0;
+ // for (var ky: i32 = -radius; ky <= radius; ky++) {
+ // for (var kx: i32 = -radius; kx <= radius; kx++) {
+ // let sample_coord = coord + vec2<i32>(kx, ky);
+ // // Load static + prev layer, multiply weights, accumulate
+ // }
+ // }
+ // output[c] = max(0.0, sum); // ReLU
+ // }
+
+ // Placeholder output
+ textureStore(output_tex, coord, vec4<u32>(0u));
+}
diff --git a/cnn_v2/shaders/cnn_v2_static.wgsl b/cnn_v2/shaders/cnn_v2_static.wgsl
new file mode 100644
index 0000000..309e832
--- /dev/null
+++ b/cnn_v2/shaders/cnn_v2_static.wgsl
@@ -0,0 +1,75 @@
+// CNN v2 Static Features Compute Shader
+// Generates 8D parametric features: [p0, p1, p2, p3, uv.x, uv.y, sin20_y, bias]
+// p0-p2: RGB sampled from the specified mip level (0=mip0, 1=mip1, 2=mip2, 3+ falls back to mip2); p3: depth
+// Note: Input image RGBD (mip0) fed separately to Layer 0
+//
+// TODO: Binary format should support arbitrary layout and ordering for feature vector (7D).
+// Current layout is hardcoded. Future versions should allow runtime-specified
+// feature combinations (e.g., [R, G, B, dx, dy, uv_x, bias] or custom encodings).
+
+struct StaticFeatureParams {
+ mip_level: u32,
+ padding0: u32,
+ padding1: u32,
+ padding2: u32,
+}
+
+@group(0) @binding(0) var input_tex: texture_2d<f32>;
+@group(0) @binding(1) var input_tex_mip1: texture_2d<f32>;
+@group(0) @binding(2) var input_tex_mip2: texture_2d<f32>;
+@group(0) @binding(3) var depth_tex: texture_2d<f32>;
+@group(0) @binding(4) var output_tex: texture_storage_2d<rgba32uint, write>;
+@group(0) @binding(5) var<uniform> params: StaticFeatureParams;
+@group(0) @binding(6) var linear_sampler: sampler;
+
+@compute @workgroup_size(8, 8)
+fn main(@builtin(global_invocation_id) id: vec3<u32>) {
+ let coord = vec2<i32>(id.xy);
+ let dims = textureDimensions(input_tex);
+
+ if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {
+ return;
+ }
+
+ // Parametric features (p0-p3) - bilinear sample from specified mip level
+ // Use UV coordinates for bilinear interpolation
+ // Note: Use textureSampleLevel (not textureSample) in compute shaders
+ let uv = (vec2<f32>(coord) + 0.5) / vec2<f32>(dims);
+ var rgba: vec4<f32>;
+ if (params.mip_level == 0u) {
+ rgba = textureSampleLevel(input_tex, linear_sampler, uv, 0.0);
+ } else if (params.mip_level == 1u) {
+ rgba = textureSampleLevel(input_tex_mip1, linear_sampler, uv, 0.0);
+ } else if (params.mip_level == 2u) {
+ rgba = textureSampleLevel(input_tex_mip2, linear_sampler, uv, 0.0);
+ } else {
+ // Mip 3 or higher: use mip 2 as fallback
+ rgba = textureSampleLevel(input_tex_mip2, linear_sampler, uv, 0.0);
+ }
+
+ let p0 = rgba.r;
+ let p1 = rgba.g;
+ let p2 = rgba.b;
+ let p3 = textureLoad(depth_tex, coord, 0).r;
+
+ // UV coordinates (normalized [0,1], top-left origin - matches training)
+ let uv_x = f32(coord.x) / f32(dims.x);
+ let uv_y = f32(coord.y) / f32(dims.y);
+
+  // Sinusoidal position encoding along y (single fixed frequency)
+ let sin20_y = sin(20.0 * uv_y);
+
+ // Bias dimension (always 1.0)
+ let bias = 1.0;
+
+ // Pack 8×f16 into 4×u32 (rgba32uint)
+ // [p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias]
+ let packed = vec4<u32>(
+ pack2x16float(vec2<f32>(p0, p1)),
+ pack2x16float(vec2<f32>(p2, p3)),
+ pack2x16float(vec2<f32>(uv_x, uv_y)),
+ pack2x16float(vec2<f32>(sin20_y, bias))
+ );
+
+ textureStore(output_tex, coord, packed);
+}
diff --git a/cnn_v2/src/cnn_v2_effect.cc b/cnn_v2/src/cnn_v2_effect.cc
new file mode 100644
index 0000000..60538d4
--- /dev/null
+++ b/cnn_v2/src/cnn_v2_effect.cc
@@ -0,0 +1,497 @@
+// CNN v2 Effect Implementation
+
+#include "cnn_v2_effect.h"
+
+#if defined(USE_TEST_ASSETS)
+#include "test_assets.h"
+#else
+#include "generated/assets.h"
+#endif
+
+#include "gpu/bind_group_builder.h"
+#include "gpu/gpu.h"
+#include "util/asset_manager.h"
+#include "util/fatal_error.h"
+#include <cstring>
+
+CNNv2Effect::CNNv2Effect(const GpuContext& ctx)
+ : PostProcessEffect(ctx), static_pipeline_(nullptr),
+ static_bind_group_(nullptr), static_params_buffer_(nullptr),
+ static_features_tex_(nullptr), static_features_view_(nullptr),
+ linear_sampler_(nullptr), layer_pipeline_(nullptr),
+ weights_buffer_(nullptr), input_mip_tex_(nullptr),
+ current_input_view_(nullptr), blend_amount_(1.0f), mip_level_(0),
+ initialized_(false) {
+ std::memset(input_mip_view_, 0, sizeof(input_mip_view_));
+}
+
+CNNv2Effect::CNNv2Effect(const GpuContext& ctx, const CNNv2EffectParams& params)
+ : PostProcessEffect(ctx), static_pipeline_(nullptr),
+ static_bind_group_(nullptr), static_params_buffer_(nullptr),
+ static_features_tex_(nullptr), static_features_view_(nullptr),
+ linear_sampler_(nullptr), layer_pipeline_(nullptr),
+ weights_buffer_(nullptr), input_mip_tex_(nullptr),
+ current_input_view_(nullptr), blend_amount_(params.blend_amount),
+ mip_level_(0), initialized_(false) {
+ std::memset(input_mip_view_, 0, sizeof(input_mip_view_));
+}
+
+CNNv2Effect::~CNNv2Effect() {
+ cleanup();
+}
+
+void CNNv2Effect::init(MainSequence* demo) {
+ (void)demo;
+ if (initialized_)
+ return;
+
+ load_weights();
+ create_textures();
+ create_pipelines();
+
+ initialized_ = true;
+}
+
+void CNNv2Effect::resize(int width, int height) {
+ PostProcessEffect::resize(width, height);
+ cleanup();
+ create_textures();
+ create_pipelines();
+}
+
+void CNNv2Effect::load_weights() {
+ // Load binary weights asset
+ size_t weights_size = 0;
+ const uint8_t* weights_data =
+ (const uint8_t*)GetAsset(AssetId::ASSET_WEIGHTS_CNN_V2, &weights_size);
+
+ if (!weights_data || weights_size < 20) {
+ // Weights not available - effect will skip
+ return;
+ }
+
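+  // Binary layout as parsed here (see docs/CNN_V2_BINARY_FORMAT.md):
+  //   u32 magic ('CNN2'), u32 version, u32 num_layers, u32 total_weights,
+  //   u32 mip_level (version >= 2 only),
+  //   per layer: u32 kernel_size, in_channels, out_channels, weight_offset, weight_count,
+  //   followed by the packed f16 weights.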
+ // Parse header
+ const uint32_t* header = (const uint32_t*)weights_data;
+ uint32_t magic = header[0];
+ uint32_t version = header[1];
+ uint32_t num_layers = header[2];
+ uint32_t total_weights = header[3];
+
+ FATAL_CHECK(magic != 0x324e4e43, "Invalid CNN v2 weights magic\n"); // 'CNN2'
+
+ // Support both version 1 (16-byte header) and version 2 (20-byte header with
+ // mip_level)
+ // TODO: Version 3 should include feature descriptor for arbitrary
+ // layout/ordering
+ if (version == 1) {
+ mip_level_ = 0; // Default for v1
+ } else if (version == 2) {
+ mip_level_ = header[4];
+ } else {
+ FATAL_ERROR("Unsupported CNN v2 weights version: %u\n", version);
+ }
+
+ // Parse layer info (20 bytes per layer)
+ // Offset depends on version: v1=16 bytes (4 u32), v2=20 bytes (5 u32)
+ const uint32_t header_u32_count = (version == 1) ? 4 : 5;
+ const uint32_t* layer_data = header + header_u32_count;
+ for (uint32_t i = 0; i < num_layers; ++i) {
+ LayerInfo info;
+ info.kernel_size = layer_data[i * 5 + 0];
+ info.in_channels = layer_data[i * 5 + 1];
+ info.out_channels = layer_data[i * 5 + 2];
+ info.weight_offset = layer_data[i * 5 + 3];
+ info.weight_count = layer_data[i * 5 + 4];
+ layer_info_.push_back(info);
+ }
+
+ // Create GPU storage buffer for weights (skip header + layer info, upload
+ // only weights)
+  size_t header_size = header_u32_count * 4; // 16 bytes (v1) or 20 bytes (v2)
+ size_t layer_info_size = 20 * num_layers; // 5 u32 per layer
+ size_t weights_offset = header_size + layer_info_size;
+ size_t weights_only_size = weights_size - weights_offset;
+
+ WGPUBufferDescriptor buffer_desc = {};
+ buffer_desc.size = weights_only_size;
+ buffer_desc.usage = WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst;
+ buffer_desc.mappedAtCreation = false;
+
+ weights_buffer_ = wgpuDeviceCreateBuffer(ctx_.device, &buffer_desc);
+
+ // Upload only weights (skip header + layer info)
+ wgpuQueueWriteBuffer(ctx_.queue, weights_buffer_, 0,
+ weights_data + weights_offset, weights_only_size);
+
+ // Create uniform buffers for layer params (one per layer)
+ for (uint32_t i = 0; i < num_layers; ++i) {
+ WGPUBufferDescriptor params_desc = {};
+ params_desc.size = sizeof(LayerParams);
+ params_desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst;
+ params_desc.mappedAtCreation = false;
+
+ WGPUBuffer buf = wgpuDeviceCreateBuffer(ctx_.device, &params_desc);
+ layer_params_buffers_.push_back(buf);
+ }
+}
+
+void CNNv2Effect::create_textures() {
+ // Static features texture (8×f16 packed as 4×u32)
+ TextureWithView static_tex = gpu_create_storage_texture_2d(
+ ctx_.device, width_, height_, WGPUTextureFormat_RGBA32Uint);
+ static_features_tex_ = static_tex.texture;
+ static_features_view_ = static_tex.view;
+
+ // Input texture with mips (for multi-scale features)
+ TextureWithView input_mip = gpu_create_texture_2d(
+ ctx_.device, width_, height_, WGPUTextureFormat_RGBA8Unorm,
+ (WGPUTextureUsage)(WGPUTextureUsage_TextureBinding |
+ WGPUTextureUsage_CopyDst),
+ 3);
+ input_mip_tex_ = input_mip.texture;
+
+ for (int i = 0; i < 3; ++i) {
+ input_mip_view_[i] =
+ gpu_create_mip_view(input_mip_tex_, WGPUTextureFormat_RGBA8Unorm, i);
+ }
+
+ // Create 2 layer textures (ping-pong buffers for intermediate results)
+ // Each stores 8×f16 channels packed as 4×u32
+ for (int i = 0; i < 2; ++i) {
+ TextureWithView layer = gpu_create_storage_texture_2d(
+ ctx_.device, width_, height_, WGPUTextureFormat_RGBA32Uint);
+ layer_textures_.push_back(layer.texture);
+ layer_views_.push_back(layer.view);
+ }
+
+ // Create uniform buffer for static feature params
+ WGPUBufferDescriptor params_desc = {};
+ params_desc.size = sizeof(StaticFeatureParams);
+ params_desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst;
+ params_desc.mappedAtCreation = false;
+ static_params_buffer_ = wgpuDeviceCreateBuffer(ctx_.device, &params_desc);
+}
+
+void CNNv2Effect::create_pipelines() {
+ // Create linear sampler for bilinear interpolation
+ WGPUSamplerDescriptor sampler_desc = {};
+ sampler_desc.addressModeU = WGPUAddressMode_ClampToEdge;
+ sampler_desc.addressModeV = WGPUAddressMode_ClampToEdge;
+ sampler_desc.addressModeW = WGPUAddressMode_ClampToEdge;
+ sampler_desc.magFilter = WGPUFilterMode_Linear;
+ sampler_desc.minFilter = WGPUFilterMode_Linear;
+ sampler_desc.mipmapFilter = WGPUMipmapFilterMode_Linear;
+ sampler_desc.lodMinClamp = 0.0f;
+ sampler_desc.lodMaxClamp = 32.0f;
+ sampler_desc.maxAnisotropy = 1;
+
+ linear_sampler_ = wgpuDeviceCreateSampler(ctx_.device, &sampler_desc);
+
+ // Static features compute pipeline
+ size_t shader_size = 0;
+ const char* static_code =
+ (const char*)GetAsset(AssetId::ASSET_SHADER_CNN_V2_STATIC, &shader_size);
+
+ if (!static_code || shader_size == 0) {
+ // Shader not available (e.g., in test mode) - skip pipeline creation
+ return;
+ }
+
+ WGPUShaderSourceWGSL wgsl_src = {};
+ wgsl_src.chain.sType = WGPUSType_ShaderSourceWGSL;
+ wgsl_src.code = str_view(static_code);
+
+ WGPUShaderModuleDescriptor shader_desc = {};
+ shader_desc.nextInChain = &wgsl_src.chain;
+
+ // Create bind group layout for static features compute
+ // Bindings: 0=input_tex, 1=input_mip1, 2=input_mip2, 3=depth_tex, 4=output,
+ // 5=params, 6=linear_sampler
+ WGPUBindGroupLayout static_bgl =
+ BindGroupLayoutBuilder()
+ .texture(0, WGPUShaderStage_Compute)
+ .texture(1, WGPUShaderStage_Compute)
+ .texture(2, WGPUShaderStage_Compute)
+ .texture(3, WGPUShaderStage_Compute)
+ .storage_texture(4, WGPUShaderStage_Compute,
+ WGPUTextureFormat_RGBA32Uint)
+ .uniform(5, WGPUShaderStage_Compute, sizeof(StaticFeatureParams))
+ .sampler(6, WGPUShaderStage_Compute)
+ .build(ctx_.device);
+
+  // Create pipeline layout for the static features pass
+ WGPUPipelineLayoutDescriptor pl_desc = {};
+ pl_desc.bindGroupLayoutCount = 1;
+ pl_desc.bindGroupLayouts = &static_bgl;
+ WGPUPipelineLayout pipeline_layout =
+ wgpuDeviceCreatePipelineLayout(ctx_.device, &pl_desc);
+
+ // Recreate pipeline with proper layout
+ WGPUComputePipelineDescriptor pipeline_desc2 = {};
+ pipeline_desc2.compute.module =
+ wgpuDeviceCreateShaderModule(ctx_.device, &shader_desc);
+ pipeline_desc2.compute.entryPoint = str_view("main");
+ pipeline_desc2.layout = pipeline_layout;
+
+ if (static_pipeline_)
+ wgpuComputePipelineRelease(static_pipeline_);
+ static_pipeline_ =
+ wgpuDeviceCreateComputePipeline(ctx_.device, &pipeline_desc2);
+
+ wgpuShaderModuleRelease(pipeline_desc2.compute.module);
+ wgpuPipelineLayoutRelease(pipeline_layout);
+ wgpuBindGroupLayoutRelease(static_bgl);
+
+ // CNN layer compute pipeline (storage buffer version)
+ if (layer_info_.empty())
+ return; // No weights loaded
+
+ size_t layer_shader_size = 0;
+ const char* layer_code = (const char*)GetAsset(
+ AssetId::ASSET_SHADER_CNN_V2_COMPUTE, &layer_shader_size);
+
+ if (!layer_code || layer_shader_size == 0)
+ return;
+
+ WGPUShaderSourceWGSL layer_wgsl = {};
+ layer_wgsl.chain.sType = WGPUSType_ShaderSourceWGSL;
+ layer_wgsl.code = str_view(layer_code);
+
+ WGPUShaderModuleDescriptor layer_shader_desc = {};
+ layer_shader_desc.nextInChain = &layer_wgsl.chain;
+
+ WGPUShaderModule layer_module =
+ wgpuDeviceCreateShaderModule(ctx_.device, &layer_shader_desc);
+ if (!layer_module)
+ return;
+
+ // Create bind group layout for layer compute
+ // 0=static_features, 1=layer_input, 2=output, 3=weights, 4=params,
+ // 5=original_input
+ WGPUBindGroupLayout layer_bgl =
+ BindGroupLayoutBuilder()
+ .uint_texture(0, WGPUShaderStage_Compute)
+ .uint_texture(1, WGPUShaderStage_Compute)
+ .storage_texture(2, WGPUShaderStage_Compute,
+ WGPUTextureFormat_RGBA32Uint)
+ .storage(3, WGPUShaderStage_Compute)
+ .uniform(4, WGPUShaderStage_Compute, sizeof(LayerParams))
+ .texture(5, WGPUShaderStage_Compute)
+ .build(ctx_.device);
+
+ WGPUPipelineLayoutDescriptor layer_pl_desc = {};
+ layer_pl_desc.bindGroupLayoutCount = 1;
+ layer_pl_desc.bindGroupLayouts = &layer_bgl;
+
+ WGPUPipelineLayout layer_pipeline_layout =
+ wgpuDeviceCreatePipelineLayout(ctx_.device, &layer_pl_desc);
+
+ WGPUComputePipelineDescriptor layer_pipeline_desc = {};
+ layer_pipeline_desc.compute.module = layer_module;
+ layer_pipeline_desc.compute.entryPoint = str_view("main");
+ layer_pipeline_desc.layout = layer_pipeline_layout;
+
+ layer_pipeline_ =
+ wgpuDeviceCreateComputePipeline(ctx_.device, &layer_pipeline_desc);
+
+ wgpuShaderModuleRelease(layer_module);
+ wgpuPipelineLayoutRelease(layer_pipeline_layout);
+ wgpuBindGroupLayoutRelease(layer_bgl);
+}
+
+void CNNv2Effect::update_bind_group(WGPUTextureView input_view) {
+ if (!static_pipeline_)
+ return;
+
+ // Cache input view
+ current_input_view_ = input_view;
+
+ // Release old bind group
+ if (static_bind_group_) {
+ wgpuBindGroupRelease(static_bind_group_);
+ static_bind_group_ = nullptr;
+ }
+
+ // Create bind group for static features compute (manual for storage texture
+ // binding)
+ WGPUBindGroupEntry bg_entries[7] = {};
+ bg_entries[0].binding = 0;
+ bg_entries[0].textureView = input_view;
+ bg_entries[1].binding = 1;
+ bg_entries[1].textureView = input_mip_view_[0];
+ bg_entries[2].binding = 2;
+ bg_entries[2].textureView =
+ input_mip_view_[1] ? input_mip_view_[1] : input_mip_view_[0];
+ bg_entries[3].binding = 3;
+ bg_entries[3].textureView = input_view;
+ bg_entries[4].binding = 4;
+ bg_entries[4].textureView = static_features_view_;
+ bg_entries[5].binding = 5;
+ bg_entries[5].buffer = static_params_buffer_;
+ bg_entries[5].size = sizeof(StaticFeatureParams);
+ bg_entries[6].binding = 6;
+ bg_entries[6].sampler = linear_sampler_;
+
+ WGPUBindGroupLayout layout =
+ wgpuComputePipelineGetBindGroupLayout(static_pipeline_, 0);
+ WGPUBindGroupDescriptor bg_desc = {};
+ bg_desc.layout = layout;
+ bg_desc.entryCount = 7;
+ bg_desc.entries = bg_entries;
+ static_bind_group_ = wgpuDeviceCreateBindGroup(ctx_.device, &bg_desc);
+ wgpuBindGroupLayoutRelease(layout);
+
+ // Create layer bind groups
+ if (!layer_pipeline_ || layer_info_.empty())
+ return;
+
+ // Release old layer bind groups
+ for (auto bg : layer_bind_groups_) {
+ wgpuBindGroupRelease(bg);
+ }
+ layer_bind_groups_.clear();
+
+ // Get bind group layout from layer pipeline
+ WGPUBindGroupLayout layer_bgl =
+ wgpuComputePipelineGetBindGroupLayout(layer_pipeline_, 0);
+
+ // Create bind group for each layer
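+  // Ping-pong scheme: layer i reads layer_views_[i % 2] and writes layer_views_[(i + 1) % 2].
+  // Layer 0 has no previous output, so the static-features view is bound as its 4D input;
+  // its first four packed channels (p0-p3) carry the RGB + depth values.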
+ for (size_t i = 0; i < layer_info_.size(); ++i) {
+ WGPUTextureView layer_input =
+ (i == 0) ? static_features_view_ : layer_views_[i % 2];
+
+ WGPUBindGroup layer_bg =
+ BindGroupBuilder()
+ .texture(0, static_features_view_)
+ .texture(1, layer_input)
+ .texture(2, layer_views_[(i + 1) % 2])
+ .buffer(3, weights_buffer_, wgpuBufferGetSize(weights_buffer_))
+ .buffer(4, layer_params_buffers_[i], sizeof(LayerParams))
+ .texture(5, input_view)
+ .build(ctx_.device, layer_bgl);
+
+ layer_bind_groups_.push_back(layer_bg);
+ }
+
+ wgpuBindGroupLayoutRelease(layer_bgl);
+}
+
+void CNNv2Effect::compute(WGPUCommandEncoder encoder,
+ const CommonPostProcessUniforms& uniforms) {
+ if (!initialized_ || !static_pipeline_ || !static_bind_group_)
+ return;
+
+ float effective_blend = blend_amount_;
+ if (beat_modulated_) {
+ effective_blend = blend_amount_ * uniforms.beat_phase * beat_scale_;
+ }
+
+ // Update static feature params
+ StaticFeatureParams static_params;
+ static_params.mip_level = mip_level_;
+ static_params.padding[0] = 0;
+ static_params.padding[1] = 0;
+ static_params.padding[2] = 0;
+ wgpuQueueWriteBuffer(ctx_.queue, static_params_buffer_, 0, &static_params,
+ sizeof(static_params));
+
+ // Pass 1: Compute static features
+ WGPUComputePassEncoder pass =
+ wgpuCommandEncoderBeginComputePass(encoder, nullptr);
+
+ wgpuComputePassEncoderSetPipeline(pass, static_pipeline_);
+ wgpuComputePassEncoderSetBindGroup(pass, 0, static_bind_group_, 0, nullptr);
+
+ // Dispatch workgroups (8×8 threads per group)
+ uint32_t workgroups_x = (width_ + 7) / 8;
+ uint32_t workgroups_y = (height_ + 7) / 8;
+ wgpuComputePassEncoderDispatchWorkgroups(pass, workgroups_x, workgroups_y, 1);
+
+ wgpuComputePassEncoderEnd(pass);
+ wgpuComputePassEncoderRelease(pass);
+
+ // Execute CNN layer passes
+ if (!layer_pipeline_ || layer_bind_groups_.empty())
+ return;
+
+ // Update layer params (each layer has own buffer)
+ for (size_t i = 0; i < layer_info_.size(); ++i) {
+ const LayerInfo& info = layer_info_[i];
+
+ LayerParams params;
+ params.kernel_size = info.kernel_size;
+ params.in_channels = info.in_channels;
+ params.out_channels = info.out_channels;
+ params.weight_offset = info.weight_offset;
+ params.is_output_layer = (i == layer_info_.size() - 1) ? 1 : 0;
+ params.blend_amount = effective_blend;
+ params.is_layer_0 = (i == 0) ? 1 : 0;
+
+ wgpuQueueWriteBuffer(ctx_.queue, layer_params_buffers_[i], 0, &params,
+ sizeof(params));
+
+ WGPUComputePassEncoder layer_pass =
+ wgpuCommandEncoderBeginComputePass(encoder, nullptr);
+
+ wgpuComputePassEncoderSetPipeline(layer_pass, layer_pipeline_);
+ wgpuComputePassEncoderSetBindGroup(layer_pass, 0, layer_bind_groups_[i], 0,
+ nullptr);
+
+ wgpuComputePassEncoderDispatchWorkgroups(layer_pass, workgroups_x,
+ workgroups_y, 1);
+
+ wgpuComputePassEncoderEnd(layer_pass);
+ wgpuComputePassEncoderRelease(layer_pass);
+ }
+}
+
+void CNNv2Effect::render(WGPURenderPassEncoder pass,
+ const CommonPostProcessUniforms& uniforms) {
+ (void)pass;
+ (void)uniforms;
+  // Compute-only effect; rendering is done by the default composite pass
+}
+
+void CNNv2Effect::cleanup() {
+ if (static_features_view_)
+ wgpuTextureViewRelease(static_features_view_);
+ if (static_features_tex_)
+ wgpuTextureRelease(static_features_tex_);
+ if (static_bind_group_)
+ wgpuBindGroupRelease(static_bind_group_);
+ if (static_params_buffer_)
+ wgpuBufferRelease(static_params_buffer_);
+ if (static_pipeline_)
+ wgpuComputePipelineRelease(static_pipeline_);
+ if (linear_sampler_)
+ wgpuSamplerRelease(linear_sampler_);
+
+ if (layer_pipeline_)
+ wgpuComputePipelineRelease(layer_pipeline_);
+ if (weights_buffer_)
+ wgpuBufferRelease(weights_buffer_);
+ for (auto buf : layer_params_buffers_)
+ wgpuBufferRelease(buf);
+ layer_params_buffers_.clear();
+
+ for (int i = 0; i < 3; ++i) {
+ if (input_mip_view_[i])
+ wgpuTextureViewRelease(input_mip_view_[i]);
+ }
+ if (input_mip_tex_)
+ wgpuTextureRelease(input_mip_tex_);
+
+ for (auto view : layer_views_)
+ wgpuTextureViewRelease(view);
+ for (auto tex : layer_textures_)
+ wgpuTextureRelease(tex);
+ for (auto bg : layer_bind_groups_)
+ wgpuBindGroupRelease(bg);
+
+ layer_views_.clear();
+ layer_textures_.clear();
+ layer_bind_groups_.clear();
+ layer_info_.clear();
+
+ initialized_ = false;
+}
diff --git a/cnn_v2/src/cnn_v2_effect.h b/cnn_v2/src/cnn_v2_effect.h
new file mode 100644
index 0000000..7960b4f
--- /dev/null
+++ b/cnn_v2/src/cnn_v2_effect.h
@@ -0,0 +1,89 @@
+// CNN v2 Effect - Parametric Static Features
+// Multi-pass post-processing with 7D feature input
+// Supports per-layer kernel sizes (e.g., 1×1, 3×3, 5×5)
+
+#pragma once
+#include "gpu/effect.h"
+#include <vector>
+
+struct CNNv2EffectParams {
+ float blend_amount = 1.0f;
+};
+
+class CNNv2Effect : public PostProcessEffect {
+ public:
+ explicit CNNv2Effect(const GpuContext& ctx);
+ explicit CNNv2Effect(const GpuContext& ctx, const CNNv2EffectParams& params);
+ ~CNNv2Effect();
+
+ void init(MainSequence* demo) override;
+ void resize(int width, int height) override;
+ void compute(WGPUCommandEncoder encoder,
+ const CommonPostProcessUniforms& uniforms) override;
+ void render(WGPURenderPassEncoder pass,
+ const CommonPostProcessUniforms& uniforms) override;
+ void update_bind_group(WGPUTextureView input_view) override;
+
+ void set_beat_modulation(bool enabled, float scale = 1.0f) {
+ beat_modulated_ = enabled;
+ beat_scale_ = scale;
+ }
+
+ private:
+ struct LayerInfo {
+ uint32_t kernel_size;
+ uint32_t in_channels;
+ uint32_t out_channels;
+ uint32_t weight_offset;
+ uint32_t weight_count;
+ };
+
+ struct LayerParams {
+ uint32_t kernel_size;
+ uint32_t in_channels;
+ uint32_t out_channels;
+ uint32_t weight_offset;
+ uint32_t is_output_layer;
+ float blend_amount;
+ uint32_t is_layer_0;
+ };
+
+ struct StaticFeatureParams {
+ uint32_t mip_level;
+ uint32_t padding[3];
+ };
+
+ void create_textures();
+ void create_pipelines();
+ void load_weights();
+ void cleanup();
+
+ // Static features compute
+ WGPUComputePipeline static_pipeline_;
+ WGPUBindGroup static_bind_group_;
+ WGPUBuffer static_params_buffer_;
+ WGPUTexture static_features_tex_;
+ WGPUTextureView static_features_view_;
+ WGPUSampler linear_sampler_;
+
+ // CNN layers (storage buffer architecture)
+ WGPUComputePipeline layer_pipeline_; // Single pipeline for all layers
+ WGPUBuffer weights_buffer_; // Storage buffer for weights
+ std::vector<WGPUBuffer>
+ layer_params_buffers_; // Uniform buffers (one per layer)
+ std::vector<LayerInfo> layer_info_; // Layer metadata
+ std::vector<WGPUBindGroup> layer_bind_groups_; // Per-layer bind groups
+ std::vector<WGPUTexture> layer_textures_; // Ping-pong buffers
+ std::vector<WGPUTextureView> layer_views_;
+
+ // Input mips
+ WGPUTexture input_mip_tex_;
+ WGPUTextureView input_mip_view_[3];
+ WGPUTextureView current_input_view_;
+
+ float blend_amount_ = 1.0f;
+ bool beat_modulated_ = false;
+ float beat_scale_ = 1.0f;
+ uint32_t mip_level_ = 0;
+ bool initialized_;
+};
diff --git a/cnn_v2/tools/cnn_v2_test/README.md b/cnn_v2/tools/cnn_v2_test/README.md
new file mode 100644
index 0000000..d41a00f
--- /dev/null
+++ b/cnn_v2/tools/cnn_v2_test/README.md
@@ -0,0 +1,251 @@
+# CNN v2 Testing Tool
+
+WebGPU-based browser tool for testing trained CNN v2 weights.
+
+---
+
+## Features
+
+- Drag-drop PNG images, video files (MP4/WebM), and `.bin` weights (or click to browse)
+- Real-time CNN inference with WebGPU compute shaders
+- View modes: CNN output, original input, difference (×10)
+- Adjustable blend amount and depth
+- Data-driven pipeline (supports variable layer count)
+- GPU timing display
+- **Left Panel:** Weights info + kernel visualization (1px/weight, all layers)
+- **Right Panel:** Layer activation viewer with 4-channel split + 4× zoom
+
+---
+
+## Requirements
+
+- Browser with WebGPU support:
+ - Chrome/Edge 113+ (enable `chrome://flags/#enable-unsafe-webgpu` if needed)
+ - Safari 18+ (macOS Ventura+)
+- Trained CNN v2 weights in binary format (`.bin`)
+- Test images (PNG format)
+
+---
+
+## Usage
+
+### 1. Open Tool
+
+```bash
+open tools/cnn_v2_test/index.html
+```
+
+Or use a local server to avoid CORS:
+```bash
+python3 -m http.server 8000
+# Open http://localhost:8000/tools/cnn_v2_test/
+```
+
+### 2. Load Data
+
+1. **Drop `.bin` weights** into the left sidebar zone (or click to browse)
+2. **Drop a PNG image or video** anywhere in the center canvas area
+3. CNN inference runs automatically once both are loaded
+
+### 3. Layout
+
+**Left Sidebar:**
+- Weights drop zone (click or drag-drop `.bin` files)
+- Weights info panel (layer specs, ranges, file size)
+- Weights visualization (click Layer 0/1/2 buttons)
+ - 1 pixel per weight, all input channels horizontally
+ - Output channels (Out 0-3) stacked vertically
+
+**Center Canvas:**
+- Main output view (CNN result, original, or diff)
+- Keyboard: `SPACE` = original, `D` = diff (×10)
+
+**Right Sidebar:**
+- Layer selection buttons (Static 0-3/4-7, Layer 0/1/2)
+- 4 small activation views (Ch0/1/2/3) in a row
+- Large zoom view below (4× magnification, follows mouse)
+
+**Header Controls:**
+- **Blend:** Mix between original (0.0) and CNN output (1.0)
+- **Depth:** Uniform depth value for all pixels (0.0–1.0)
+- **View:** Current display mode
+
+**Footer:**
+- Status: GPU timing (ms), image dimensions, view mode
+- Console: Timestamped event log (file loads, errors)
+
+---
+
+## Preparing Test Data
+
+### Export Weights
+
+```bash
+# From trained checkpoint
+./training/export_cnn_v2_weights.py \
+ checkpoints/checkpoint_epoch_100.pth \
+ --output-weights tools/cnn_v2_test/test_weights.bin
+```
+
+Binary format: 16-byte header (20 bytes in version 2, which adds a `mip_level` field) + 20 bytes per layer + packed f16 weights (~3.2 KB for a 3-layer model)
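+
+To sanity-check an exported file outside the browser, here is a minimal Python sketch of the header layout (mirroring `parseWeights()` in `index.html`; the function name is illustrative):
+
+```python
+import struct
+
+def read_cnn_v2_header(path):
+    """Parse the CNN v2 .bin header and per-layer records (v1: 16-byte header, v2: 20-byte)."""
+    with open(path, "rb") as f:
+        data = f.read()
+    magic, version, num_layers, total_weights = struct.unpack_from("<4I", data, 0)
+    assert magic == 0x324E4E43, "bad magic (expected ASCII 'CNN2')"
+    offset, mip_level = 16, 0
+    if version == 2:
+        (mip_level,) = struct.unpack_from("<I", data, 16)
+        offset = 20
+    layers = []
+    for _ in range(num_layers):
+        # kernel_size, in_channels, out_channels, weight_offset, weight_count
+        layers.append(struct.unpack_from("<5I", data, offset))
+        offset += 20
+    return version, mip_level, total_weights, layers
+```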
+
+### Test Images
+
+Use training images or any PNG:
+```bash
+# Copy test image
+cp training/input/test.png tools/cnn_v2_test/
+```
+
+**Note:** Grayscale images are automatically converted to RGB.
+
+---
+
+## Validation
+
+### Visual Comparison
+
+Compare browser output with C++ tool:
+
+```bash
+# Generate C++ output
+./build/cnn_test training/input/test.png /tmp/cpp_output.png
+
+# Load same image in browser tool
+# Visually compare outputs
+```
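+
+For a numeric check in addition to the visual one, a small sketch (assumes NumPy and Pillow; the browser-side file is whatever you saved via the Save PNG button, so the paths here are hypothetical):
+
+```python
+import numpy as np
+from PIL import Image
+
+# Hypothetical output paths - adjust to wherever the two images were written.
+cpp = np.asarray(Image.open("/tmp/cpp_output.png").convert("RGB"), dtype=np.float32) / 255.0
+web = np.asarray(Image.open("/tmp/browser_output.png").convert("RGB"), dtype=np.float32) / 255.0
+print("mean abs diff:", float(np.abs(cpp - web).mean()))
+print("max abs diff: ", float(np.abs(cpp - web).max()))
+```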
+
+### GPU Timing
+
+Expected performance:
+- 512×512: ~1-2 ms (integrated GPU)
+- 1024×1024: ~3-5 ms
+- 1920×1080: ~5-8 ms
+
+Slower than expected? Check:
+- WebGPU enabled in browser
+- Dedicated GPU selected (if available)
+- No background tabs consuming GPU
+
+---
+
+## Troubleshooting
+
+### "WebGPU not supported"
+
+- Update browser to latest version
+- Enable WebGPU flag: `chrome://flags/#enable-unsafe-webgpu`
+- Try Safari 18+ (native WebGPU on macOS)
+
+### "Invalid .bin file"
+
+- Check magic number: `hexdump -C weights.bin | head`
+- Should start with: `43 4e 4e 32` ('CNN2')
+- Re-export weights: `./training/export_cnn_v2_weights.py`
+
+### Black output / incorrect colors
+
+- Check blend slider (set to 1.0 for full CNN output)
+- Verify training converged (loss < 0.01)
+- Compare with C++ tool output
+
+### Shader compilation errors
+
+Open browser console (F12) for detailed errors. Common issues:
+- Image too large (>4096×4096 not tested)
+- Unsupported texture format (rare on modern GPUs)
+
+---
+
+## Architecture
+
+**Pipeline:**
+1. **Static Features Pass** - Generate 8D features (RGBD, UV, sin, bias)
+2. **CNN Layer Passes** - Compute N layers with ping-pong textures
+3. **Display Pass** - Unpack and render with view mode
+
+**Textures:**
+- Input: RGBA8 (original image)
+- Depth: R32F (uniform depth)
+- Static features: RGBA32Uint (8×f16 packed)
+- Layer buffers: RGBA32Uint (ping-pong)
+
+**Data-Driven Execution:**
+- Layer count read from binary header
+- Per-layer params (kernel size, channels, offsets) from binary
+- Single CNN shader dispatched N times
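+
+One subtlety: per-layer `weight_offset` values in the binary are relative to the start of the f16 weight blob, while the tool binds the whole file as the storage buffer, so offsets are converted to absolute f16 indices first. A sketch of that arithmetic (same math as in `run()` in `index.html`):
+
+```python
+def absolute_weight_offset(version, num_layers, layer_weight_offset):
+    header_u32 = 4 if version == 1 else 5       # 16-byte v1 header, 20-byte v2 header
+    prelude_u32 = header_u32 + num_layers * 5   # header + 20-byte per-layer records
+    return prelude_u32 * 2 + layer_weight_offset  # two f16 values per u32
+```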
+
+---
+
+## Implemented Features
+
+**✓ Weights Metadata Panel:**
+- Layer descriptions (kernel size, channels, weight count)
+- Weight statistics (min/max per layer)
+- File size and layer count
+
+**✓ Weights Visualization:**
+- Per-layer kernel heatmaps (1px/weight)
+- All input channels displayed horizontally
+- Output channels stacked vertically
+- Normalized grayscale display
+
+**✓ Layer Activation Viewer:**
+- Static features (8D split into 0-3 and 4-7 views)
+- All CNN layer outputs (Layer 0/1/2...)
+- 4-channel split view (grayscale per channel)
+- Mouse-driven 4× zoom view
+
+## TODO
+
+**Future Enhancements:**
+- Weight distribution histograms per layer
+- Activation statistics (min/max/mean overlay)
+- Side-by-side diff mode (browser vs C++ output)
+- Export rendered layers as PNG
+
+---
+
+## Extensions (v2+)
+
+Planned enhancements:
+
+**Variable Feature Count:**
+- Binary v2: Add `num_features` to header
+- Shader: Dynamic feature array or multiple textures
+
+**Multi-Scale Input (Mip Levels):**
+- Uncomment mip bindings in static shader
+- No binary format change needed
+
+**8-bit Quantized Weights:**
+- Binary version bump (format field already present)
+- Add quantization codepath in `get_weight()` function
+- 2× size reduction (~1.6 KB)
+
+**Pre-defined Test Images:**
+- Dropdown menu with training/input/*.png
+- Requires local file server
+
+---
+
+## Size
+
+- HTML structure: ~2 KB
+- CSS styling: ~2 KB
+- JavaScript logic: ~10 KB (includes zoom + weights viz)
+- Static shader: ~1 KB
+- CNN shader: ~3 KB
+- Display shader: ~1 KB
+- Layer viz shader: ~2 KB
+- Zoom shader: ~1 KB
+- **Total: ~22 KB** (single file, no dependencies)
+
+---
+
+## See Also
+
+- `docs/CNN_V2.md` - Architecture and design
+- `doc/HOWTO.md` - Training workflows
+- `training/export_cnn_v2_weights.py` - Binary format
+- `src/cnn_v2_effect.cc` - C++ reference implementation
diff --git a/cnn_v2/tools/cnn_v2_test/index.html b/cnn_v2/tools/cnn_v2_test/index.html
new file mode 100644
index 0000000..84702d5
--- /dev/null
+++ b/cnn_v2/tools/cnn_v2_test/index.html
@@ -0,0 +1,2014 @@
+<!DOCTYPE html>
+<html lang="en">
+<!--
+ CNN v2 Testing Tool - WebGPU-based inference validator
+
+ Architecture:
+  - Static features (8D): p0-p3 (parametric), uv_x, uv_y, sin(20*uv_y), bias (NOT a CNN layer)
+ - Layer 0: input RGBD (4D) + static (8D) = 12D → 4 channels
+ - Layer 1+: previous layer (4D) + static (8D) = 12D → 4 channels
+ - All CNN layers: uniform 12D input, 4D output (ping-pong buffer)
+
+ Naming convention (matches train_cnn_v2.py / .wgsl / .cc):
+ - UI shows: "Static 0-3", "Static 4-7", "Layer 0", "Layer 1", "Layer 2"
+ - weights.layers[] array: Layer 0 = weights.layers[0], Layer 1 = weights.layers[1]
+
+ Features:
+ - Input: PNG images or video files (MP4, WebM, etc.)
+ - Video playback: Play/Pause, frame-by-frame navigation (◄/► buttons)
+ - Video mode: Non-realtime processing (drops frames if CNN slower than playback)
+ - Side panel: .bin metadata display, weight statistics per layer
+ - Layer inspection: 4-channel grayscale split, intermediate layer visualization
+ - View modes: CNN output, original, diff (×10)
+ - Optimization: Layer viz updates only on pause/seek during video playback
+
+ WGSL Shader Reuse:
+ - CNN_SHADER (inference), STATIC_SHADER, LAYER_VIZ_SHADER are inline for single-file deployment
+ - Can extract to .wgsl files for: better IDE support, testing, cross-tool reuse
+ - Tradeoff: extraction needs fetch() or build step, breaks single-file portability
+ - C++ sync: manual (WGSL ≠ GLSL) but logic identical
+-->
+<head>
+ <meta charset="UTF-8">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
+ <title>CNN v2 Testing Tool</title>
+ <link rel="stylesheet" href="../common/style.css">
+ <style>
+ body {
+ display: flex;
+ flex-direction: column;
+ height: 100vh;
+ }
+ .header {
+ padding: 16px;
+ border-bottom: 1px solid #404040;
+ gap: 24px;
+ }
+ h1 { font-size: 18px; }
+ .controls {
+ gap: 16px;
+ }
+ .control-group {
+ display: flex;
+ gap: 8px;
+ align-items: center;
+ }
+ .control-group label { font-size: 12px; }
+ input[type="range"] { width: 120px; }
+ input[type="number"] { width: 60px; padding: 4px; }
+ .drop-zone {
+ border: 3px dashed #606060;
+ padding: 20px;
+ text-align: center;
+ cursor: pointer;
+ transition: all 0.2s;
+ font-size: 13px;
+ font-weight: bold;
+ background: #252525;
+ border-radius: 6px;
+ color: #4a9eff;
+ }
+ button {
+ padding: 6px 12px;
+ font-size: 12px;
+ }
+ button:hover { border-color: #606060; background: #252525; }
+ video { display: none; }
+ .drop-zone:hover { border-color: #4a9eff; background: #2a3545; }
+ .drop-zone.active { border-color: #4a9eff; background: #1a2a3a; }
+ .drop-zone.error { border-color: #ff4a4a; background: #3a1a1a; }
+ .content {
+ flex: 1;
+ display: flex;
+ overflow: hidden;
+ gap: 1px;
+ background: #404040;
+ }
+ .left-sidebar {
+ width: 315px;
+ background: #2a2a2a;
+ overflow-y: auto;
+ display: flex;
+ flex-direction: column;
+ gap: 16px;
+ padding: 16px;
+ }
+ .main {
+ flex: 1;
+ display: flex;
+ justify-content: center;
+ align-items: center;
+ padding: 24px;
+ overflow: auto;
+ position: relative;
+ }
+ .video-controls-float {
+ position: absolute;
+ top: 16px;
+ left: 50%;
+ transform: translateX(-50%);
+ display: flex;
+ gap: 8px;
+ background: rgba(42, 42, 42, 0.95);
+ padding: 8px 12px;
+ border-radius: 4px;
+ border: 1px solid #404040;
+ z-index: 100;
+ }
+ .bottom-controls-float {
+ position: absolute;
+ bottom: 16px;
+ left: 50%;
+ transform: translateX(-50%);
+ display: flex;
+ gap: 16px;
+ align-items: center;
+ background: rgba(42, 42, 42, 0.95);
+ padding: 8px 16px;
+ border-radius: 4px;
+ border: 1px solid #404040;
+ z-index: 100;
+ }
+ .bottom-controls-float .control-group {
+ display: flex;
+ gap: 8px;
+ align-items: center;
+ }
+ .bottom-controls-float #videoControls {
+ display: flex;
+ gap: 8px;
+ align-items: center;
+ padding-right: 16px;
+ border-right: 1px solid #404040;
+ }
+ .main.drop-active::after {
+ content: 'Drop PNG/video here';
+ position: absolute;
+ inset: 24px;
+ display: flex;
+ align-items: center;
+ justify-content: center;
+ border: 3px dashed #4a9eff;
+ background: rgba(74, 158, 255, 0.1);
+ font-size: 24px;
+ color: #4a9eff;
+ pointer-events: none;
+ z-index: 10;
+ }
+ .sidebar {
+ width: 400px;
+ background: #2a2a2a;
+ overflow-y: auto;
+ display: flex;
+ flex-direction: column;
+ gap: 16px;
+ padding: 16px;
+ }
+ .panel {
+ border-radius: 4px;
+ overflow: hidden;
+ }
+ .panel.collapsed .panel-content {
+ display: none;
+ }
+ .panel-header {
+ background: #1a1a1a;
+ padding: 8px 12px;
+ font-size: 12px;
+ font-weight: bold;
+ border-bottom: 1px solid #404040;
+ }
+ .panel-content {
+ padding: 12px;
+ font-size: 11px;
+ }
+ .panel-content table {
+ width: 100%;
+ border-collapse: collapse;
+ }
+ .panel-content th {
+ text-align: left;
+ padding: 4px;
+ font-size: 10px;
+ color: #808080;
+ border-bottom: 1px solid #404040;
+ }
+ .panel-content td {
+ padding: 4px;
+ font-size: 10px;
+ }
+ .panel-content tr:hover {
+ background: #1a1a1a;
+ }
+ .layer-buttons {
+ display: flex;
+ flex-wrap: wrap;
+ gap: 6px;
+ margin-bottom: 12px;
+ }
+ .layer-buttons button {
+ padding: 6px 12px;
+ font-size: 10px;
+ }
+ .layer-buttons button.active {
+ background: #4a9eff;
+ border-color: #4a9eff;
+ color: #1a1a1a;
+ }
+ .layer-buttons button:disabled:hover {
+ border-color: #404040;
+ background: #1a1a1a;
+ }
+ .layer-grid {
+ display: grid;
+ grid-template-columns: repeat(4, 1fr);
+ gap: 4px;
+ margin-bottom: 12px;
+ }
+ .layer-view {
+ aspect-ratio: 1;
+ background: #1a1a1a;
+ border: 1px solid #404040;
+ display: flex;
+ flex-direction: column;
+ overflow: hidden;
+ }
+ .layer-preview {
+ background: #1a1a1a;
+ border: 1px solid #404040;
+ display: flex;
+ flex-direction: column;
+ overflow: hidden;
+ margin-top: 8px;
+ }
+ .layer-preview canvas {
+ width: 100%;
+ height: 100%;
+ image-rendering: pixelated;
+ }
+ .layer-view.active {
+ border: 2px solid #ffffff;
+ }
+ .layer-view canvas {
+ cursor: pointer;
+ }
+ .layer-view-label {
+ background: #2a2a2a;
+ padding: 4px;
+ font-size: 9px;
+ text-align: center;
+ border-bottom: 1px solid #404040;
+ }
+ .layer-view canvas {
+ width: 100%;
+ height: 100%;
+ image-rendering: pixelated;
+ }
+ canvas {
+ max-width: 100%;
+ max-height: 100%;
+ image-rendering: pixelated;
+ box-shadow: 0 4px 12px rgba(0,0,0,0.5);
+ }
+ .footer {
+ background: #2a2a2a;
+ border-top: 1px solid #404040;
+ font-size: 11px;
+ display: flex;
+ flex-direction: column;
+ gap: 8px;
+ }
+ .footer-top {
+ padding: 12px 16px 0;
+ display: flex;
+ justify-content: space-between;
+ }
+ .status { color: #4a9eff; }
+ .shortcuts { color: #808080; }
+ .console {
+ background: #1a1a1a;
+ padding: 8px 16px;
+ font-family: 'Courier New', monospace;
+ font-size: 10px;
+ color: #808080;
+ max-height: 100px;
+ overflow-y: auto;
+ border-top: 1px solid #404040;
+ }
+ .console-line { margin: 2px 0; }
+ .console-line.error { color: #ff4a4a; }
+ .console-line.info { color: #4a9eff; }
+ </style>
+</head>
+<body>
+ <div class="header">
+ <h1>CNN v2 Testing Tool</h1>
+ </div>
+ <video id="videoSource" muted loop></video>
+ <div class="content">
+ <div class="left-sidebar">
+ <input type="file" id="weightsFile" accept=".bin" style="display: none;">
+ <div class="drop-zone" id="weightsDrop" onclick="document.getElementById('weightsFile').click()">
+ Drop .bin Weights or Click to Browse
+ </div>
+ <div class="panel" id="weightsInfoPanel">
+ <div class="panel-header">Weights Info</div>
+ <div class="panel-content" id="weightsInfo">
+ <p style="color: #808080; text-align: center;">No weights loaded</p>
+ </div>
+ </div>
+ <div class="panel" id="weightsVizPanel" style="display: none;">
+ <div class="panel-header">Weights Visualization</div>
+ <div class="panel-content" id="weightsViz">
+ <div class="layer-buttons" id="weightsLayerButtons"></div>
+ <canvas id="weightsCanvas" style="width: 100%; image-rendering: pixelated; border: 1px solid #404040;"></canvas>
+ </div>
+ </div>
+ <div class="panel">
+ <div class="panel-content">
+ <label for="mipLevel" style="font-size: 11px;">Mip Level:</label>
+ <select id="mipLevel" style="width: 100%; background: #1a1a1a; color: #e0e0e0; border: 1px solid #404040; padding: 4px; margin-top: 4px;">
+ <option value="0">Mip 0 (original)</option>
+ <option value="1">Mip 1 (half res)</option>
+ <option value="2">Mip 2 (quarter res)</option>
+ </select>
+ </div>
+ </div>
+ </div>
+ <div class="main" id="mainDrop">
+ <div class="bottom-controls-float">
+ <div id="videoControls">
+ <button id="playPauseBtn" disabled>Play</button>
+ <button id="stepBackBtn" disabled>◄ Frame</button>
+ <button id="stepForwardBtn" disabled>Frame ►</button>
+ </div>
+ <div class="control-group">
+ <label>Blend:</label>
+ <input type="range" id="blend" min="0" max="1" step="0.01" value="1.0">
+ <span id="blendValue">1.0</span>
+ </div>
+ <div class="control-group">
+ <label>Depth:</label>
+ <input type="range" id="depth" min="0" max="1" step="0.01" value="1.0">
+ <span id="depthValue">1.0</span>
+ </div>
+ <button id="savePngBtn">Save PNG</button>
+ </div>
+ <canvas id="canvas"></canvas>
+ </div>
+ <div class="sidebar">
+ <div class="panel" style="flex: 1; display: flex; flex-direction: column; min-height: 0;">
+ <div class="panel-header">Layer Visualization</div>
+ <div class="panel-content" id="layerViz" style="flex: 1; overflow: hidden;">
+ <p style="color: #808080; text-align: center;">Load image + weights</p>
+ </div>
+ </div>
+ </div>
+ </div>
+ <div class="footer">
+ <div class="footer-top">
+ <span class="status" id="status">Drop PNG/video anywhere to begin</span>
+ <span class="shortcuts">[SPACE] Original | [D] Diff (×10)</span>
+ </div>
+ <div class="console" id="console"></div>
+ </div>
+
+ <script>
+// ============================================================================
+// EMBEDDED WEIGHTS & CONSTANTS
+// ============================================================================
+
+// Default pre-trained weights (base64-encoded binary format)
+// Version 2: 4 layers (3×3, 5×5, 3×3, 3×3), 2496 f16 weights, mip_level=2
+const DEFAULT_WEIGHTS_B64 = 'Q05OMgIAAAAEAAAAwAkAAAIAAAADAAAADAAAAAQAAAAAAAAAsAEAAAUAAAAMAAAABAAAALABAACwBAAAAwAAAAwAAAAEAAAAYAYAALABAAADAAAADAAAAAQAAAAQCAAAsAEAAAU3faplMDmtR7gnMLqt6bSrLM4RCa/En4q257kVsmWz57aSHJMxz6wILJC0tLdBriWww7IULUehCClCo60dBiu1nWqsf60ZKn6ktCWKjrswATSfLwQunzJjKKWkN6hxLTMwbS2DJvgvUjFDL1YsQDFFL78ysC5OL/cvxC2kJ6qh0i1BLH2rzCrcKFUoeixTqwwopjD+rXmewCY6sYUtXCwwsaKqGjBcqoykKigRJYStaqjMp+siPi1BLI+tGatfK5Ii6C1qLY0tYSGFKz4wpzNdH1QuJDKmMJi0lLVAs0y2Q7YWtY21fLXusf+n8LDSsaethK3drB4rtSROKYOrLK53qrqu0REYLEUuVy1qEqohDSzgqk4sDKKSKi0clKcVKvupJ69rKTmw8q7qptatQK7OsFUw5Z5JKJ4udSp9LLQeui87LbcxljEgJ6Iw75jDLfUvIjCxnh0g763Lq/ItMqzDqP0sXCRcqnkl9qDlJUStSyR8oTuwA616IrAnNqo5JS4qDKeILmahyaHZI48tryiajuEs0aghLBcuny+aovQpAhj6Kqkwdy+8MZ0wLzBvKBStsrRAKJez+raaKAotBiVSqZqyk7b2sHO1e7cJsfGmQLACpWizBLP9LnWxYLWoJPeb/CY5ISokXqynJ4qtG6K1qpesL6zGqYssIDJRpnErRi3RL9kh1zBFLPkdGSNvKtEuvyywmgilbC43LNovbywCKj4pFzEbMmMuly2gMFYscCgzliIomSqZnpSnyK3hJJKsAasgJGMrfCyNqXwpqaYNq14wiyzWLrSn/yLbqm+tnauOpkKtRKdCrBcYQS0dnGAveqeBrD8sMiGpLkAugzEaLM6lLzAkL5YydzYnqGo15zh2MuSwJK0nqxI04jZ5LAs2TjilNeSc3yANLecrCzBCprUvfjUHMWCuFrAkItyq/an0JSUnvKnrrAosv5CRrTGvQKesntuur6v2rsyxzbCAsHYn1y5GrAGsASYUmawrpSLooRSy86sBqmaxAq67sD0lJalOKxOtkqx8H+wqgygMLhup8SzNKZuhcafWKUKs567KI1opDCsoplatAykJpc+skavUrK4p2iznLlMqcig4Le6mDKiaJpIsMiOgLGOtQqI7sFGworKfsTOq86ZIlru0dLCEoMqq4KzsI6I2MzixMocqSym8MwQtT7Njqrwy26rEthe2nTGxL/Gq+az8MPg1Tq6EqXmslqyArkKs/S73MqEwmyuzrUUxejLhKYaw0yUlMzgxAZULsZ4rhq8ssgarCjDTrPop0ywBLswwjbT7MMAxdq2fsEC04DZoOIovG7G4LwM1gTNnKDsuEbByrzyxvLLBKJgkGDQANSMy66wVrM21ebURriAluK5quFa3wLBsK2wvaDU7OEg3RDGWKVUzpTfPNG+tbrGcr3ytRKosr7yuCbB2rV6gZq3msWmtjqvmoNurP6YXrOIpf6l/J2irl6/iqK2jy6MCLkkhjSDQoAWWACo1JrWjP6nvKvmthay+KJ6rUqoKqaatHKyJrUOarydBo5yu/CUaKFoxFCW1CNgpri2WK02kgqvYqkotwqlIrdiiEa1aKZ2tXa6mrkax4KkYKp2vcKgErYsi2RvbqWapU6EAnMyqtyPBpYwdZyVZkwGl1yhhJ2QBPaUJqMmMJJ54IikpcqmUHzmacCDzq1Cr3yR9n8aizKlWKFiogapBFlknrimnHmemDqbVKHciNRyII5AsxZ0+Lf0Xmyh7LMIqDS2KK9EkxyxRHKgp2iL9K0QfxCwGLLEuwiqrLcWob6xpppasp6+lotypGrC9qdmpPKUuplagES2cpSyrsSyHJTMi3Kk4KWAlSCaqKNMtR626rKaoj6koI1wqeivGI9cpuqQ9KQUkZyEJKOmquyW0JymirSjhprWgkBpKLFykzZyloWSrNKxrGaCtMi1MqL6t56lLqu+wbbTetYkqYDR1rB0wqir/sWQwNas8N9E4wq+9I6WwT6xuMDy1yC9tM/Kwka+btK8vJisnIJWeUa30LRkwDaqIsNqzWK9lLnEzKjEMqYMuWy8uMs0qI6xKLjcvxicEqYCv06zrrLusKK/lMeMz8CyCMmqxO7AtNpW38zFzL5i2Wq19tkCuBaTlt8Kv85Mlsg6wWLfgstutzDJVNAqZxCywrQgspDYOMS0mGbQCuf63QS7GJ4GsBLizuRS0mKyiKKMkBbLXseCufCr4qKUpah7Vqh8tV6eqLLQoGy1bMNEu6i4fMD4wZSvbjwOpmCBzLMmeJKddoYqkIic6qpqRY6nNqDiwIq5dqcmndqbnKnGkSCjmKBUsriySrHWsZyTaG7smSKxAIwolIi2zLX6unK5KqXCwKq03qyarcKWMqQmmd6tIodWtH6UvLg2tTadPJOOp2iGgny0ufyy+L7AvNClhpiEpC6qMqqMp7KTopJ4mmB2ylM6mrKhfKiQrTyiiKdGoQqjKJ6Umxip/qDiq/ChgKtmqIiwOr+CunZF7Kfot36poqkcthCx+Ksapg5T5pn0oNqOPq4osMSbSqQQmGqgXKhEl3yV1piyswazLK7QoQBTaqU8lIS13Ldch+qQqJ2AsPKfmp3Ink5Z2HhosR5z4qLIoGqkNLCct2Ck3KPGnUC0oJBQq7agOKyaq0qsqpAap8SylLg4qriy6M3MqKCtdKpMjSi86KigsGCz/n2erEyu7J/QRVCkpILUwcC35LI8qxiw6Knoq5jAAKo8wnieqLF0vVTAYMZw4Jyx2t/ayTjGWMoGzKbwus1w4QRxeJse1dTGSNJGwmCrEJV8uQKygKe4gjSqkrLeydiaMroS0FrQms8Uygi28qe2uXS2Ko4q1d7ZxszEpiDSBMoc0STWpNc0xJKSvrMWm6bCKsOC3CrEOJNC1Ga5Qubi7U6/+NRQ0AqnSuFoySDmKtJS0b7KcNAMmqi45IbMvGzjeMg2qSioPKVWtSK6EpaA1UTckMt2m16nwM5E2oDHBsZ+pniVpMc4vQy1epXkqHifBl7Mu36T/KzQorix4JAOmWyqJFVUqq67doiot2CxYME8i2JxVKhQt5ioYJsWp1KiSpL0lhq1JpWAgbCweKW2o1CrCIMsrcghkHUqW3hiTI5osYqMlB+WaLy0uKNUooKx4qdEezqRlJEapyKuUoEmoZyT7nqcoo6v3n4yqZaGcpNElwij3IkinQiAFIFQK2ygqIoKsiZxEI6ukqCf7KFSkgqSTqjEq8JZLJPufXKmFkaEj36lCKj2qURxfKkQouaqQhRIrGSmepKin7Cl8KEcuKI+ip4Evz6xIF0woVK/yHLyfLSj0ny+oWywSJHWmQaEomWos6ZTMpPWlY61pqLelZqYGpAidcyzQE5kneBr1pnQkJSwIqWYpIabdK
A8oHKroGeCnYplOKzAmC51LJ0emp6o+rXAofCkCKV4w4x1sKCYjrKAgKa0r+BcPJDMmP6o2JW4pIqqtm4srTqgHlLWlsBBepaqrKq27rBat9aTlot8qkaw2o5sl76ivKDkjNyjzKKWY5KlHrQCr8SjxquarXqrlKB2xyyfZL1Sqq7LWpxA04zZwMkyvUiyHMig1ay+GJqenVq1Ao1awVLHQnrEqxTD/LO8kKB+NH1grfKsPsY6u+aIELLaj4LBmLBU0wDOlM8ksdKjbqPSqQykHJmYodC+WMcYuSCJ7psYvNDTaLqWw/qy7Myw4xjTnMIouQTV9OJ81YSlbLiIx3TVuMUcokrDzI0ow8CQQr9IvDyxsLnk0OTVhLmmobLAULN4zkyyZsGC0LK01L3Upw52Jroywlix0MCwr5qkQJkot9aWzsYuui66HrHykMa9ZsDet96yBqXWvXbAXsraxIqgpsVOvtq5frF+iZa2WqROwcaP+qX2w+aW3rxWpI7Bwrlqu5K0LrxexX7DUrfOvhK3QrUGwP7BrsY2tU6yWr8qkpK18rn2rHCbloYmfaqM1nfSr7Sn1qjuk2KT2qyem4KXJJ4MdxaidqPWsa58zKTSsoKXAJUymz6rJpv+oGKsOJo2hSicHqA4oOiiRmr4k0BxBq8Ui16jTKvyq7ijmqHcpZanhHnGfMikxIiEk7S4Yq90sfKWSoZyntKg/qh+nJiifnAyvlKeXJMIdViKeoxEjLKvZpXymAqkhraCofK5SnTGmLqdkq7mjYCD8qV0qQKo0qrUo+KsZKVSs0iaULFUI8qS0mlWtiiqbGBegACwBoAErhaW1qMwqHSxfKVKpp6x7poiweKxCrdkivK48sJewrKdArHYnqyhoHbUnsagYK58qSjAgMcUwsCt0K/4rLC7mJGwtvStOMFQu0SzuJQUsBTBMLswqcJyEnVQsESn3ox2z9ai/qFqwES7tKP0vSChMoqQwVzR4LKaT+y/NK06q2y0LIi2wHrIcKZuzsrSHn/6xkrPssAovJzEipEQiDbDjr3SqIis5LGIoOSm6p1apeqGGrtAqJzCIJRuptqrApiktWTAwMB4xQizXKoIgASFFsLwweTHbLdQtqyzXoKYtay3SLeOke6wgoPWr/SpFKUEmDacWptSoMChJKm6s6azkHe+mfzFKKyamfi6bK/wr5atPqEMxUTAlKSeueiRxoSQjQqxQLRavgauKriOssymXLZOooa97pFoufTSppqgoVq05tEg196yCsQIy7bEitAItJ7RgtUEzxjGML/QmEKIlrPgjPDFaoTYoPDFcJRavtK4XrKmsk6zjsCwsTa4UsPQs9jI/I3ct1C6cMV+b5y7wJZ0tYTF9MGojdS/oLTShziM/MVmnxC8FKJUwRCUxIz8wiS4QLWipLCCYq9EseabMKnEll6kPqIawRq+xGcgjyCkgqKed7SB6qZcr6CwJLW+st6ePq7WuHycUrhqsSq7zsKuZtimgCXCrmKkqnIGp4LHNsX2wnqyBsH2xIbDhpwCzra1ss44wTCypKDCyyK23LRiwYKKPMJmxcaqZKcshCCYipoyxNa1Nsbwozi1+MB8lQ5mtsDel3jDnlbutxiPzsWmp5SpTHaqys7EstauTPqoRsOosf6g3sLOgeaAfKUIsWi/BJdosUSzdMM4pSy3kpGM0DjWvLWw0cjR4MWWqQaYMLo2rZSijJjstZiFaLBadMq0TseyjYi0VGsQt8yo5oZCgti/HMLciM6r3KgMk8K6OqKup9q0srT0xcaWMMMwra67qrhSfsZ3GrrIj2a2+pqSvdrEcrRQ0IDhgMB+PCDWVM8qjnJ5ZKOmw4C0dMGyuG6DGMQUvrq+Oq4UsTSzHMRg2ibbXs+Axa7N5sAqqnSoerQUmky8oKIiuUjGsoBitdKy9q6iw661pqg4thKnpkYmt+a3gseypGp5Co22fM6YSKJap66hwopmsmqhlrCMkZyiLL4KnGKupKvUmyCQbLFUrbSZerKahlaRoqCYm5SqYKW0rcS8WrAUkzaMcGlqpRK3bnresXy18IXapEKqHKFssXKCpKMUrfamapf4tKjBiKJGoU54HK+8q5qq4qVuiZiy4JuEsTixNMFQnlSSIIw4k1KzxpbMlDqyKqz6gra4SpcOw3a3Vq+qqC6tOq22eORvnpC8hRadkka2q/K7HHUiowawpqPInLyA0qYMlsihUqGGkWCb7K1WdWK5Dr5EhnKv5KHKlXqYnJ/2l9i0YKUYuMzHxpyCs/ChMkPEtwanxoFQqJi3Uq7Mseq3arXskWKc5pOAc7CZcqCwc5w7qKO4f3iaKIDsq/KRgLpWsQqn5rYYkxCWPoU0bx6hzGdkkqibtofEoxy8GpUupSCTiKiwvpij7LbiulqkErXetejFkL2+upqtUp0OwiLAPsdCpxLIlrKOyQ7C2r3utIg0drZEl2y6oLkquoaX4rCysAa9GDRCwKrHDsNivAbHsqtioqiGvrqgJE66Kqw4rzKyDKgaomp6TK2EsDyc0oOSol6NZJkmsvyxorMss5pR0KBquEixPpjsgXCpsnXQocq2MrfGmoivvLBeacahmLROpe6kcGCSfdC03qL6i6yitHHohrxzqq4UiP6JMqF8qThOshWAVUqHupDsoohQuJSkv/ywqLiwlNjG7o++hxi3vIKmleCdyrH6wYatdsPWsjLCNol+sSTDpryCptbBDK+qs4zBpLGc0Nqc1rdo09jX5MqsrHi2xKOad8igwJxAoeSsiqgkqdChcLOYxJzGlMkAsUzCuKzskTjAOKhuplqjHqf8wzDKYIGefNDISqd8pIC23Ltwu7zC9KgMsQDL/JcgrryYzLJ0oTSoyqpkmLax+KuejVyqxr08ulZ2XpyQr5yxRsEMpwzD0KmEqoihRC6mwF6xOplwmjSSmpMep0SvhpOEndCluqLyvtCGgo3unOyy9IXKtmZ9yIK8hlqohrEUtxh0XKH0sGi18p6coHa3Tow6psqa/JRUMU6yiKbUoXigQpo2i7C18q3ur6CnWrSateC3/KY+jlCJ6o6qr+x8VJUkSFadyAgGpji0xraytBSd+rYksTqDAHQAtxSjkqMAmNqxhqNesEi5uKsqlFqo9Kg6seizOrdusAasErjmtoKv8rb8ph6cYLnMmcKlCLJ6pjiuIKpkpKK1UKvyq3RhVpZac+izlrYitWB+DrI4omKOZKikiZS1Fqicf+q25rJmsqKrYrNGt0JWRLWel2KfLqQ==';
+
+// Reusable fullscreen quad vertex shader (2 triangles covering NDC)
+const FULLSCREEN_QUAD_VS = `
+@vertex
+fn vs_main(@builtin(vertex_index) idx: u32) -> @builtin(position) vec4<f32> {
+ var pos = array<vec2<f32>, 6>(
+ vec2<f32>(-1.0, -1.0), vec2<f32>(1.0, -1.0), vec2<f32>(-1.0, 1.0),
+ vec2<f32>(-1.0, 1.0), vec2<f32>(1.0, -1.0), vec2<f32>(1.0, 1.0)
+ );
+ return vec4<f32>(pos[idx], 0.0, 1.0);
+}`;
+
+// ============================================================================
+// WGSL SHADERS
+// ============================================================================
+
+// Static features: 8 packed parametric features (RGBD + UV + sin(20*uv_y) + bias)
+const STATIC_SHADER = `
+@group(0) @binding(0) var input_tex: texture_2d<f32>;
+@group(0) @binding(1) var linear_sampler: sampler;
+@group(0) @binding(2) var depth_tex: texture_2d<f32>;
+@group(0) @binding(3) var output_tex: texture_storage_2d<rgba32uint, write>;
+@group(0) @binding(4) var<uniform> mip_level: u32;
+
+@compute @workgroup_size(8, 8)
+fn main(@builtin(global_invocation_id) id: vec3<u32>) {
+ let coord = vec2<i32>(id.xy);
+ let dims = textureDimensions(input_tex);
+ if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) { return; }
+
+ // Use normalized UV coords with linear sampler (bilinear filtering)
+ let uv = (vec2<f32>(coord) + 0.5) / vec2<f32>(dims);
+ let rgba = textureSampleLevel(input_tex, linear_sampler, uv, f32(mip_level));
+
+ let p0 = rgba.r;
+ let p1 = rgba.g;
+ let p2 = rgba.b;
+ let p3 = textureLoad(depth_tex, coord, 0).r;
+
+ let uv_x = f32(coord.x) / f32(dims.x);
+ let uv_y = f32(coord.y) / f32(dims.y);
+ let sin20_y = sin(20.0 * uv_y);
+ let bias = 1.0;
+
+ let packed = vec4<u32>(
+ pack2x16float(vec2<f32>(p0, p1)),
+ pack2x16float(vec2<f32>(p2, p3)),
+ pack2x16float(vec2<f32>(uv_x, uv_y)),
+ pack2x16float(vec2<f32>(sin20_y, bias))
+ );
+ textureStore(output_tex, coord, packed);
+}`;
+
+const CNN_SHADER = `
+struct LayerParams {
+ kernel_size: u32,
+ in_channels: u32,
+ out_channels: u32,
+ weight_offset: u32,
+ is_output_layer: u32,
+ blend_amount: f32,
+ is_layer_0: u32,
+}
+
+@group(0) @binding(0) var static_features: texture_2d<u32>;
+@group(0) @binding(1) var layer_input: texture_2d<u32>;
+@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>;
+@group(0) @binding(3) var<storage, read> weights_buffer: array<u32>;
+@group(0) @binding(4) var<uniform> params: LayerParams;
+@group(0) @binding(5) var original_input: texture_2d<f32>;
+
+fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {
+ let packed = textureLoad(static_features, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let v2 = unpack2x16float(packed.z);
+ let v3 = unpack2x16float(packed.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}
+
+fn unpack_layer_channels(coord: vec2<i32>) -> vec4<f32> {
+ let packed = textureLoad(layer_input, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ return vec4<f32>(v0.x, v0.y, v1.x, v1.y);
+}
+
+fn pack_channels(values: vec4<f32>) -> vec4<u32> {
+ return vec4<u32>(
+ pack2x16float(vec2<f32>(values.x, values.y)),
+ pack2x16float(vec2<f32>(values.z, values.w)),
+ 0u,
+ 0u
+ );
+}
+
+fn get_weight(idx: u32) -> f32 {
+ let pair_idx = idx / 2u;
+ let packed = weights_buffer[pair_idx];
+ let unpacked = unpack2x16float(packed);
+ return select(unpacked.y, unpacked.x, (idx & 1u) == 0u);
+}
+
+@compute @workgroup_size(8, 8)
+fn main(@builtin(global_invocation_id) id: vec3<u32>) {
+ let coord = vec2<i32>(id.xy);
+ let dims = textureDimensions(static_features);
+ if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) { return; }
+
+ let kernel_size = params.kernel_size;
+ let in_channels = params.in_channels; // Always 12 (4 prev + 8 static)
+ let out_channels = params.out_channels; // Always 4
+ let weight_offset = params.weight_offset;
+ let is_output = params.is_output_layer != 0u;
+ let kernel_radius = i32(kernel_size / 2u);
+
+ let static_feat = unpack_static_features(coord);
+
+ var output: vec4<f32> = vec4<f32>(0.0);
+ for (var c: u32 = 0u; c < 4u; c++) {
+ var sum: f32 = 0.0;
+ for (var ky: i32 = -kernel_radius; ky <= kernel_radius; ky++) {
+ for (var kx: i32 = -kernel_radius; kx <= kernel_radius; kx++) {
+ let sample_coord = coord + vec2<i32>(kx, ky);
+ let clamped = vec2<i32>(
+ clamp(sample_coord.x, 0, i32(dims.x) - 1),
+ clamp(sample_coord.y, 0, i32(dims.y) - 1)
+ );
+ let static_local = unpack_static_features(clamped);
+ let layer_local = unpack_layer_channels(clamped);
+
+ let ky_idx = u32(ky + kernel_radius);
+ let kx_idx = u32(kx + kernel_radius);
+ let spatial_idx = ky_idx * kernel_size + kx_idx;
+
+ // Previous layer channels (4D)
+ for (var i: u32 = 0u; i < 4u; i++) {
+ let w_idx = weight_offset +
+ c * in_channels * kernel_size * kernel_size +
+ i * kernel_size * kernel_size + spatial_idx;
+ sum += get_weight(w_idx) * layer_local[i];
+ }
+
+ // Static features (8D)
+ for (var i: u32 = 0u; i < 8u; i++) {
+ let w_idx = weight_offset +
+ c * in_channels * kernel_size * kernel_size +
+ (4u + i) * kernel_size * kernel_size + spatial_idx;
+ sum += get_weight(w_idx) * static_local[i];
+ }
+ }
+ }
+
+ if (is_output || params.is_layer_0 != 0u) {
+ output[c] = 1.0 / (1.0 + exp(-sum)); // Sigmoid [0,1]
+ } else {
+ output[c] = max(0.0, sum); // ReLU
+ }
+ }
+
+ if (is_output) {
+ let original = textureLoad(original_input, coord, 0).rgb;
+ let result_rgb = vec3<f32>(output.x, output.y, output.z);
+ let blended = mix(original, result_rgb, params.blend_amount);
+ output.x = blended.r;
+ output.y = blended.g;
+ output.z = blended.b;
+ }
+
+ textureStore(output_tex, coord, pack_channels(output));
+}`;
+
+const DISPLAY_SHADER = `
+@group(0) @binding(0) var result_tex: texture_2d<u32>;
+@group(0) @binding(1) var original_tex: texture_2d<f32>;
+@group(0) @binding(2) var<uniform> mode: u32;
+
+@vertex
+fn vs_main(@builtin(vertex_index) idx: u32) -> @builtin(position) vec4<f32> {
+ var pos = array<vec2<f32>, 6>(
+ vec2<f32>(-1.0, -1.0), vec2<f32>(1.0, -1.0), vec2<f32>(-1.0, 1.0),
+ vec2<f32>(-1.0, 1.0), vec2<f32>(1.0, -1.0), vec2<f32>(1.0, 1.0)
+ );
+ return vec4<f32>(pos[idx], 0.0, 1.0);
+}
+
+@fragment
+fn fs_main(@builtin(position) pos: vec4<f32>) -> @location(0) vec4<f32> {
+ let coord = vec2<i32>(pos.xy);
+ let packed = textureLoad(result_tex, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let result = vec3<f32>(v0.x, v0.y, v1.x);
+
+ if (mode == 0u) {
+ return vec4<f32>(result, 1.0);
+ } else if (mode == 1u) {
+ let original = textureLoad(original_tex, coord, 0).rgb;
+ return vec4<f32>(original, 1.0);
+ } else {
+ let original = textureLoad(original_tex, coord, 0).rgb;
+ let diff = abs(result - original) * 10.0;
+ return vec4<f32>(diff, 1.0);
+ }
+}`;
+
+const LAYER_VIZ_SHADER = `
+@group(0) @binding(0) var layer_tex: texture_2d<u32>;
+@group(0) @binding(1) var<uniform> viz_params: vec2<f32>; // x=channel_idx, y=scale
+
+@vertex
+fn vs_main(@builtin(vertex_index) idx: u32) -> @builtin(position) vec4<f32> {
+ var pos = array<vec2<f32>, 6>(
+ vec2<f32>(-1.0, -1.0), vec2<f32>(1.0, -1.0), vec2<f32>(-1.0, 1.0),
+ vec2<f32>(-1.0, 1.0), vec2<f32>(1.0, -1.0), vec2<f32>(1.0, 1.0)
+ );
+ return vec4<f32>(pos[idx], 0.0, 1.0);
+}
+
+@fragment
+fn fs_main(@builtin(position) pos: vec4<f32>) -> @location(0) vec4<f32> {
+ let coord = vec2<i32>(pos.xy);
+ let dims = textureDimensions(layer_tex);
+
+ let channel = u32(viz_params.x);
+
+ // DEBUG MODE 1: Texture coordinates (channel 10)
+ if (channel == 10u) {
+ let uv = vec2<f32>(f32(coord.x) / f32(dims.x), f32(coord.y) / f32(dims.y));
+ return vec4<f32>(uv.x, uv.y, 0.0, 1.0);
+ }
+
+ let packed = textureLoad(layer_tex, coord, 0);
+
+ // DEBUG MODE 2: Raw packed data (channel 11)
+ if (channel == 11u) {
+ let raw_val = f32(packed.x) / 4294967295.0;
+ return vec4<f32>(raw_val, raw_val, raw_val, 1.0);
+ }
+
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let v2 = unpack2x16float(packed.z);
+ let v3 = unpack2x16float(packed.w);
+
+ // DEBUG MODE 3: First unpacked value (channel 12)
+ if (channel == 12u) {
+ return vec4<f32>(v0.x, v0.x, v0.x, 1.0);
+ }
+
+ var channels: array<f32, 8>;
+ channels[0] = v0.x;
+ channels[1] = v0.y;
+ channels[2] = v1.x;
+ channels[3] = v1.y;
+ channels[4] = v2.x;
+ channels[5] = v2.y;
+ channels[6] = v3.x;
+ channels[7] = v3.y;
+
+ let scale = viz_params.y;
+
+ let idx = min(channel, 7u);
+ let raw = channels[idx];
+
+ // Apply scale: multiply and clamp to [0, 1]
+ let val = clamp(raw * scale, 0.0, 1.0);
+
+ return vec4<f32>(val, val, val, 1.0);
+}`;
+
+class CNNTester {
+ constructor() {
+ this.canvas = document.getElementById('canvas');
+ this.status = document.getElementById('status');
+ this.console = document.getElementById('console');
+ this.image = null;
+ this.video = document.getElementById('videoSource');
+ this.weights = null;
+ this.viewMode = 0;
+ this.blendAmount = 1.0;
+ this.depth = 1.0;
+ this.currentLayerIdx = null;
+ this.currentChannelOffset = null;
+ this.isVideo = false;
+ this.fps = 30;
+ this.isProcessing = false;
+ this.mipLevel = 0;
+ this.selectedChannel = 0;
+ this.init();
+ }
+
+ log(msg, type = 'info') {
+ const line = document.createElement('div');
+ line.className = `console-line ${type}`;
+ line.textContent = `[${new Date().toLocaleTimeString()}] ${msg}`;
+ this.console.appendChild(line);
+ this.console.scrollTop = this.console.scrollHeight;
+ }
+
+ async init() {
+ if (!navigator.gpu) {
+ this.setStatus('WebGPU not supported', true);
+ this.log('WebGPU not supported in this browser', 'error');
+ return;
+ }
+
+ try {
+ this.adapter = await navigator.gpu.requestAdapter();
+ this.device = await this.adapter.requestDevice();
+ this.context = this.canvas.getContext('webgpu');
+ this.format = navigator.gpu.getPreferredCanvasFormat();
+ this.log('WebGPU initialized successfully');
+ } catch (e) {
+ this.setStatus(`GPU init failed: ${e.message}`, true);
+ this.log(`GPU initialization failed: ${e.message}`, 'error');
+ }
+ }
+
+ setStatus(msg, isError = false) {
+ this.status.textContent = msg;
+ this.status.style.color = isError ? '#ff4a4a' : '#4a9eff';
+ }
+
+ // Get current source dimensions (video or image)
+ getDimensions() {
+ if (this.isVideo) {
+ return { width: this.video.videoWidth, height: this.video.videoHeight };
+ }
+ return { width: this.image.width, height: this.image.height };
+ }
+
+ // Enable/disable video playback controls
+ setVideoControlsEnabled(enabled) {
+ ['playPauseBtn', 'stepBackBtn', 'stepForwardBtn'].forEach(id =>
+ document.getElementById(id).disabled = !enabled
+ );
+ }
+
+ parseWeights(buffer) {
+ const view = new DataView(buffer);
+ const magic = view.getUint32(0, true);
+ if (magic !== 0x32_4E_4E_43) {
+ throw new Error('Invalid .bin file (bad magic)');
+ }
+
+ const version = view.getUint32(4, true);
+ const numLayers = view.getUint32(8, true);
+ const totalWeights = view.getUint32(12, true);
+
+ // Version 2: added mip_level field (20-byte header)
+ let mipLevel = 0;
+ let headerSize = 16;
+ if (version === 2) {
+ mipLevel = view.getUint32(16, true);
+ headerSize = 20;
+ this.log(`Binary header: version=${version}, layers=${numLayers}, weights=${totalWeights}, mip_level=${mipLevel}`);
+ } else if (version === 1) {
+ this.log(`Binary header: version=${version}, layers=${numLayers}, weights=${totalWeights}`);
+ } else {
+ throw new Error(`Unsupported binary version: ${version}`);
+ }
+
+ const layers = [];
+ for (let i = 0; i < numLayers; i++) {
+ const offset = headerSize + i * 20;
+ const layer = {
+ kernelSize: view.getUint32(offset, true),
+ inChannels: view.getUint32(offset + 4, true),
+ outChannels: view.getUint32(offset + 8, true),
+ weightOffset: view.getUint32(offset + 12, true),
+ weightCount: view.getUint32(offset + 16, true),
+ };
+ layers.push(layer);
+ this.log(` Layer ${i}: ${layer.inChannels}→${layer.outChannels}, kernel=${layer.kernelSize}×${layer.kernelSize}, weights=${layer.weightCount}`);
+ }
+
+ const weightsOffset = headerSize + numLayers * 20;
+ const weights = new Uint32Array(buffer.slice(weightsOffset));
+
+ // Calculate min/max per layer
+ for (let i = 0; i < numLayers; i++) {
+ const layer = layers[i];
+ let min = Infinity, max = -Infinity;
+ const startIdx = layer.weightOffset;
+ const endIdx = startIdx + layer.weightCount;
+
+ for (let j = startIdx; j < endIdx; j++) {
+ const pairIdx = Math.floor(j / 2);
+ const packed = weights[pairIdx];
+ const unpacked = this.unpackF16(packed);
+ const val = (j % 2 === 0) ? unpacked[0] : unpacked[1];
+ min = Math.min(min, val);
+ max = Math.max(max, val);
+ }
+
+ layer.min = min;
+ layer.max = max;
+ this.log(` Layer ${i} range: [${min.toFixed(4)}, ${max.toFixed(4)}]`);
+ }
+
+ let nonZero = 0;
+ for (let i = 0; i < weights.length; i++) {
+ if (weights[i] !== 0) nonZero++;
+ }
+ this.log(` Weight buffer: ${weights.length} u32 (${nonZero} non-zero)`);
+
+ return { version, layers, weights, mipLevel, fileSize: buffer.byteLength };
+ }
+
+ unpackF16(packed) {
+ const lo = packed & 0xFFFF;
+ const hi = (packed >> 16) & 0xFFFF;
+ const toFloat = (bits) => {
+ const sign = (bits >> 15) & 1;
+ const exp = (bits >> 10) & 0x1F;
+ const frac = bits & 0x3FF;
+ if (exp === 0) return (sign ? -1 : 1) * Math.pow(2, -14) * (frac / 1024);
+ if (exp === 31) return frac ? NaN : (sign ? -Infinity : Infinity);
+ return (sign ? -1 : 1) * Math.pow(2, exp - 15) * (1 + frac / 1024);
+ };
+ return [toFloat(lo), toFloat(hi)];
+ }
+
+ async loadImage(file) {
+ const img = await createImageBitmap(file);
+ this.image = img;
+ this.isVideo = false;
+ this.canvas.width = img.width;
+ this.canvas.height = img.height;
+ this.setVideoControlsEnabled(false);
+ this.log(`Loaded image: ${file.name} (${img.width}×${img.height})`);
+ if (this.weights) {
+ this.setStatus(`Ready: ${img.width}×${img.height}`);
+ this.run();
+ } else {
+ this.setStatus(`Image loaded (${img.width}×${img.height}) - drop .bin weights to process`);
+ this.displayOriginal();
+ }
+ }
+
+ // Video loading: wait for metadata, then first frame decode (readyState≥2)
+ async loadVideo(file) {
+ return new Promise((resolve, reject) => {
+ this.video.src = URL.createObjectURL(file);
+
+ this.video.onloadedmetadata = () => {
+ const w = this.video.videoWidth;
+ const h = this.video.videoHeight;
+ if (w === 0 || h === 0) {
+ reject(new Error('Video has invalid dimensions'));
+ return;
+ }
+
+ this.isVideo = true;
+ this.canvas.width = w;
+ this.canvas.height = h;
+ this.fps = 30;
+ this.log(`Loaded video: ${file.name} (${w}×${h}, ${this.video.duration.toFixed(1)}s)`);
+ this.setVideoControlsEnabled(true);
+
+ // Set up event handlers
+ this.video.onpause = () => { document.getElementById('playPauseBtn').textContent = 'Play'; };
+ this.video.onplay = () => { document.getElementById('playPauseBtn').textContent = 'Pause'; this.playbackLoop(); };
+
+ // Wait for first frame to be decoded before displaying
+ const displayFirstFrame = () => {
+ this.video.onseeked = () => { if (!this.isProcessing) this.processVideoFrame(); };
+ if (this.video.readyState >= 2) { // HAVE_CURRENT_DATA or better
+ if (this.weights) {
+ this.setStatus(`Ready: ${w}×${h}`);
+ this.processVideoFrame().then(() => resolve());
+ } else {
+ this.setStatus(`Video loaded - drop .bin weights to process`);
+ this.displayOriginal();
+ resolve();
+ }
+ } else {
+ setTimeout(displayFirstFrame, 50); // Poll until frame ready
+ }
+ };
+
+ this.video.onseeked = displayFirstFrame;
+ this.video.currentTime = 0;
+ };
+
+ this.video.onerror = () => reject(new Error('Failed to load video'));
+ });
+ }
+
+ // Video playback loop (non-realtime, drops frames if CNN slow)
+ playbackLoop() {
+ if (this.video.paused || this.video.ended) return;
+ if (!this.isProcessing) this.processVideoFrame();
+ requestAnimationFrame(() => this.playbackLoop());
+ }
+
+ // Process current video frame through CNN pipeline
+ async processVideoFrame() {
+ if (!this.weights || this.isProcessing) return;
+ this.isProcessing = true;
+ await this.run();
+ this.isProcessing = false;
+ }
+
+ // Video controls
+ togglePlayPause() {
+ this.video.paused ? this.video.play() : this.video.pause();
+ }
+
+ stepFrame(direction) {
+ if (!this.isVideo) return;
+ this.video.pause();
+ this.video.currentTime = Math.max(0, Math.min(this.video.duration,
+ this.video.currentTime + direction / this.fps));
+ }
+
+ async loadWeights(file) {
+ const buffer = await file.arrayBuffer();
+ this.weights = this.parseWeights(buffer);
+ this.weightsBuffer = buffer;
+ this.mipLevel = this.weights.mipLevel; // Set mip level from binary format
+ this.log(`Loaded weights: ${file.name} (${this.weights.layers.length} layers, ${(buffer.byteLength/1024).toFixed(1)} KB)`);
+
+ // Update UI dropdown to reflect loaded mip level
+ const mipLevelSelect = document.getElementById('mipLevel');
+ if (mipLevelSelect) {
+ mipLevelSelect.value = this.mipLevel.toString();
+ }
+
+ this.updateWeightsPanel();
+ if (this.image) {
+ this.setStatus(`Ready: ${this.image.width}×${this.image.height}`);
+ this.run();
+ } else {
+ this.setStatus('Weights loaded - drop PNG image to process');
+ }
+ }
+
+ updateWeightsPanel() {
+ const panel = document.getElementById('weightsInfo');
+ const { version, layers, mipLevel, fileSize } = this.weights;
+
+ let html = `
+ <div style="margin-bottom: 12px;">
+ <div><strong>File Size:</strong> ${(fileSize / 1024).toFixed(2)} KB</div>
+ <div><strong>Version:</strong> ${version}</div>
+ <div><strong>CNN Layers:</strong> ${layers.length}</div>
+ <div><strong>Mip Level:</strong> ${mipLevel} (p0-p3 features)</div>
+ <div style="font-size: 9px; color: #808080; margin-top: 4px;">Static features (input) + ${layers.length} conv layers</div>
+ </div>
+ <table>
+ <thead>
+ <tr>
+ <th>Layer</th>
+ <th>Size</th>
+ <th>Weights</th>
+ <th>Min</th>
+ <th>Max</th>
+ </tr>
+ </thead>
+ <tbody>
+ `;
+
+ // Display layers as "Layer 0", "Layer 1", etc. (matching codebase convention)
+ for (let i = 0; i < layers.length; i++) {
+ const l = layers[i];
+ html += `
+ <tr>
+ <td>Layer ${i}</td>
+ <td>${l.inChannels}→${l.outChannels} (${l.kernelSize}×${l.kernelSize})</td>
+ <td>${l.weightCount}</td>
+ <td>${l.min.toFixed(3)}</td>
+ <td>${l.max.toFixed(3)}</td>
+ </tr>
+ `;
+ }
+
+ html += `
+ </tbody>
+ </table>
+ `;
+
+ panel.innerHTML = html;
+
+ // Show weights visualization panel and create layer buttons
+ const weightsVizPanel = document.getElementById('weightsVizPanel');
+ weightsVizPanel.style.display = 'block';
+
+ const weightsLayerButtons = document.getElementById('weightsLayerButtons');
+ let buttonsHtml = '';
+ for (let i = 0; i < layers.length; i++) {
+ buttonsHtml += `<button onclick="tester.visualizeWeights(${i})" id="weightsBtn${i}">Layer ${i}</button>`;
+ }
+ weightsLayerButtons.innerHTML = buttonsHtml;
+
+ // Auto-select first layer
+ this.visualizeWeights(0);
+ }
+
+ generateMipmaps(texture, width, height) {
+ if (!this.mipmapPipeline) {
+ const mipmapShader = FULLSCREEN_QUAD_VS + `
+ @group(0) @binding(0) var src: texture_2d<f32>;
+ @fragment
+ fn fs_main(@builtin(position) pos: vec4<f32>) -> @location(0) vec4<f32> {
+ let coord = vec2<i32>(i32(pos.x) * 2, i32(pos.y) * 2);
+ var sum = vec4<f32>(0.0);
+ for (var y: i32 = 0; y < 2; y++) {
+ for (var x: i32 = 0; x < 2; x++) {
+ sum += textureLoad(src, coord + vec2<i32>(x, y), 0);
+ }
+ }
+ return sum * 0.25;
+ }
+ `;
+ this.mipmapPipeline = this.device.createRenderPipeline({
+ layout: 'auto',
+ vertex: { module: this.device.createShaderModule({ code: mipmapShader }), entryPoint: 'vs_main' },
+ fragment: {
+ module: this.device.createShaderModule({ code: mipmapShader }),
+ entryPoint: 'fs_main',
+ targets: [{ format: 'rgba8unorm' }]
+ }
+ });
+ }
+
+ const encoder = this.device.createCommandEncoder();
+
+ for (let mip = 1; mip < 3; mip++) {
+ const mipWidth = Math.max(1, width >> mip);
+ const mipHeight = Math.max(1, height >> mip);
+
+ const bindGroup = this.device.createBindGroup({
+ layout: this.mipmapPipeline.getBindGroupLayout(0),
+ entries: [
+ { binding: 0, resource: texture.createView({ baseMipLevel: mip - 1, mipLevelCount: 1 }) }
+ ]
+ });
+
+ const renderPass = encoder.beginRenderPass({
+ colorAttachments: [{
+ view: texture.createView({ baseMipLevel: mip, mipLevelCount: 1 }),
+ loadOp: 'clear',
+ storeOp: 'store'
+ }]
+ });
+
+ renderPass.setPipeline(this.mipmapPipeline);
+ renderPass.setBindGroup(0, bindGroup);
+ renderPass.setViewport(0, 0, mipWidth, mipHeight, 0, 1);
+ renderPass.draw(6);
+ renderPass.end();
+ }
+
+ this.device.queue.submit([encoder.finish()]);
+ }
+
+ displayOriginal() {
+ const source = this.isVideo ? this.video : this.image;
+ if (!source || !this.device) return;
+
+ const { width, height } = this.getDimensions();
+ this.context.configure({ device: this.device, format: this.format });
+
+ const inputTex = this.device.createTexture({
+ size: [width, height],
+ format: 'rgba8unorm',
+ usage: GPUTextureUsage.TEXTURE_BINDING | GPUTextureUsage.COPY_DST | GPUTextureUsage.RENDER_ATTACHMENT
+ });
+
+ this.device.queue.copyExternalImageToTexture(
+ { source: source },
+ { texture: inputTex },
+ [width, height]
+ );
+
+ const simpleShader = FULLSCREEN_QUAD_VS + `
+ @group(0) @binding(0) var tex: texture_2d<f32>;
+ @fragment
+ fn fs_main(@builtin(position) pos: vec4<f32>) -> @location(0) vec4<f32> {
+ let coord = vec2<i32>(pos.xy);
+ return textureLoad(tex, coord, 0);
+ }
+ `;
+
+ const pipeline = this.device.createRenderPipeline({
+ layout: 'auto',
+ vertex: { module: this.device.createShaderModule({ code: simpleShader }), entryPoint: 'vs_main' },
+ fragment: {
+ module: this.device.createShaderModule({ code: simpleShader }),
+ entryPoint: 'fs_main',
+ targets: [{ format: this.format }]
+ }
+ });
+
+ const bindGroup = this.device.createBindGroup({
+ layout: pipeline.getBindGroupLayout(0),
+ entries: [{ binding: 0, resource: inputTex.createView() }]
+ });
+
+ const encoder = this.device.createCommandEncoder();
+ const renderPass = encoder.beginRenderPass({
+ colorAttachments: [{
+ view: this.context.getCurrentTexture().createView(),
+ loadOp: 'clear',
+ storeOp: 'store'
+ }]
+ });
+ renderPass.setPipeline(pipeline);
+ renderPass.setBindGroup(0, bindGroup);
+ renderPass.draw(6);
+ renderPass.end();
+
+ this.device.queue.submit([encoder.finish()]);
+ }
+
+ // Run CNN inference pipeline on current source (image or video frame)
+ async run() {
+ const t0 = performance.now();
+ const source = this.isVideo ? this.video : this.image;
+ if (!source) return;
+ const { width, height } = this.getDimensions();
+
+ this.context.configure({ device: this.device, format: this.format });
+
+ // Create persistent input texture for original view with mipmaps
+ if (this.inputTexture) this.inputTexture.destroy();
+ this.inputTexture = this.device.createTexture({
+ size: [width, height],
+ format: 'rgba8unorm',
+ mipLevelCount: 3,
+ usage: GPUTextureUsage.TEXTURE_BINDING | GPUTextureUsage.COPY_DST | GPUTextureUsage.RENDER_ATTACHMENT
+ });
+
+ this.device.queue.copyExternalImageToTexture(
+ { source: source },
+ { texture: this.inputTexture, mipLevel: 0 },
+ [width, height]
+ );
+
+ // Generate mipmaps
+ this.generateMipmaps(this.inputTexture, width, height);
+
+ const staticTex = this.device.createTexture({
+ size: [width, height],
+ format: 'rgba32uint',
+ usage: GPUTextureUsage.STORAGE_BINDING | GPUTextureUsage.TEXTURE_BINDING | GPUTextureUsage.COPY_SRC
+ });
+
+ // Create one texture per layer output (static + all CNN layers)
+ this.layerOutputs = [];
+ const numLayers = this.weights.layers.length + 1; // +1 for static features
+ const layerTextures = [];
+ for (let i = 0; i < numLayers; i++) {
+ layerTextures.push(this.device.createTexture({
+ size: [width, height],
+ format: 'rgba32uint',
+ usage: GPUTextureUsage.STORAGE_BINDING | GPUTextureUsage.TEXTURE_BINDING | GPUTextureUsage.COPY_DST
+ }));
+ }
+
+ // Ping-pong buffers for computation
+ const computeTextures = [
+ this.device.createTexture({
+ size: [width, height],
+ format: 'rgba32uint',
+ usage: GPUTextureUsage.STORAGE_BINDING | GPUTextureUsage.TEXTURE_BINDING | GPUTextureUsage.COPY_SRC
+ }),
+ this.device.createTexture({
+ size: [width, height],
+ format: 'rgba32uint',
+ usage: GPUTextureUsage.STORAGE_BINDING | GPUTextureUsage.TEXTURE_BINDING | GPUTextureUsage.COPY_SRC
+ })
+ ];
+
+ const weightsGPU = this.device.createBuffer({
+ size: this.weightsBuffer.byteLength,
+ usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST
+ });
+ this.device.queue.writeBuffer(weightsGPU, 0, this.weightsBuffer);
+ const staticPipeline = this.device.createComputePipeline({
+ layout: 'auto',
+ compute: { module: this.device.createShaderModule({ code: STATIC_SHADER }), entryPoint: 'main' }
+ });
+
+ const cnnPipeline = this.device.createComputePipeline({
+ layout: 'auto',
+ compute: { module: this.device.createShaderModule({ code: CNN_SHADER }), entryPoint: 'main' }
+ });
+
+ const displayPipeline = this.device.createRenderPipeline({
+ layout: 'auto',
+ vertex: { module: this.device.createShaderModule({ code: DISPLAY_SHADER }), entryPoint: 'vs_main' },
+ fragment: {
+ module: this.device.createShaderModule({ code: DISPLAY_SHADER }),
+ entryPoint: 'fs_main',
+ targets: [{ format: this.format }]
+ }
+ });
+
+ const encoder = this.device.createCommandEncoder();
+
+ const mipLevelBuffer = this.device.createBuffer({
+ size: 4,
+ usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
+ });
+ this.device.queue.writeBuffer(mipLevelBuffer, 0, new Uint32Array([this.mipLevel]));
+
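+    // Note: despite the name, this sampler uses linear min/mag/mip filtering
+    // (it backs the mip-aware textureSampleLevel call in STATIC_SHADER).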
+ if (!this.pointSampler) {
+ this.pointSampler = this.device.createSampler({
+ magFilter: 'linear',
+ minFilter: 'linear',
+ mipmapFilter: 'linear'
+ });
+ }
+
+ // Extract depth from alpha channel (or 1.0 if no alpha)
+ const depthTex = this.device.createTexture({
+ size: [width, height, 1],
+ format: 'r32float',
+ usage: GPUTextureUsage.TEXTURE_BINDING | GPUTextureUsage.COPY_DST
+ });
+
+ // Read image data to extract alpha channel
+ const tempCanvas = document.createElement('canvas');
+ tempCanvas.width = width;
+ tempCanvas.height = height;
+ const tempCtx = tempCanvas.getContext('2d');
+ tempCtx.drawImage(source, 0, 0, width, height);
+ const imageData = tempCtx.getImageData(0, 0, width, height);
+ const pixels = imageData.data;
+
+ // Extract alpha channel (RGBA format: every 4th byte)
+ const depthData = new Float32Array(width * height);
+ for (let i = 0; i < width * height; i++) {
+ depthData[i] = pixels[i * 4 + 3] / 255.0; // Alpha channel [0, 255] → [0, 1]
+ }
+
+ this.device.queue.writeTexture(
+ { texture: depthTex },
+ depthData,
+ { bytesPerRow: width * 4 },
+ [width, height, 1]
+ );
+
+ const staticBG = this.device.createBindGroup({
+ layout: staticPipeline.getBindGroupLayout(0),
+ entries: [
+ { binding: 0, resource: this.inputTexture.createView() },
+ { binding: 1, resource: this.pointSampler },
+ { binding: 2, resource: depthTex.createView() }, // Depth from alpha (matches training)
+ { binding: 3, resource: staticTex.createView() },
+ { binding: 4, resource: { buffer: mipLevelBuffer } }
+ ]
+ });
+
+ const staticPass = encoder.beginComputePass();
+ staticPass.setPipeline(staticPipeline);
+ staticPass.setBindGroup(0, staticBG);
+ staticPass.dispatchWorkgroups(Math.ceil(width / 8), Math.ceil(height / 8));
+ staticPass.end();
+
+ // Copy static features to persistent storage (visualization index 0, shown as Static 0-3 / Static 4-7)
+ encoder.copyTextureToTexture(
+ { texture: staticTex },
+ { texture: layerTextures[0] },
+ [width, height]
+ );
+ this.layerOutputs.push(layerTextures[0]);
+
+ let srcTex = staticTex;
+ let dstTex = computeTextures[0];
+
+ for (let i = 0; i < this.weights.layers.length; i++) {
+ const layer = this.weights.layers[i];
+ const isOutput = i === this.weights.layers.length - 1;
+
+ // Calculate absolute weight offset in f16 units (add header offset)
+ // Version 1: 4 u32 header, Version 2: 5 u32 header
+ const headerSizeU32 = (this.weights.version === 1) ? 4 : 5;
+ const headerOffsetU32 = headerSizeU32 + this.weights.layers.length * 5; // Header + layer info in u32
+ const absoluteWeightOffset = headerOffsetU32 * 2 + layer.weightOffset; // Convert to f16 units
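+      // Example: with version 2 and 3 layers, header + layer info = 5 + 3*5 = 20 u32,
+      // i.e. 40 f16 slots, so absoluteWeightOffset = 40 + layer.weightOffset.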
+
+ const paramsData = new Uint32Array(7);
+ paramsData[0] = layer.kernelSize;
+ paramsData[1] = layer.inChannels;
+ paramsData[2] = layer.outChannels;
+ paramsData[3] = absoluteWeightOffset; // Use absolute offset
+ paramsData[4] = isOutput ? 1 : 0;
+ paramsData[6] = (i === 0) ? 1 : 0; // is_layer_0 flag
+
+ const paramsView = new Float32Array(paramsData.buffer);
+ paramsView[5] = this.blendAmount;
+
+ const paramsBuffer = this.device.createBuffer({
+ size: 28,
+ usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
+ });
+ this.device.queue.writeBuffer(paramsBuffer, 0, paramsData);
+
+ const cnnBG = this.device.createBindGroup({
+ layout: cnnPipeline.getBindGroupLayout(0),
+ entries: [
+ { binding: 0, resource: layerTextures[0].createView() },
+ { binding: 1, resource: srcTex.createView() },
+ { binding: 2, resource: dstTex.createView() },
+ { binding: 3, resource: { buffer: weightsGPU } },
+ { binding: 4, resource: { buffer: paramsBuffer } },
+ { binding: 5, resource: this.inputTexture.createView() }
+ ]
+ });
+
+ const cnnPass = encoder.beginComputePass();
+ cnnPass.setPipeline(cnnPipeline);
+ cnnPass.setBindGroup(0, cnnBG);
+ cnnPass.dispatchWorkgroups(Math.ceil(width / 8), Math.ceil(height / 8));
+ cnnPass.end();
+
+ [srcTex, dstTex] = [dstTex, srcTex];
+
+ // Copy CNN layer output to persistent storage for visualization
+ // i=0: Layer 0 → layerTextures[1]
+ // i=1: Layer 1 → layerTextures[2], etc.
+ encoder.copyTextureToTexture(
+ { texture: srcTex },
+ { texture: layerTextures[i + 1] },
+ [width, height]
+ );
+
+ // Always push layer outputs for visualization (including output layer)
+ this.layerOutputs.push(layerTextures[i + 1]);
+ }
+
+ const modeBuffer = this.device.createBuffer({
+ size: 4,
+ usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
+ });
+ this.device.queue.writeBuffer(modeBuffer, 0, new Uint32Array([this.viewMode]));
+
+ // Store result texture and display pipeline for view mode switching
+ this.resultTexture = srcTex;
+ this.displayPipeline = displayPipeline;
+ this.modeBuffer = modeBuffer;
+
+ const displayBG = this.device.createBindGroup({
+ layout: displayPipeline.getBindGroupLayout(0),
+ entries: [
+ { binding: 0, resource: srcTex.createView() },
+ { binding: 1, resource: this.inputTexture.createView() },
+ { binding: 2, resource: { buffer: modeBuffer } }
+ ]
+ });
+ this.displayBindGroup = displayBG;
+
+ const renderPass = encoder.beginRenderPass({
+ colorAttachments: [{
+ view: this.context.getCurrentTexture().createView(),
+ loadOp: 'clear',
+ storeOp: 'store'
+ }]
+ });
+ renderPass.setPipeline(displayPipeline);
+ renderPass.setBindGroup(0, displayBG);
+ renderPass.draw(6);
+ renderPass.end();
+
+ this.device.queue.submit([encoder.finish()]);
+
+ // Wait for GPU to finish before visualizing layers
+ await this.device.queue.onSubmittedWorkDone();
+
+ const t1 = performance.now();
+ const mode = ['CNN Output', 'Original', 'Diff (×10)'][this.viewMode];
+ this.setStatus(`GPU: ${(t1-t0).toFixed(1)}ms | ${width}×${height} | ${mode}`);
+ this.log(`Completed in ${(t1-t0).toFixed(1)}ms`);
+
+ // Update layer visualization panel
+ this.updateLayerVizPanel();
+ }
+
+ updateLayerVizPanel() {
+ const panel = document.getElementById('layerViz');
+
+ if (!this.layerOutputs || this.layerOutputs.length === 0) {
+ panel.innerHTML = '<p style="color: #808080; text-align: center;">No layers to visualize</p>';
+ return;
+ }
+
+ // Only rebuild panel structure if layer count changed
+ const needsRebuild = !this.lastLayerCount || this.lastLayerCount !== this.layerOutputs.length;
+
+ if (needsRebuild) {
+ let html = '<div class="layer-buttons">';
+ html += `<button onclick="tester.visualizeLayer(0, 0)" id="layerBtn0_0">Static 0-3</button>`;
+ html += `<button onclick="tester.visualizeLayer(0, 4)" id="layerBtn0_4">Static 4-7</button>`;
+
+ for (let i = 1; i < this.layerOutputs.length; i++) {
+ const label = `Layer ${i - 1}`;
+ html += `<button onclick="tester.visualizeLayer(${i})" id="layerBtn${i}">${label}</button>`;
+ }
+ html += `<button onclick="tester.saveCompositedLayer()" style="margin-left: 20px; background: #28a745;">Save Composited</button>`;
+ html += '</div>';
+
+ html += '<div class="layer-grid" id="layerGrid"></div>';
+ html += '<div class="layer-preview"><div class="layer-view-label" id="previewLabel">Ch0</div><canvas id="previewCanvas"></canvas></div>';
+
+ panel.innerHTML = html;
+ this.log(`Layer visualization ready: ${this.layerOutputs.length} layers`);
+ this.recreateCanvases();
+ this.lastLayerCount = this.layerOutputs.length;
+ }
+
+ // Update current visualization
+ if (this.currentLayerIdx !== null) {
+ this.visualizeLayer(this.currentLayerIdx, this.currentChannelOffset || 0);
+ } else {
+ this.visualizeLayer(0, 0);
+ }
+ }
+
+ recreateCanvases() {
+ const grid = document.getElementById('layerGrid');
+ if (!grid) return;
+
+ // Force removal of old canvases to clear any WebGPU contexts
+ const oldCanvases = grid.querySelectorAll('canvas');
+ oldCanvases.forEach(canvas => {
+ canvas.width = 0;
+ canvas.height = 0;
+ });
+
+ grid.innerHTML = '';
+ for (let c = 0; c < 4; c++) {
+ const div = document.createElement('div');
+ div.className = 'layer-view';
+ div.innerHTML = `
+ <div class="layer-view-label" id="channelLabel${c}">Ch ${c}</div>
+ <canvas id="layerCanvas${c}"></canvas>
+ `;
+ div.onclick = () => this.selectChannel(c);
+ grid.appendChild(div);
+ }
+ this.selectedChannel = 0;
+ }
+
+ async visualizeLayer(layerIdx, channelOffset = 0) {
+ if (!this.layerOutputs || layerIdx >= this.layerOutputs.length) {
+ this.log(`Cannot visualize layer ${layerIdx}: no data`, 'error');
+ return;
+ }
+
+ // Store current selection
+ this.currentLayerIdx = layerIdx;
+ this.currentChannelOffset = channelOffset;
+
+ // Update button states
+ document.querySelectorAll('.layer-buttons button').forEach(btn => btn.classList.remove('active'));
+ if (layerIdx === 0) {
+ // Static features
+ const btnId = `layerBtn0_${channelOffset}`;
+ const btn = document.getElementById(btnId);
+ if (btn) btn.classList.add('active');
+ } else {
+ const btn = document.getElementById(`layerBtn${layerIdx}`);
+ if (btn) btn.classList.add('active');
+ }
+
+ const layerName = layerIdx === 0 ? `Static Features (${channelOffset}-${channelOffset + 3})` : `Layer ${layerIdx - 1}`;
+ const layerTex = this.layerOutputs[layerIdx];
+ const { width, height } = this.getDimensions();
+
+ // Update channel labels based on layer type
+ // Static features (layerIdx=0): 8 channels split into two views
+ // CNN layers (layerIdx≥1): 4 channels per layer
+ const staticLabels = [
+ ['Ch0 (p0)', 'Ch1 (p1)', 'Ch2 (p2)', 'Ch3 (p3)'],
+ ['Ch4 (uv_x)', 'Ch5 (uv_y)', 'Ch6 (sin10_x)', 'Ch7 (bias)']
+ ];
+ const channelLabels = layerIdx === 0
+ ? staticLabels[channelOffset / 4]
+ : ['Ch0', 'Ch1', 'Ch2', 'Ch3'];
+
+ for (let c = 0; c < 4; c++) {
+ const label = document.getElementById(`channelLabel${c}`);
+ if (label) label.textContent = channelLabels[c];
+ }
+
+ // Create layer viz pipeline if needed
+ if (!this.layerVizPipeline) {
+ this.layerVizPipeline = this.device.createRenderPipeline({
+ layout: 'auto',
+ vertex: {
+ module: this.device.createShaderModule({ code: LAYER_VIZ_SHADER }),
+ entryPoint: 'vs_main'
+ },
+ fragment: {
+ module: this.device.createShaderModule({ code: LAYER_VIZ_SHADER }),
+ entryPoint: 'fs_main',
+ targets: [{ format: this.format }]
+ }
+ });
+ this.log('Created layer visualization pipeline');
+ }
+
+ // Render each channel to its canvas
+ for (let c = 0; c < 4; c++) {
+ const canvas = document.getElementById(`layerCanvas${c}`);
+ if (!canvas) {
+ this.log(`Canvas layerCanvas${c} not found`, 'error');
+ continue;
+ }
+
+ // Set canvas size BEFORE getting context
+ canvas.width = width;
+ canvas.height = height;
+
+ const ctx = canvas.getContext('webgpu');
+ if (!ctx) {
+ this.log(`Failed to get WebGPU context for channel ${c}`, 'error');
+ continue;
+ }
+
+ try {
+ ctx.configure({ device: this.device, format: this.format });
+ } catch (e) {
+ this.log(`Failed to configure canvas ${c}: ${e.message}`, 'error');
+ continue;
+ }
+
+ const vizScale = 1.0; // Always 1.0, shader clamps to [0,1]
+ const paramsBuffer = this.device.createBuffer({
+ size: 8,
+ usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
+ });
+ // Use channel index with offset for static features
+ const actualChannel = channelOffset + c;
+ const paramsData = new Float32Array([actualChannel, vizScale]);
+ this.device.queue.writeBuffer(paramsBuffer, 0, paramsData);
+
+ const bindGroup = this.device.createBindGroup({
+ layout: this.layerVizPipeline.getBindGroupLayout(0),
+ entries: [
+ { binding: 0, resource: layerTex.createView() },
+ { binding: 1, resource: { buffer: paramsBuffer } }
+ ]
+ });
+
+ const encoder = this.device.createCommandEncoder();
+ const renderPass = encoder.beginRenderPass({
+ colorAttachments: [{
+ view: ctx.getCurrentTexture().createView(),
+ loadOp: 'clear',
+ clearValue: { r: 1.0, g: 0.0, b: 1.0, a: 1.0 }, // Magenta clear for debugging
+ storeOp: 'store'
+ }]
+ });
+
+ renderPass.setPipeline(this.layerVizPipeline);
+ renderPass.setBindGroup(0, bindGroup);
+ renderPass.draw(6);
+ renderPass.end();
+
+ this.device.queue.submit([encoder.finish()]);
+ }
+
+ // Wait for all renders to complete
+ await this.device.queue.onSubmittedWorkDone();
+
+ // Update active channel highlighting and preview
+ this.updateChannelSelection();
+ await this.renderChannelPreview();
+ }
+
+ selectChannel(channelIdx) {
+ this.selectedChannel = channelIdx;
+ this.updateChannelSelection();
+ this.renderChannelPreview();
+ }
+
+ updateChannelSelection() {
+ const grid = document.getElementById('layerGrid');
+ if (!grid) return;
+
+ const views = grid.querySelectorAll('.layer-view');
+ views.forEach((view, idx) => {
+ view.classList.toggle('active', idx === this.selectedChannel);
+ });
+ }
+
+ async renderChannelPreview() {
+ const previewCanvas = document.getElementById('previewCanvas');
+ const previewLabel = document.getElementById('previewLabel');
+ if (!previewCanvas || !this.device) return;
+
+ const { width, height } = this.getDimensions();
+ previewCanvas.width = width;
+ previewCanvas.height = height;
+
+ const ctx = previewCanvas.getContext('webgpu');
+ if (!ctx) return;
+
+ try {
+ ctx.configure({ device: this.device, format: this.format });
+ } catch (e) {
+ return;
+ }
+
+ // Update label
+ const channelLabel = document.getElementById(`channelLabel${this.selectedChannel}`);
+ if (channelLabel && previewLabel) {
+ previewLabel.textContent = channelLabel.textContent;
+ }
+
+ // Render selected channel
+ const layerIdx = this.currentLayerIdx;
+ const channelOffset = this.currentChannelOffset;
+ const layerTex = this.layerOutputs[layerIdx];
+ if (!layerTex) return;
+
+ // Always 1.0, shader clamps to [0,1] - show exact layer values
+ const vizScale = 1.0;
+ const actualChannel = channelOffset + this.selectedChannel;
+
+ const paramsBuffer = this.device.createBuffer({
+ size: 8,
+ usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
+ });
+ const paramsData = new Float32Array([actualChannel, vizScale]);
+ this.device.queue.writeBuffer(paramsBuffer, 0, paramsData);
+
+ const bindGroup = this.device.createBindGroup({
+ layout: this.layerVizPipeline.getBindGroupLayout(0),
+ entries: [
+ { binding: 0, resource: layerTex.createView() },
+ { binding: 1, resource: { buffer: paramsBuffer } }
+ ]
+ });
+
+ const encoder = this.device.createCommandEncoder();
+ const renderPass = encoder.beginRenderPass({
+ colorAttachments: [{
+ view: ctx.getCurrentTexture().createView(),
+ loadOp: 'clear',
+ storeOp: 'store'
+ }]
+ });
+
+ renderPass.setPipeline(this.layerVizPipeline);
+ renderPass.setBindGroup(0, bindGroup);
+ renderPass.draw(6);
+ renderPass.end();
+
+ this.device.queue.submit([encoder.finish()]);
+ }
+
+ visualizeWeights(cnnLayerIdx) {
+ const layer = this.weights.layers[cnnLayerIdx];
+ if (!layer) {
+ this.log(`Layer ${cnnLayerIdx} not found`, 'error');
+ return;
+ }
+
+ // Update button states
+ document.querySelectorAll('#weightsLayerButtons button').forEach(btn => btn.classList.remove('active'));
+ const btn = document.getElementById(`weightsBtn${cnnLayerIdx}`);
+ if (btn) btn.classList.add('active');
+
+ const { kernelSize, inChannels, outChannels, weightOffset, min, max } = layer;
+
+ const canvas = document.getElementById('weightsCanvas');
+ const ctx = canvas.getContext('2d', { willReadFrequently: false });
+
+ // 1 pixel per weight, show all input channels horizontally
+ const width = inChannels * kernelSize;
+ const height = outChannels * kernelSize;
+
+ canvas.width = width;
+ canvas.height = height;
+
+ ctx.fillStyle = '#1a1a1a';
+ ctx.fillRect(0, 0, width, height);
+
+ // Stack output channels vertically
+ for (let outCh = 0; outCh < outChannels; outCh++) {
+ const yOffset = outCh * kernelSize;
+
+ for (let inCh = 0; inCh < inChannels; inCh++) {
+ const xOffset = inCh * kernelSize;
+
+ for (let ky = 0; ky < kernelSize; ky++) {
+ for (let kx = 0; kx < kernelSize; kx++) {
+ const spatialIdx = ky * kernelSize + kx;
+ const wIdx = weightOffset +
+ outCh * inChannels * kernelSize * kernelSize +
+ inCh * kernelSize * kernelSize +
+ spatialIdx;
+
+ const weight = this.getWeightValue(wIdx);
+ const normalized = (weight - min) / (max - min);
+ const intensity = Math.floor(normalized * 255);
+
+ ctx.fillStyle = `rgb(${intensity}, ${intensity}, ${intensity})`;
+ ctx.fillRect(xOffset + kx, yOffset + ky, 1, 1);
+ }
+ }
+ }
+ }
+ }
+
+ getWeightValue(idx) {
+ const pairIdx = Math.floor(idx / 2);
+ const packed = this.weights.weights[pairIdx];
+ const unpacked = this.unpackF16(packed);
+ return (idx % 2 === 0) ? unpacked[0] : unpacked[1];
+ }
+
+ toggleWeightsInfo() {
+ const panel = document.getElementById('weightsInfoPanel');
+ const toggle = document.getElementById('weightsInfoToggle');
+ panel.classList.toggle('collapsed');
+ toggle.textContent = panel.classList.contains('collapsed') ? '▶' : '▼';
+ }
+
+ updateDisplay() {
+ if (!this.displayPipeline || !this.displayBindGroup) return;
+
+ this.device.queue.writeBuffer(this.modeBuffer, 0, new Uint32Array([this.viewMode]));
+
+ const encoder = this.device.createCommandEncoder();
+ const renderPass = encoder.beginRenderPass({
+ colorAttachments: [{
+ view: this.context.getCurrentTexture().createView(),
+ loadOp: 'clear',
+ storeOp: 'store'
+ }]
+ });
+ renderPass.setPipeline(this.displayPipeline);
+ renderPass.setBindGroup(0, this.displayBindGroup);
+ renderPass.draw(6);
+ renderPass.end();
+
+ this.device.queue.submit([encoder.finish()]);
+ }
+
+ async savePNG() {
+ if (!this.image && !this.isVideo) {
+ this.log('No image loaded', 'error');
+ return;
+ }
+
+ if (!this.resultTexture) {
+ this.log('No result to save', 'error');
+ return;
+ }
+
+ try {
+ const { width, height } = this.getDimensions();
+
+ // GPU readback from result texture
+ const bytesPerRow = width * 16; // 4×u32 per pixel
+ const paddedBytesPerRow = Math.ceil(bytesPerRow / 256) * 256;
+ const bufferSize = paddedBytesPerRow * height;
+
+ const stagingBuffer = this.device.createBuffer({
+ size: bufferSize,
+ usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ
+ });
+
+ const encoder = this.device.createCommandEncoder();
+ encoder.copyTextureToBuffer(
+ { texture: this.resultTexture },
+ { buffer: stagingBuffer, bytesPerRow: paddedBytesPerRow, rowsPerImage: height },
+ { width, height, depthOrArrayLayers: 1 }
+ );
+ this.device.queue.submit([encoder.finish()]);
+
+ await stagingBuffer.mapAsync(GPUMapMode.READ);
+ const mapped = new Uint8Array(stagingBuffer.getMappedRange());
+
+ // Unpack f16 to RGBA8
+ const pixels = new Uint8Array(width * height * 4);
+ for (let y = 0; y < height; y++) {
+ const rowOffset = y * paddedBytesPerRow;
+ for (let x = 0; x < width; x++) {
+ const pixelOffset = rowOffset + x * 16;
+ const data = new Uint32Array(mapped.buffer, mapped.byteOffset + pixelOffset, 4);
+
+ // Unpack f16 (first 4 channels only)
+ const unpack = (u32, idx) => {
+ const h = (idx === 0) ? (u32 & 0xFFFF) : ((u32 >> 16) & 0xFFFF);
+ const sign = (h >> 15) & 1;
+ const exp = (h >> 10) & 0x1F;
+ const frac = h & 0x3FF;
+ if (exp === 0) return 0;
+ if (exp === 31) return sign ? 0 : 255;
+ const e = exp - 15;
+              const val = (sign ? -1 : 1) * (1 + frac / 1024) * Math.pow(2, e);
+              return Math.max(0, Math.min(255, Math.round(val * 255))); // Negative values clamp to 0
+ };
+
+ const outIdx = (y * width + x) * 4;
+ pixels[outIdx + 0] = unpack(data[0], 0); // R
+ pixels[outIdx + 1] = unpack(data[0], 1); // G
+ pixels[outIdx + 2] = unpack(data[1], 0); // B
+ pixels[outIdx + 3] = 255; // A
+ }
+ }
+
+ stagingBuffer.unmap();
+ stagingBuffer.destroy();
+
+ // Create blob from pixels
+ const canvas = document.createElement('canvas');
+ canvas.width = width;
+ canvas.height = height;
+ const ctx = canvas.getContext('2d');
+ const imageData = new ImageData(new Uint8ClampedArray(pixels), width, height);
+ ctx.putImageData(imageData, 0, 0);
+
+ const blob = await new Promise(resolve => canvas.toBlob(resolve, 'image/png'));
+ const url = URL.createObjectURL(blob);
+ const a = document.createElement('a');
+ const mode = ['cnn', 'original', 'diff'][this.viewMode];
+ a.href = url;
+ a.download = `output_${width}x${height}_${mode}.png`;
+ a.click();
+ URL.revokeObjectURL(url);
+
+ this.log(`Saved PNG: ${a.download}`);
+ this.setStatus(`Saved: ${a.download}`);
+ } catch (err) {
+ this.log(`Failed to save PNG: ${err.message}`, 'error');
+ this.setStatus(`Save failed: ${err.message}`, true);
+ }
+ }
+
+ async saveCompositedLayer() {
+    // A falsy index covers both "nothing selected" and the static-feature view (index 0),
+    // which is shown as two 4-channel views and is not composited here.
+    if (!this.currentLayerIdx) {
+ this.log('No layer selected for compositing', 'error');
+ return;
+ }
+
+ try {
+ const canvases = [];
+ for (let i = 0; i < 4; i++) {
+ const canvas = document.getElementById(`layerCanvas${i}`);
+ if (!canvas) {
+ this.log(`Canvas layerCanvas${i} not found`, 'error');
+ return;
+ }
+ canvases.push(canvas);
+ }
+
+ const width = canvases[0].width;
+ const height = canvases[0].height;
+ const compositedWidth = width * 4;
+
+ // Create composited canvas
+ const compositedCanvas = document.createElement('canvas');
+ compositedCanvas.width = compositedWidth;
+ compositedCanvas.height = height;
+ const ctx = compositedCanvas.getContext('2d');
+
+ // Composite horizontally
+ for (let i = 0; i < 4; i++) {
+ ctx.drawImage(canvases[i], i * width, 0);
+ }
+
+ // Convert to grayscale
+ const imageData = ctx.getImageData(0, 0, compositedWidth, height);
+ const pixels = imageData.data;
+ for (let i = 0; i < pixels.length; i += 4) {
+ const gray = 0.299 * pixels[i] + 0.587 * pixels[i + 1] + 0.114 * pixels[i + 2];
+ pixels[i] = pixels[i + 1] = pixels[i + 2] = gray;
+ }
+ ctx.putImageData(imageData, 0, 0);
+
+ // Save as PNG
+ const blob = await new Promise(resolve => compositedCanvas.toBlob(resolve, 'image/png'));
+ const url = URL.createObjectURL(blob);
+ const a = document.createElement('a');
+ a.href = url;
+ a.download = `composited_layer${this.currentLayerIdx - 1}_${compositedWidth}x${height}.png`;
+ a.click();
+ URL.revokeObjectURL(url);
+
+ this.log(`Saved composited layer: ${a.download}`);
+ this.setStatus(`Saved: ${a.download}`);
+ } catch (err) {
+ this.log(`Failed to save composited layer: ${err.message}`, 'error');
+ this.setStatus(`Compositing failed: ${err.message}`, true);
+ }
+ }
+}
+
+const tester = new CNNTester();
+
+// Load default weights on startup
+(async () => {
+ try {
+ const binaryString = atob(DEFAULT_WEIGHTS_B64);
+ const bytes = new Uint8Array(binaryString.length);
+ for (let i = 0; i < binaryString.length; i++) {
+ bytes[i] = binaryString.charCodeAt(i);
+ }
+ await tester.loadWeights({ name: 'default.bin', arrayBuffer: () => Promise.resolve(bytes.buffer) });
+ tester.log('Loaded default weights');
+ } catch (err) {
+ tester.log(`Failed to load default weights: ${err.message}`, 'error');
+ }
+})();
+
+function setupDropZone(id, callback) {
+ const zone = document.getElementById(id);
+ ['dragenter', 'dragover', 'dragleave', 'drop'].forEach(e => {
+ zone.addEventListener(e, ev => { ev.preventDefault(); ev.stopPropagation(); });
+ });
+ ['dragenter', 'dragover'].forEach(e => zone.addEventListener(e, () => zone.classList.add('active')));
+ ['dragleave', 'drop'].forEach(e => zone.addEventListener(e, () => zone.classList.remove('active')));
+ zone.addEventListener('drop', e => {
+ const file = e.dataTransfer.files[0];
+ if (file) callback(file).catch(err => {
+ zone.classList.add('error');
+ tester.setStatus(err.message, true);
+ tester.log(err.message, 'error');
+ setTimeout(() => zone.classList.remove('error'), 2000);
+ });
+ });
+}
+
+// Whole window drop for PNG images and videos
+const mainArea = document.getElementById('mainDrop');
+['dragenter', 'dragover', 'dragleave', 'drop'].forEach(e => {
+ mainArea.addEventListener(e, ev => { ev.preventDefault(); ev.stopPropagation(); });
+});
+['dragenter', 'dragover'].forEach(e => mainArea.addEventListener(e, () => mainArea.classList.add('drop-active')));
+['dragleave', 'drop'].forEach(e => mainArea.addEventListener(e, () => mainArea.classList.remove('drop-active')));
+mainArea.addEventListener('drop', e => {
+ const file = e.dataTransfer.files[0];
+ if (file) {
+ if (file.type.startsWith('image/')) {
+ tester.loadImage(file).catch(err => {
+ tester.setStatus(err.message, true);
+ tester.log(err.message, 'error');
+ });
+ } else if (file.type.startsWith('video/')) {
+ tester.loadVideo(file).catch(err => {
+ tester.setStatus(err.message, true);
+ tester.log(err.message, 'error');
+ });
+ }
+ }
+});
+
+// Weights drop zone
+setupDropZone('weightsDrop', f => tester.loadWeights(f));
+
+// Weights file input
+document.getElementById('weightsFile').addEventListener('change', e => {
+ const file = e.target.files[0];
+ if (file) {
+ tester.loadWeights(file).catch(err => {
+ tester.setStatus(err.message, true);
+ tester.log(err.message, 'error');
+ });
+ }
+});
+
+document.getElementById('blend').addEventListener('input', e => {
+ tester.blendAmount = parseFloat(e.target.value);
+ document.getElementById('blendValue').textContent = e.target.value;
+ if ((tester.image || tester.isVideo) && tester.weights) {
+ tester.log(`Blend changed to ${e.target.value}`);
+ tester.run();
+ }
+});
+
+document.getElementById('depth').addEventListener('input', e => {
+ tester.depth = parseFloat(e.target.value);
+ document.getElementById('depthValue').textContent = e.target.value;
+ if ((tester.image || tester.isVideo) && tester.weights) tester.run();
+});
+
+document.getElementById('mipLevel').addEventListener('change', e => {
+ tester.mipLevel = parseInt(e.target.value);
+ tester.log(`Mip level changed to ${e.target.value}`);
+ if ((tester.image || tester.isVideo) && tester.weights) tester.run();
+});
+
+document.getElementById('playPauseBtn').addEventListener('click', () => tester.togglePlayPause());
+document.getElementById('stepBackBtn').addEventListener('click', () => tester.stepFrame(-1));
+document.getElementById('stepForwardBtn').addEventListener('click', () => tester.stepFrame(1));
+document.getElementById('savePngBtn').addEventListener('click', () => tester.savePNG());
+
+document.addEventListener('keydown', e => {
+  // Space toggles CNN Output <-> Original, D toggles CNN Output <-> Diff (×10)
+  const toggleViewMode = mode => {
+    tester.viewMode = (tester.viewMode === mode) ? 0 : mode;
+    const modeName = ['CNN Output', 'Original', 'Diff (×10)'][tester.viewMode];
+    if ((tester.image || tester.isVideo) && tester.weights) {
+      tester.log(`View mode: ${modeName}`);
+      tester.updateDisplay();
+      const width = tester.isVideo ? tester.video.videoWidth : tester.image.width;
+      const height = tester.isVideo ? tester.video.videoHeight : tester.image.height;
+      tester.setStatus(`${width}×${height} | ${modeName}`);
+    }
+  };
+  if (e.code === 'Space') {
+    e.preventDefault();
+    toggleViewMode(1);
+  } else if (e.code === 'KeyD') {
+    e.preventDefault();
+    toggleViewMode(2);
+  }
+});
+ </script>
+</body>
+</html>
diff --git a/cnn_v2/training/export_cnn_v2_shader.py b/cnn_v2/training/export_cnn_v2_shader.py
new file mode 100755
index 0000000..8692a62
--- /dev/null
+++ b/cnn_v2/training/export_cnn_v2_shader.py
@@ -0,0 +1,218 @@
+#!/usr/bin/env python3
+"""CNN v2 Shader Export Script - Uniform 12D→4D Architecture
+
+Converts PyTorch checkpoints to WGSL compute shaders with f16 weights.
+Generates one shader per layer with embedded weight arrays.
+
+Note: Storage buffer approach (export_cnn_v2_weights.py) is preferred for size.
+ This script is for debugging/testing with per-layer shaders.
+"""
+
+import argparse
+import numpy as np
+import torch
+from pathlib import Path
+
+# Path resolution for running from any directory
+SCRIPT_DIR = Path(__file__).parent
+PROJECT_ROOT = SCRIPT_DIR.parent.parent
+
+
+def export_layer_shader(layer_idx, weights, kernel_size, output_dir, mip_level=0, is_output_layer=False):
+ """Generate WGSL compute shader for a single CNN layer.
+
+ Args:
+ layer_idx: Layer index (0, 1, 2, ...)
+ weights: (4, 12, k, k) weight tensor (uniform 12D→4D)
+ kernel_size: Kernel size (3, 5, etc.)
+ output_dir: Output directory path
+ mip_level: Mip level used for p0-p3 (0=original, 1=half, etc.)
+ is_output_layer: True if this is the final RGBA output layer
+ """
+ weights_flat = weights.flatten()
+ weights_f16 = weights_flat.astype(np.float16)
+ weights_f32 = weights_f16.astype(np.float32) # WGSL stores as f32 literals
+
+ # Format weights as WGSL array
+ weights_str = ",\n ".join(
+ ", ".join(f"{w:.6f}" for w in weights_f32[i:i+8])
+ for i in range(0, len(weights_f32), 8)
+ )
+
+ radius = kernel_size // 2
+ if is_output_layer:
+ activation = "output[c] = clamp(sum, 0.0, 1.0); // Output layer"
+ elif layer_idx == 0:
+ activation = "output[c] = clamp(sum, 0.0, 1.0); // Layer 0: clamp [0,1]"
+ else:
+ activation = "output[c] = max(0.0, sum); // Middle layers: ReLU"
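+    # Note: training (train_cnn_v2.py) applies sigmoid to layer 0 and the final layer;
+    # this debug shader approximates the [0, 1] range with clamp, so outputs may differ from the trained model.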
+
+ shader_code = f"""// CNN v2 Layer {layer_idx} - Auto-generated (uniform 12D→4D)
+// Kernel: {kernel_size}×{kernel_size}, In: 12D (4 prev + 8 static), Out: 4D
+// Mip level: {mip_level} (p0-p3 features)
+
+const KERNEL_SIZE: u32 = {kernel_size}u;
+const IN_CHANNELS: u32 = 12u; // 4 (input/prev) + 8 (static)
+const OUT_CHANNELS: u32 = 4u; // Uniform output
+const KERNEL_RADIUS: i32 = {radius};
+
+// Weights quantized to float16 (stored as f32 in WGSL)
+const weights: array<f32, {len(weights_f32)}> = array(
+ {weights_str}
+);
+
+@group(0) @binding(0) var static_features: texture_2d<u32>;
+@group(0) @binding(1) var layer_input: texture_2d<u32>;
+@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>;
+
+fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {{
+ let packed = textureLoad(static_features, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let v2 = unpack2x16float(packed.z);
+ let v3 = unpack2x16float(packed.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}}
+
+fn unpack_layer_channels(coord: vec2<i32>) -> vec4<f32> {{
+ let packed = textureLoad(layer_input, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ return vec4<f32>(v0.x, v0.y, v1.x, v1.y);
+}}
+
+fn pack_channels(values: vec4<f32>) -> vec4<u32> {{
+ return vec4<u32>(
+ pack2x16float(vec2<f32>(values.x, values.y)),
+ pack2x16float(vec2<f32>(values.z, values.w)),
+ 0u, // Unused
+ 0u // Unused
+ );
+}}
+
+@compute @workgroup_size(8, 8)
+fn main(@builtin(global_invocation_id) id: vec3<u32>) {{
+ let coord = vec2<i32>(id.xy);
+ let dims = textureDimensions(static_features);
+
+ if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {{
+ return;
+ }}
+
+ // Load static features (always available)
+ let static_feat = unpack_static_features(coord);
+
+ // Convolution: 12D input (4 prev + 8 static) → 4D output
+ var output: vec4<f32> = vec4<f32>(0.0);
+ for (var c: u32 = 0u; c < 4u; c++) {{
+ var sum: f32 = 0.0;
+
+ for (var ky: i32 = -KERNEL_RADIUS; ky <= KERNEL_RADIUS; ky++) {{
+ for (var kx: i32 = -KERNEL_RADIUS; kx <= KERNEL_RADIUS; kx++) {{
+ let sample_coord = coord + vec2<i32>(kx, ky);
+
+ // Border handling (clamp)
+ let clamped = vec2<i32>(
+ clamp(sample_coord.x, 0, i32(dims.x) - 1),
+ clamp(sample_coord.y, 0, i32(dims.y) - 1)
+ );
+
+ // Load features at this spatial location
+ let static_local = unpack_static_features(clamped);
+ let layer_local = unpack_layer_channels(clamped); // 4D
+
+ // Weight index calculation
+ let ky_idx = u32(ky + KERNEL_RADIUS);
+ let kx_idx = u32(kx + KERNEL_RADIUS);
+ let spatial_idx = ky_idx * KERNEL_SIZE + kx_idx;
+
+ // Accumulate: previous/input channels (4D)
+ for (var i: u32 = 0u; i < 4u; i++) {{
+ let w_idx = c * 12u * KERNEL_SIZE * KERNEL_SIZE +
+ i * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
+ sum += weights[w_idx] * layer_local[i];
+ }}
+
+ // Accumulate: static features (8D)
+ for (var i: u32 = 0u; i < 8u; i++) {{
+ let w_idx = c * 12u * KERNEL_SIZE * KERNEL_SIZE +
+ (4u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
+ sum += weights[w_idx] * static_local[i];
+ }}
+ }}
+ }}
+
+ {activation}
+ }}
+
+ // Pack and store
+ textureStore(output_tex, coord, pack_channels(output));
+}}
+"""
+
+    output_path = Path(output_dir) / "cnn_v2" / f"cnn_v2_layer_{layer_idx}.wgsl"
+    output_path.parent.mkdir(parents=True, exist_ok=True)  # Ensure the cnn_v2/ subdirectory exists
+    output_path.write_text(shader_code)
+ print(f" → {output_path}")
+
+
+def export_checkpoint(checkpoint_path, output_dir):
+ """Export PyTorch checkpoint to WGSL shaders.
+
+ Args:
+ checkpoint_path: Path to .pth checkpoint
+ output_dir: Output directory for shaders
+ """
+ print(f"Loading checkpoint: {checkpoint_path}")
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
+
+ state_dict = checkpoint['model_state_dict']
+ config = checkpoint['config']
+
+ kernel_size = config.get('kernel_size', 3)
+ num_layers = config.get('num_layers', 3)
+ mip_level = config.get('mip_level', 0)
+
+ print(f"Configuration:")
+ print(f" Kernel size: {kernel_size}×{kernel_size}")
+ print(f" Layers: {num_layers}")
+ print(f" Mip level: {mip_level} (p0-p3 features)")
+ print(f" Architecture: uniform 12D→4D")
+
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ print(f"\nExporting shaders to {output_dir}/")
+
+ # All layers uniform: 12D→4D
+ for i in range(num_layers):
+ layer_key = f'layers.{i}.weight'
+ if layer_key not in state_dict:
+ raise ValueError(f"Missing weights for layer {i}: {layer_key}")
+
+ layer_weights = state_dict[layer_key].detach().numpy()
+ is_output = (i == num_layers - 1)
+
+ export_layer_shader(
+ layer_idx=i,
+ weights=layer_weights,
+ kernel_size=kernel_size,
+ output_dir=output_dir,
+ mip_level=mip_level,
+ is_output_layer=is_output
+ )
+
+ print(f"\nExport complete! Generated {num_layers} shader files.")
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Export CNN v2 checkpoint to WGSL shaders')
+ parser.add_argument('checkpoint', type=str, help='Path to checkpoint .pth file')
+ parser.add_argument('--output-dir', type=str, default=str(PROJECT_ROOT / 'workspaces/main/shaders'),
+ help='Output directory for shaders')
+
+ args = parser.parse_args()
+ export_checkpoint(args.checkpoint, args.output_dir)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/cnn_v2/training/export_cnn_v2_weights.py b/cnn_v2/training/export_cnn_v2_weights.py
new file mode 100755
index 0000000..d66b980
--- /dev/null
+++ b/cnn_v2/training/export_cnn_v2_weights.py
@@ -0,0 +1,288 @@
+#!/usr/bin/env python3
+"""CNN v2 Weight Export Script
+
+Converts PyTorch checkpoints to binary weight format for storage buffer.
+Exports single shader template + binary weights asset.
+"""
+
+import argparse
+import numpy as np
+import torch
+import struct
+from pathlib import Path
+
+# Path resolution for running from any directory
+SCRIPT_DIR = Path(__file__).parent
+PROJECT_ROOT = SCRIPT_DIR.parent.parent
+
+
+def export_weights_binary(checkpoint_path, output_path, quiet=False):
+ """Export CNN v2 weights to binary format.
+
+ Binary format:
+ Header (20 bytes):
+ uint32 magic ('CNN2')
+ uint32 version (2)
+ uint32 num_layers
+ uint32 total_weights (f16 count)
+ uint32 mip_level (0-3)
+
+ LayerInfo × num_layers (20 bytes each):
+ uint32 kernel_size
+ uint32 in_channels
+ uint32 out_channels
+ uint32 weight_offset (f16 index)
+ uint32 weight_count
+
+ Weights (f16 array):
+ float16[] all_weights
+
+ Args:
+ checkpoint_path: Path to .pth checkpoint
+ output_path: Output .bin file path
+
+ Returns:
+ config dict for shader generation
+ """
+ if not quiet:
+ print(f"Loading checkpoint: {checkpoint_path}")
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
+
+ state_dict = checkpoint['model_state_dict']
+ config = checkpoint['config']
+
+ # Support both old (kernel_size) and new (kernel_sizes) format
+ if 'kernel_sizes' in config:
+ kernel_sizes = config['kernel_sizes']
+ elif 'kernel_size' in config:
+ kernel_size = config['kernel_size']
+ num_layers = config.get('num_layers', 3)
+ kernel_sizes = [kernel_size] * num_layers
+ else:
+ kernel_sizes = [3, 3, 3] # fallback
+
+ num_layers = config.get('num_layers', len(kernel_sizes))
+ mip_level = config.get('mip_level', 0)
+
+ if not quiet:
+ print(f"Configuration:")
+ print(f" Kernel sizes: {kernel_sizes}")
+ print(f" Layers: {num_layers}")
+ print(f" Mip level: {mip_level} (p0-p3 features)")
+ print(f" Architecture: uniform 12D→4D (bias=False)")
+
+ # Collect layer info - all layers uniform 12D→4D
+ layers = []
+ all_weights = []
+ weight_offset = 0
+
+ for i in range(num_layers):
+ layer_key = f'layers.{i}.weight'
+ if layer_key not in state_dict:
+ raise ValueError(f"Missing weights for layer {i}: {layer_key}")
+
+ layer_weights = state_dict[layer_key].detach().numpy()
+ layer_flat = layer_weights.flatten()
+ kernel_size = kernel_sizes[i]
+
+ layers.append({
+ 'kernel_size': kernel_size,
+ 'in_channels': 12, # 4 (input/prev) + 8 (static)
+ 'out_channels': 4, # Uniform output
+ 'weight_offset': weight_offset,
+ 'weight_count': len(layer_flat)
+ })
+ all_weights.extend(layer_flat)
+ weight_offset += len(layer_flat)
+
+ if not quiet:
+ print(f" Layer {i}: 12D→4D, {kernel_size}×{kernel_size}, {len(layer_flat)} weights")
+
+ # Convert to f16
+ # TODO: Use 8-bit quantization for 2× size reduction
+ # Requires quantization-aware training (QAT) to maintain accuracy
+ all_weights_f16 = np.array(all_weights, dtype=np.float16)
+
+ # Pack f16 pairs into u32 for storage buffer
+ # Pad to even count if needed
+ if len(all_weights_f16) % 2 == 1:
+ all_weights_f16 = np.append(all_weights_f16, np.float16(0.0))
+
+ # Pack pairs using numpy view
+ weights_u32 = all_weights_f16.view(np.uint32)
+
+ binary_size = 20 + len(layers) * 20 + len(weights_u32) * 4
+ if not quiet:
+ print(f"\nWeight statistics:")
+ print(f" Total layers: {len(layers)}")
+ print(f" Total weights: {len(all_weights_f16)} (f16)")
+ print(f" Packed: {len(weights_u32)} u32")
+ print(f" Binary size: {binary_size} bytes")
+
+ # Write binary file
+ output_path = Path(output_path)
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+
+ with open(output_path, 'wb') as f:
+ # Header (20 bytes) - version 2 with mip_level
+ f.write(struct.pack('<4sIIII',
+ b'CNN2', # magic
+ 2, # version (bumped to 2)
+ len(layers), # num_layers
+ len(all_weights_f16), # total_weights (f16 count)
+ mip_level)) # mip_level
+
+ # Layer info (20 bytes per layer)
+ for layer in layers:
+ f.write(struct.pack('<IIIII',
+ layer['kernel_size'],
+ layer['in_channels'],
+ layer['out_channels'],
+ layer['weight_offset'],
+ layer['weight_count']))
+
+ # Weights (u32 packed f16 pairs)
+ f.write(weights_u32.tobytes())
+
+ if quiet:
+ print(f" Exported {num_layers} layers, {len(all_weights_f16)} weights, {binary_size} bytes → {output_path}")
+ else:
+ print(f" → {output_path}")
+
+ return {
+ 'num_layers': len(layers),
+ 'layers': layers
+ }
+
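+# Minimal readback sketch (not part of the export pipeline): parses the 20-byte header
+# written above, mirroring the documented layout, for quick sanity checks.
+def read_weights_header(path):
+    """Return (magic, version, num_layers, total_weights, mip_level) from a .bin file."""
+    with open(path, 'rb') as f:
+        return struct.unpack('<4sIIII', f.read(20))
+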
+
+def export_shader_template(config, output_dir):
+ """Generate single WGSL shader template with storage buffer binding.
+
+ Args:
+ config: Layer configuration from export_weights_binary()
+ output_dir: Output directory path
+ """
+ shader_code = """// CNN v2 Compute Shader - Storage Buffer Version
+// Reads weights from storage buffer, processes all layers in sequence
+
+struct CNNv2Header {
+ magic: u32, // 'CNN2'
+ version: u32, // 1
+ num_layers: u32, // Number of layers
+ total_weights: u32, // Total f16 weight count
+}
+
+struct CNNv2LayerInfo {
+ kernel_size: u32,
+ in_channels: u32,
+ out_channels: u32,
+ weight_offset: u32, // Offset in weights array
+ weight_count: u32,
+}
+
+@group(0) @binding(0) var static_features: texture_2d<u32>;
+@group(0) @binding(1) var layer_input: texture_2d<u32>;
+@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>;
+@group(0) @binding(3) var<storage, read> weights: array<u32>; // Packed f16 pairs
+
+fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {
+ let packed = textureLoad(static_features, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let v2 = unpack2x16float(packed.z);
+ let v3 = unpack2x16float(packed.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}
+
+fn unpack_layer_channels(coord: vec2<i32>) -> vec4<f32> {
+ let packed = textureLoad(layer_input, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ return vec4<f32>(v0.x, v0.y, v1.x, v1.y);
+}
+
+fn pack_channels(values: vec4<f32>) -> vec4<u32> {
+ return vec4<u32>(
+ pack2x16float(vec2<f32>(values.x, values.y)),
+ pack2x16float(vec2<f32>(values.z, values.w)),
+ 0u, // Unused
+ 0u // Unused
+ );
+}
+
+fn get_weight(idx: u32) -> f32 {
+ let pair_idx = idx / 2u;
+ let packed = weights[8u + pair_idx]; // Skip header (32 bytes = 8 u32)
+ let unpacked = unpack2x16float(packed);
+ return select(unpacked.y, unpacked.x, (idx & 1u) == 0u);
+}
+
+@compute @workgroup_size(8, 8)
+fn main(@builtin(global_invocation_id) id: vec3<u32>) {
+ let coord = vec2<i32>(id.xy);
+ let dims = textureDimensions(static_features);
+
+ if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {
+ return;
+ }
+
+ // Read header
+ let header_packed = weights[0]; // magic + version
+ let counts_packed = weights[1]; // num_layers + total_weights
+ let num_layers = counts_packed & 0xFFFFu;
+
+ // Load static features
+ let static_feat = unpack_static_features(coord);
+
+ // Process each layer (hardcoded for 3 layers for now)
+ // TODO: Dynamic layer loop when needed
+
+ // Example for layer 0 - expand to full multi-layer when tested
+ let layer_info_offset = 2u; // After header
+ let layer0_info_base = layer_info_offset;
+
+ // Read layer 0 info (5 u32 values = 20 bytes)
+ let kernel_size = weights[layer0_info_base];
+ let in_channels = weights[layer0_info_base + 1u];
+ let out_channels = weights[layer0_info_base + 2u];
+ let weight_offset = weights[layer0_info_base + 3u];
+
+ // Convolution: 12D input (4 prev + 8 static) → 4D output
+ var output: vec4<f32> = vec4<f32>(0.0);
+ for (var c: u32 = 0u; c < 4u; c++) {
+ output[c] = 0.0; // TODO: Actual convolution
+ }
+
+ textureStore(output_tex, coord, pack_channels(output));
+}
+"""
+
+    output_path = Path(output_dir) / "cnn_v2" / "cnn_v2_compute.wgsl"
+    output_path.parent.mkdir(parents=True, exist_ok=True)  # Ensure the cnn_v2/ subdirectory exists
+    output_path.write_text(shader_code)
+ print(f" → {output_path}")
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Export CNN v2 weights to binary format')
+ parser.add_argument('checkpoint', type=str, help='Path to checkpoint .pth file')
+ parser.add_argument('--output-weights', type=str, default=str(PROJECT_ROOT / 'workspaces/main/weights/cnn_v2_weights.bin'),
+ help='Output binary weights file')
+ parser.add_argument('--output-shader', type=str, default=str(PROJECT_ROOT / 'workspaces/main/shaders'),
+ help='Output directory for shader template')
+ parser.add_argument('--quiet', action='store_true',
+ help='Suppress detailed output')
+
+ args = parser.parse_args()
+
+ if not args.quiet:
+ print("=== CNN v2 Weight Export ===\n")
+ config = export_weights_binary(args.checkpoint, args.output_weights, quiet=args.quiet)
+ if not args.quiet:
+ print()
+ # Shader is manually maintained in cnn_v2_compute.wgsl
+ # export_shader_template(config, args.output_shader)
+ print("\nExport complete!")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/cnn_v2/training/gen_identity_weights.py b/cnn_v2/training/gen_identity_weights.py
new file mode 100755
index 0000000..08eecc6
--- /dev/null
+++ b/cnn_v2/training/gen_identity_weights.py
@@ -0,0 +1,175 @@
+#!/usr/bin/env python3
+"""Generate Identity CNN v2 Weights
+
+Creates trivial .bin with 1 layer, 1×1 kernel, identity passthrough.
+Output Ch{0,1,2,3} = Input Ch{0,1,2,3} (ignores static features).
+
+With --mix: Output Ch{i} = 0.5*prev[i] + 0.5*static_p{4+i}
+ (50-50 blend of prev layer with uv_x, uv_y, sin20_y, bias)
+
+With --p47: Output Ch{i} = static p{4+i} (uv_x, uv_y, sin20_y, bias)
+ (p4/uv_x→ch0, p5/uv_y→ch1, p6/sin20_y→ch2, p7/bias→ch3)
+
+Usage:
+ ./training/gen_identity_weights.py [output.bin]
+ ./training/gen_identity_weights.py --mix [output.bin]
+ ./training/gen_identity_weights.py --p47 [output.bin]
+"""
+
+import argparse
+import numpy as np
+import struct
+from pathlib import Path
+
+# Path resolution for running from any directory
+SCRIPT_DIR = Path(__file__).parent
+PROJECT_ROOT = SCRIPT_DIR.parent.parent
+
+
+def generate_identity_weights(output_path, kernel_size=1, mip_level=0, mix=False, p47=False):
+ """Generate identity weights: output = input (ignores static features).
+
+    If mix=True, 50-50 blend: out[i] = 0.5*prev[i] + 0.5*p{4+i} (weights sum to 1, so outputs stay in [0, 1]).
+ If p47=True, transfers static p4-p7 (uv_x, uv_y, sin20_y, bias) to output channels.
+
+ Input channel layout: [0-3: prev layer, 4-11: static (p0-p7)]
+ Static features: p0-p3 (RGB+D), p4 (uv_x), p5 (uv_y), p6 (sin20_y), p7 (bias)
+
+ Binary format:
+ Header (20 bytes):
+ uint32 magic ('CNN2')
+ uint32 version (2)
+ uint32 num_layers (1)
+ uint32 total_weights (f16 count)
+ uint32 mip_level
+
+ LayerInfo (20 bytes):
+ uint32 kernel_size
+ uint32 in_channels (12)
+ uint32 out_channels (4)
+ uint32 weight_offset (0)
+ uint32 weight_count
+
+ Weights (u32 packed f16):
+ Identity matrix for first 4 input channels
+ Zeros for static features (channels 4-11) OR
+ Mix matrix (p0+p4, p1+p5, p2+p6, p3+p7) if mix=True
+ """
+ # Identity: 4 output channels, 12 input channels
+ # Weight shape: [out_ch, in_ch, kernel_h, kernel_w]
+ in_channels = 12 # 4 input + 8 static
+ out_channels = 4
+
+ # Identity matrix: diagonal 1.0 for first 4 channels, 0.0 for rest
+ weights = np.zeros((out_channels, in_channels, kernel_size, kernel_size), dtype=np.float32)
+
+ # Center position for kernel
+ center = kernel_size // 2
+
+ if p47:
+ # p47 mode: p4→ch0, p5→ch1, p6→ch2, p7→ch3 (static features only)
+ # Input channels: [0-3: prev layer, 4-11: static features (p0-p7)]
+ # p4-p7 are at input channels 8-11
+ for i in range(out_channels):
+ weights[i, i + 8, center, center] = 1.0
+ elif mix:
+ # Mix mode: 50-50 blend (p0+p4, p1+p5, p2+p6, p3+p7)
+ # p0-p3 are at channels 0-3 (prev layer), p4-p7 at channels 8-11 (static)
+ for i in range(out_channels):
+ weights[i, i, center, center] = 0.5 # 0.5*p{i} (prev layer)
+ weights[i, i + 8, center, center] = 0.5 # 0.5*p{i+4} (static)
+ else:
+ # Identity: output ch i = input ch i
+ for i in range(out_channels):
+ weights[i, i, center, center] = 1.0
+
+ # Flatten
+ weights_flat = weights.flatten()
+ weight_count = len(weights_flat)
+
+ mode_name = 'p47' if p47 else ('mix' if mix else 'identity')
+ print(f"Generating {mode_name} weights:")
+ print(f" Kernel size: {kernel_size}×{kernel_size}")
+ print(f" Channels: 12D→4D")
+ print(f" Weights: {weight_count}")
+ print(f" Mip level: {mip_level}")
+ if mix:
+ print(f" Mode: 0.5*prev[i] + 0.5*static_p{{4+i}} (blend with uv/sin/bias)")
+ elif p47:
+ print(f" Mode: p4→ch0, p5→ch1, p6→ch2, p7→ch3")
+
+ # Convert to f16
+ weights_f16 = np.array(weights_flat, dtype=np.float16)
+
+ # Pad to even count
+ if len(weights_f16) % 2 == 1:
+ weights_f16 = np.append(weights_f16, np.float16(0.0))
+
+ # Pack f16 pairs into u32
+ weights_u32 = weights_f16.view(np.uint32)
+
+ print(f" Packed: {len(weights_u32)} u32")
+ print(f" Binary size: {20 + 20 + len(weights_u32) * 4} bytes")
+
+ # Write binary
+ output_path = Path(output_path)
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+
+ with open(output_path, 'wb') as f:
+ # Header (20 bytes)
+ f.write(struct.pack('<4sIIII',
+ b'CNN2', # magic
+ 2, # version
+ 1, # num_layers
+ len(weights_f16), # total_weights
+ mip_level)) # mip_level
+
+ # Layer info (20 bytes)
+ f.write(struct.pack('<IIIII',
+ kernel_size, # kernel_size
+ in_channels, # in_channels
+ out_channels, # out_channels
+ 0, # weight_offset
+ weight_count)) # weight_count
+
+ # Weights (u32 packed f16)
+ f.write(weights_u32.tobytes())
+
+ print(f" → {output_path}")
+
+ # Verify
+ print("\nVerification:")
+ with open(output_path, 'rb') as f:
+ data = f.read()
+ magic, version, num_layers, total_weights, mip = struct.unpack('<4sIIII', data[:20])
+ print(f" Magic: {magic}")
+ print(f" Version: {version}")
+ print(f" Layers: {num_layers}")
+ print(f" Total weights: {total_weights}")
+ print(f" Mip level: {mip}")
+ print(f" File size: {len(data)} bytes")
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Generate identity CNN v2 weights')
+ parser.add_argument('output', type=str, nargs='?',
+ default=str(PROJECT_ROOT / 'workspaces/main/weights/cnn_v2_identity.bin'),
+ help='Output .bin file path')
+ parser.add_argument('--kernel-size', type=int, default=1,
+ help='Kernel size (default: 1×1)')
+ parser.add_argument('--mip-level', type=int, default=0,
+ help='Mip level for p0-p3 features (default: 0)')
+ parser.add_argument('--mix', action='store_true',
+ help='Mix mode: 50-50 blend of p0-p3 and p4-p7')
+ parser.add_argument('--p47', action='store_true',
+ help='Static features only: p4→ch0, p5→ch1, p6→ch2, p7→ch3')
+
+ args = parser.parse_args()
+
+ print("=== Identity Weight Generator ===\n")
+ generate_identity_weights(args.output, args.kernel_size, args.mip_level, args.mix, args.p47)
+ print("\nDone!")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/cnn_v2/training/train_cnn_v2.py b/cnn_v2/training/train_cnn_v2.py
new file mode 100755
index 0000000..9e5df2f
--- /dev/null
+++ b/cnn_v2/training/train_cnn_v2.py
@@ -0,0 +1,472 @@
+#!/usr/bin/env python3
+"""CNN v2 Training Script - Uniform 12D→4D Architecture
+
+Architecture:
+- Static features (8D): p0-p3 (parametric), uv_x, uv_y, sin(20×uv_y), bias
+- Input RGBD (4D): original image mip 0
+- All layers: 4D (input RGBD for layer 0, previous layer output after) + static (8D) = 12D → 4 channels
+- Per-layer kernel sizes (e.g., 1×1, 3×3, 5×5)
+- Uniform layer structure with bias=False (bias in static features)
+"""
+
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import Dataset, DataLoader
+from pathlib import Path
+from PIL import Image
+import time
+import cv2
+
+
+def compute_static_features(rgb, depth=None, mip_level=0):
+ """Generate 8D static features (parametric + spatial).
+
+ Args:
+ rgb: (H, W, 3) RGB image [0, 1]
+ depth: (H, W) depth map [0, 1], optional (defaults to 1.0 = far plane)
+ mip_level: Mip level for p0-p3 (0=original, 1=half, 2=quarter, 3=eighth)
+
+ Returns:
+ (H, W, 8) static features: [p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias]
+
+ Note: p0-p3 are parametric features from mip level. p3 uses depth (alpha channel) or 1.0
+
+ TODO: Binary format should support arbitrary layout and ordering for feature vector (7D),
+ alongside mip-level indication. Current layout is hardcoded as:
+ [p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias]
+ Future: Allow experimentation with different feature combinations without shader recompilation.
+ Examples: [R, G, B, dx, dy, uv_x, bias] or [mip1.r, mip2.g, laplacian, uv_x, sin20_x, bias]
+ """
+ h, w = rgb.shape[:2]
+
+ # Generate mip level for p0-p3
+ if mip_level > 0:
+ # Downsample to mip level
+ mip_rgb = rgb.copy()
+ for _ in range(mip_level):
+ mip_rgb = cv2.pyrDown(mip_rgb)
+ # Upsample back to original size
+ for _ in range(mip_level):
+ mip_rgb = cv2.pyrUp(mip_rgb)
+ # Crop/pad to exact original size if needed
+ if mip_rgb.shape[:2] != (h, w):
+ mip_rgb = cv2.resize(mip_rgb, (w, h), interpolation=cv2.INTER_LINEAR)
+ else:
+ mip_rgb = rgb
+
+ # Parametric features (p0-p3) from mip level
+ p0 = mip_rgb[:, :, 0].astype(np.float32)
+ p1 = mip_rgb[:, :, 1].astype(np.float32)
+ p2 = mip_rgb[:, :, 2].astype(np.float32)
+ p3 = depth.astype(np.float32) if depth is not None else np.ones((h, w), dtype=np.float32) # Default 1.0 = far plane
+
+ # UV coordinates (normalized [0, 1])
+ uv_x = np.linspace(0, 1, w)[None, :].repeat(h, axis=0).astype(np.float32)
+ uv_y = np.linspace(0, 1, h)[:, None].repeat(w, axis=1).astype(np.float32)
+
+    # Sinusoidal position encoding (single frequency: sin(20·uv_y))
+    sin20_y = np.sin(20.0 * uv_y).astype(np.float32)
+
+ # Bias dimension (always 1.0) - replaces Conv2d bias parameter
+ bias = np.ones((h, w), dtype=np.float32)
+
+ # Stack: [p0, p1, p2, p3, uv.x, uv.y, sin20_y, bias]
+ features = np.stack([p0, p1, p2, p3, uv_x, uv_y, sin20_y, bias], axis=-1)
+ return features
+
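+# Example (shapes only): for a 256×256 RGB frame,
+#   feats = compute_static_features(rgb, depth=None, mip_level=1)
+# gives feats.shape == (256, 256, 8), with feats[..., 7] the constant bias plane.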
+
+class CNNv2(nn.Module):
+ """CNN v2 - Uniform 12D→4D Architecture
+
+    All layers: 4D input (RGBD for layer 0, previous layer output after) + static (8D) = 12D → 4 channels
+ Per-layer kernel sizes supported (e.g., [1, 3, 5])
+ Uses bias=False (bias integrated in static features as 1.0)
+
+ TODO: Add quantization-aware training (QAT) for 8-bit weights
+ - Use torch.quantization.QuantStub/DeQuantStub
+ - Train with fake quantization to adapt to 8-bit precision
+ - Target: ~1.3 KB weights (vs 2.6 KB with f16)
+ """
+
+ def __init__(self, kernel_sizes, num_layers=3):
+ super().__init__()
+ if isinstance(kernel_sizes, int):
+ kernel_sizes = [kernel_sizes] * num_layers
+ assert len(kernel_sizes) == num_layers, "kernel_sizes must match num_layers"
+
+ self.kernel_sizes = kernel_sizes
+ self.num_layers = num_layers
+ self.layers = nn.ModuleList()
+
+ # All layers: 12D input (4 RGBD + 8 static) → 4D output
+ for kernel_size in kernel_sizes:
+ self.layers.append(
+ nn.Conv2d(12, 4, kernel_size=kernel_size,
+ padding=kernel_size//2, bias=False)
+ )
+
+ def forward(self, input_rgbd, static_features):
+ """Forward pass with uniform 12D→4D layers.
+
+ Args:
+ input_rgbd: (B, 4, H, W) input image RGBD (mip 0)
+ static_features: (B, 8, H, W) static features
+
+ Returns:
+ (B, 4, H, W) RGBA output [0, 1]
+ """
+ # Layer 0: input RGBD (4D) + static (8D) = 12D
+ x = torch.cat([input_rgbd, static_features], dim=1)
+ x = self.layers[0](x)
+ x = torch.sigmoid(x) # Soft [0,1] for layer 0
+
+ # Layer 1+: previous (4D) + static (8D) = 12D
+ for i in range(1, self.num_layers):
+ x_input = torch.cat([x, static_features], dim=1)
+ x = self.layers[i](x_input)
+ if i < self.num_layers - 1:
+ x = F.relu(x)
+ else:
+ x = torch.sigmoid(x) # Soft [0,1] for final layer
+
+ return x
+
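+# Minimal usage sketch (shapes only; values are illustrative):
+#   model = CNNv2(kernel_sizes=[1, 3, 5], num_layers=3)
+#   rgbd = torch.zeros(1, 4, 64, 64)      # input RGBD (mip 0)
+#   static = torch.zeros(1, 8, 64, 64)    # 8D static features
+#   out = model(rgbd, static)             # (1, 4, 64, 64), sigmoid output in [0, 1]
+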
+
+class PatchDataset(Dataset):
+ """Patch-based dataset extracting salient regions from images."""
+
+ def __init__(self, input_dir, target_dir, patch_size=32, patches_per_image=64,
+ detector='harris', mip_level=0):
+ self.input_paths = sorted(Path(input_dir).glob("*.png"))
+ self.target_paths = sorted(Path(target_dir).glob("*.png"))
+ self.patch_size = patch_size
+ self.patches_per_image = patches_per_image
+ self.detector = detector
+ self.mip_level = mip_level
+
+ assert len(self.input_paths) == len(self.target_paths), \
+ f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets"
+
+ print(f"Found {len(self.input_paths)} image pairs")
+ print(f"Extracting {patches_per_image} patches per image using {detector} detector")
+ print(f"Total patches: {len(self.input_paths) * patches_per_image}")
+
+ def __len__(self):
+ return len(self.input_paths) * self.patches_per_image
+
+ def _detect_salient_points(self, img_array):
+ """Detect salient points on original image.
+
+ TODO: Add random sampling to training vectors
+ - In addition to salient points, incorporate randomly-located samples
+ - Default: 10% random samples, 90% salient points
+ - Prevents overfitting to only high-gradient regions
+ - Improves generalization across entire image
+ - Configurable via --random-sample-percent parameter
+ """
+ gray = cv2.cvtColor((img_array * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
+ h, w = gray.shape
+ half_patch = self.patch_size // 2
+
+ corners = None
+ if self.detector == 'harris':
+ corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2,
+ qualityLevel=0.01, minDistance=half_patch)
+ elif self.detector == 'fast':
+ fast = cv2.FastFeatureDetector_create(threshold=20)
+ keypoints = fast.detect(gray, None)
+ corners = np.array([[kp.pt[0], kp.pt[1]] for kp in keypoints[:self.patches_per_image * 2]])
+ corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None
+ elif self.detector == 'shi-tomasi':
+ corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2,
+ qualityLevel=0.01, minDistance=half_patch,
+ useHarrisDetector=False)
+ elif self.detector == 'gradient':
+ grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
+ grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
+ gradient_mag = np.sqrt(grad_x**2 + grad_y**2)
+ threshold = np.percentile(gradient_mag, 95)
+ y_coords, x_coords = np.where(gradient_mag > threshold)
+
+ if len(x_coords) > self.patches_per_image * 2:
+ indices = np.random.choice(len(x_coords), self.patches_per_image * 2, replace=False)
+ x_coords = x_coords[indices]
+ y_coords = y_coords[indices]
+
+ corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)])
+ corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None
+
+ # Fallback to random if no corners found
+ if corners is None or len(corners) == 0:
+ x_coords = np.random.randint(half_patch, w - half_patch, self.patches_per_image)
+ y_coords = np.random.randint(half_patch, h - half_patch, self.patches_per_image)
+ corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)])
+ corners = corners.reshape(-1, 1, 2)
+
+ # Filter valid corners
+ valid_corners = []
+ for corner in corners:
+ x, y = int(corner[0][0]), int(corner[0][1])
+ if half_patch <= x < w - half_patch and half_patch <= y < h - half_patch:
+ valid_corners.append((x, y))
+ if len(valid_corners) >= self.patches_per_image:
+ break
+
+ # Fill with random if not enough
+ while len(valid_corners) < self.patches_per_image:
+ x = np.random.randint(half_patch, w - half_patch)
+ y = np.random.randint(half_patch, h - half_patch)
+ valid_corners.append((x, y))
+
+ return valid_corners
+
+ def __getitem__(self, idx):
+ img_idx = idx // self.patches_per_image
+ patch_idx = idx % self.patches_per_image
+
+ # Load original images (no resize)
+ input_img = np.array(Image.open(self.input_paths[img_idx]).convert('RGB')) / 255.0
+ target_pil = Image.open(self.target_paths[img_idx])
+ target_img = np.array(target_pil.convert('RGBA')) / 255.0 # Preserve alpha
+
+ # Detect salient points on original image (use RGB only)
+ salient_points = self._detect_salient_points(input_img)
+ cx, cy = salient_points[patch_idx]
+
+ # Extract patch
+ half_patch = self.patch_size // 2
+ y1, y2 = cy - half_patch, cy + half_patch
+ x1, x2 = cx - half_patch, cx + half_patch
+
+ input_patch = input_img[y1:y2, x1:x2]
+ target_patch = target_img[y1:y2, x1:x2] # RGBA
+
+ # Extract depth from target alpha channel (or default to 1.0)
+ depth = target_patch[:, :, 3] if target_patch.shape[2] == 4 else None
+
+ # Compute static features for patch
+ static_feat = compute_static_features(input_patch.astype(np.float32), depth=depth, mip_level=self.mip_level)
+
+        # Input RGBD (mip 0) - append a zero-filled depth channel placeholder
+ input_rgbd = np.concatenate([input_patch, np.zeros((self.patch_size, self.patch_size, 1))], axis=-1)
+
+ # Convert to tensors (C, H, W)
+ input_rgbd = torch.from_numpy(input_rgbd.astype(np.float32)).permute(2, 0, 1)
+ static_feat = torch.from_numpy(static_feat).permute(2, 0, 1)
+ target = torch.from_numpy(target_patch.astype(np.float32)).permute(2, 0, 1) # RGBA from image
+
+ return input_rgbd, static_feat, target
+
+
+class ImagePairDataset(Dataset):
+ """Dataset of input/target image pairs (full-image mode)."""
+
+ def __init__(self, input_dir, target_dir, target_size=(256, 256), mip_level=0):
+ self.input_paths = sorted(Path(input_dir).glob("*.png"))
+ self.target_paths = sorted(Path(target_dir).glob("*.png"))
+ self.target_size = target_size
+ self.mip_level = mip_level
+ assert len(self.input_paths) == len(self.target_paths), \
+ f"Mismatch: {len(self.input_paths)} inputs vs {len(self.target_paths)} targets"
+
+ def __len__(self):
+ return len(self.input_paths)
+
+ def __getitem__(self, idx):
+ # Load and resize images to fixed size
+ input_pil = Image.open(self.input_paths[idx]).convert('RGB')
+ target_pil = Image.open(self.target_paths[idx])
+
+ # Resize to target size
+ input_pil = input_pil.resize(self.target_size, Image.LANCZOS)
+ target_pil = target_pil.resize(self.target_size, Image.LANCZOS)
+
+ input_img = np.array(input_pil) / 255.0
+ target_img = np.array(target_pil.convert('RGBA')) / 255.0 # Preserve alpha
+
+ # Extract depth from target alpha channel (or default to 1.0)
+ depth = target_img[:, :, 3] if target_img.shape[2] == 4 else None
+
+ # Compute static features
+ static_feat = compute_static_features(input_img.astype(np.float32), depth=depth, mip_level=self.mip_level)
+
+        # Input RGBD (mip 0) - append a zero-filled depth channel placeholder
+ h, w = input_img.shape[:2]
+ input_rgbd = np.concatenate([input_img, np.zeros((h, w, 1))], axis=-1)
+
+ # Convert to tensors (C, H, W)
+ input_rgbd = torch.from_numpy(input_rgbd.astype(np.float32)).permute(2, 0, 1)
+ static_feat = torch.from_numpy(static_feat).permute(2, 0, 1)
+ target = torch.from_numpy(target_img.astype(np.float32)).permute(2, 0, 1) # RGBA from image
+
+ return input_rgbd, static_feat, target
+
+
+def train(args):
+ """Train CNN v2 model."""
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ print(f"Training on {device}")
+
+ # Create dataset (patch-based or full-image)
+ if args.full_image:
+ print(f"Mode: Full-image (resized to {args.image_size}x{args.image_size})")
+ target_size = (args.image_size, args.image_size)
+ dataset = ImagePairDataset(args.input, args.target, target_size=target_size, mip_level=args.mip_level)
+ dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
+ else:
+ print(f"Mode: Patch-based ({args.patch_size}x{args.patch_size} patches)")
+ dataset = PatchDataset(args.input, args.target,
+ patch_size=args.patch_size,
+ patches_per_image=args.patches_per_image,
+ detector=args.detector,
+ mip_level=args.mip_level)
+ dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
+
+ # Parse kernel sizes
+ kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
+ if len(kernel_sizes) == 1:
+ kernel_sizes = kernel_sizes * args.num_layers
+ else:
+ # When multiple kernel sizes provided, derive num_layers from list length
+ args.num_layers = len(kernel_sizes)
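+    # e.g. --kernel-sizes 3,5,3            -> [3, 5, 3], --num-layers overridden to 3
+    #      --kernel-sizes 3 --num-layers 4 -> [3, 3, 3, 3]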
+
+ # Create model
+ model = CNNv2(kernel_sizes=kernel_sizes, num_layers=args.num_layers).to(device)
+ total_params = sum(p.numel() for p in model.parameters())
+ kernel_desc = ','.join(map(str, kernel_sizes))
+ print(f"Model: {args.num_layers} layers, kernel sizes [{kernel_desc}], {total_params} weights")
+ print(f"Using mip level {args.mip_level} for p0-p3 features")
+
+ # Optimizer and loss
+ optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+ criterion = nn.MSELoss()
+
+ # Training loop
+ print(f"\nTraining for {args.epochs} epochs...")
+ start_time = time.time()
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ epoch_loss = 0.0
+
+ for input_rgbd, static_feat, target in dataloader:
+ input_rgbd = input_rgbd.to(device)
+ static_feat = static_feat.to(device)
+ target = target.to(device)
+
+ optimizer.zero_grad()
+ output = model(input_rgbd, static_feat)
+
+ # Compute loss (grayscale or RGBA)
+ if args.grayscale_loss:
+ # Convert RGBA to grayscale: Y = 0.299*R + 0.587*G + 0.114*B
+ output_gray = 0.299 * output[:, 0:1] + 0.587 * output[:, 1:2] + 0.114 * output[:, 2:3]
+ target_gray = 0.299 * target[:, 0:1] + 0.587 * target[:, 1:2] + 0.114 * target[:, 2:3]
+ loss = criterion(output_gray, target_gray)
+ else:
+ loss = criterion(output, target)
+
+ loss.backward()
+ optimizer.step()
+
+ epoch_loss += loss.item()
+
+ avg_loss = epoch_loss / len(dataloader)
+
+ # Print loss at every epoch (overwrite line with \r)
+ elapsed = time.time() - start_time
+ print(f"\rEpoch {epoch:4d}/{args.epochs} | Loss: {avg_loss:.6f} | Time: {elapsed:.1f}s", end='', flush=True)
+
+ # Save checkpoint
+ if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0:
+ print() # Newline before checkpoint message
+ checkpoint_path = Path(args.checkpoint_dir) / f"checkpoint_epoch_{epoch}.pth"
+ checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
+ torch.save({
+ 'epoch': epoch,
+ 'model_state_dict': model.state_dict(),
+ 'optimizer_state_dict': optimizer.state_dict(),
+ 'loss': avg_loss,
+ 'config': {
+ 'kernel_sizes': kernel_sizes,
+ 'num_layers': args.num_layers,
+ 'mip_level': args.mip_level,
+ 'grayscale_loss': args.grayscale_loss,
+ 'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin20_y', 'bias']
+ }
+ }, checkpoint_path)
+ print(f" → Saved checkpoint: {checkpoint_path}")
+
+ # Always save final checkpoint
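+    # (if epochs is a multiple of checkpoint_every, this overwrites the checkpoint
+    # just written in the loop with identical contents)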
+ print() # Newline after training
+ final_checkpoint = Path(args.checkpoint_dir) / f"checkpoint_epoch_{args.epochs}.pth"
+ final_checkpoint.parent.mkdir(parents=True, exist_ok=True)
+ torch.save({
+ 'epoch': args.epochs,
+ 'model_state_dict': model.state_dict(),
+ 'optimizer_state_dict': optimizer.state_dict(),
+ 'loss': avg_loss,
+ 'config': {
+ 'kernel_sizes': kernel_sizes,
+ 'num_layers': args.num_layers,
+ 'mip_level': args.mip_level,
+ 'grayscale_loss': args.grayscale_loss,
+ 'features': ['p0', 'p1', 'p2', 'p3', 'uv.x', 'uv.y', 'sin20_y', 'bias']
+ }
+ }, final_checkpoint)
+ print(f" → Saved final checkpoint: {final_checkpoint}")
+
+ print(f"\nTraining complete! Total time: {time.time() - start_time:.1f}s")
+ return model
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Train CNN v2 with parametric static features')
+ parser.add_argument('--input', type=str, required=True, help='Input images directory')
+ parser.add_argument('--target', type=str, required=True, help='Target images directory')
+
+ # Training mode
+ parser.add_argument('--full-image', action='store_true',
+ help='Use full-image mode (resize all images)')
+ parser.add_argument('--image-size', type=int, default=256,
+ help='Full-image mode: resize to this size (default: 256)')
+
+ # Patch-based mode (default)
+ parser.add_argument('--patch-size', type=int, default=32,
+ help='Patch mode: patch size (default: 32)')
+ parser.add_argument('--patches-per-image', type=int, default=64,
+ help='Patch mode: patches per image (default: 64)')
+ parser.add_argument('--detector', type=str, default='harris',
+ choices=['harris', 'fast', 'shi-tomasi', 'gradient'],
+ help='Patch mode: salient point detector (default: harris)')
+ # TODO: Add --random-sample-percent parameter (default: 10)
+ # Mix salient points with random samples for better generalization
+
+ # Model architecture
+ parser.add_argument('--kernel-sizes', type=str, default='3',
+                        help='Comma-separated kernel sizes per layer (e.g. "3,5,3"); a single value is replicated across all --num-layers layers (default: 3)')
+ parser.add_argument('--num-layers', type=int, default=3,
+ help='Number of CNN layers (default: 3)')
+ parser.add_argument('--mip-level', type=int, default=0, choices=[0, 1, 2, 3],
+ help='Mip level for p0-p3 features: 0=original, 1=half, 2=quarter, 3=eighth (default: 0)')
+
+ # Training parameters
+ parser.add_argument('--epochs', type=int, default=5000, help='Training epochs')
+ parser.add_argument('--batch-size', type=int, default=16, help='Batch size')
+ parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
+ parser.add_argument('--grayscale-loss', action='store_true',
+ help='Compute loss on grayscale (Y = 0.299*R + 0.587*G + 0.114*B) instead of RGBA')
+ parser.add_argument('--checkpoint-dir', type=str, default='checkpoints',
+ help='Checkpoint directory')
+ parser.add_argument('--checkpoint-every', type=int, default=1000,
+ help='Save checkpoint every N epochs (0 = disable)')
+
+ args = parser.parse_args()
+ train(args)
+
+
+if __name__ == '__main__':
+ main()
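+
+# Example invocation (illustrative; the directory paths are placeholders, all
+# flags are defined in main() above):
+#
+#   python cnn_v2/training/train_cnn_v2.py \
+#       --input data/input --target data/target \
+#       --detector harris --patch-size 32 --patches-per-image 64 \
+#       --kernel-sizes 3,5,3 --epochs 5000 --grayscale-loss \
+#       --checkpoint-dir checkpoints --checkpoint-every 1000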