summaryrefslogtreecommitdiff
path: root/cnn_v1
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v1')
-rw-r--r--cnn_v1/README.md64
-rw-r--r--cnn_v1/docs/CNN.md79
-rw-r--r--cnn_v1/docs/CNN_BIAS_FIX_2026-02.md85
-rw-r--r--cnn_v1/docs/CNN_DEBUG.md43
-rw-r--r--cnn_v1/docs/CNN_FLATTEN_ANALYSIS.md189
-rw-r--r--cnn_v1/docs/CNN_RGBD_GRAYSCALE_SUMMARY.md136
-rw-r--r--cnn_v1/docs/CNN_TEST_TOOL.md244
-rw-r--r--cnn_v1/docs/CNN_V1_EFFECT.md400
-rw-r--r--cnn_v1/shaders/cnn_activation.wgsl18
-rw-r--r--cnn_v1/shaders/cnn_conv1x1.wgsl100
-rw-r--r--cnn_v1/shaders/cnn_conv3x3.wgsl100
-rw-r--r--cnn_v1/shaders/cnn_conv5x5.wgsl101
-rw-r--r--cnn_v1/shaders/cnn_conv7x7.wgsl53
-rw-r--r--cnn_v1/shaders/cnn_layer.wgsl55
-rw-r--r--cnn_v1/shaders/cnn_weights_generated.wgsl302
-rw-r--r--cnn_v1/src/cnn_v1_effect.cc129
-rw-r--r--cnn_v1/src/cnn_v1_effect.h53
-rwxr-xr-xcnn_v1/training/train_cnn.py943
18 files changed, 3094 insertions, 0 deletions
diff --git a/cnn_v1/README.md b/cnn_v1/README.md
new file mode 100644
index 0000000..052f22a
--- /dev/null
+++ b/cnn_v1/README.md
@@ -0,0 +1,64 @@
+# CNN v1: Original Post-Processing Neural Network
+
+**Architecture:** 3-layer convolution, generated shader weights
+**Status:** Active (used in timeline), legacy architecture
+
+## Overview
+
+Original CNN implementation with per-layer WGSL shaders. Supports multiple kernel sizes (1×1, 3×3, 5×5, 7×7) with generated weight arrays.
+
+**For new work, use CNN v2** (`cnn_v2/`) which provides:
+- Storage buffer architecture (~3.2 KB vs generated WGSL)
+- 7D static features (RGBD + UV + sin + bias)
+- Sigmoid activation with stable training
+- Dynamic layer configuration
+
+## Quick Reference
+
+**Training:**
+```bash
+./cnn_v1/training/train_cnn.py --input training/input --target training/output \
+ --layers 3 --kernel_sizes 3,5,3 --epochs 5000
+```
+
+**Integration:**
+- **C++:** `cnn_v1/src/cnn_v1_effect.{h,cc}`
+- **Assets:** `workspaces/main/assets.txt` (lines 40-46)
+- **Timeline:** `workspaces/main/timeline.seq` (CNNEffect)
+
+## Documentation
+
+- [CNN.md](docs/CNN.md) - Architecture overview
+- [CNN_V1_EFFECT.md](docs/CNN_V1_EFFECT.md) - Implementation details
+- [CNN_TEST_TOOL.md](docs/CNN_TEST_TOOL.md) - Testing guide
+- [CNN_DEBUG.md](docs/CNN_DEBUG.md) - Debugging notes
+
+## Directory Structure
+
+```
+cnn_v1/
+├── README.md # This file
+├── src/
+│ ├── cnn_v1_effect.h # Effect header
+│ └── cnn_v1_effect.cc # Effect implementation
+├── shaders/ # WGSL shaders (7 files)
+├── training/ # Python training script
+└── docs/ # Documentation (7 markdown files)
+```
+
+## Differences from CNN v2
+
+| Feature | CNN v1 | CNN v2 |
+|---------|--------|--------|
+| Architecture | Generated WGSL weights | Storage buffer weights |
+| Input Features | 4D (RGBA/prev layer) | 12D (4D + 8D static) |
+| Activation | ReLU | Sigmoid + ReLU |
+| Size | Variable (generated WGSL) | ~3.2 KB (binary) |
+| Training | Full-image | Patch-based (default) |
+| Layer Config | Compile-time | Runtime (dynamic) |
+
+## Migration Notes
+
+CNN v1 remains in the timeline for historical validation. For new effects or experiments, use CNN v2's enhanced feature set and compact binary format.
+
+See `cnn_v2/docs/CNN_V2.md` for CNN v2 architecture details.
diff --git a/cnn_v1/docs/CNN.md b/cnn_v1/docs/CNN.md
new file mode 100644
index 0000000..5d9a667
--- /dev/null
+++ b/cnn_v1/docs/CNN.md
@@ -0,0 +1,79 @@
+# Convolutional Neural Net Shader (CNN) post-processing
+
+**Status:** ✅ Foundation implemented (single-layer, expandable to multi-pass)
+
+## Idea
+
+Process the rendered 3D scene with a multi-layer CNN trained offline.
+Input: the rendered scene.
+Output: a 'stylized' scene produced by CNN post-processing.
+
+**See `CNN_V1_EFFECT.md` for implementation details, usage, and API reference.**
+
+## Shader implementation
+
+### input / output
+
+One texture buffer is needed per CNN layer.
+Input: (r, g, b, 1/z) from the rendered 3D scene for layer 0, or the output of layer N-1 for layer N.
+Output: (r, g, b, alpha). The 1/z information is not needed in the output (it can be fetched from the layer-0 input).
+
+### size of one layer
+
+Notation:
+S: the number of input samples taken from layer N-1.
+Example: a 3×3 kernel -> S = 3×3 = 9.
+
+Each of the S samples carries 4 values (r, g, b, w = 1/z).
+
+Each sample is multiplied by a mat4 weight matrix: 4 inputs => 4 outputs.
+
+Weight data per layer = S mat4 matrices.
+
+Final bias: 4 values.
+
+WGSL code example: See file CNN.shader
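+
+As a quick size check (a sketch only, assuming float32 weights), the per-layer parameter count follows directly from the notation above:
+
+```python
+# Parameter/size estimate for one CNN v1 layer (float32 weights assumed).
+def layer_param_count(kernel: int) -> int:
+    s = kernel * kernel   # S input samples (3x3 kernel -> S = 9)
+    weights = s * 4 * 4   # S mat4 matrices, 16 floats each
+    bias = 4              # one final vec4 bias
+    return weights + bias
+
+for k in (3, 5, 7):
+    n = layer_param_count(k)
+    print(f"{k}x{k}: {n} floats = {n * 4} bytes")
+# 3x3: 148 floats = 592 bytes, 5x5: 404 -> 1616 bytes, 7x7: 788 -> 3152 bytes
+```
+
+Three such layers stay within the ~2-4 KB weight budget quoted below.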
+
+### Layers
+
+Do we need 3 or 4 layers?
+Several different shaders, one per layer.
+Ping-pong the input/output texture buffers between layers?
+
+## Implementation Status
+
+**Completed:**
+- ✅ Modular WGSL shader architecture (6 snippet files)
+- ✅ CNNEffect C++ class (single-layer rendering)
+- ✅ ShaderComposer integration (#include resolution)
+- ✅ Asset registration (7 new shader assets)
+- ✅ Test coverage (test_demo_effects.cc)
+- ✅ Placeholder identity weights for testing
+
+**Size:** ~3-4 KB shader code + ~2-4 KB weights = **5-8 KB total**
+
+**Pending:**
+- ⏳ Training script (`scripts/train_cnn.py`) to generate real weights
+- ⏳ Multi-layer rendering with ping-pong textures
+- ⏳ Weight quantization for size optimization
+
+---
+
+## Training (To Be Implemented)
+
+The layer weight/bias data are hard-coded in the shaders.
+Training workflow:
+
+1. Prepare image pairs (before: raw render, after: target style)
+2. Run `python scripts/train_cnn.py --input scene.png --target stylized.png`
+3. Script generates `cnn_weights_generated.wgsl`
+4. Rebuild: `cmake --build build -j4`
+
+**Reference:** File `CNN.py` contains training example (needs adaptation).
+
+A repository of reference image pairs (before/after) is needed for training and validation.
+Each input image is randomly sampled into 3×3 patches of (r, g, b, 1/z) input samples,
+which are trained to match the corresponding (r, g, b, a) output.
+
+Training generates the .wgsl code for layers' shaders.
+
diff --git a/cnn_v1/docs/CNN_BIAS_FIX_2026-02.md b/cnn_v1/docs/CNN_BIAS_FIX_2026-02.md
new file mode 100644
index 0000000..26db8eb
--- /dev/null
+++ b/cnn_v1/docs/CNN_BIAS_FIX_2026-02.md
@@ -0,0 +1,85 @@
+# CNN Bias Accumulation Fix (2026-02-11)
+
+## Problem
+Bias was being added multiple times in shader convolution loops (once per kernel position), causing mismatch between PyTorch training and WGSL inference.
+
+## Root Cause
+**Location**: `training/train_cnn.py:381, 398`
+
+When exporting weights to WGSL, bias was replicated for every kernel position. The shader loops through positions doing:
+```wgsl
+sum += dot(weights[pos], rgbd) + dot(weights[pos+1], in1); // in1.w = 1.0
+```
+
+For a 3×3 kernel (9 positions), the bias was added 9×; for a 5×5 kernel, 25×.
+
+## Fix
+Divide bias by `num_positions` during export:
+```python
+# Final layer (7→1)
+v1.append(f"{bias[0] / num_positions:.6f}")
+
+# Inner layers (7→4)
+v1.append(f"{bias[out_c] / num_positions:.6f}")
+```
+
+The shader then accumulates (bias / num_positions) × num_positions = original bias (correct).
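+
+A minimal numerical check of the fix (plain Python, independent of the shader):
+
+```python
+# Bias accumulation in the shader loop, before vs. after the export fix.
+num_positions = 9   # 3x3 kernel
+bias = 0.42         # example bias learned by PyTorch for one output channel
+
+before = sum(bias for _ in range(num_positions))                 # 3.78 -> wrong (9x too large)
+after = sum(bias / num_positions for _ in range(num_positions))  # 0.42 -> matches PyTorch
+print(before, round(after, 6))
+```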
+
+---
+
+## Additional Improvements
+
+### 1. RGBA Output Support
+**train_cnn.py**: Now saves 4-channel RGBA PNG preserving alpha from input:
+```python
+alpha = img_tensor[0, 3:4, :, :].permute(1, 2, 0).numpy()
+output_rgba = np.concatenate([output, alpha], axis=2)
+Image.fromarray((output_rgba * 255).astype(np.uint8), mode='RGBA').save(out_path)  # out_path: illustrative name
+```
+
+Intermediate layers also save RGBA if 4-channel.
+
+### 2. Debug Hex Output
+**Both tools** support `--debug-hex` to print first 8 pixels as hex:
+```bash
+./training/train_cnn.py --infer input.png --export-only checkpoint.pth --debug-hex
+./build/cnn_test input.png output.png --debug-hex
+```
+
+Output format: `[0] 0xRRGGBBAA` for pixel-level comparison.
+
+### 3. Cleanup
+Removed sRGB/linear_png debug code from `cnn_test.cc` (simplified PNG saving).
+
+---
+
+## Files Modified
+- `training/train_cnn.py`: Bias fix, RGBA output, --debug-hex
+- `tools/cnn_test.cc`: --debug-hex, remove linear_png
+- `workspaces/main/shaders/cnn/cnn_weights_generated.wgsl`: Regenerated with fixed bias
+
+## Testing
+```bash
+# Train with fixed export
+./training/train_cnn.py --input training/input/ --target training/output/ \
+ --layers 3 --kernel_sizes 3,3,3 --epochs 5000
+
+# Generate ground truth
+./training/train_cnn.py --infer input.png --export-only checkpoint.pth \
+ --output ground_truth.png --debug-hex
+
+# Run GPU tool
+./build/cnn_test input.png tool_output.png --debug-hex
+
+# Compare hex output for first 8 pixels
+```
+
+---
+
+## Status
+✅ Bias accumulation bug fixed
+✅ RGBA output with alpha preservation
+✅ Debug hex comparison tool
+✅ Weights regenerated
+
+Commit: `8ff8c56`
diff --git a/cnn_v1/docs/CNN_DEBUG.md b/cnn_v1/docs/CNN_DEBUG.md
new file mode 100644
index 0000000..ba220a0
--- /dev/null
+++ b/cnn_v1/docs/CNN_DEBUG.md
@@ -0,0 +1,43 @@
+# CNN Effect Black Screen Bug - Resolution (2026-02)
+
+## Problem
+CNN post-processing effect showed black screen when activated at 11.50s, despite scene rendering correctly before CNN started.
+
+## Root Causes
+
+### Bug 1: Framebuffer Capture Timing
+**Location**: `src/gpu/effect.cc`
+**Issue**: Capture ran INSIDE post-effect loop after ping-pong buffer swaps. CNN layers 1+ captured wrong buffer (output being written to, not scene).
+**Fix**: Moved capture before loop starts (lines 308-346). Capture now copies `framebuffer_a` to `captured_frame` auxiliary texture ONCE before any post-effects run.
+
+### Bug 2: Missing Uniforms Update ⚠️ CRITICAL
+**Location**: `src/effects/cnn_effect.cc`
+**Issue**: `CNNEffect::update_bind_group()` never updated `uniforms_` buffer. `uniforms.resolution` uninitialized (0,0 or garbage) → UV calculation `p.xy / uniforms.resolution` produced NaN → all texture samples black.
+**Fix**: Added uniforms update before bind group creation (lines 132-142):
+```cpp
+const CommonPostProcessUniforms u = {
+ .resolution = {(float)width_, (float)height_},
+ .aspect_ratio = (float)width_ / (float)height_,
+ .time = 0.0f,
+ .beat = 0.0f,
+ .audio_intensity = 0.0f,
+};
+uniforms_.update(ctx_.queue, u);
+```
+
+## Key Lessons
+
+1. **All post-process effects MUST update `uniforms_` buffer** - Required for UV calculations and shader parameters
+2. **Framebuffer capture timing is critical** - Must happen before post-chain ping-pong starts
+3. **Uninitialized uniforms cause silent failures** - Produces black output without validation errors
+4. **Post-effects must render or chain breaks** - `loadOp=Load` preserves previous (black) content if no draw call executes
+
+## Files Modified
+- `src/gpu/effect.cc`: Lines 308-346 (capture timing)
+- `src/effects/cnn_effect.cc`: Lines 132-142 (uniforms update)
+
+## Verification
+Test: `demo64k --seek 11.5`
+- ✅ Scene visible with RotatingCube
+- ✅ CNN stylization applied
+- ✅ All 3 layers process with correct original texture reference
diff --git a/cnn_v1/docs/CNN_FLATTEN_ANALYSIS.md b/cnn_v1/docs/CNN_FLATTEN_ANALYSIS.md
new file mode 100644
index 0000000..8664157
--- /dev/null
+++ b/cnn_v1/docs/CNN_FLATTEN_ANALYSIS.md
@@ -0,0 +1,189 @@
+# CNN Shader Flatten Mode - Technical Analysis
+
+**Status:** Analysis complete - flatten mode NOT RECOMMENDED
+
+**Date:** February 2026
+
+---
+
+## Context
+
+Current CNN architecture uses **3 sequential render passes** (linear chaining):
+- **Layer 0:** 5×5 conv (7→4 channels) → framebuffer
+- **Layer 1:** 3×3 conv (7→4 channels) → reads L0 output, writes framebuffer
+- **Layer 2:** 3×3 conv (7→1 channel) → reads L1 output, blends with original
+
+Proposed **"flatten mode"**: Collapse all layers into **single shader pass** using intermediate arrays, eliminating framebuffer read/write between layers.
+
+---
+
+## Current Architecture
+
+**Shader Structure:**
+- 1 pipeline with layer branching (`layer_index` uniform)
+- 5 bindings: sampler, input texture, uniforms, layer params, original capture
+- Total shader size: ~8 KB (snippets + weights)
+
+**Performance Profile:**
+- 3 render pass dispatches
+- 2 framebuffer writes + reads between layers
+- Memory bandwidth: ~2× framebuffer size per layer
+- Register pressure: Low (per-layer isolation)
+
+**Weight Buffer:** 290 vec4s (4.6 KB) - already unified
+
+---
+
+## Flatten Approaches Evaluated
+
+### Option A: Full Flatten (All 3 Layers)
+
+**Cascading Receptive Field:**
+
+To compute final output at position (x, y):
+- Layer 2 needs 3×3 neighborhood of Layer 1 outputs
+- Each Layer 1 output needs 3×3 neighborhood of Layer 0 outputs
+- Each Layer 0 output needs 5×5 neighborhood of input samples
+
+**Effective input sampling:** 9×9 pixels (vs current 5×5 max)
+
+**Intermediate Storage (per thread/pixel):**
+```
+Layer 0 outputs: 5×5 positions × 4 channels = 100 floats
+Layer 1 outputs: 3×3 positions × 4 channels = 36 floats
+ TOTAL = 136 floats (544 bytes)
+```
+
+**GPU Register Pressure:**
+- Modern GPUs: 32-64 KB registers per SM, shared across warps
+- 544 bytes/thread → max 64 threads/SM (**low occupancy**)
+- Current multi-pass: ~4-8 bytes/thread (high occupancy)
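+
+A back-of-the-envelope occupancy check (an illustrative Python sketch only; real drivers allocate registers less naively):
+
+```python
+# Rough occupancy estimate for the flattened shader vs. the multi-pass baseline.
+register_file_bytes = 32 * 1024   # low end of the 32-64 KB/SM range above
+
+def per_thread_bytes(l0_positions, l1_positions, channels=4, bytes_per_float=4):
+    return (l0_positions + l1_positions) * channels * bytes_per_float
+
+flatten = per_thread_bytes(5 * 5, 3 * 3)   # 544 bytes/thread (100 + 36 floats)
+multipass = 8                              # ~4-8 bytes/thread baseline
+print(flatten, register_file_bytes // flatten)       # 544 -> ~60 threads/SM
+print(multipass, register_file_bytes // multipass)   # 8   -> ~4096 threads/SM
+```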
+
+**Pros:**
+- 1 dispatch vs 3 (reduce CPU overhead)
+- Zero framebuffer bandwidth between layers
+
+**Cons:**
+- **Severe register pressure** (10-20× increase)
+- Reduced occupancy → potential performance loss
+- Complex shader (harder debug, larger binary)
+- 9×9 input sampling
+
+**Assessment:** ❌ **Not Recommended**
+Register cost outweighs bandwidth savings.
+
+---
+
+### Option B: Partial Flatten (Layers 1 + 2)
+
+Keep Layer 0 separate, flatten only Layers 1 and 2.
+
+**Pass Structure:**
+1. **Pass 1:** Layer 0 (5×5 conv) → framebuffer
+2. **Pass 2 (flattened):** Compute Layer 1 + Layer 2 in single shader
+
+**Intermediate Storage:**
+```
+Layer 0 samples: 3×3 × 4 = 36 floats (read once)
+Layer 1 outputs: 3×3 × 4 = 36 floats (computed)
+ TOTAL = 72 floats (288 bytes)
+```
+
+**Receptive Field:** 5×5 Layer 0 samples required for 3×3 Layer 1 outputs
+
+**Pros:**
+- 2 passes vs 3 (33% reduction)
+- 1 framebuffer write saved
+- More manageable register usage
+
+**Cons:**
+- Still significant register pressure (288 bytes vs ~8 bytes baseline)
+- Medium complexity increase
+- Layer 0 (heaviest kernel) still separate
+
+**Assessment:** ⚠️ **Marginal Benefit**
+Saves 1 pass but register cost still high.
+
+---
+
+### Option C: Keep Current Multi-Pass ✅
+
+**Rationale:**
+- Current architecture well-suited to GPU design (high throughput via parallelism)
+- Minimal register usage → high occupancy → hides memory latency
+- Framebuffer bandwidth cost < register pressure cost
+- Clean separation aids debugging/iteration
+- Modular (easy to add/remove layers)
+
+**Alternative Optimizations (if bandwidth critical):**
+1. Merge passes via render pass load/store ops (Vulkan subpasses)
+2. Reduce intermediate channel count (4→3 or 2)
+3. Hybrid: Compute shaders + workgroup shared memory
+4. Layer pruning (2-layer vs 3-layer quality comparison)
+
+---
+
+## Recommendation
+
+**✅ Keep current multi-pass architecture**
+
+### Decision Matrix
+
+| Factor | Multi-Pass | Partial Flatten | Full Flatten |
+|--------|-----------|----------------|--------------|
+| Register pressure | ✅ Low | ⚠️ High | ❌ Extreme |
+| Occupancy | ✅ High | ⚠️ Medium | ❌ Low |
+| Memory bandwidth | ⚠️ Medium | ✅ Lower | ✅ Lowest |
+| Shader complexity | ✅ Simple | ⚠️ Medium | ❌ High |
+| Debuggability | ✅ Easy | ⚠️ Harder | ❌ Very hard |
+| Binary size | ✅ Small | ⚠️ Larger | ⚠️ Largest |
+
+**Modern GPU Architecture Favors:**
+- High parallelism (many small threads) over complex threads
+- Hiding latency via occupancy over minimizing operations
+- Memory bandwidth via caching, not elimination
+
+---
+
+## Alternative: Compute Shader + Shared Memory
+
+**If bandwidth becomes critical:**
+- Use compute shader with workgroup shared memory
+- Load tile + halos into shared memory (9×9 input samples)
+- Compute all 3 layers for tile interior (avoids redundant sampling)
+- Requires explicit synchronization (`workgroupBarrier`)
+
+**Trade-offs:**
+- ✅ Low register pressure + low bandwidth
+- ❌ Compute pipeline complexity (no render pass integration)
+- ❌ Tile edge handling
+- ❌ Larger code size
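+
+A quick shared-memory sizing sketch for this approach (assumptions: 8×8 output tile, 4-pixel halo from the cascaded 5×5 + 3×3 + 3×3 receptive field, RGBA f32 samples):
+
+```python
+# Workgroup shared-memory footprint for one tile, including its halo.
+def tile_shared_memory_bytes(tile=8, halo=4, channels=4, bytes_per_float=4):
+    side = tile + 2 * halo   # 16x16 input tile including halo
+    return side * side * channels * bytes_per_float
+
+print(tile_shared_memory_bytes())   # 4096 bytes of shared memory per workgroup
+```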
+
+---
+
+## Conclusion
+
+Current 3-pass architecture is **appropriate for demo64k**:
+- Size-efficient (modular shaders)
+- Performance adequate (bandwidth not bottleneck)
+- Maintainable (clean layer isolation)
+
+**Flatten mode not recommended** unless profiling reveals specific bandwidth constraint.
+
+### Size Optimization Alternatives (Better ROI)
+
+If size optimization is critical, focus on:
+1. **Weight quantization:** 4.6 KB → ~2 KB (8-bit or 4-bit quantization; see the sketch after this list)
+2. **Kernel size reduction:** 5×5 → 3×3 for Layer 0 (200 vec4s → 72 vec4s)
+3. **Channel reduction:** 7 inputs → 4 inputs (remove UV/grayscale channels)
+
+These yield better size/performance than shader architecture changes.
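+
+For item 1, a minimal 8-bit quantization sketch (assumption: symmetric per-tensor scaling; not necessarily what a future exporter in `train_cnn.py` would use):
+
+```python
+import numpy as np
+
+def quantize_int8(weights: np.ndarray):
+    """Symmetric per-tensor int8 quantization: w ~= scale * q."""
+    scale = float(np.abs(weights).max()) / 127.0
+    q = np.round(weights / scale).astype(np.int8)
+    return q, scale
+
+w = np.random.randn(290, 4).astype(np.float32)   # 290 vec4s, as in the current buffer
+q, scale = quantize_int8(w)
+restored = q.astype(np.float32) * scale          # dequantize on load (or in the shader)
+print(w.nbytes, "->", q.nbytes + 4)              # 4640 bytes -> 1164 bytes (+ one f32 scale)
+```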
+
+---
+
+## References
+
+- `CNN_V1_EFFECT.md` - CNN implementation details
+- `CNN.md` - High-level CNN design
+- `../src/cnn_v1_effect.cc` - Current implementation
+- `workspaces/main/shaders/cnn_*.wgsl` - Shader snippets
diff --git a/cnn_v1/docs/CNN_RGBD_GRAYSCALE_SUMMARY.md b/cnn_v1/docs/CNN_RGBD_GRAYSCALE_SUMMARY.md
new file mode 100644
index 0000000..3439f2c
--- /dev/null
+++ b/cnn_v1/docs/CNN_RGBD_GRAYSCALE_SUMMARY.md
@@ -0,0 +1,136 @@
+# CNN RGBD→Grayscale Architecture Implementation
+
+## Summary
+
+Implemented CNN architecture upgrade: RGBD input → grayscale output with 7-channel augmented input.
+
+## Changes Made
+
+### Architecture
+
+**Input:** RGBD (4 channels: RGB + inverse depth D=1/z)
+**Output:** Grayscale (1 channel)
+**Layer Input:** 7 channels = [RGBD, UV coords, grayscale] all normalized to [-1,1]
+
+**Layer Configuration:**
+- Inner layers (0..N-2): Conv2d(7→4) - output RGBD with tanh activation
+- Final layer (N-1): Conv2d(7→1) - output grayscale, no activation
+
+### Input Normalization (all to [-1,1])
+
+- **RGBD:** `(rgbd - 0.5) * 2`
+- **UV coords:** `(uv - 0.5) * 2`
+- **Grayscale:** `dot(original.rgb, vec3<f32>(0.2126, 0.7152, 0.0722))` (computed once, passed as parameter)
+
+**Rationale:** Zero-centered inputs for tanh activation, better gradient flow.
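+
+A minimal PyTorch-style sketch of assembling the 7-channel layer input under this normalization scheme (helper and tensor names are illustrative, not the actual `train_cnn.py` code):
+
+```python
+import torch
+
+LUMA = torch.tensor([0.2126, 0.7152, 0.0722]).view(1, 3, 1, 1)
+
+def build_layer_input(rgbd_01: torch.Tensor) -> torch.Tensor:
+    """rgbd_01: (B, 4, H, W) in [0,1]. Returns the (B, 7, H, W) layer input in [-1,1]."""
+    b, _, h, w = rgbd_01.shape
+    rgbd = rgbd_01 * 2.0 - 1.0                             # RGBD -> [-1,1]
+    gray = (rgbd[:, :3] * LUMA).sum(dim=1, keepdim=True)   # luminance of normalized RGB
+    ys = torch.linspace(0.0, 1.0, h).view(1, 1, h, 1).expand(b, 1, h, w)
+    xs = torch.linspace(0.0, 1.0, w).view(1, 1, 1, w).expand(b, 1, h, w)
+    uv = torch.cat([xs, ys], dim=1) * 2.0 - 1.0            # UV coords -> [-1,1]
+    return torch.cat([rgbd, uv, gray], dim=1)              # 4 + 2 + 1 = 7 channels
+```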
+
+### Modified Files
+
+**Training (`/Users/skal/demo/training/train_cnn.py`):**
+1. Removed `CoordConv2d` class
+2. Updated `SimpleCNN`:
+ - Inner layers: `Conv2d(7, 4)` - RGBD output
+ - Final layer: `Conv2d(7, 1)` - grayscale output
+3. Updated `forward()`:
+ - Normalize RGBD/coords/gray to [-1,1]
+ - Concatenate 7-channel input for each layer
+ - Apply tanh (inner) or none (final)
+ - Denormalize final output
+4. Updated `export_weights_to_wgsl()`:
+ - Inner: `array<array<f32, 8>, 36>` (9 pos × 4 ch × 8 values)
+ - Final: `array<array<f32, 8>, 9>` (9 pos × 8 values)
+5. Updated `generate_layer_shader()`:
+ - Use `cnn_conv3x3_7to4` for inner layers
+ - Use `cnn_conv3x3_7to1` for final layer
+ - Denormalize outputs from [-1,1] to [0,1]
+6. Updated `ImagePairDataset`:
+ - Load RGBA input (was RGB)
+
+**Shaders (`/Users/skal/demo/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl`):**
+1. Added `cnn_conv3x3_7to4()`:
+ - 7-channel input: [RGBD, uv_x, uv_y, gray] (gray passed as parameter)
+ - 4-channel output: RGBD
+ - Weights: `array<array<f32, 8>, 36>`
+2. Added `cnn_conv3x3_7to1()`:
+ - 7-channel input: [RGBD, uv_x, uv_y, gray] (gray passed as parameter)
+ - 1-channel output: grayscale
+ - Weights: `array<array<f32, 8>, 9>`
+3. Optimized: gray computed once in caller using `dot()`, not per-function
+
+**Documentation (`/Users/skal/demo/doc/CNN_EFFECT.md`):**
+1. Updated architecture section with RGBD→grayscale pipeline
+2. Updated training data requirements (RGBA input)
+3. Updated weight storage format
+
+### No C++ Changes
+
+CNNLayerParams and bind groups remain unchanged.
+
+## Data Flow
+
+1. Layer 0 captures original RGBD to `captured_frame`
+2. Each layer:
+ - Samples previous layer output (RGBD in [0,1])
+ - Normalizes RGBD to [-1,1]
+ - Computes gray once using `dot()` (fs_main level)
+ - Normalizes UV coords to [-1,1] (inside conv functions)
+ - Concatenates 7-channel input
+ - Applies convolution with layer-specific weights
+ - Outputs RGBD (inner) or grayscale (final) in [-1,1]
+ - Applies tanh (inner only)
+ - Denormalizes to [0,1] for texture storage
+ - Blends with original
+
+## Next Steps
+
+1. **Prepare RGBD training data:**
+ - Input: RGBA images (RGB + depth in alpha)
+ - Target: Grayscale stylized output
+
+2. **Train network:**
+ ```bash
+ python3 training/train_cnn.py \
+ --input training/input \
+ --target training/output \
+ --layers 3 \
+ --epochs 1000
+ ```
+
+3. **Verify generated shaders:**
+ - Check `cnn_weights_generated.wgsl` structure
+ - Check `cnn_layer.wgsl` uses new conv functions
+
+4. **Test in demo:**
+ ```bash
+ cmake --build build -j4
+ ./build/demo64k
+ ```
+
+## Design Rationale
+
+**Why [-1,1] normalization?**
+- Centered inputs for tanh (operates best around 0)
+- Better gradient flow
+- Standard ML practice for normalized data
+
+**Why RGBD throughout vs RGB?**
+- Depth information propagates through network
+- Enables depth-aware stylization
+- Consistent 4-channel processing
+
+**Why 7-channel input?**
+- Coordinates: position-dependent effects (vignettes)
+- Grayscale: luminance-aware processing
+- RGBD: full color+depth information
+- Enables richer feature learning
+
+## Testing Checklist
+
+- [ ] Train network with RGBD input data
+- [ ] Verify `cnn_weights_generated.wgsl` structure
+- [ ] Verify `cnn_layer.wgsl` uses `7to4`/`7to1` functions
+- [ ] Build demo without errors
+- [ ] Visual test: inner layers show RGBD evolution
+- [ ] Visual test: final layer produces grayscale
+- [ ] Visual test: blending works correctly
+- [ ] Compare quality with previous RGB→RGB architecture
diff --git a/cnn_v1/docs/CNN_TEST_TOOL.md b/cnn_v1/docs/CNN_TEST_TOOL.md
new file mode 100644
index 0000000..4307894
--- /dev/null
+++ b/cnn_v1/docs/CNN_TEST_TOOL.md
@@ -0,0 +1,244 @@
+# CNN Shader Testing Tool
+
+Standalone tool for validating trained CNN shaders with GPU-to-CPU readback. Supports both CNN v1 (render pipeline) and v2 (compute, storage buffer).
+
+---
+
+## Purpose
+
+- Validate trained weights against ground truth
+- Debug CNN layer behavior in isolation
+- Generate test outputs for training workflow
+- Match Python training script's inference mode
+
+---
+
+## Architecture
+
+**Two implementations:**
+
+1. **CNN v1** (render pipeline, texture atlas weights)
+ - 3 fixed layers
+ - RGBA16Float intermediates
+ - BGRA8Unorm final output
+
+2. **CNN v2** (compute shaders, storage buffer weights)
+ - Dynamic layer count from binary
+ - 7D static features (RGBD + UV + sin + bias)
+ - RGBA32Uint packed f16 intermediates
+ - Storage buffer: ~3-5 KB weights
+
+**Core GPU utility:** `src/gpu/texture_readback.{h,cc}`
+- Synchronous texture-to-CPU readback
+- Supports RGBA16Float, RGBA32Uint, BGRA8Unorm
+- Protected with STRIP_ALL (0 bytes in release)
+
+---
+
+## Usage
+
+```bash
+cnn_test input.png output.png [OPTIONS]
+
+OPTIONS:
+ --cnn-version N CNN version: 1 (default) or 2 (ignored with --weights)
+ --weights PATH Load weights from .bin (forces CNN v2, overrides layer config)
+ --blend F Final blend amount (0.0-1.0, default: 1.0)
+ --format ppm|png Output format (default: png)
+ --layers N Number of CNN layers (1-10, v1 only, default: 3, ignored with --weights)
+ --save-intermediates DIR Save intermediate layers to directory
+ --debug-hex Print first 8 pixels as hex (debug)
+ --help Show usage
+```
+
+**Examples:**
+```bash
+# CNN v1 (render pipeline, 3 layers)
+./build/cnn_test input.png output.png --cnn-version 1
+
+# CNN v2 (compute, storage buffer, uses asset system weights)
+./build/cnn_test input.png output.png --cnn-version 2
+
+# CNN v2 with runtime weight loading (loads layer config from .bin)
+./build/cnn_test input.png output.png --weights checkpoints/checkpoint_epoch_100.pth.bin
+
+# 50% blend with original (v2)
+./build/cnn_test input.png output.png --cnn-version 2 --blend 0.5
+
+# Debug hex dump
+./build/cnn_test input.png output.png --cnn-version 2 --debug-hex
+```
+
+**Important:** When using `--weights`, the layer count and kernel sizes are read from the binary file header, overriding any `--layers` or `--cnn-version` arguments.
+
+---
+
+## Implementation Details
+
+### Core Readback Utility
+
+**File:** `src/gpu/texture_readback.{h,cc}`
+
+**Function:**
+```cpp
+std::vector<uint8_t> read_texture_pixels(
+ WGPUInstance instance,
+ WGPUDevice device,
+ WGPUTexture texture,
+ int width,
+ int height);
+```
+
+**Features:**
+- Returns BGRA8 format (4 bytes per pixel)
+- Synchronous blocking operation
+- Cross-platform async callback handling (Win32 vs Native API)
+- Automatic staging buffer creation and cleanup
+
+**Refactored OffscreenRenderTarget:**
+```cpp
+std::vector<uint8_t> OffscreenRenderTarget::read_pixels() {
+#if !defined(STRIP_ALL)
+ return read_texture_pixels(instance_, device_, texture_, width_, height_);
+#else
+ return std::vector<uint8_t>();
+#endif
+}
+```
+
+### CNN v1 Pipeline (Render)
+
+**Fixed 3-layer architecture:**
+- Ping-pong RGBA16Float textures
+- CNNLayerParams (binding 3): layer_index, blend_amount
+- Shader composer resolves #include directives
+
+### CNN v2 Pipeline (Compute)
+
+**Dynamic layer architecture:**
+1. **Static features compute:** Generate 7D features (RGBD + UV + sin + bias)
+2. **Layer computes:** N layers from binary weights (3-5 typically)
+ - Storage buffer weights (read-only)
+ - RGBA32Uint packed f16 textures (ping-pong)
+ - CNNv2LayerParams: kernel_size, channels, weight_offset, blend
+3. **Readback:** RGBA32Uint → f16 decode → u8 clamp
+
+**Binary format:** Header (20B) + layer info (20B×N) + f16 weights
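+
+A minimal loader sketch based only on the sizes above (the field layout inside the 20-byte header and the per-layer records is not specified here, so both are treated as opaque blobs and the layer count is passed in by hand):
+
+```python
+def read_cnn_v2_blob(path: str, num_layers: int):
+    """Split a weights .bin into header, per-layer records and f16 payload, by size only."""
+    with open(path, "rb") as f:
+        data = f.read()
+    header = data[:20]                                 # 20-byte header (opaque here)
+    records = [data[20 + 20 * i: 40 + 20 * i] for i in range(num_layers)]
+    weights_f16 = data[20 + 20 * num_layers:]          # remaining bytes: packed f16 weights
+    return header, records, weights_f16
+```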
+
+**Weight Loading:**
+- **Without `--weights`:** Loads from asset system (`ASSET_WEIGHTS_CNN_V2`)
+- **With `--weights PATH`:** Loads from external `.bin` file (e.g., checkpoint exports)
+ - Layer count and kernel sizes parsed from binary header
+ - Overrides any `--layers` or `--cnn-version` arguments
+ - Enables runtime testing of training checkpoints without rebuild
+
+---
+
+## Build Integration
+
+**CMakeLists.txt:**
+
+1. Added `src/gpu/texture_readback.cc` to GPU_SOURCES (both sections)
+2. Tool target:
+```cmake
+add_executable(cnn_test
+ tools/cnn_test.cc
+ src/tests/common/webgpu_test_fixture.cc
+ src/tests/common/offscreen_render_target.cc
+ ${PLATFORM_SOURCES}
+ ${GEN_DEMO_CC})
+
+target_link_libraries(cnn_test PRIVATE
+ gpu util procedural ${DEMO_LIBS})
+
+add_dependencies(cnn_test generate_demo_assets)
+
+target_compile_definitions(cnn_test PRIVATE
+ STB_IMAGE_IMPLEMENTATION
+ STB_IMAGE_WRITE_IMPLEMENTATION)
+```
+
+**Build:**
+```bash
+cmake -S . -B build -DDEMO_BUILD_TOOLS=ON
+cmake --build build -j4
+```
+
+---
+
+## Validation Workflow (CNN v2)
+
+### 1. Train and Export
+```bash
+# Train and export weights
+./scripts/train_cnn_v2_full.sh --epochs 200 --batch-size 16
+```
+
+### 2. Tool Inference
+```bash
+# Run tool with v2
+./build/cnn_test training/input/img_000.png output.png --cnn-version 2
+```
+
+### 3. Visual Comparison
+Compare output.png with training/target_X/img_000.png
+
+---
+
+## Status
+
+**CNN v1:** Builds and runs, produces incorrect output (all white). Use CNNEffect in demo for visual validation.
+
+**CNN v2:** ⚠️ Partially functional. Readback works but output differs from HTML validation tool.
+- Loads binary weights from `workspaces/main/weights/cnn_v2_weights.bin`
+- Matches CNNv2Effect architecture
+- **Known Issue:** Visual output differs from `tools/cnn_v2_test/index.html` despite matching shader code
+- Root cause under investigation (weight indexing? texture sampling? activation clamping?)
+- Use HTML tool (`tools/cnn_v2_test/index.html`) for accurate validation
+
+---
+
+## Technical Notes (Readback Fix)
+
+**Original Bug:** Buffer mapping returned `WGPUMapAsyncStatus_Unknown` (status=5)
+
+**Root Cause:** Callback mode mismatch
+- Used `WGPUCallbackMode_WaitAnyOnly` (fires only during `wgpuInstanceWaitAny`)
+- Called `wgpuInstanceProcessEvents` in wait loop (wrong API for this mode)
+- Callback never fired → timeout → empty buffer
+
+**Fix Applied:**
+1. Changed callback mode to `WGPUCallbackMode_AllowProcessEvents`
+2. Replaced `wgpuInstanceProcessEvents` with `wgpuDevicePoll(device, true, nullptr)`
+3. Added pre-mapping device poll to ensure copy completes
+
+**Relevant Code:** `src/gpu/texture_readback.cc` lines 97-110
+
+**Reference:** WebGPU spec - Asynchronous Operations, Callback Modes
+
+---
+
+## Limitations
+
+- **CNN v1:** Produces incorrect output, use for debugging only
+- **Single image:** Batch processing requires shell loop
+- **No real-time preview:** Offline processing only
+- **Input formats:** loaded via stb_image (PNG, JPEG, BMP, TGA)
+
+---
+
+## Technical Notes
+
+**CNN v2 f16 decoding:**
+- RGBA32Uint texture stores 8×f16 as 4×u32
+- Custom decoder: extract u16, decode f16→f32, clamp [0,1]→u8
+- Handles denormals, infinity, NaN
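+
+An equivalent decode in Python/numpy, shown only to illustrate the packing (assumption: the low 16 bits of each u32 hold the first of its two f16 channels):
+
+```python
+import numpy as np
+
+def decode_rgba32uint_texel(texel_u32) -> np.ndarray:
+    """4 x uint32 (one RGBA32Uint texel) -> 8 x float32 channel values."""
+    u32 = np.asarray(texel_u32, dtype=np.uint32)
+    lo = (u32 & 0xFFFF).astype(np.uint16)              # low 16 bits of each u32
+    hi = (u32 >> 16).astype(np.uint16)                 # high 16 bits
+    halves = np.stack([lo, hi], axis=-1).reshape(-1)   # 8 x u16
+    return halves.view(np.float16).astype(np.float32)  # denormals/inf/NaN handled by numpy
+
+vals = decode_rgba32uint_texel([0x3C000000, 0, 0, 0])  # 0x3C00 = 1.0 in f16 (high half)
+print(np.clip(vals, 0.0, 1.0) * 255.0)                 # clamp [0,1] and scale to u8 range
+```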
+
+**Cross-platform:**
+- macOS, Linux (native WebGPU)
+- Windows (mingw-w64 cross-compile)
+
+**Size impact:**
+- Debug/STRIP_ALL=OFF: compiled
+- STRIP_ALL=ON: 0 bytes (compiled out)
+- FINAL_STRIP=ON: tool not built
diff --git a/cnn_v1/docs/CNN_V1_EFFECT.md b/cnn_v1/docs/CNN_V1_EFFECT.md
new file mode 100644
index 0000000..40f095e
--- /dev/null
+++ b/cnn_v1/docs/CNN_V1_EFFECT.md
@@ -0,0 +1,400 @@
+# CNN Post-Processing Effect
+
+Neural network-based stylization for rendered scenes.
+
+---
+
+## Overview
+
+Trainable convolutional neural network layers for artistic stylization (painterly, sketch, cel-shaded effects) with minimal runtime overhead.
+
+**Key Features:**
+- Position-aware layer 0 (coordinate input for vignetting, edge effects)
+- Multi-layer convolutions (3×3, 5×5, 7×7 kernels) with automatic chaining
+- Original input available to all layers via framebuffer capture
+- Configurable final blend with original scene
+- Modular WGSL shader architecture
+- Hardcoded weights (trained offline via PyTorch)
+- ~5-8 KB binary footprint
+
+---
+
+## Architecture
+
+### RGBD → Grayscale Pipeline
+
+**Input:** RGBD (RGB + inverse depth D=1/z)
+**Output:** Grayscale (1 channel)
+**Layer Input:** 7 channels = [RGBD, UV coords, grayscale] all normalized to [-1,1]
+
+**Architecture:**
+- **Inner layers (0..N-2):** Conv2d(7→4) - output RGBD
+- **Final layer (N-1):** Conv2d(7→1) - output grayscale
+
+```wgsl
+// Inner layers: 7→4 (RGBD output, vec4-optimized)
+fn cnn_conv3x3_7to4(
+ tex: texture_2d<f32>,
+ samp: sampler,
+ uv: vec2<f32>,
+ resolution: vec2<f32>,
+ gray: f32, // Grayscale [-1,1]
+ weights: array<vec4<f32>, 72> // 9 pos × 4 ch × 2 vec4 (8 floats per filter)
+) -> vec4<f32>
+
+// Final layer: 7→1 (grayscale output, vec4-optimized)
+fn cnn_conv3x3_7to1(
+ tex: texture_2d<f32>,
+ samp: sampler,
+ uv: vec2<f32>,
+ resolution: vec2<f32>,
+ gray: f32,
+ weights: array<vec4<f32>, 18> // 9 pos × 2 vec4 (8 floats per filter)
+) -> f32
+```
+
+**Input normalization:**
+- **fs_main** normalizes textures once: `(tex - 0.5) * 2` → [-1,1]
+- **Conv functions** normalize UV coords: `(uv - 0.5) * 2` → [-1,1]
+- **Grayscale** computed once in fs_main using dot product: `dot(original.rgb, vec3(0.2126, 0.7152, 0.0722))`
+- **Inter-layer data** stays in [-1,1] (no denormalization)
+- **Final output** denormalized for display: `(result + 1.0) * 0.5` → [0,1]
+
+**Activation:** tanh for inner layers (output stays [-1,1]), none for final layer
+
+### Multi-Layer Architecture
+
+CNNEffect supports multi-layer networks via automatic effect chaining:
+
+1. **Timeline specifies total layers**: `CNNEffect layers=3 blend=0.7`
+2. **Compiler expands to chain**: 3 separate CNNEffect instances (layer 0→1→2)
+3. **Framebuffer capture**: Layer 0 captures original input to `"captured_frame"`
+4. **Original input binding**: All layers access original via `@binding(4)`
+5. **Final blend**: Last layer blends result with original: `mix(original, result, 0.7)`
+
+**Framebuffer Capture API:**
+- `Effect::needs_framebuffer_capture()` - effect requests pre-capture
+- MainSequence automatically blits input → `"captured_frame"` auxiliary texture
+- Generic mechanism usable by any effect
+
+### File Structure
+
+```
+src/effects/
+ cnn_effect.h/cc # CNNEffect class + framebuffer capture
+
+workspaces/main/shaders/cnn/
+ cnn_activation.wgsl # tanh, ReLU, sigmoid, leaky_relu
+ cnn_conv3x3.wgsl # 3×3 convolution (standard + coord-aware)
+ cnn_conv5x5.wgsl # 5×5 convolution (standard + coord-aware)
+ cnn_conv7x7.wgsl # 7×7 convolution (standard + coord-aware)
+ cnn_weights_generated.wgsl # Weight arrays (auto-generated by train_cnn.py)
+ cnn_layer.wgsl # Main shader with layer switches (auto-generated by train_cnn.py)
+```
+
+---
+
+## Training Workflow
+
+### 1. Prepare Training Data
+
+Input/target image pairs:
+```
+training/input/img_000.png # RGBA (RGB + alpha)
+training/output/img_000.png # Grayscale target
+```
+
+**Note:** Alpha channel can be depth (1/z) or constant (255). Network learns from RGB primarily.
+
+### 2. Train Network
+
+**Patch-based (Recommended)** - Preserves natural pixel scale:
+```bash
+python3 training/train_cnn.py \
+ --input training/input --target training/output \
+ --patch-size 32 --patches-per-image 64 --detector harris \
+ --layers 3 --kernel-sizes 3,5,3 \
+ --epochs 5000 --batch-size 16 --checkpoint-every 1000
+```
+
+**Detectors:** `harris` (corners), `fast` (features), `shi-tomasi` (corners), `gradient` (edges)
+
+**Full-image (Legacy)** - Resizes to 256×256:
+```bash
+python3 training/train_cnn.py \
+ --input training/input --target training/output \
+ --layers 3 --kernel-sizes 3,5,3 \
+ --epochs 10000 --batch-size 8 --checkpoint-every 1000
+```
+
+**Auto-generates:**
+- `cnn_weights_generated.wgsl` - Weight arrays
+- `cnn_layer.wgsl` - Layer shader
+
+### 3. Export & Validate
+
+```bash
+# Export shaders
+./training/train_cnn.py --export-only checkpoints/checkpoint_epoch_5000.pth
+
+# Generate ground truth
+./training/train_cnn.py --infer input.png \
+ --export-only checkpoints/checkpoint_epoch_5000.pth --output ground_truth.png
+```
+
+### 4. Rebuild Demo
+
+```bash
+cmake --build build -j4 && ./build/demo64k
+```
+
+---
+
+## Usage
+
+### C++ Integration
+
+**Single layer (manual):**
+```cpp
+#include "effects/cnn_effect.h"
+
+CNNEffectParams p;
+p.layer_index = 0;
+p.total_layers = 1;
+p.blend_amount = 1.0f;
+auto cnn = std::make_shared<CNNEffect>(ctx, p);
+timeline.add_effect(cnn, start_time, end_time);
+```
+
+**Multi-layer (automatic via timeline compiler):**
+
+Use timeline syntax - `seq_compiler` expands to multiple instances.
+
+### Timeline Examples
+
+**Single-layer CNN (full stylization):**
+```
+SEQUENCE 10.0 0
+ EFFECT + Hybrid3DEffect 0.00 5.00
+ EFFECT + CNNEffect 0.50 5.00 layers=1
+```
+
+**Multi-layer CNN with blend:**
+```
+SEQUENCE 10.0 0
+ EFFECT + Hybrid3DEffect 0.00 5.00
+ EFFECT + CNNEffect 0.50 5.00 layers=3 blend=0.7
+```
+
+Expands to:
+```cpp
+// Layer 0 (captures original, blend=1.0)
+{
+ CNNEffectParams p;
+ p.layer_index = 0;
+ p.total_layers = 3;
+ p.blend_amount = 1.0f;
+ seq->add_effect(std::make_shared<CNNEffect>(ctx, p), 0.50f, 5.00f, 1);
+}
+// Layer 1 (blend=1.0)
+{
+ CNNEffectParams p;
+ p.layer_index = 1;
+ p.total_layers = 3;
+ p.blend_amount = 1.0f;
+ seq->add_effect(std::make_shared<CNNEffect>(ctx, p), 0.50f, 5.00f, 2);
+}
+// Layer 2 (final blend=0.7)
+{
+ CNNEffectParams p;
+ p.layer_index = 2;
+ p.total_layers = 3;
+ p.blend_amount = 0.7f;
+ seq->add_effect(std::make_shared<CNNEffect>(ctx, p), 0.50f, 5.00f, 3);
+}
+```
+
+---
+
+## Shader Structure
+
+**Bindings:**
+```wgsl
+@group(0) @binding(0) var smplr: sampler;
+@group(0) @binding(1) var txt: texture_2d<f32>; // Current layer input
+@group(0) @binding(2) var<uniform> uniforms: CommonUniforms;
+@group(0) @binding(3) var<uniform> params: CNNLayerParams;
+@group(0) @binding(4) var original_input: texture_2d<f32>; // Layer 0 input (captured)
+```
+
+**Fragment shader logic:**
+```wgsl
+@fragment fn fs_main(@builtin(position) p: vec4<f32>) -> @location(0) vec4<f32> {
+ let uv = p.xy / uniforms.resolution;
+ let original_raw = textureSample(original_input, smplr, uv);
+ let original = (original_raw - 0.5) * 2.0; // Normalize to [-1,1]
+ let gray = dot(original.rgb, vec3<f32>(0.2126, 0.7152, 0.0722));
+ var result = vec4<f32>(0.0);
+
+ if (params.layer_index == 0) {
+ result = cnn_conv3x3_7to4_src(txt, smplr, uv, uniforms.resolution,
+ weights_layer0);
+ result = cnn_tanh(result);
+ }
+ else if (params.layer_index == 1) {
+ result = cnn_conv5x5_7to4(txt, smplr, uv, uniforms.resolution,
+ gray, weights_layer1);
+ result = cnn_tanh(result);
+ }
+ // ... other layers
+
+ // Blend with ORIGINAL input (not previous layer)
+ return mix(original_raw, result, params.blend_amount);
+}
+```
+
+**Weight Storage (vec4-optimized):**
+
+**Inner layers (7→4 RGBD output):**
+```wgsl
+// Structure: array<vec4<f32>, 72>
+// 9 pos × 4 ch × 2 vec4 (8 floats per filter: [rgba][uv,gray,1])
+const weights_layer0: array<vec4<f32>, 72> = array(
+ vec4<f32>(w0_r, w0_g, w0_b, w0_d), // pos0_ch0 (rgba weights)
+ vec4<f32>(w0_u, w0_v, w0_gray, bias0), // pos0_ch0 (uv, gray, bias)
+ vec4<f32>(w1_r, w1_g, w1_b, w1_d), // pos0_ch1 (rgba weights)
+ vec4<f32>(w1_u, w1_v, w1_gray, bias1), // pos0_ch1 (uv, gray, bias)
+ // ... 68 more vec4s
+);
+```
+
+**Final layer (7→1 grayscale output):**
+```wgsl
+// Structure: array<vec4<f32>, 18>
+// 9 pos × 2 vec4 (8 floats per filter: [rgba][uv,gray,1])
+const weights_layerN: array<vec4<f32>, 18> = array(
+ vec4<f32>(w0_r, w0_g, w0_b, w0_d), // pos0 (rgba weights)
+ vec4<f32>(w0_u, w0_v, w0_gray, bias0), // pos0 (uv, gray, bias)
+ // ... 16 more vec4s
+);
+```
+
+**Optimization:** Bias integrated as 4th component via `vec4(uv, gray, 1.0)` input. Two dot4 operations replace 8 scalar MADs.
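+
+For reference, a sketch of how a PyTorch `Conv2d(7, 4, 3)` weight tensor could be packed into this 72-vec4 layout (illustrative helper, not the actual `train_cnn.py` export code; note the bias / num_positions division documented in `CNN_BIAS_FIX_2026-02.md`):
+
+```python
+import numpy as np
+
+def pack_inner_layer(weight: np.ndarray, bias: np.ndarray):
+    """weight: (4, 7, k, k) Conv2d(7->4) tensor, bias: (4,). Returns k*k*4*2 vec4s."""
+    out_ch, _, k, _ = weight.shape
+    num_positions = k * k
+    vec4s = []
+    for y in range(k):
+        for x in range(k):                # kernel position, row-major
+            for oc in range(out_ch):      # then output channel
+                w = weight[oc, :, y, x]   # the 7 input weights at this tap
+                vec4s.append(w[0:4])                                       # [w_r, w_g, w_b, w_d]
+                vec4s.append(np.append(w[4:7], bias[oc] / num_positions))  # [w_u, w_v, w_gray, bias/N]
+    return vec4s                          # 3x3 kernel: 9 pos x 4 ch x 2 = 72 vec4s
+```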
+
+---
+
+## Size Budget
+
+| Component | Size | Notes |
+|-----------|------|-------|
+| Activation functions | ~200 B | 4 functions |
+| Conv3x3 (standard + coord) | ~500 B | Both variants |
+| Conv5x5 (standard + coord) | ~700 B | Both variants |
+| Conv7x7 (standard + coord) | ~900 B | Both variants |
+| Main shader | ~800 B | Layer composition |
+| C++ implementation | ~300 B | Effect class |
+| **Coord weights** | **+32 B** | Per-layer overhead (layer 0 only) |
+| **RGBA weights** | **2-6 KB** | Depends on depth/kernel sizes |
+| **Total** | **5-9 KB** | Acceptable for 64k |
+
+**Optimization strategies:**
+- Quantize weights (float32 → int8)
+- Prune near-zero weights
+- Use separable convolutions
+
+---
+
+## Testing
+
+```bash
+./build/test_demo_effects # CNN construction/shader tests
+./build/demo64k # Visual test
+```
+
+---
+
+## Blend Parameter Behavior
+
+**blend_amount** controls final compositing with original:
+- `blend=0.0`: Pure original (no CNN effect)
+- `blend=0.5`: 50% original + 50% CNN
+- `blend=1.0`: Pure CNN output (full stylization)
+
+**Important:** Blend uses captured layer 0 input, not previous layer output.
+
+**Example use cases:**
+- `blend=1.0`: Full stylization (default)
+- `blend=0.7`: Subtle effect preserving original details
+- `blend=0.3`: Light artistic touch
+
+## Troubleshooting
+
+**Shader compilation fails:**
+- Check `cnn_weights_generated.wgsl` syntax
+- Verify snippets registered in `shaders.cc::InitShaderComposer()`
+- Ensure `cnn_layer.wgsl` has 5 bindings (including `original_input`)
+
+**Black/corrupted output:**
+- Weights untrained (identity placeholder)
+- Check `captured_frame` auxiliary texture is registered
+- Verify layer priorities in timeline are sequential
+
+**Wrong blend result:**
+- Ensure layer 0 has `needs_framebuffer_capture() == true`
+- Check MainSequence framebuffer capture logic
+- Verify `original_input` binding is populated
+
+**Training loss not decreasing:**
+- Lower learning rate (`--learning-rate 0.0001`)
+- More epochs (`--epochs 1000`)
+- Check input/target image alignment
+
+---
+
+## Vec4 Optimization
+
+**Architecture:** Weights stored as vec4 pairs for SIMD efficiency.
+
+**Input representation:**
+```wgsl
+let rgbd = textureSample(...); // vec4: [r, g, b, d]
+let in1 = vec4<f32>(uv_norm, gray, 1.0); // vec4: [u, v, gray, 1.0]
+```
+
+**Weight indexing:**
+```wgsl
+var pos = 0; // Direct weight array index
+for (var dy = -1; dy <= 1; dy++) {
+ for (var dx = -1; dx <= 1; dx++) {
+ // Unrolled channel loop (4 output channels)
+ sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1);
+ sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1);
+ sum.b += dot(weights[pos+4], rgbd) + dot(weights[pos+5], in1);
+ sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1);
+ pos += 8; // 4 channels × 2 vec4s per channel
+ }
+}
+```
+
+**Benefits:**
+- **SIMD-native:** GPU executes `dot(vec4, vec4)` as single instruction (4 parallel MADs)
+- **Memory bandwidth:** 2 vec4 loads vs 8 scalar loads (better cache alignment)
+- **Bias integration:** Free via `[..., 1.0]` component (no separate add)
+- **Code simplicity:** Eliminates inner loop, direct indexing with `pos`
+- **Performance:** 2-3× GPU throughput improvement over scalar version
+
+**Weight layout per filter (8 floats):**
+- vec4[0]: [w_r, w_g, w_b, w_d] (rgba input weights)
+- vec4[1]: [w_u, w_v, w_gray, bias] (uv, grayscale, bias)
+
+**3×3 kernel sizes:**
+- Inner layer (7→4): 72 vec4s (9 pos × 4 ch × 2 vec4 = 1152 bytes)
+- Final layer (7→1): 18 vec4s (9 pos × 1 ch × 2 vec4 = 288 bytes)
+
+---
+
+## References
+
+- **Training Script:** `training/train_cnn.py`
+- **Shader Composition:** `doc/SEQUENCE.md`
+- **Effect System:** `src/gpu/effect.h`
diff --git a/cnn_v1/shaders/cnn_activation.wgsl b/cnn_v1/shaders/cnn_activation.wgsl
new file mode 100644
index 0000000..4fe771e
--- /dev/null
+++ b/cnn_v1/shaders/cnn_activation.wgsl
@@ -0,0 +1,18 @@
+// CNN activation functions
+// 4 functions: tanh, ReLU, sigmoid, leaky_relu
+
+fn cnn_tanh(x: vec4<f32>) -> vec4<f32> {
+ return tanh(x);
+}
+
+fn cnn_relu(x: vec4<f32>) -> vec4<f32> {
+ return max(vec4<f32>(0.0), x);
+}
+
+fn cnn_sigmoid(x: vec4<f32>) -> vec4<f32> {
+ return 1.0 / (1.0 + exp(-x));
+}
+
+fn cnn_leaky_relu(x: vec4<f32>, alpha: f32) -> vec4<f32> {
+ return max(alpha * x, x);
+}
diff --git a/cnn_v1/shaders/cnn_conv1x1.wgsl b/cnn_v1/shaders/cnn_conv1x1.wgsl
new file mode 100644
index 0000000..f77cfa8
--- /dev/null
+++ b/cnn_v1/shaders/cnn_conv1x1.wgsl
@@ -0,0 +1,100 @@
+// 1x1 convolution (vec4-optimized)
+
+// Inner layers: 7→4 channels (vec4-optimized)
+// Assumes 'tex' is already normalized to [-1,1]
+fn cnn_conv1x1_7to4(
+ tex: texture_2d<f32>,
+ samp: sampler,
+ uv: vec2<f32>,
+ resolution: vec2<f32>,
+ gray: f32,
+ weights: array<vec4<f32>, 8>
+) -> vec4<f32> {
+ let step = 1.0 / resolution;
+ let uv_norm = (uv - 0.5) * 2.0;
+
+ var sum = vec4<f32>(0.0);
+ var pos = 0;
+
+ for (var dy = -0; dy <= 0; dy++) {
+ for (var dx = -0; dx <= 0; dx++) {
+ let offset = vec2<f32>(f32(dx), f32(dy)) * step;
+ let rgbd = textureSample(tex, samp, uv + offset);
+ let in1 = vec4<f32>(uv_norm, gray, 1.0);
+
+ sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1);
+ sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1);
+ sum.b += dot(weights[pos+4], rgbd) + dot(weights[pos+5], in1);
+ sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1);
+ pos += 8;
+ }
+ }
+
+ return sum;
+}
+
+// Source layer: 7→4 channels (vec4-optimized)
+// Normalizes [0,1] input to [-1,1] internally
+fn cnn_conv1x1_7to4_src(
+ tex: texture_2d<f32>,
+ samp: sampler,
+ uv: vec2<f32>,
+ resolution: vec2<f32>,
+ weights: array<vec4<f32>, 8>
+) -> vec4<f32> {
+ let step = 1.0 / resolution;
+
+ var original = (textureSample(tex, samp, uv) - 0.5) * 2.0;
+ let gray = dot(original.rgb, vec3<f32>(0.2126, 0.7152, 0.0722));
+ let uv_norm = (uv - 0.5) * 2.0;
+ let in1 = vec4<f32>(uv_norm, gray, 1.0);
+
+ var sum = vec4<f32>(0.0);
+ var pos = 0;
+
+ for (var dy = -0; dy <= 0; dy++) {
+ for (var dx = -0; dx <= 0; dx++) {
+ let offset = vec2<f32>(f32(dx), f32(dy)) * step;
+ var rgbd = (textureSample(tex, samp, uv + offset) - 0.5) * 2.0;
+
+ sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1);
+ sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1);
+ sum.b += dot(weights[pos+4], rgbd) + dot(weights[pos+5], in1);
+ sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1);
+ pos += 8;
+ }
+ }
+
+ return sum;
+}
+
+// Final layer: 7→1 channel (vec4-optimized)
+// Assumes 'tex' is already normalized to [-1,1]
+// Returns raw sum (activation applied at call site)
+fn cnn_conv1x1_7to1(
+ tex: texture_2d<f32>,
+ samp: sampler,
+ uv: vec2<f32>,
+ resolution: vec2<f32>,
+ gray: f32,
+ weights: array<vec4<f32>, 2>
+) -> f32 {
+ let step = 1.0 / resolution;
+ let uv_norm = (uv - 0.5) * 2.0;
+ let in1 = vec4<f32>(uv_norm, gray, 1.0);
+
+ var sum = 0.0;
+ var pos = 0;
+
+ for (var dy = -0; dy <= 0; dy++) {
+ for (var dx = -0; dx <= 0; dx++) {
+ let offset = vec2<f32>(f32(dx), f32(dy)) * step;
+ let rgbd = textureSample(tex, samp, uv + offset);
+
+ sum += dot(weights[pos], rgbd) + dot(weights[pos+1], in1);
+ pos += 2;
+ }
+ }
+
+ return sum;
+}
diff --git a/cnn_v1/shaders/cnn_conv3x3.wgsl b/cnn_v1/shaders/cnn_conv3x3.wgsl
new file mode 100644
index 0000000..f7d11b1
--- /dev/null
+++ b/cnn_v1/shaders/cnn_conv3x3.wgsl
@@ -0,0 +1,100 @@
+// 3x3 convolution (vec4-optimized)
+
+// Inner layers: 7→4 channels (vec4-optimized)
+// Assumes 'tex' is already normalized to [-1,1]
+fn cnn_conv3x3_7to4(
+ tex: texture_2d<f32>,
+ samp: sampler,
+ uv: vec2<f32>,
+ resolution: vec2<f32>,
+ gray: f32,
+ weights: array<vec4<f32>, 72>
+) -> vec4<f32> {
+ let step = 1.0 / resolution;
+ let uv_norm = (uv - 0.5) * 2.0;
+
+ var sum = vec4<f32>(0.0);
+ var pos = 0;
+
+ for (var dy = -1; dy <= 1; dy++) {
+ for (var dx = -1; dx <= 1; dx++) {
+ let offset = vec2<f32>(f32(dx), f32(dy)) * step;
+ let rgbd = textureSample(tex, samp, uv + offset);
+ let in1 = vec4<f32>(uv_norm, gray, 1.0);
+
+ sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1);
+ sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1);
+ sum.b += dot(weights[pos+4], rgbd) + dot(weights[pos+5], in1);
+ sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1);
+ pos += 8;
+ }
+ }
+
+ return sum;
+}
+
+// Source layer: 7→4 channels (vec4-optimized)
+// Normalizes [0,1] input to [-1,1] internally
+fn cnn_conv3x3_7to4_src(
+ tex: texture_2d<f32>,
+ samp: sampler,
+ uv: vec2<f32>,
+ resolution: vec2<f32>,
+ weights: array<vec4<f32>, 72>
+) -> vec4<f32> {
+ let step = 1.0 / resolution;
+
+ let original = (textureSample(tex, samp, uv) - 0.5) * 2.0;
+ let gray = dot(original.rgb, vec3<f32>(0.2126, 0.7152, 0.0722));
+ let uv_norm = (uv - 0.5) * 2.0;
+ let in1 = vec4<f32>(uv_norm, gray, 1.0);
+
+ var sum = vec4<f32>(0.0);
+ var pos = 0;
+
+ for (var dy = -1; dy <= 1; dy++) {
+ for (var dx = -1; dx <= 1; dx++) {
+ let offset = vec2<f32>(f32(dx), f32(dy)) * step;
+ let rgbd = (textureSample(tex, samp, uv + offset) - 0.5) * 2.0;
+
+ sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1);
+ sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1);
+ sum.b += dot(weights[pos+4], rgbd) + dot(weights[pos+5], in1);
+ sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1);
+ pos += 8;
+ }
+ }
+
+ return sum;
+}
+
+// Final layer: 7→1 channel (vec4-optimized)
+// Assumes 'tex' is already normalized to [-1,1]
+// Returns raw sum (activation applied at call site)
+fn cnn_conv3x3_7to1(
+ tex: texture_2d<f32>,
+ samp: sampler,
+ uv: vec2<f32>,
+ resolution: vec2<f32>,
+ gray: f32,
+ weights: array<vec4<f32>, 18>
+) -> f32 {
+ let step = 1.0 / resolution;
+ let uv_norm = (uv - 0.5) * 2.0;
+ let in1 = vec4<f32>(uv_norm, gray, 1.0);
+
+ var sum = 0.0;
+ var pos = 0;
+
+ for (var dy = -1; dy <= 1; dy++) {
+ for (var dx = -1; dx <= 1; dx++) {
+ let offset = vec2<f32>(f32(dx), f32(dy)) * step;
+ let rgbd = textureSample(tex, samp, uv + offset);
+
+ sum += dot(weights[pos], rgbd) + dot(weights[pos+1], in1);
+ pos += 2;
+ }
+ }
+
+ return sum;
+}
diff --git a/cnn_v1/shaders/cnn_conv5x5.wgsl b/cnn_v1/shaders/cnn_conv5x5.wgsl
new file mode 100644
index 0000000..9328d75
--- /dev/null
+++ b/cnn_v1/shaders/cnn_conv5x5.wgsl
@@ -0,0 +1,101 @@
+// 5×5 variant for 7→4 channels (vec4-optimized)
+// Assumes 'tex' is already normalized to [-1,1]
+// UV coordinates remain in [0,1] and are normalized internally
+// weights: array<vec4<f32>, 200> (25 pos × 4 ch × 2 vec4)
+fn cnn_conv5x5_7to4(
+ tex: texture_2d<f32>,
+ samp: sampler,
+ uv: vec2<f32>,
+ resolution: vec2<f32>,
+ gray: f32,
+ weights: array<vec4<f32>, 200>
+) -> vec4<f32> {
+ let step = 1.0 / resolution;
+ let uv_norm = (uv - 0.5) * 2.0;
+ let in1 = vec4<f32>(uv_norm, gray, 1.0);
+
+ var sum = vec4<f32>(0.0);
+ var pos = 0;
+
+ for (var dy = -2; dy <= 2; dy++) {
+ for (var dx = -2; dx <= 2; dx++) {
+ let offset = vec2<f32>(f32(dx), f32(dy)) * step;
+ let rgbd = textureSample(tex, samp, uv + offset);
+
+ sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1);
+ sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1);
+ sum.b += dot(weights[pos+4], rgbd) + dot(weights[pos+5], in1);
+ sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1);
+ pos += 8;
+ }
+ }
+
+ return sum;
+}
+
+// 5×5 variant for 7→1 channel (vec4-optimized)
+// Assumes 'tex' is already normalized to [-1,1]
+// UV coordinates remain in [0,1] and are normalized internally
+// weights: array<vec4<f32>, 50> (25 pos × 2 vec4)
+fn cnn_conv5x5_7to1(
+ tex: texture_2d<f32>,
+ samp: sampler,
+ uv: vec2<f32>,
+ resolution: vec2<f32>,
+ gray: f32,
+ weights: array<vec4<f32>, 50>
+) -> f32 {
+ let step = 1.0 / resolution;
+ let uv_norm = (uv - 0.5) * 2.0;
+ let in1 = vec4<f32>(uv_norm, gray, 1.0);
+
+ var sum = 0.0;
+ var pos = 0;
+
+ for (var dy = -2; dy <= 2; dy++) {
+ for (var dx = -2; dx <= 2; dx++) {
+ let offset = vec2<f32>(f32(dx), f32(dy)) * step;
+ let rgbd = textureSample(tex, samp, uv + offset);
+
+ sum += dot(weights[pos], rgbd) + dot(weights[pos+1], in1);
+ pos += 2;
+ }
+ }
+
+ return sum;
+}
+
+// Source layer: 7→4 channels (vec4-optimized)
+// Normalizes [0,1] input to [-1,1] internally
+fn cnn_conv5x5_7to4_src(
+ tex: texture_2d<f32>,
+ samp: sampler,
+ uv: vec2<f32>,
+ resolution: vec2<f32>,
+ weights: array<vec4<f32>, 200>
+) -> vec4<f32> {
+ let step = 1.0 / resolution;
+
+ let original = (textureSample(tex, samp, uv) - 0.5) * 2.0;
+ let gray = dot(original.rgb, vec3<f32>(0.2126, 0.7152, 0.0722));
+ let uv_norm = (uv - 0.5) * 2.0;
+ let in1 = vec4<f32>(uv_norm, gray, 1.0);
+
+ var sum = vec4<f32>(0.0);
+ var pos = 0;
+
+ for (var dy = -2; dy <= 2; dy++) {
+ for (var dx = -2; dx <= 2; dx++) {
+ let offset = vec2<f32>(f32(dx), f32(dy)) * step;
+ let rgbd = (textureSample(tex, samp, uv + offset) - 0.5) * 2.0;
+
+ sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1);
+ sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1);
+ sum.b += dot(weights[pos+4], rgbd) + dot(weights[pos+5], in1);
+ sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1);
+ pos += 8;
+ }
+ }
+
+ return sum;
+}
diff --git a/cnn_v1/shaders/cnn_conv7x7.wgsl b/cnn_v1/shaders/cnn_conv7x7.wgsl
new file mode 100644
index 0000000..e68d644
--- /dev/null
+++ b/cnn_v1/shaders/cnn_conv7x7.wgsl
@@ -0,0 +1,53 @@
+// 7x7 convolution with 49 samples
+// Applies mat4 weights per sample
+
+fn cnn_conv7x7(
+ tex: texture_2d<f32>,
+ samp: sampler,
+ uv: vec2<f32>,
+ resolution: vec2<f32>,
+ weights: array<mat4x4<f32>, 49>,
+ bias: vec4<f32>
+) -> vec4<f32> {
+ let step = 1.0 / resolution;
+ var sum = bias;
+ var idx = 0;
+
+ for (var dy = -3; dy <= 3; dy++) {
+ for (var dx = -3; dx <= 3; dx++) {
+ let offset = vec2<f32>(f32(dx), f32(dy)) * step;
+ let sample = textureSample(tex, samp, uv + offset);
+ sum += weights[idx] * sample;
+ idx++;
+ }
+ }
+
+ return sum;
+}
+
+fn cnn_conv7x7_with_coord(
+ tex: texture_2d<f32>,
+ samp: sampler,
+ uv: vec2<f32>,
+ resolution: vec2<f32>,
+ rgba_weights: array<mat4x4<f32>, 49>,
+ coord_weights: mat2x4<f32>,
+ bias: vec4<f32>
+) -> vec4<f32> {
+ let step = 1.0 / resolution;
+ var sum = bias;
+
+ sum += coord_weights * uv;
+
+ var idx = 0;
+ for (var dy = -3; dy <= 3; dy++) {
+ for (var dx = -3; dx <= 3; dx++) {
+ let offset = vec2<f32>(f32(dx), f32(dy)) * step;
+ let rgba = textureSample(tex, samp, uv + offset);
+ sum += rgba_weights[idx] * rgba;
+ idx++;
+ }
+ }
+
+ return sum;
+}
diff --git a/cnn_v1/shaders/cnn_layer.wgsl b/cnn_v1/shaders/cnn_layer.wgsl
new file mode 100644
index 0000000..cbd1686
--- /dev/null
+++ b/cnn_v1/shaders/cnn_layer.wgsl
@@ -0,0 +1,55 @@
+// CNN layer shader - uses modular convolution snippets
+// Supports multi-pass rendering with residual connections
+// DO NOT EDIT - Generated by train_cnn.py
+
+@group(0) @binding(0) var smplr: sampler;
+@group(0) @binding(1) var txt: texture_2d<f32>;
+
+#include "common_uniforms"
+#include "cnn_activation"
+#include "cnn_conv3x3"
+#include "cnn_conv5x5"
+#include "cnn_weights_generated"
+
+struct CNNLayerParams {
+ layer_index: i32,
+ blend_amount: f32,
+ _pad: vec2<f32>,
+};
+
+@group(0) @binding(2) var<uniform> uniforms: CommonUniforms;
+@group(0) @binding(3) var<uniform> params: CNNLayerParams;
+@group(0) @binding(4) var original_input: texture_2d<f32>;
+
+@vertex fn vs_main(@builtin(vertex_index) i: u32) -> @builtin(position) vec4<f32> {
+ var pos = array<vec2<f32>, 3>(
+ vec2<f32>(-1.0, -1.0), vec2<f32>(3.0, -1.0), vec2<f32>(-1.0, 3.0)
+ );
+ return vec4<f32>(pos[i], 0.0, 1.0);
+}
+
+@fragment fn fs_main(@builtin(position) p: vec4<f32>) -> @location(0) vec4<f32> {
+ // Match PyTorch linspace
+ let uv = (p.xy - 0.5) / (uniforms.resolution - 1.0);
+ let original_raw = textureSample(original_input, smplr, uv);
+ let original = (original_raw - 0.5) * 2.0; // Normalize to [-1,1]
+ let gray = (dot(original_raw.rgb, vec3<f32>(0.2126, 0.7152, 0.0722)) - 0.5) * 2.0;
+ var result = vec4<f32>(0.0);
+
+ // Layer 0: 7→4 (RGBD output, normalizes [0,1] input)
+ if (params.layer_index == 0) {
+ result = cnn_conv5x5_7to4_src(txt, smplr, uv, uniforms.resolution, weights_layer0);
+ result = cnn_tanh(result);
+ }
+ else if (params.layer_index == 1) {
+ result = cnn_conv3x3_7to4(txt, smplr, uv, uniforms.resolution, gray, weights_layer1);
+ result = cnn_tanh(result); // Keep in [-1,1]
+ }
+ else if (params.layer_index == 2) {
+ let sum = cnn_conv3x3_7to1(txt, smplr, uv, uniforms.resolution, gray, weights_layer2);
+ let gray_out = 1.0 / (1.0 + exp(-sum)); // Sigmoid activation
+ result = vec4<f32>(gray_out, gray_out, gray_out, 1.0);
+ return mix(original_raw, result, params.blend_amount); // [0,1]
+ }
+ return result; // [-1,1]
+}
diff --git a/cnn_v1/shaders/cnn_weights_generated.wgsl b/cnn_v1/shaders/cnn_weights_generated.wgsl
new file mode 100644
index 0000000..510f86f
--- /dev/null
+++ b/cnn_v1/shaders/cnn_weights_generated.wgsl
@@ -0,0 +1,302 @@
+// Auto-generated CNN weights (vec4-optimized)
+// DO NOT EDIT - Generated by train_cnn.py
+
+const weights_layer0: array<vec4<f32>, 200> = array(
+ vec4<f32>(0.235493, 0.070711, -0.007171, 0.029242),
+ vec4<f32>(0.010796, -0.007094, 0.104870, -0.001741),
+ vec4<f32>(-0.363645, 0.625662, 0.044248, 0.046890),
+ vec4<f32>(0.016731, -0.099652, 0.198682, -0.002050),
+ vec4<f32>(-0.738196, -1.196639, -0.153794, 0.059818),
+ vec4<f32>(-0.012392, 0.206094, -1.159788, 0.001624),
+ vec4<f32>(-0.089846, -0.097056, 0.533546, -0.256308),
+ vec4<f32>(0.052460, 0.007740, -0.025518, -0.011569),
+ vec4<f32>(0.024563, -0.123127, -0.189236, -0.034605),
+ vec4<f32>(0.027494, 0.077022, -0.073083, -0.001741),
+ vec4<f32>(0.127897, -1.191688, -0.289229, -0.057213),
+ vec4<f32>(-0.017651, -0.095915, -0.540725, -0.002050),
+ vec4<f32>(0.459141, 1.047422, 1.008783, 0.082279),
+ vec4<f32>(-0.148789, 0.141891, 0.964934, 0.001624),
+ vec4<f32>(-0.458732, -0.253084, 0.429181, -0.267647),
+ vec4<f32>(0.029582, 0.043901, -0.332350, -0.011569),
+ vec4<f32>(-0.089206, -0.379760, -0.267976, -0.033062),
+ vec4<f32>(-0.059616, 0.042331, -0.297211, -0.001741),
+ vec4<f32>(0.347450, 0.349807, -0.107598, -0.038193),
+ vec4<f32>(-0.054979, -0.022737, 0.368773, -0.002050),
+ vec4<f32>(1.185666, 2.203693, 1.743948, 0.015765),
+ vec4<f32>(-0.004807, 0.138734, 2.114184, 0.001624),
+ vec4<f32>(-0.397312, -0.423930, 0.436068, -0.309529),
+ vec4<f32>(-0.025822, 0.061618, -0.358850, -0.011569),
+ vec4<f32>(0.031591, -0.133625, -0.210201, -0.058735),
+ vec4<f32>(0.026377, 0.074180, -0.075918, -0.001741),
+ vec4<f32>(-0.632064, -0.365984, -0.183357, -0.064294),
+ vec4<f32>(-0.038233, -0.027135, -0.529794, -0.002050),
+ vec4<f32>(-0.079942, -0.108489, 0.284420, 0.068003),
+ vec4<f32>(-0.033783, 0.131316, -0.006431, 0.001624),
+ vec4<f32>(-0.096003, -0.037157, 0.523401, -0.332369),
+ vec4<f32>(0.098362, 0.049597, 0.024988, -0.011569),
+ vec4<f32>(-0.042374, 0.215371, 0.044488, -0.079190),
+ vec4<f32>(-0.108483, 0.244548, 0.195395, -0.001741),
+ vec4<f32>(0.121079, 0.214838, 0.292411, -0.013912),
+ vec4<f32>(0.098564, -0.117552, 0.392438, -0.002050),
+ vec4<f32>(-0.994368, -0.526871, 0.165568, 0.006371),
+ vec4<f32>(-0.142932, 0.234835, -0.612723, 0.001624),
+ vec4<f32>(-0.430247, -0.230031, 0.035994, -0.340101),
+ vec4<f32>(-0.134622, -0.045299, -0.264801, -0.011569),
+ vec4<f32>(-0.116651, 0.042012, -0.004781, 0.018667),
+ vec4<f32>(0.000405, -0.068494, 0.084279, -0.001741),
+ vec4<f32>(0.180754, -0.853766, -0.384955, 0.013426),
+ vec4<f32>(0.038369, 0.010519, -0.437544, -0.002050),
+ vec4<f32>(0.373661, 0.677625, 0.617145, -0.028541),
+ vec4<f32>(0.071383, 0.012678, 0.734573, 0.001624),
+ vec4<f32>(-0.187586, -0.167658, 0.445526, -0.213674),
+ vec4<f32>(-0.054012, -0.048233, -0.111101, -0.011569),
+ vec4<f32>(-0.329708, 0.124956, 0.150447, 0.038372),
+ vec4<f32>(0.042139, -0.014901, 0.056693, -0.001741),
+ vec4<f32>(0.547166, 1.493724, 0.572366, 0.044038),
+ vec4<f32>(-0.055818, 0.022352, 1.209448, -0.002050),
+ vec4<f32>(-0.669255, -0.481531, -0.593402, 0.125846),
+ vec4<f32>(-0.086191, -0.012315, -0.692654, 0.001624),
+ vec4<f32>(-0.667836, -0.543086, 0.253854, -0.236805),
+ vec4<f32>(0.045048, 0.047535, -0.607491, -0.011569),
+ vec4<f32>(-0.262418, 0.247133, 0.225155, -0.084126),
+ vec4<f32>(0.017065, 0.007371, 0.103683, -0.001741),
+ vec4<f32>(0.216644, 1.179116, 0.436799, 0.041116),
+ vec4<f32>(0.006571, 0.012147, 0.674660, -0.002050),
+ vec4<f32>(0.290965, -0.022340, -0.616338, 0.021808),
+ vec4<f32>(-0.091234, -0.016764, 0.116976, 0.001624),
+ vec4<f32>(-0.689736, -0.685681, 0.342797, -0.213249),
+ vec4<f32>(0.040683, 0.038921, -0.663171, -0.011569),
+ vec4<f32>(-0.150412, 0.018053, -0.103426, 0.026070),
+ vec4<f32>(0.016183, -0.090006, 0.028738, -0.001741),
+ vec4<f32>(0.851827, -0.499315, 0.146696, 0.047324),
+ vec4<f32>(0.059725, 0.031269, 0.184268, -0.002050),
+ vec4<f32>(0.160719, -0.309456, -0.432633, -0.021171),
+ vec4<f32>(-0.060075, -0.052701, -0.248520, 0.001624),
+ vec4<f32>(-0.217727, 0.354527, 0.663356, -0.267530),
+ vec4<f32>(-0.032714, 0.000761, 0.246687, -0.011569),
+ vec4<f32>(0.077123, 0.069934, 0.077986, 0.004388),
+ vec4<f32>(-0.107897, 0.103689, 0.072698, -0.001741),
+ vec4<f32>(-0.216285, -0.206663, -0.497913, -0.019433),
+ vec4<f32>(0.042063, -0.036315, -0.306115, -0.002050),
+ vec4<f32>(0.351038, 0.116104, -0.046132, 0.022280),
+ vec4<f32>(-0.026460, -0.025197, 0.286924, 0.001624),
+ vec4<f32>(-0.480131, -0.253209, -0.259724, -0.353796),
+ vec4<f32>(-0.069436, -0.026651, -0.285359, -0.011569),
+ vec4<f32>(0.225811, -0.092313, -0.152689, 0.007505),
+ vec4<f32>(0.120530, 0.012846, -0.020303, -0.001741),
+ vec4<f32>(0.305262, 0.699468, 0.474383, -0.002565),
+ vec4<f32>(-0.036377, 0.008052, 0.424588, -0.002050),
+ vec4<f32>(0.557323, 0.489104, 0.312243, 0.072877),
+ vec4<f32>(0.096476, -0.012612, 0.586454, 0.001624),
+ vec4<f32>(-0.370964, -0.252666, 0.235903, -0.299915),
+ vec4<f32>(-0.066341, -0.008435, -0.158507, -0.011569),
+ vec4<f32>(0.070604, -0.016186, -0.079075, 0.015055),
+ vec4<f32>(0.042533, -0.085281, -0.014053, -0.001741),
+ vec4<f32>(-1.115748, -0.531544, -0.207050, -0.040691),
+ vec4<f32>(0.010035, -0.008330, -0.718958, -0.002050),
+ vec4<f32>(-1.404958, -2.000416, -1.884062, 0.014171),
+ vec4<f32>(0.019375, -0.078894, -1.999592, 0.001624),
+ vec4<f32>(-1.144367, -0.681485, 0.145197, -0.310542),
+ vec4<f32>(0.071912, -0.001021, -0.817277, -0.011569),
+ vec4<f32>(-0.018298, 0.109930, -0.067419, -0.031281),
+ vec4<f32>(0.072086, -0.047123, -0.018405, -0.001741),
+ vec4<f32>(-2.926982, -5.479454, -1.936543, 0.034851),
+ vec4<f32>(0.005592, 0.052238, -4.695754, -0.002050),
+ vec4<f32>(0.504616, -0.384917, -0.623795, 0.009371),
+ vec4<f32>(-0.105685, -0.049385, -0.154266, 0.001624),
+ vec4<f32>(-1.428979, -0.829611, 0.160294, -0.239524),
+ vec4<f32>(0.054180, -0.058797, -0.939519, -0.011569),
+ vec4<f32>(0.088147, -0.158820, -0.199674, -0.083067),
+ vec4<f32>(0.073984, -0.059593, -0.103344, -0.001741),
+ vec4<f32>(0.465084, 2.259005, 0.899806, -0.010464),
+ vec4<f32>(0.058231, -0.075668, 1.383652, -0.002050),
+ vec4<f32>(-0.162736, -0.899540, -0.559890, 0.066380),
+ vec4<f32>(0.029594, 0.036117, -0.780812, 0.001624),
+ vec4<f32>(-0.605431, 0.342970, 0.671602, -0.313734),
+ vec4<f32>(0.072950, 0.058100, 0.232742, -0.011569),
+ vec4<f32>(0.161941, -0.017279, -0.010904, -0.041589),
+ vec4<f32>(-0.118079, 0.090886, 0.001212, -0.001741),
+ vec4<f32>(-0.136354, 0.155269, 0.058437, -0.043499),
+ vec4<f32>(0.029368, 0.079326, -0.060807, -0.002050),
+ vec4<f32>(0.222824, 0.267939, 0.010260, 0.093258),
+ vec4<f32>(-0.091763, 0.028527, 0.290062, 0.001624),
+ vec4<f32>(-0.584501, -0.074002, -0.187352, -0.247388),
+ vec4<f32>(-0.067679, -0.036398, -0.237425, -0.011569),
+ vec4<f32>(-0.026121, -0.231360, 0.002505, -0.096021),
+ vec4<f32>(0.073173, -0.059323, -0.128630, -0.001741),
+ vec4<f32>(-0.118509, -0.931686, -0.328151, 0.027222),
+ vec4<f32>(0.006670, -0.094619, -0.605555, -0.002050),
+ vec4<f32>(0.260254, 0.186958, 0.235441, -0.030871),
+ vec4<f32>(0.111987, -0.056380, 0.227175, 0.001624),
+ vec4<f32>(0.012446, -0.068683, 0.273271, -0.315052),
+ vec4<f32>(-0.020011, 0.046984, 0.026316, -0.011569),
+ vec4<f32>(0.149830, 0.108146, 0.141757, 0.040947),
+ vec4<f32>(-0.060874, -0.004303, 0.196782, -0.001741),
+ vec4<f32>(1.031257, 1.493831, 0.443644, -0.089572),
+ vec4<f32>(-0.035087, 0.049431, 1.193984, -0.002050),
+ vec4<f32>(-0.204666, -0.340174, -0.045684, 0.053997),
+ vec4<f32>(0.000214, -0.073696, -0.299299, 0.001624),
+ vec4<f32>(-1.040674, -0.828753, 0.007912, -0.326534),
+ vec4<f32>(0.040669, -0.036526, -0.794626, -0.011569),
+ vec4<f32>(-0.018212, -0.031610, 0.259871, -0.041978),
+ vec4<f32>(0.021055, -0.061307, -0.004348, -0.001741),
+ vec4<f32>(0.002720, 0.570871, 0.371837, -0.076940),
+ vec4<f32>(0.023420, 0.006175, 0.318983, -0.002050),
+ vec4<f32>(0.259713, 0.294528, 0.907401, 0.043367),
+ vec4<f32>(-0.087576, -0.053953, 0.273380, 0.001624),
+ vec4<f32>(-1.177213, -0.464727, 0.211285, -0.266637),
+ vec4<f32>(0.075274, -0.007404, -0.703821, -0.011569),
+ vec4<f32>(-0.089204, -0.053316, 0.280138, -0.056155),
+ vec4<f32>(0.030981, -0.005136, 0.038455, -0.001741),
+ vec4<f32>(0.936459, -0.196866, 0.270033, -0.096884),
+ vec4<f32>(0.025329, -0.032176, 0.473732, -0.002050),
+ vec4<f32>(0.312348, 0.234105, 0.580837, 0.099177),
+ vec4<f32>(0.019877, -0.096514, 0.450075, 0.001624),
+ vec4<f32>(-1.099700, -0.203693, 0.157253, -0.331450),
+ vec4<f32>(-0.033353, -0.072074, -0.453590, -0.011569),
+ vec4<f32>(-0.084598, -0.039735, 0.162495, -0.070988),
+ vec4<f32>(-0.038491, 0.071525, 0.034601, -0.001741),
+ vec4<f32>(-0.199528, -0.475454, -0.297979, 0.037322),
+ vec4<f32>(-0.003106, 0.003258, -0.475664, -0.002050),
+ vec4<f32>(-0.282845, 0.058921, -0.300971, -0.011632),
+ vec4<f32>(-0.102320, 0.065302, -0.035173, 0.001624),
+ vec4<f32>(-0.515296, 0.497936, 0.313751, -0.245144),
+ vec4<f32>(-0.126936, 0.016721, 0.233370, -0.011569),
+ vec4<f32>(-0.220154, 0.069414, 0.194344, 0.000786),
+ vec4<f32>(0.037788, -0.095021, -0.055585, -0.001741),
+ vec4<f32>(-0.186244, 0.434960, 0.138978, -0.017604),
+ vec4<f32>(0.014466, 0.055976, 0.306540, -0.002050),
+ vec4<f32>(0.000614, -0.087365, -0.327816, 0.025776),
+ vec4<f32>(0.227096, -0.143725, -0.046319, 0.001624),
+ vec4<f32>(0.468607, -0.441809, -0.025186, -0.260166),
+ vec4<f32>(0.018770, -0.067388, -0.240128, -0.011569),
+ vec4<f32>(-0.013968, 0.032027, -0.111361, -0.023976),
+ vec4<f32>(0.041929, -0.033460, 0.001994, -0.001741),
+ vec4<f32>(0.005203, -0.837762, -0.287991, -0.026139),
+ vec4<f32>(-0.077592, 0.021388, -0.524153, -0.002050),
+ vec4<f32>(0.250865, 0.313428, -0.248465, 0.059517),
+ vec4<f32>(0.034922, -0.054528, 0.257107, 0.001624),
+ vec4<f32>(0.010692, -0.067238, 0.233031, -0.310017),
+ vec4<f32>(0.176915, -0.059644, 0.016072, -0.011569),
+ vec4<f32>(0.016422, 0.016187, -0.037382, -0.083725),
+ vec4<f32>(0.002691, -0.110865, -0.012957, -0.001741),
+ vec4<f32>(0.095561, 0.396829, 0.128803, 0.037097),
+ vec4<f32>(0.019823, 0.093399, 0.310928, -0.002050),
+ vec4<f32>(-0.193791, -0.079385, 0.332894, 0.039734),
+ vec4<f32>(0.119291, -0.053947, 0.020449, 0.001624),
+ vec4<f32>(-0.446965, -0.003325, 0.231982, -0.298212),
+ vec4<f32>(0.063248, -0.060392, -0.103558, -0.011569),
+ vec4<f32>(-0.044501, -0.246630, -0.254448, -0.025872),
+ vec4<f32>(0.044620, -0.074284, -0.183828, -0.001741),
+ vec4<f32>(-0.369636, -0.171104, -0.485456, -0.085980),
+ vec4<f32>(-0.053131, 0.016452, -0.377567, -0.002050),
+ vec4<f32>(-0.183644, -0.028271, 0.226453, 0.010102),
+ vec4<f32>(0.039391, -0.132828, -0.009034, 0.001624),
+ vec4<f32>(-0.644046, -0.335421, 0.011161, -0.222670),
+ vec4<f32>(0.091183, 0.005457, -0.472058, -0.011569),
+ vec4<f32>(0.045107, 0.080623, -0.132791, 0.064920),
+ vec4<f32>(-0.110745, 0.109524, 0.092569, -0.001741),
+ vec4<f32>(0.064397, 0.190407, 0.257845, 0.024637),
+ vec4<f32>(-0.042557, 0.128625, 0.317239, -0.002050),
+ vec4<f32>(-0.362482, 0.271381, -0.115412, 0.103104),
+ vec4<f32>(0.088766, 0.042583, 0.069687, 0.001624),
+ vec4<f32>(-0.353634, 0.554832, 0.442496, -0.351794),
+ vec4<f32>(-0.140207, -0.064649, 0.346336, -0.011569)
+);
+
+const weights_layer1: array<vec4<f32>, 72> = array(
+ vec4<f32>(-0.059078, -0.087833, -0.048345, -0.276761),
+ vec4<f32>(-0.101904, 0.058647, -0.405575, -0.064215),
+ vec4<f32>(-0.382952, 0.579364, -0.051813, -0.155723),
+ vec4<f32>(-0.140997, -0.006771, 0.212267, 0.120289),
+ vec4<f32>(-0.152651, -0.134768, -0.076617, -0.506104),
+ vec4<f32>(0.089304, 0.078492, 0.541122, 0.129289),
+ vec4<f32>(0.739323, -0.014103, -0.012980, -0.112747),
+ vec4<f32>(-0.089971, -0.088661, -0.520901, 0.158290),
+ vec4<f32>(0.819725, 2.866048, 0.080441, 0.380885),
+ vec4<f32>(0.035196, 0.028422, -0.748029, -0.064215),
+ vec4<f32>(-0.551722, 0.995924, -0.203047, -0.220742),
+ vec4<f32>(-0.081721, 0.039584, 0.581791, 0.120289),
+ vec4<f32>(-0.752329, -0.482903, -0.317275, 0.515372),
+ vec4<f32>(-0.087637, 0.040969, 0.481261, 0.129289),
+ vec4<f32>(0.532382, -0.653574, 0.078268, 0.139585),
+ vec4<f32>(-0.089350, -0.072701, -1.289249, 0.158290),
+ vec4<f32>(0.384272, -0.051717, 0.428463, -0.006561),
+ vec4<f32>(0.034003, 0.036653, -0.778556, -0.064215),
+ vec4<f32>(-0.788796, 0.332339, -0.181283, -0.213141),
+ vec4<f32>(0.196044, -0.062422, 0.724631, 0.120289),
+ vec4<f32>(-0.416297, -0.520778, -0.009510, -0.304383),
+ vec4<f32>(0.094475, -0.033135, 0.942838, 0.129289),
+ vec4<f32>(0.887455, 0.054078, 0.193434, 0.268549),
+ vec4<f32>(-0.055369, -0.042953, -0.172902, 0.158290),
+ vec4<f32>(0.419144, -0.159019, 0.189637, -0.235703),
+ vec4<f32>(-0.098285, 0.021026, -0.041846, -0.064215),
+ vec4<f32>(-1.009575, 0.934207, -0.120383, -0.243756),
+ vec4<f32>(-0.054562, 0.123804, 0.004157, 0.120289),
+ vec4<f32>(-0.504099, 0.696545, -0.850290, 0.493131),
+ vec4<f32>(-0.090043, -0.020600, -1.148702, 0.129289),
+ vec4<f32>(0.302269, -0.662429, 0.315052, -0.276341),
+ vec4<f32>(-0.084626, -0.029208, -0.799132, 0.158290),
+ vec4<f32>(0.318365, 2.531235, 0.349606, 0.231242),
+ vec4<f32>(0.053525, -0.031474, -0.570432, -0.064215),
+ vec4<f32>(-0.635031, 0.498836, 0.009884, -0.465079),
+ vec4<f32>(0.059087, 0.038415, 0.009928, 0.120289),
+ vec4<f32>(-0.522592, -3.781285, 0.418296, -0.608186),
+ vec4<f32>(0.100879, -0.083891, 1.653884, 0.129289),
+ vec4<f32>(0.258571, 2.590279, 0.221239, -0.143175),
+ vec4<f32>(0.121409, -0.084177, -1.397735, 0.158290),
+ vec4<f32>(0.907284, -0.034063, 0.573987, -0.125626),
+ vec4<f32>(-0.017610, -0.059485, -0.242599, -0.064215),
+ vec4<f32>(-0.748146, 0.686047, -0.074510, -0.248879),
+ vec4<f32>(-0.034986, -0.121423, -0.406087, 0.120289),
+ vec4<f32>(-0.559352, -2.921763, -0.718019, -0.764524),
+ vec4<f32>(0.165658, 0.097044, 0.773885, 0.129289),
+ vec4<f32>(0.006276, -0.801820, 0.215264, 0.115919),
+ vec4<f32>(0.081513, -0.023028, -0.590423, 0.158290),
+ vec4<f32>(-0.207850, 0.088171, -0.173170, 0.351969),
+ vec4<f32>(-0.042732, -0.024059, -0.087492, -0.064215),
+ vec4<f32>(-0.711148, 0.312318, -0.145549, -0.113749),
+ vec4<f32>(0.053038, 0.093166, -0.473856, 0.120289),
+ vec4<f32>(-0.343481, -0.137305, -0.340862, 0.445920),
+ vec4<f32>(-0.070473, -0.024914, -0.735660, 0.129289),
+ vec4<f32>(0.212955, -0.200508, 0.105125, -0.165284),
+ vec4<f32>(-0.123633, 0.052941, 0.099918, 0.158290),
+ vec4<f32>(0.362468, -0.709693, 0.281097, -0.155976),
+ vec4<f32>(-0.034566, 0.002014, 0.443026, -0.064215),
+ vec4<f32>(-0.346208, 1.179972, -0.563868, -0.424647),
+ vec4<f32>(0.012676, -0.023351, -0.703819, 0.120289),
+ vec4<f32>(-0.476282, -0.001002, -0.456911, -0.143433),
+ vec4<f32>(0.061018, -0.051173, -0.992671, 0.129289),
+ vec4<f32>(0.340925, -0.869046, 0.333377, -0.070414),
+ vec4<f32>(0.022279, 0.022837, -0.389711, 0.158290),
+ vec4<f32>(0.217347, -0.092030, -0.004346, 0.209850),
+ vec4<f32>(-0.116637, -0.096003, -0.333961, -0.064215),
+ vec4<f32>(-0.105262, 0.443411, -0.443104, 0.032732),
+ vec4<f32>(0.014939, 0.058855, -0.723723, 0.120289),
+ vec4<f32>(-0.598907, -0.166341, -0.635385, 0.463685),
+ vec4<f32>(0.151976, 0.049510, 0.155364, 0.129289),
+ vec4<f32>(0.138981, -0.109141, 0.272429, 0.190495),
+ vec4<f32>(-0.005729, 0.020860, -0.062157, 0.158290)
+);
+
+const weights_layer2: array<vec4<f32>, 18> = array(
+ vec4<f32>(0.043207, -0.056041, 0.131565, 0.116278),
+ vec4<f32>(-0.038849, -0.028105, -0.112979, 0.023741),
+ vec4<f32>(-0.010112, -0.085145, 0.257510, 0.245113),
+ vec4<f32>(0.041108, 0.049255, -0.082008, 0.023741),
+ vec4<f32>(0.012368, -0.035856, 0.018924, 0.174452),
+ vec4<f32>(0.052554, 0.039427, -0.279445, 0.023741),
+ vec4<f32>(-0.160061, -0.232735, 0.256951, 0.208887),
+ vec4<f32>(-0.088352, 0.100106, 0.103566, 0.023741),
+ vec4<f32>(-0.406607, -1.336396, 0.454171, 0.310834),
+ vec4<f32>(-0.061166, 0.105463, 1.572779, 0.023741),
+ vec4<f32>(-0.188413, -0.523344, 0.082813, 0.209113),
+ vec4<f32>(0.052509, -0.069748, -0.065008, 0.023741),
+ vec4<f32>(-0.124016, 0.005237, 0.177859, 0.138953),
+ vec4<f32>(0.072167, 0.070582, -0.209545, 0.023741),
+ vec4<f32>(-0.384457, -0.186386, 0.273595, 0.235457),
+ vec4<f32>(-0.032392, -0.086899, -0.006561, 0.023741),
+ vec4<f32>(-0.195800, 0.017395, 0.023080, 0.181437),
+ vec4<f32>(-0.035524, -0.095398, -0.204917, 0.023741)
+);
+
diff --git a/cnn_v1/src/cnn_v1_effect.cc b/cnn_v1/src/cnn_v1_effect.cc
new file mode 100644
index 0000000..1f44619
--- /dev/null
+++ b/cnn_v1/src/cnn_v1_effect.cc
@@ -0,0 +1,129 @@
+// CNN post-processing effect implementation
+// Neural network-based stylization with modular WGSL
+
+#include "cnn_v1_effect.h"
+#include "gpu/bind_group_builder.h"
+#include "gpu/effect.h"
+#include "gpu/pipeline_builder.h"
+#include "gpu/post_process_helper.h"
+#include "gpu/sampler_cache.h"
+#include "gpu/shader_composer.h"
+#include "gpu/shaders.h"
+
+// Create custom pipeline with 5 bindings (includes original texture)
+static WGPURenderPipeline create_cnn_pipeline(WGPUDevice device,
+ WGPUTextureFormat format,
+ const char* shader_code) {
+ WGPUBindGroupLayout bgl =
+ BindGroupLayoutBuilder()
+ .sampler(0, WGPUShaderStage_Fragment)
+ .texture(1, WGPUShaderStage_Fragment)
+ .uniform(2, WGPUShaderStage_Vertex | WGPUShaderStage_Fragment)
+ .uniform(3, WGPUShaderStage_Fragment)
+ .texture(4, WGPUShaderStage_Fragment)
+ .build(device);
+
+ WGPURenderPipeline pipeline = RenderPipelineBuilder(device)
+ .shader(shader_code)
+ .bind_group_layout(bgl)
+ .format(format)
+ .build();
+
+ wgpuBindGroupLayoutRelease(bgl);
+ return pipeline;
+}
+
+CNNv1Effect::CNNv1Effect(const GpuContext& ctx)
+ : PostProcessEffect(ctx), layer_index_(0), total_layers_(1),
+ blend_amount_(1.0f), input_view_(nullptr), original_view_(nullptr),
+ bind_group_(nullptr) {
+ pipeline_ =
+ create_cnn_pipeline(ctx_.device, ctx_.format, cnn_layer_shader_wgsl);
+}
+
+CNNv1Effect::CNNv1Effect(const GpuContext& ctx, const CNNv1EffectParams& params)
+ : PostProcessEffect(ctx), layer_index_(params.layer_index),
+ total_layers_(params.total_layers), blend_amount_(params.blend_amount),
+ input_view_(nullptr), original_view_(nullptr), bind_group_(nullptr) {
+ pipeline_ =
+ create_cnn_pipeline(ctx_.device, ctx_.format, cnn_layer_shader_wgsl);
+}
+
+void CNNv1Effect::init(MainSequence* demo) {
+ PostProcessEffect::init(demo);
+ demo_ = demo;
+ params_buffer_.init(ctx_.device);
+
+ // Register auxiliary texture for layer 0 (width_/height_ set by resize())
+ if (layer_index_ == 0) {
+ demo_->register_auxiliary_texture("captured_frame", width_, height_);
+ }
+
+ // Initialize uniforms BEFORE any bind group creation
+ uniforms_.update(ctx_.queue, get_common_uniforms());
+
+ CNNv1LayerParams params = {layer_index_, blend_amount_, {0.0f, 0.0f}};
+ params_buffer_.update(ctx_.queue, params);
+}
+
+void CNNv1Effect::resize(int width, int height) {
+ if (width == width_ && height == height_)
+ return;
+
+ PostProcessEffect::resize(width, height);
+
+ // Only layer 0 owns the captured_frame texture
+ if (layer_index_ == 0 && demo_) {
+ demo_->resize_auxiliary_texture("captured_frame", width, height);
+ }
+}
+
+void CNNv1Effect::render(WGPURenderPassEncoder pass,
+ const CommonPostProcessUniforms& uniforms) {
+ if (!bind_group_) {
+ fprintf(stderr, "CNN render: no bind_group\n");
+ return;
+ }
+
+ float effective_blend = blend_amount_;
+ if (beat_modulated_) {
+ effective_blend = blend_amount_ * uniforms.beat_phase * beat_scale_;
+ }
+
+ CNNv1LayerParams params = {layer_index_, effective_blend, {0.0f, 0.0f}};
+ params_buffer_.update(ctx_.queue, params);
+
+ wgpuRenderPassEncoderSetPipeline(pass, pipeline_);
+ wgpuRenderPassEncoderSetBindGroup(pass, 0, bind_group_, 0, nullptr);
+ wgpuRenderPassEncoderDraw(pass, 3, 1, 0, 0);
+}
+
+void CNNv1Effect::update_bind_group(WGPUTextureView input_view) {
+ input_view_ = input_view;
+
+ // Update common uniforms (CRITICAL for UV calculation!)
+ uniforms_.update(ctx_.queue, get_common_uniforms());
+
+ // All layers: get captured frame (original input from layer 0)
+ if (demo_) {
+ original_view_ = demo_->get_auxiliary_view("captured_frame");
+ }
+
+ // Create bind group with original texture
+ if (bind_group_)
+ wgpuBindGroupRelease(bind_group_);
+
+ WGPUBindGroupLayout bgl = wgpuRenderPipelineGetBindGroupLayout(pipeline_, 0);
+  // Use clamp (not repeat) to approximate PyTorch Conv2d zero-padding at the borders
+ WGPUSampler sampler =
+ SamplerCache::Get().get_or_create(ctx_.device, SamplerCache::clamp());
+
+ bind_group_ =
+ BindGroupBuilder()
+ .sampler(0, sampler)
+ .texture(1, input_view_)
+ .buffer(2, uniforms_.get().buffer, uniforms_.get().size)
+ .buffer(3, params_buffer_.get().buffer, params_buffer_.get().size)
+ .texture(4, original_view_ ? original_view_ : input_view_)
+ .build(ctx_.device, bgl);
+}
diff --git a/cnn_v1/src/cnn_v1_effect.h b/cnn_v1/src/cnn_v1_effect.h
new file mode 100644
index 0000000..e820275
--- /dev/null
+++ b/cnn_v1/src/cnn_v1_effect.h
@@ -0,0 +1,53 @@
+// CNN post-processing effect header
+// Multi-layer neural network stylization
+
+#pragma once
+#include "gpu/effect.h"
+#include "gpu/uniform_helper.h"
+
+struct CNNv1LayerParams {
+ int layer_index;
+ float blend_amount; // Blend: mix(input, output, blend_amount)
+ float _pad[2];
+};
+static_assert(sizeof(CNNv1LayerParams) == 16);
+
+struct CNNv1EffectParams {
+ int layer_index = 0; // Which layer to render (0-based)
+ int total_layers = 1; // Total number of layers in the CNN
+ float blend_amount = 1.0f; // Final blend with original input
+};
+
+class CNNv1Effect : public PostProcessEffect {
+ public:
+ explicit CNNv1Effect(const GpuContext& ctx);
+ explicit CNNv1Effect(const GpuContext& ctx, const CNNv1EffectParams& params);
+
+ void init(MainSequence* demo) override;
+ void resize(int width, int height) override;
+ void render(WGPURenderPassEncoder pass,
+ const CommonPostProcessUniforms& uniforms) override;
+ void update_bind_group(WGPUTextureView input_view) override;
+
+ // Layer 0 needs framebuffer capture for original input
+ bool needs_framebuffer_capture() const override {
+ return layer_index_ == 0;
+ }
+
+ void set_beat_modulation(bool enabled, float scale = 1.0f) {
+ beat_modulated_ = enabled;
+ beat_scale_ = scale;
+ }
+
+ private:
+ int layer_index_;
+ int total_layers_;
+ float blend_amount_;
+ bool beat_modulated_ = false;
+ float beat_scale_ = 1.0f;
+ WGPUTextureView input_view_;
+ WGPUTextureView original_view_;
+ UniformBuffer<CNNv1LayerParams> params_buffer_;
+ WGPUBindGroup bind_group_;
+ MainSequence* demo_ = nullptr;
+};
diff --git a/cnn_v1/training/train_cnn.py b/cnn_v1/training/train_cnn.py
new file mode 100755
index 0000000..4171dcb
--- /dev/null
+++ b/cnn_v1/training/train_cnn.py
@@ -0,0 +1,943 @@
+#!/usr/bin/env python3
+"""
+CNN Training Script for Image-to-Image Transformation
+
+Trains a convolutional neural network on multiple input/target image pairs.
+
+Usage:
+ # Training
+ python3 train_cnn.py --input input_dir/ --target target_dir/ [options]
+
+ # Inference (generate ground truth)
+ python3 train_cnn.py --infer image.png --export-only checkpoint.pth --output result.png
+
+Example:
+ python3 train_cnn.py --input ./input --target ./output --layers 3 --epochs 100
+ python3 train_cnn.py --infer input.png --export-only checkpoints/checkpoint_epoch_10000.pth
+"""
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms
+from PIL import Image
+import numpy as np
+import cv2
+import os
+import sys
+import argparse
+import glob
+
+
+class ImagePairDataset(Dataset):
+ """Dataset for loading matching input/target image pairs"""
+
+ def __init__(self, input_dir, target_dir, transform=None):
+ self.input_dir = input_dir
+ self.target_dir = target_dir
+ self.transform = transform
+
+ # Find all images in input directory
+ input_patterns = ['*.png', '*.jpg', '*.jpeg', '*.PNG', '*.JPG', '*.JPEG']
+ self.image_pairs = []
+
+ for pattern in input_patterns:
+ input_files = glob.glob(os.path.join(input_dir, pattern))
+ for input_path in input_files:
+ filename = os.path.basename(input_path)
+ # Try to find matching target with same name but any supported extension
+ target_path = None
+ for ext in ['png', 'jpg', 'jpeg', 'PNG', 'JPG', 'JPEG']:
+ base_name = os.path.splitext(filename)[0]
+ candidate = os.path.join(target_dir, f"{base_name}.{ext}")
+ if os.path.exists(candidate):
+ target_path = candidate
+ break
+
+ if target_path:
+ self.image_pairs.append((input_path, target_path))
+
+ if not self.image_pairs:
+ raise ValueError(f"No matching image pairs found between {input_dir} and {target_dir}")
+
+ print(f"Found {len(self.image_pairs)} matching image pairs")
+
+ def __len__(self):
+ return len(self.image_pairs)
+
+ def __getitem__(self, idx):
+ input_path, target_path = self.image_pairs[idx]
+
+ # Load RGBD input (4 channels: RGB + Depth)
+ input_img = Image.open(input_path).convert('RGBA')
+ target_img = Image.open(target_path).convert('RGB')
+
+ if self.transform:
+ input_img = self.transform(input_img)
+ target_img = self.transform(target_img)
+
+ return input_img, target_img
+
+
+class PatchDataset(Dataset):
+ """Dataset for extracting salient patches from image pairs"""
+
+ def __init__(self, input_dir, target_dir, patch_size=32, patches_per_image=64,
+ detector='harris', transform=None):
+ self.input_dir = input_dir
+ self.target_dir = target_dir
+ self.patch_size = patch_size
+ self.patches_per_image = patches_per_image
+ self.detector = detector
+ self.transform = transform
+
+ # Find all image pairs
+ input_patterns = ['*.png', '*.jpg', '*.jpeg', '*.PNG', '*.JPG', '*.JPEG']
+ self.image_pairs = []
+
+ for pattern in input_patterns:
+ input_files = glob.glob(os.path.join(input_dir, pattern))
+ for input_path in input_files:
+ filename = os.path.basename(input_path)
+ target_path = None
+ for ext in ['png', 'jpg', 'jpeg', 'PNG', 'JPG', 'JPEG']:
+ base_name = os.path.splitext(filename)[0]
+ candidate = os.path.join(target_dir, f"{base_name}.{ext}")
+ if os.path.exists(candidate):
+ target_path = candidate
+ break
+
+ if target_path:
+ self.image_pairs.append((input_path, target_path))
+
+ if not self.image_pairs:
+ raise ValueError(f"No matching image pairs found between {input_dir} and {target_dir}")
+
+ print(f"Found {len(self.image_pairs)} image pairs")
+ print(f"Extracting {patches_per_image} patches per image using {detector} detector")
+ print(f"Total patches: {len(self.image_pairs) * patches_per_image}")
+
+ def __len__(self):
+ return len(self.image_pairs) * self.patches_per_image
+
+ def _detect_salient_points(self, img_array):
+ """Detect salient points using specified detector"""
+ gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
+ h, w = gray.shape
+ half_patch = self.patch_size // 2
+
+ if self.detector == 'harris':
+ # Harris corner detection
+ corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2,
+ qualityLevel=0.01, minDistance=half_patch)
+ elif self.detector == 'fast':
+ # FAST feature detection
+ fast = cv2.FastFeatureDetector_create(threshold=20)
+ keypoints = fast.detect(gray, None)
+ corners = np.array([[kp.pt[0], kp.pt[1]] for kp in keypoints[:self.patches_per_image * 2]])
+ corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None
+ elif self.detector == 'shi-tomasi':
+ # Shi-Tomasi corner detection (goodFeaturesToTrack with different params)
+ corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2,
+ qualityLevel=0.01, minDistance=half_patch,
+ useHarrisDetector=False)
+ elif self.detector == 'gradient':
+ # High-gradient regions
+ grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
+ grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
+ gradient_mag = np.sqrt(grad_x**2 + grad_y**2)
+
+ # Find top gradient locations
+ threshold = np.percentile(gradient_mag, 95)
+ y_coords, x_coords = np.where(gradient_mag > threshold)
+
+ if len(x_coords) > self.patches_per_image * 2:
+ indices = np.random.choice(len(x_coords), self.patches_per_image * 2, replace=False)
+ x_coords = x_coords[indices]
+ y_coords = y_coords[indices]
+
+ corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)])
+ corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None
+ else:
+ raise ValueError(f"Unknown detector: {self.detector}")
+
+ # Fallback to random if no corners found
+ if corners is None or len(corners) == 0:
+ x_coords = np.random.randint(half_patch, w - half_patch, self.patches_per_image)
+ y_coords = np.random.randint(half_patch, h - half_patch, self.patches_per_image)
+ corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)])
+ corners = corners.reshape(-1, 1, 2)
+
+ # Filter valid corners (within bounds)
+ valid_corners = []
+ for corner in corners:
+ x, y = int(corner[0][0]), int(corner[0][1])
+ if half_patch <= x < w - half_patch and half_patch <= y < h - half_patch:
+ valid_corners.append((x, y))
+ if len(valid_corners) >= self.patches_per_image:
+ break
+
+ # Fill with random if not enough
+ while len(valid_corners) < self.patches_per_image:
+ x = np.random.randint(half_patch, w - half_patch)
+ y = np.random.randint(half_patch, h - half_patch)
+ valid_corners.append((x, y))
+
+ return valid_corners
+
+ def __getitem__(self, idx):
+ img_idx = idx // self.patches_per_image
+ patch_idx = idx % self.patches_per_image
+
+ input_path, target_path = self.image_pairs[img_idx]
+
+ # Load images
+ input_img = Image.open(input_path).convert('RGBA')
+ target_img = Image.open(target_path).convert('RGB')
+
+ # Detect salient points (use input image for detection)
+ input_array = np.array(input_img)[:, :, :3] # Use RGB for detection
+ corners = self._detect_salient_points(input_array)
+
+ # Extract patch at specified index
+ x, y = corners[patch_idx]
+ half_patch = self.patch_size // 2
+
+ # Crop patches
+ input_patch = input_img.crop((x - half_patch, y - half_patch,
+ x + half_patch, y + half_patch))
+ target_patch = target_img.crop((x - half_patch, y - half_patch,
+ x + half_patch, y + half_patch))
+
+ if self.transform:
+ input_patch = self.transform(input_patch)
+ target_patch = self.transform(target_patch)
+
+ return input_patch, target_patch
+
+
+class SimpleCNN(nn.Module):
+ """CNN for RGBD→RGB with 7-channel input (RGBD + UV + gray)
+
+ Internally computes grayscale, expands to 3-channel RGB output.
+ """
+
+ def __init__(self, num_layers=1, kernel_sizes=None):
+ super(SimpleCNN, self).__init__()
+
+ if kernel_sizes is None:
+ kernel_sizes = [3] * num_layers
+
+ assert len(kernel_sizes) == num_layers, "kernel_sizes must match num_layers"
+
+ self.kernel_sizes = kernel_sizes
+ self.layers = nn.ModuleList()
+
+ for i, kernel_size in enumerate(kernel_sizes):
+ padding = kernel_size // 2
+ if i < num_layers - 1:
+ # Inner layers: 7→4 (RGBD output)
+ self.layers.append(nn.Conv2d(7, 4, kernel_size=kernel_size, padding=padding, bias=True))
+ else:
+ # Final layer: 7→1 (grayscale output)
+ self.layers.append(nn.Conv2d(7, 1, kernel_size=kernel_size, padding=padding, bias=True))
+
+ def forward(self, x, return_intermediates=False):
+ # x: [B,4,H,W] - RGBD input (D = 1/z)
+ B, C, H, W = x.shape
+
+ intermediates = [] if return_intermediates else None
+
+ # Normalize RGBD to [-1,1]
+ x_norm = (x - 0.5) * 2.0
+
+ # Compute normalized coordinates [-1,1]
+ y_coords = torch.linspace(-1, 1, H, device=x.device).view(1,1,H,1).expand(B,1,H,W)
+ x_coords = torch.linspace(-1, 1, W, device=x.device).view(1,1,1,W).expand(B,1,H,W)
+
+ # Compute grayscale from original RGB (Rec.709) and normalize to [-1,1]
+ gray = 0.2126*x[:,0:1] + 0.7152*x[:,1:2] + 0.0722*x[:,2:3] # [B,1,H,W] in [0,1]
+ gray = (gray - 0.5) * 2.0 # [-1,1]
+
+ # Layer 0
+ layer0_input = torch.cat([x_norm, x_coords, y_coords, gray], dim=1) # [B,7,H,W]
+ out = self.layers[0](layer0_input) # [B,4,H,W]
+ out = torch.tanh(out) # [-1,1]
+ if return_intermediates:
+ intermediates.append(out.clone())
+
+ # Inner layers
+ for i in range(1, len(self.layers)-1):
+ layer_input = torch.cat([out, x_coords, y_coords, gray], dim=1)
+ out = self.layers[i](layer_input)
+ out = torch.tanh(out)
+ if return_intermediates:
+ intermediates.append(out.clone())
+
+ # Final layer (grayscale→RGB)
+ final_input = torch.cat([out, x_coords, y_coords, gray], dim=1)
+ out = self.layers[-1](final_input) # [B,1,H,W] grayscale
+ out = torch.sigmoid(out) # Map to [0,1] with smooth gradients
+ final_out = out.expand(-1, 3, -1, -1) # [B,3,H,W] expand to RGB
+
+ if return_intermediates:
+ return final_out, intermediates
+ return final_out
+
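+# Usage sketch (illustrative only, not exercised by the script itself): shape check for
+# the README's 3-layer 3,5,3 configuration on a dummy RGBD tensor.
+#
+#   model = SimpleCNN(num_layers=3, kernel_sizes=[3, 5, 3])
+#   rgbd = torch.rand(1, 4, 64, 64)                    # RGB + depth (1/z) in [0,1]
+#   rgb = model(rgbd)                                  # [1, 3, 64, 64], sigmoid output in [0,1]
+#   _, feats = model(rgbd, return_intermediates=True)  # two [1, 4, 64, 64] tanh maps in [-1,1]
+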
+
+def generate_layer_shader(output_path, num_layers, kernel_sizes):
+ """Generate cnn_layer.wgsl with proper layer switches"""
+
+ with open(output_path, 'w') as f:
+ f.write("// CNN layer shader - uses modular convolution snippets\n")
+ f.write("// Supports multi-pass rendering with residual connections\n")
+ f.write("// DO NOT EDIT - Generated by train_cnn.py\n\n")
+ f.write("@group(0) @binding(0) var smplr: sampler;\n")
+ f.write("@group(0) @binding(1) var txt: texture_2d<f32>;\n\n")
+ f.write("#include \"common_uniforms\"\n")
+ f.write("#include \"cnn_activation\"\n")
+
+ # Include necessary conv functions
+ conv_sizes = set(kernel_sizes)
+ for ks in sorted(conv_sizes):
+ f.write(f"#include \"cnn_conv{ks}x{ks}\"\n")
+ f.write("#include \"cnn_weights_generated\"\n\n")
+
+ f.write("struct CNNLayerParams {\n")
+ f.write(" layer_index: i32,\n")
+ f.write(" blend_amount: f32,\n")
+ f.write(" _pad: vec2<f32>,\n")
+ f.write("};\n\n")
+ f.write("@group(0) @binding(2) var<uniform> uniforms: CommonUniforms;\n")
+ f.write("@group(0) @binding(3) var<uniform> params: CNNLayerParams;\n")
+ f.write("@group(0) @binding(4) var original_input: texture_2d<f32>;\n\n")
+ f.write("@vertex fn vs_main(@builtin(vertex_index) i: u32) -> @builtin(position) vec4<f32> {\n")
+ f.write(" var pos = array<vec2<f32>, 3>(\n")
+ f.write(" vec2<f32>(-1.0, -1.0), vec2<f32>(3.0, -1.0), vec2<f32>(-1.0, 3.0)\n")
+ f.write(" );\n")
+ f.write(" return vec4<f32>(pos[i], 0.0, 1.0);\n")
+ f.write("}\n\n")
+ f.write("@fragment fn fs_main(@builtin(position) p: vec4<f32>) -> @location(0) vec4<f32> {\n")
+ f.write(" // Match PyTorch linspace\n")
+ f.write(" let uv = (p.xy - 0.5) / (uniforms.resolution - 1.0);\n")
+ f.write(" let original_raw = textureSample(original_input, smplr, uv);\n")
+ f.write(" let original = (original_raw - 0.5) * 2.0; // Normalize to [-1,1]\n")
+ f.write(" let gray = (dot(original_raw.rgb, vec3<f32>(0.2126, 0.7152, 0.0722)) - 0.5) * 2.0;\n")
+ f.write(" var result = vec4<f32>(0.0);\n\n")
+
+ # Generate layer switches
+ for layer_idx in range(num_layers):
+ is_final = layer_idx == num_layers - 1
+ ks = kernel_sizes[layer_idx]
+ conv_fn = f"cnn_conv{ks}x{ks}_7to4" if not is_final else f"cnn_conv{ks}x{ks}_7to1"
+
+ if layer_idx == 0:
+ conv_fn_src = f"cnn_conv{ks}x{ks}_7to4_src"
+ f.write(f" // Layer 0: 7→4 (RGBD output, normalizes [0,1] input)\n")
+ f.write(f" if (params.layer_index == {layer_idx}) {{\n")
+ f.write(f" result = {conv_fn_src}(txt, smplr, uv, uniforms.resolution, weights_layer{layer_idx});\n")
+ f.write(f" result = cnn_tanh(result);\n")
+ f.write(f" }}\n")
+ elif not is_final:
+ f.write(f" else if (params.layer_index == {layer_idx}) {{\n")
+ f.write(f" result = {conv_fn}(txt, smplr, uv, uniforms.resolution, gray, weights_layer{layer_idx});\n")
+ f.write(f" result = cnn_tanh(result); // Keep in [-1,1]\n")
+ f.write(f" }}\n")
+ else:
+ f.write(f" else if (params.layer_index == {layer_idx}) {{\n")
+ f.write(f" let sum = {conv_fn}(txt, smplr, uv, uniforms.resolution, gray, weights_layer{layer_idx});\n")
+ f.write(f" let gray_out = 1.0 / (1.0 + exp(-sum)); // Sigmoid activation\n")
+ f.write(f" result = vec4<f32>(gray_out, gray_out, gray_out, 1.0);\n")
+ f.write(f" return mix(original_raw, result, params.blend_amount); // [0,1]\n")
+ f.write(f" }}\n")
+
+ f.write(" return result; // [-1,1]\n")
+ f.write("}\n")
+
+
+def export_weights_to_wgsl(model, output_path, kernel_sizes):
+ """Export trained weights to WGSL format (vec4-optimized)"""
+
+ with open(output_path, 'w') as f:
+ f.write("// Auto-generated CNN weights (vec4-optimized)\n")
+ f.write("// DO NOT EDIT - Generated by train_cnn.py\n\n")
+
+ for i, layer in enumerate(model.layers):
+ weights = layer.weight.data.cpu().numpy()
+ bias = layer.bias.data.cpu().numpy()
+ out_ch, in_ch, kh, kw = weights.shape
+ num_positions = kh * kw
+
+ is_final = (i == len(model.layers) - 1)
+
+            if is_final:
+                # Final layer: 7→1, structure: array<vec4<f32>, k*k*2> (2 vec4s per
+                # position, e.g. 18 entries for a 3×3 kernel)
+                # Input: [rgba, uv_gray_1] → 2 vec4s per position
+                f.write(f"const weights_layer{i}: array<vec4<f32>, {num_positions * 2}> = array(\n")
+                for pos in range(num_positions):
+                    row, col = pos // kw, pos % kw
+                    # First vec4: [w0, w1, w2, w3] (rgba)
+                    v0 = [f"{weights[0, in_c, row, col]:.6f}" for in_c in range(4)]
+                    # Second vec4: [w4, w5, w6, bias] (uv, gray, 1)
+                    v1 = [f"{weights[0, in_c, row, col]:.6f}" for in_c in range(4, 7)]
+                    v1.append(f"{bias[0] / num_positions:.6f}")
+                    f.write(f"  vec4<f32>({', '.join(v0)}),\n")
+                    f.write(f"  vec4<f32>({', '.join(v1)})")
+                    f.write(",\n" if pos < num_positions-1 else "\n")
+                f.write(");\n\n")
+            else:
+                # Inner layers: 7→4, structure: array<vec4<f32>, k*k*4*2> (2 vec4s per
+                # output channel per position, e.g. 72 entries for 3×3, 200 for 5×5)
+                # Each filter: 2 vec4s for [rgba][uv_gray_1] inputs
+                num_vec4s = num_positions * 4 * 2
+                f.write(f"const weights_layer{i}: array<vec4<f32>, {num_vec4s}> = array(\n")
+                for pos in range(num_positions):
+                    row, col = pos // kw, pos % kw
+                    for out_c in range(4):
+                        # First vec4: [w0, w1, w2, w3] (rgba)
+                        v0 = [f"{weights[out_c, in_c, row, col]:.6f}" for in_c in range(4)]
+                        # Second vec4: [w4, w5, w6, bias] (uv, gray, 1)
+                        v1 = [f"{weights[out_c, in_c, row, col]:.6f}" for in_c in range(4, 7)]
+                        v1.append(f"{bias[out_c] / num_positions:.6f}")
+                        idx = (pos * 4 + out_c) * 2
+                        f.write(f"  vec4<f32>({', '.join(v0)}),\n")
+                        f.write(f"  vec4<f32>({', '.join(v1)})")
+                        f.write(",\n" if idx < num_vec4s-2 else "\n")
+                f.write(");\n\n")
+
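+# Layout check (illustrative): the WGSL conv functions accumulate two dot products per
+# output channel at every kernel tap, so the exported arrays have k*k*2 entries for the
+# final 7→1 layer and k*k*4*2 for the 7→4 layers (18 and 72 for 3×3, 200 for the 5×5
+# first layer, matching cnn_weights_generated.wgsl above). The PyTorch bias is split
+# evenly across taps because the shader multiplies the .w slot by 1.0 at each tap; e.g.
+# the repeated 0.023741 in weights_layer2 is that layer's bias divided by its 9 positions.
+#
+#   k = 3
+#   assert k * k * 2 == 18 and k * k * 4 * 2 == 72 and 5 * 5 * 4 * 2 == 200
+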
+
+def generate_conv_base_function(kernel_size, output_path):
+ """Generate cnn_conv{K}x{K}_7to4() function for inner layers (vec4-optimized)"""
+
+ k = kernel_size
+ num_positions = k * k
+ radius = k // 2
+
+ with open(output_path, 'a') as f:
+ f.write(f"\n// Inner layers: 7→4 channels (vec4-optimized)\n")
+ f.write(f"// Assumes 'tex' is already normalized to [-1,1]\n")
+ f.write(f"fn cnn_conv{k}x{k}_7to4(\n")
+ f.write(f" tex: texture_2d<f32>,\n")
+ f.write(f" samp: sampler,\n")
+ f.write(f" uv: vec2<f32>,\n")
+ f.write(f" resolution: vec2<f32>,\n")
+ f.write(f" gray: f32,\n")
+ f.write(f" weights: array<vec4<f32>, {num_positions * 8}>\n")
+ f.write(f") -> vec4<f32> {{\n")
+ f.write(f" let step = 1.0 / resolution;\n")
+ f.write(f" let uv_norm = (uv - 0.5) * 2.0;\n\n")
+ f.write(f" var sum = vec4<f32>(0.0);\n")
+ f.write(f" var pos = 0;\n\n")
+
+ # Convolution loop
+ f.write(f" for (var dy = -{radius}; dy <= {radius}; dy++) {{\n")
+ f.write(f" for (var dx = -{radius}; dx <= {radius}; dx++) {{\n")
+ f.write(f" let offset = vec2<f32>(f32(dx), f32(dy)) * step;\n")
+ f.write(f" let rgbd = textureSample(tex, samp, uv + offset);\n")
+ f.write(f" let in1 = vec4<f32>(uv_norm, gray, 1.0);\n\n")
+
+ # Accumulate
+ f.write(f" sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1);\n")
+ f.write(f" sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1);\n")
+ f.write(f" sum.b += dot(weights[pos+4], rgbd) + dot(weights[pos+5], in1);\n")
+ f.write(f" sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1);\n")
+ f.write(f" pos += 8;\n")
+ f.write(f" }}\n")
+ f.write(f" }}\n\n")
+
+ f.write(f" return sum;\n")
+ f.write(f"}}\n")
+
+
+def generate_conv_src_function(kernel_size, output_path):
+ """Generate cnn_conv{K}x{K}_7to4_src() function for layer 0 (vec4-optimized)"""
+
+ k = kernel_size
+ num_positions = k * k
+ radius = k // 2
+
+ with open(output_path, 'a') as f:
+ f.write(f"\n// Source layer: 7→4 channels (vec4-optimized)\n")
+ f.write(f"// Normalizes [0,1] input to [-1,1] internally\n")
+ f.write(f"fn cnn_conv{k}x{k}_7to4_src(\n")
+ f.write(f" tex: texture_2d<f32>,\n")
+ f.write(f" samp: sampler,\n")
+ f.write(f" uv: vec2<f32>,\n")
+ f.write(f" resolution: vec2<f32>,\n")
+ f.write(f" weights: array<vec4<f32>, {num_positions * 8}>\n")
+ f.write(f") -> vec4<f32> {{\n")
+ f.write(f" let step = 1.0 / resolution;\n\n")
+
+ # Normalize center pixel for gray channel
+ f.write(f" let original = (textureSample(tex, samp, uv) - 0.5) * 2.0;\n")
+ f.write(f" let gray = dot(original.rgb, vec3<f32>(0.2126, 0.7152, 0.0722));\n")
+ f.write(f" let uv_norm = (uv - 0.5) * 2.0;\n")
+ f.write(f" let in1 = vec4<f32>(uv_norm, gray, 1.0);\n\n")
+
+ f.write(f" var sum = vec4<f32>(0.0);\n")
+ f.write(f" var pos = 0;\n\n")
+
+ # Convolution loop
+ f.write(f" for (var dy = -{radius}; dy <= {radius}; dy++) {{\n")
+ f.write(f" for (var dx = -{radius}; dx <= {radius}; dx++) {{\n")
+ f.write(f" let offset = vec2<f32>(f32(dx), f32(dy)) * step;\n")
+ f.write(f" let rgbd = (textureSample(tex, samp, uv + offset) - 0.5) * 2.0;\n\n")
+
+ # Accumulate with dot products (unrolled)
+ f.write(f" sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1);\n")
+ f.write(f" sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1);\n")
+ f.write(f" sum.b += dot(weights[pos+4], rgbd) + dot(weights[pos+5], in1);\n")
+ f.write(f" sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1);\n")
+ f.write(f" pos += 8;\n")
+ f.write(f" }}\n")
+ f.write(f" }}\n\n")
+
+ f.write(f" return sum;\n")
+ f.write(f"}}\n")
+
+
+def generate_conv_final_function(kernel_size, output_path):
+ """Generate cnn_conv{K}x{K}_7to1() function for final layer (vec4-optimized)"""
+
+ k = kernel_size
+ num_positions = k * k
+ radius = k // 2
+
+ with open(output_path, 'a') as f:
+ f.write(f"\n// Final layer: 7→1 channel (vec4-optimized)\n")
+ f.write(f"// Assumes 'tex' is already normalized to [-1,1]\n")
+ f.write(f"// Returns raw sum (activation applied at call site)\n")
+ f.write(f"fn cnn_conv{k}x{k}_7to1(\n")
+ f.write(f" tex: texture_2d<f32>,\n")
+ f.write(f" samp: sampler,\n")
+ f.write(f" uv: vec2<f32>,\n")
+ f.write(f" resolution: vec2<f32>,\n")
+ f.write(f" gray: f32,\n")
+ f.write(f" weights: array<vec4<f32>, {num_positions * 2}>\n")
+ f.write(f") -> f32 {{\n")
+ f.write(f" let step = 1.0 / resolution;\n")
+ f.write(f" let uv_norm = (uv - 0.5) * 2.0;\n")
+ f.write(f" let in1 = vec4<f32>(uv_norm, gray, 1.0);\n\n")
+ f.write(f" var sum = 0.0;\n")
+ f.write(f" var pos = 0;\n\n")
+
+ # Convolution loop
+ f.write(f" for (var dy = -{radius}; dy <= {radius}; dy++) {{\n")
+ f.write(f" for (var dx = -{radius}; dx <= {radius}; dx++) {{\n")
+ f.write(f" let offset = vec2<f32>(f32(dx), f32(dy)) * step;\n")
+ f.write(f" let rgbd = textureSample(tex, samp, uv + offset);\n\n")
+
+ # Accumulate with dot products
+ f.write(f" sum += dot(weights[pos], rgbd) + dot(weights[pos+1], in1);\n")
+ f.write(f" pos += 2;\n")
+ f.write(f" }}\n")
+ f.write(f" }}\n\n")
+
+ f.write(f" return sum;\n")
+ f.write(f"}}\n")
+
+
+def train(args):
+ """Main training loop"""
+
+ # Setup device
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+
+ # Prepare dataset
+ if args.patch_size:
+ # Patch-based training (preserves natural scale)
+ transform = transforms.Compose([
+ transforms.ToTensor(),
+ ])
+ dataset = PatchDataset(args.input, args.target,
+ patch_size=args.patch_size,
+ patches_per_image=args.patches_per_image,
+ detector=args.detector,
+ transform=transform)
+ else:
+ # Full-image training (resize mode)
+ transform = transforms.Compose([
+ transforms.Resize((256, 256)),
+ transforms.ToTensor(),
+ ])
+ dataset = ImagePairDataset(args.input, args.target, transform=transform)
+
+ dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
+
+ # Parse kernel sizes
+ kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
+ if len(kernel_sizes) == 1 and args.layers > 1:
+ kernel_sizes = kernel_sizes * args.layers
+
+ # Create model
+ model = SimpleCNN(num_layers=args.layers, kernel_sizes=kernel_sizes).to(device)
+
+ # Loss and optimizer
+ criterion = nn.MSELoss()
+ optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
+
+ # Resume from checkpoint
+ start_epoch = 0
+ if args.resume:
+ if os.path.exists(args.resume):
+ print(f"Loading checkpoint from {args.resume}...")
+ checkpoint = torch.load(args.resume, map_location=device)
+ model.load_state_dict(checkpoint['model_state'])
+ optimizer.load_state_dict(checkpoint['optimizer_state'])
+ start_epoch = checkpoint['epoch'] + 1
+ print(f"Resumed from epoch {start_epoch}")
+ else:
+ print(f"Warning: Checkpoint file '{args.resume}' not found, starting from scratch")
+
+ # Compute valid center region (exclude conv padding borders)
+ num_layers = args.layers
+    border = num_layers  # 1 px per 3x3 layer, accumulated across layers (underestimates for larger kernels)
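+    # e.g. the README's 3-layer 3,5,3 configuration with 32x32 patches trims 3 px per
+    # side, so the MSE below is evaluated on the inner 26x26 region of each patch.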
+
+ # Early stopping setup
+ loss_history = []
+ early_stop_triggered = False
+
+ # Training loop
+ print(f"\nTraining for {args.epochs} epochs (starting from epoch {start_epoch})...")
+ print(f"Computing loss on center region only (excluding {border}px border)")
+ if args.early_stop_patience > 0:
+ print(f"Early stopping: patience={args.early_stop_patience}, eps={args.early_stop_eps}")
+
+ for epoch in range(start_epoch, args.epochs):
+ epoch_loss = 0.0
+ for batch_idx, (inputs, targets) in enumerate(dataloader):
+ inputs, targets = inputs.to(device), targets.to(device)
+
+ optimizer.zero_grad()
+ outputs = model(inputs)
+
+ # Only compute loss on center pixels with valid neighborhoods
+ if border > 0 and outputs.shape[2] > 2*border and outputs.shape[3] > 2*border:
+ outputs_center = outputs[:, :, border:-border, border:-border]
+ targets_center = targets[:, :, border:-border, border:-border]
+ loss = criterion(outputs_center, targets_center)
+ else:
+ loss = criterion(outputs, targets)
+
+ loss.backward()
+ optimizer.step()
+
+ epoch_loss += loss.item()
+
+ avg_loss = epoch_loss / len(dataloader)
+ if (epoch + 1) % 10 == 0:
+ print(f"Epoch [{epoch+1}/{args.epochs}], Loss: {avg_loss:.6f}")
+
+ # Early stopping check
+ if args.early_stop_patience > 0:
+ loss_history.append(avg_loss)
+ if len(loss_history) >= args.early_stop_patience:
+ oldest_loss = loss_history[-args.early_stop_patience]
+ loss_change = abs(avg_loss - oldest_loss)
+ if loss_change < args.early_stop_eps:
+ print(f"Early stopping triggered at epoch {epoch+1}")
+ print(f"Loss change over last {args.early_stop_patience} epochs: {loss_change:.8f} < {args.early_stop_eps}")
+ early_stop_triggered = True
+ break
+
+ # Save checkpoint
+ if args.checkpoint_every > 0 and (epoch + 1) % args.checkpoint_every == 0:
+ checkpoint_dir = args.checkpoint_dir or 'training/checkpoints'
+ os.makedirs(checkpoint_dir, exist_ok=True)
+ checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')
+ torch.save({
+ 'epoch': epoch,
+ 'model_state': model.state_dict(),
+ 'optimizer_state': optimizer.state_dict(),
+ 'loss': avg_loss,
+ 'kernel_sizes': kernel_sizes,
+ 'num_layers': args.layers
+ }, checkpoint_path)
+ print(f"Saved checkpoint to {checkpoint_path}")
+
+ # Export weights and shader
+ output_path = args.output or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl'
+ print(f"\nExporting weights to {output_path}...")
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+ export_weights_to_wgsl(model, output_path, kernel_sizes)
+
+ # Generate layer shader
+ shader_dir = os.path.dirname(output_path)
+ shader_path = os.path.join(shader_dir, 'cnn_layer.wgsl')
+ print(f"Generating layer shader to {shader_path}...")
+ generate_layer_shader(shader_path, args.layers, kernel_sizes)
+
+ # Generate conv shader files for all kernel sizes
+ for ks in set(kernel_sizes):
+ conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl')
+
+ # Create file with header if it doesn't exist
+ if not os.path.exists(conv_path):
+ print(f"Creating {conv_path}...")
+ with open(conv_path, 'w') as f:
+ f.write(f"// {ks}x{ks} convolution (vec4-optimized)\n")
+ generate_conv_base_function(ks, conv_path)
+ generate_conv_src_function(ks, conv_path)
+ generate_conv_final_function(ks, conv_path)
+ print(f"Generated complete {conv_path}")
+ continue
+
+ # File exists, check for missing functions
+ with open(conv_path, 'r') as f:
+ content = f.read()
+
+ # Generate base 7to4 if missing
+ if f"cnn_conv{ks}x{ks}_7to4" not in content:
+ generate_conv_base_function(ks, conv_path)
+ print(f"Added base 7to4 to {conv_path}")
+ with open(conv_path, 'r') as f:
+ content = f.read()
+
+ # Generate _src variant if missing
+ if f"cnn_conv{ks}x{ks}_7to4_src" not in content:
+ generate_conv_src_function(ks, conv_path)
+ print(f"Added _src variant to {conv_path}")
+ with open(conv_path, 'r') as f:
+ content = f.read()
+
+ # Generate 7to1 final layer if missing
+ if f"cnn_conv{ks}x{ks}_7to1" not in content:
+ generate_conv_final_function(ks, conv_path)
+ print(f"Added 7to1 variant to {conv_path}")
+
+ print("Training complete!")
+
+
+def export_from_checkpoint(checkpoint_path, output_path=None):
+ """Export WGSL files from checkpoint without training"""
+
+ if not os.path.exists(checkpoint_path):
+ print(f"Error: Checkpoint file '{checkpoint_path}' not found")
+ sys.exit(1)
+
+ print(f"Loading checkpoint from {checkpoint_path}...")
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
+
+ kernel_sizes = checkpoint['kernel_sizes']
+ num_layers = checkpoint['num_layers']
+
+ # Recreate model
+ model = SimpleCNN(num_layers=num_layers, kernel_sizes=kernel_sizes)
+ model.load_state_dict(checkpoint['model_state'])
+
+ # Export weights
+ output_path = output_path or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl'
+ print(f"Exporting weights to {output_path}...")
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+ export_weights_to_wgsl(model, output_path, kernel_sizes)
+
+ # Generate layer shader
+ shader_dir = os.path.dirname(output_path)
+ shader_path = os.path.join(shader_dir, 'cnn_layer.wgsl')
+ print(f"Generating layer shader to {shader_path}...")
+ generate_layer_shader(shader_path, num_layers, kernel_sizes)
+
+ # Generate conv shader files for all kernel sizes
+ for ks in set(kernel_sizes):
+ conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl')
+
+ # Create file with header if it doesn't exist
+ if not os.path.exists(conv_path):
+ print(f"Creating {conv_path}...")
+ with open(conv_path, 'w') as f:
+ f.write(f"// {ks}x{ks} convolution (vec4-optimized)\n")
+ generate_conv_base_function(ks, conv_path)
+ generate_conv_src_function(ks, conv_path)
+ generate_conv_final_function(ks, conv_path)
+ print(f"Generated complete {conv_path}")
+ continue
+
+ # File exists, check for missing functions
+ with open(conv_path, 'r') as f:
+ content = f.read()
+
+ # Generate base 7to4 if missing
+ if f"cnn_conv{ks}x{ks}_7to4" not in content:
+ generate_conv_base_function(ks, conv_path)
+ print(f"Added base 7to4 to {conv_path}")
+ with open(conv_path, 'r') as f:
+ content = f.read()
+
+ # Generate _src variant if missing
+ if f"cnn_conv{ks}x{ks}_7to4_src" not in content:
+ generate_conv_src_function(ks, conv_path)
+ print(f"Added _src variant to {conv_path}")
+ with open(conv_path, 'r') as f:
+ content = f.read()
+
+ # Generate 7to1 final layer if missing
+ if f"cnn_conv{ks}x{ks}_7to1" not in content:
+ generate_conv_final_function(ks, conv_path)
+ print(f"Added 7to1 variant to {conv_path}")
+
+ print("Export complete!")
+
+
+def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32, save_intermediates=None, zero_weights=False, debug_hex=False):
+ """Run sliding-window inference to match WGSL shader behavior
+
+ Outputs RGBA PNG (RGB from model + alpha from input).
+ """
+
+ if not os.path.exists(checkpoint_path):
+ print(f"Error: Checkpoint '{checkpoint_path}' not found")
+ sys.exit(1)
+
+ if not os.path.exists(input_path):
+ print(f"Error: Input image '{input_path}' not found")
+ sys.exit(1)
+
+ print(f"Loading checkpoint from {checkpoint_path}...")
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
+
+ # Reconstruct model
+ model = SimpleCNN(
+ num_layers=checkpoint['num_layers'],
+ kernel_sizes=checkpoint['kernel_sizes']
+ )
+ model.load_state_dict(checkpoint['model_state'])
+
+ # Debug: Zero out all weights and biases
+ if zero_weights:
+ print("DEBUG: Zeroing out all weights and biases")
+ for layer in model.layers:
+ with torch.no_grad():
+ layer.weight.zero_()
+ layer.bias.zero_()
+
+ model.eval()
+
+ # Load image
+ print(f"Loading input image: {input_path}")
+ img = Image.open(input_path).convert('RGBA')
+ img_tensor = transforms.ToTensor()(img).unsqueeze(0) # [1,4,H,W]
+ W, H = img.size
+
+ # Process full image with sliding window (matches WGSL shader)
+ print(f"Processing full image ({W}×{H}) with sliding window...")
+ with torch.no_grad():
+ if save_intermediates:
+ output_tensor, intermediates = model(img_tensor, return_intermediates=True)
+ else:
+ output_tensor = model(img_tensor) # [1,3,H,W] RGB
+
+ # Convert to numpy and append alpha
+ output = output_tensor.squeeze(0).permute(1, 2, 0).numpy() # [H,W,3] RGB
+ alpha = img_tensor[0, 3:4, :, :].permute(1, 2, 0).numpy() # [H,W,1] alpha from input
+ output_rgba = np.concatenate([output, alpha], axis=2) # [H,W,4] RGBA
+
+ # Debug: print first 8 pixels as hex
+ if debug_hex:
+ output_u8 = (output_rgba * 255).astype(np.uint8)
+ print("First 8 pixels (RGBA hex):")
+ for i in range(min(8, output_u8.shape[0] * output_u8.shape[1])):
+ y, x = i // output_u8.shape[1], i % output_u8.shape[1]
+ r, g, b, a = output_u8[y, x]
+ print(f" [{i}] 0x{r:02X}{g:02X}{b:02X}{a:02X}")
+
+ # Save final output as RGBA
+ print(f"Saving output to: {output_path}")
+ os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else '.', exist_ok=True)
+ output_img = Image.fromarray((output_rgba * 255).astype(np.uint8), mode='RGBA')
+ output_img.save(output_path)
+
+ # Save intermediates if requested
+ if save_intermediates:
+ os.makedirs(save_intermediates, exist_ok=True)
+ print(f"Saving {len(intermediates)} intermediate layers to: {save_intermediates}")
+ for layer_idx, layer_tensor in enumerate(intermediates):
+ # Convert [-1,1] to [0,1] for visualization
+ layer_data = (layer_tensor.squeeze(0).permute(1, 2, 0).numpy() + 1.0) * 0.5
+ layer_u8 = (layer_data.clip(0, 1) * 255).astype(np.uint8)
+
+ # Debug: print first 8 pixels as hex
+ if debug_hex:
+ print(f"Layer {layer_idx} first 8 pixels (RGBA hex):")
+ for i in range(min(8, layer_u8.shape[0] * layer_u8.shape[1])):
+ y, x = i // layer_u8.shape[1], i % layer_u8.shape[1]
+ if layer_u8.shape[2] == 4:
+ r, g, b, a = layer_u8[y, x]
+ print(f" [{i}] 0x{r:02X}{g:02X}{b:02X}{a:02X}")
+ else:
+ r, g, b = layer_u8[y, x]
+ print(f" [{i}] 0x{r:02X}{g:02X}{b:02X}")
+
+ # Save all 4 channels for intermediate layers
+ if layer_data.shape[2] == 4:
+ layer_img = Image.fromarray(layer_u8, mode='RGBA')
+ else:
+ layer_img = Image.fromarray(layer_u8)
+ layer_path = os.path.join(save_intermediates, f'layer_{layer_idx}.png')
+ layer_img.save(layer_path)
+ print(f" Saved layer {layer_idx} to {layer_path}")
+
+ print("Done!")
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Train CNN for image-to-image transformation')
+ parser.add_argument('--input', help='Input image directory (training) or single image (inference)')
+ parser.add_argument('--target', help='Target image directory')
+ parser.add_argument('--layers', type=int, default=1, help='Number of CNN layers (default: 1)')
+ parser.add_argument('--kernel_sizes', default='3', help='Comma-separated kernel sizes (default: 3)')
+ parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs (default: 100)')
+ parser.add_argument('--batch_size', type=int, default=4, help='Batch size (default: 4)')
+ parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (default: 0.001)')
+ parser.add_argument('--output', help='Output path (WGSL for training/export, PNG for inference)')
+ parser.add_argument('--checkpoint-every', type=int, default=0, help='Save checkpoint every N epochs (default: 0 = disabled)')
+ parser.add_argument('--checkpoint-dir', help='Checkpoint directory (default: training/checkpoints)')
+ parser.add_argument('--resume', help='Resume from checkpoint file')
+ parser.add_argument('--export-only', help='Export WGSL from checkpoint without training')
+ parser.add_argument('--infer', help='Run inference on single image (requires --export-only for checkpoint)')
+ parser.add_argument('--patch-size', type=int, help='Extract patches of this size (e.g., 32) instead of resizing (default: None = resize to 256x256)')
+ parser.add_argument('--patches-per-image', type=int, default=64, help='Number of patches to extract per image (default: 64)')
+ parser.add_argument('--detector', default='harris', choices=['harris', 'fast', 'shi-tomasi', 'gradient'],
+ help='Salient point detector for patch extraction (default: harris)')
+ parser.add_argument('--early-stop-patience', type=int, default=0, help='Stop if loss changes less than eps over N epochs (default: 0 = disabled)')
+ parser.add_argument('--early-stop-eps', type=float, default=1e-6, help='Loss change threshold for early stopping (default: 1e-6)')
+ parser.add_argument('--save-intermediates', help='Directory to save intermediate layer outputs (inference only)')
+ parser.add_argument('--zero-weights', action='store_true', help='Zero out all weights/biases during inference (debug only)')
+ parser.add_argument('--debug-hex', action='store_true', help='Print first 8 pixels as hex (debug only)')
+
+ args = parser.parse_args()
+
+ # Inference mode
+ if args.infer:
+ checkpoint = args.export_only
+ if not checkpoint:
+ print("Error: --infer requires --export-only <checkpoint>")
+ sys.exit(1)
+ output_path = args.output or 'inference_output.png'
+ patch_size = args.patch_size or 32
+ infer_from_checkpoint(checkpoint, args.infer, output_path, patch_size, args.save_intermediates, args.zero_weights, args.debug_hex)
+ return
+
+ # Export-only mode
+ if args.export_only:
+ export_from_checkpoint(args.export_only, args.output)
+ return
+
+ # Validate directories for training
+ if not args.input or not args.target:
+ print("Error: --input and --target required for training (or use --export-only)")
+ sys.exit(1)
+
+ if not os.path.isdir(args.input):
+ print(f"Error: Input directory '{args.input}' does not exist")
+ sys.exit(1)
+
+ if not os.path.isdir(args.target):
+ print(f"Error: Target directory '{args.target}' does not exist")
+ sys.exit(1)
+
+ train(args)
+
+
+if __name__ == "__main__":
+ main()