Diffstat (limited to 'cnn_v1')
| -rw-r--r-- | cnn_v1/README.md | 64 |
| -rw-r--r-- | cnn_v1/docs/CNN.md | 79 |
| -rw-r--r-- | cnn_v1/docs/CNN_BIAS_FIX_2026-02.md | 85 |
| -rw-r--r-- | cnn_v1/docs/CNN_DEBUG.md | 43 |
| -rw-r--r-- | cnn_v1/docs/CNN_FLATTEN_ANALYSIS.md | 189 |
| -rw-r--r-- | cnn_v1/docs/CNN_RGBD_GRAYSCALE_SUMMARY.md | 136 |
| -rw-r--r-- | cnn_v1/docs/CNN_TEST_TOOL.md | 244 |
| -rw-r--r-- | cnn_v1/docs/CNN_V1_EFFECT.md | 400 |
| -rw-r--r-- | cnn_v1/shaders/cnn_activation.wgsl | 18 |
| -rw-r--r-- | cnn_v1/shaders/cnn_conv1x1.wgsl | 100 |
| -rw-r--r-- | cnn_v1/shaders/cnn_conv3x3.wgsl | 100 |
| -rw-r--r-- | cnn_v1/shaders/cnn_conv5x5.wgsl | 101 |
| -rw-r--r-- | cnn_v1/shaders/cnn_conv7x7.wgsl | 53 |
| -rw-r--r-- | cnn_v1/shaders/cnn_layer.wgsl | 55 |
| -rw-r--r-- | cnn_v1/shaders/cnn_weights_generated.wgsl | 302 |
| -rw-r--r-- | cnn_v1/src/cnn_v1_effect.cc | 129 |
| -rw-r--r-- | cnn_v1/src/cnn_v1_effect.h | 53 |
| -rwxr-xr-x | cnn_v1/training/train_cnn.py | 943 |
18 files changed, 3094 insertions, 0 deletions
diff --git a/cnn_v1/README.md b/cnn_v1/README.md new file mode 100644 index 0000000..052f22a --- /dev/null +++ b/cnn_v1/README.md @@ -0,0 +1,64 @@ +# CNN v1: Original Post-Processing Neural Network + +**Architecture:** 3-layer convolution, generated shader weights +**Status:** Active (used in timeline), legacy architecture + +## Overview + +Original CNN implementation with per-layer WGSL shaders. Supports multiple kernel sizes (1×1, 3×3, 5×5, 7×7) with generated weight arrays. + +**For new work, use CNN v2** (`cnn_v2/`) which provides: +- Storage buffer architecture (~3.2 KB vs generated WGSL) +- 7D static features (RGBD + UV + sin + bias) +- Sigmoid activation with stable training +- Dynamic layer configuration + +## Quick Reference + +**Training:** +```bash +./cnn_v1/training/train_cnn.py --input training/input --target training/output \ + --layers 3 --kernel_sizes 3,5,3 --epochs 5000 +``` + +**Integration:** +- **C++:** `cnn_v1/src/cnn_effect.{h,cc}` +- **Assets:** `workspaces/main/assets.txt` (lines 40-46) +- **Timeline:** `workspaces/main/timeline.seq` (CNNEffect) + +## Documentation + +- [CNN.md](docs/CNN.md) - Architecture overview +- [CNN_V1_EFFECT.md](docs/CNN_V1_EFFECT.md) - Implementation details +- [CNN_TEST_TOOL.md](docs/CNN_TEST_TOOL.md) - Testing guide +- [CNN_DEBUG.md](docs/CNN_DEBUG.md) - Debugging notes + +## Directory Structure + +``` +cnn_v1/ +├── README.md # This file +├── src/ +│ ├── cnn_effect.h # Effect header +│ └── cnn_effect.cc # Effect implementation +├── shaders/ # WGSL shaders (7 files) +├── training/ # Python training script +└── docs/ # Documentation (7 markdown files) +``` + +## Differences from CNN v2 + +| Feature | CNN v1 | CNN v2 | +|---------|--------|--------| +| Architecture | Generated WGSL weights | Storage buffer weights | +| Input Features | 4D (RGBA/prev layer) | 12D (4D + 8D static) | +| Activation | ReLU | Sigmoid + ReLU | +| Size | ~Variable (WGSL gen) | ~3.2 KB (binary) | +| Training | Full-image | Patch-based (default) | +| Layer Config | Compile-time | Runtime (dynamic) | + +## Migration Notes + +CNN v1 remains in the timeline for historical validation. For new effects or experiments, use CNN v2's enhanced feature set and compact binary format. + +See `cnn_v2/docs/CNN_V2.md` for CNN v2 architecture details. diff --git a/cnn_v1/docs/CNN.md b/cnn_v1/docs/CNN.md new file mode 100644 index 0000000..5d9a667 --- /dev/null +++ b/cnn_v1/docs/CNN.md @@ -0,0 +1,79 @@ +# Convolutional Neural Net Shader (CNN) post-processing + +**Status:** ✅ Foundation implemented (single-layer, expandable to multi-pass) + +## Idea + +Have the input 3d scene be processed by a multi-layer CNN trained on the side. +Input: some rendered scene. +Output: 'stylized' scene with CNN post-processing. + +**See `CNN_V1_EFFECT.md` for implementation details, usage, and API reference.** + +## Shader implementation + +### input / output + +Need 1 texture buffer per CNN layer. +Input (r,g,b,1/z) for layer 0 (render 3d scene), or output from layer N-1 for layer N. +output: (r,g,b, alpha). Don't need the 1/z information (can be fetched from input) + +### size of one layer + +Notation: +S: the number of input samples from layer N-1. +Example: 3x3 input -> S = 3x3 = 9. + +Each S samples is 4 values (r,g,b, w=1/z). + +Each sample is processed by a mat4 matrix. 4 input => 4 output. + +Weight matrix = S x mat4 + +Final bias: 4 values. + +WGSL code example: See file CNN.shader + +### Layers + +we need 3 or 4 layer ? +Several different shaders for each layer. 
+Ping-pong for input/output texture buffer between each layers? + +## Implementation Status + +**Completed:** +- ✅ Modular WGSL shader architecture (6 snippet files) +- ✅ CNNEffect C++ class (single-layer rendering) +- ✅ ShaderComposer integration (#include resolution) +- ✅ Asset registration (7 new shader assets) +- ✅ Test coverage (test_demo_effects.cc) +- ✅ Placeholder identity weights for testing + +**Size:** ~3-4 KB shader code + ~2-4 KB weights = **5-8 KB total** + +**Pending:** +- ⏳ Training script (`scripts/train_cnn.py`) to generate real weights +- ⏳ Multi-layer rendering with ping-pong textures +- ⏳ Weight quantization for size optimization + +--- + +## Training (To Be Implemented) + +The layer weight/bias data are hard-coded in the shaders. +Training workflow: + +1. Prepare image pairs (before: raw render, after: target style) +2. Run `python scripts/train_cnn.py --input scene.png --target stylized.png` +3. Script generates `cnn_weights_generated.wgsl` +4. Rebuild: `cmake --build build -j4` + +**Reference:** File `CNN.py` contains training example (needs adaptation). + +Need a repository of reference image pairs (before/after) for training and validation. +Each input image is randomly sampled into 3×3 patch of (r,g,b,1/z) input samples. +And trained to match the (r,g,b,a) output. + +Training generates the .wgsl code for layers' shaders. + diff --git a/cnn_v1/docs/CNN_BIAS_FIX_2026-02.md b/cnn_v1/docs/CNN_BIAS_FIX_2026-02.md new file mode 100644 index 0000000..26db8eb --- /dev/null +++ b/cnn_v1/docs/CNN_BIAS_FIX_2026-02.md @@ -0,0 +1,85 @@ +# CNN Bias Accumulation Fix (2026-02-11) + +## Problem +Bias was being added multiple times in shader convolution loops (once per kernel position), causing mismatch between PyTorch training and WGSL inference. + +## Root Cause +**Location**: `training/train_cnn.py:381, 398` + +When exporting weights to WGSL, bias was replicated for every kernel position. The shader loops through positions doing: +```wgsl +sum += dot(weights[pos], rgbd) + dot(weights[pos+1], in1); // in1.w = 1.0 +``` + +For 3×3 kernel (9 positions), bias added 9×. For 5×5, added 25×. + +## Fix +Divide bias by `num_positions` during export: +```python +# Final layer (7→1) +v1.append(f"{bias[0] / num_positions:.6f}") + +# Inner layers (7→4) +v1.append(f"{bias[out_c] / num_positions:.6f}") +``` + +Shader accumulates bias × num_positions = original bias (correct). + +--- + +## Additional Improvements + +### 1. RGBA Output Support +**train_cnn.py**: Now saves 4-channel RGBA PNG preserving alpha from input: +```python +alpha = img_tensor[0, 3:4, :, :].permute(1, 2, 0).numpy() +output_rgba = np.concatenate([output, alpha], axis=2) +Image.fromarray((output_rgba * 255).astype(np.uint8), mode='RGBA') +``` + +Intermediate layers also save RGBA if 4-channel. + +### 2. Debug Hex Output +**Both tools** support `--debug-hex` to print first 8 pixels as hex: +```bash +./training/train_cnn.py --infer input.png --export-only checkpoint.pth --debug-hex +./build/cnn_test input.png output.png --debug-hex +``` + +Output format: `[0] 0xRRGGBBAA` for pixel-level comparison. + +### 3. Cleanup +Removed sRGB/linear_png debug code from `cnn_test.cc` (simplified PNG saving). 
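
To make the bias compensation described above concrete, here is a minimal Python sketch. It is not the actual `train_cnn.py` export code; `export_bias_compensated` is a hypothetical helper, and it assumes a standard PyTorch `nn.Conv2d` with `bias=True`. It shows why pre-dividing each bias by the kernel-position count cancels the per-position accumulation performed by the shader loop.

```python
import torch.nn as nn

def export_bias_compensated(conv: nn.Conv2d, kernel_size: int) -> list[float]:
    """Divide each bias by the number of kernel positions so the WGSL loop,
    which adds the bias term at every tap (via in1.w == 1.0), recovers it once."""
    num_positions = kernel_size * kernel_size      # 9 for 3x3, 25 for 5x5
    return [b / num_positions for b in conv.bias.detach().tolist()]

conv = nn.Conv2d(7, 4, kernel_size=3, padding=1)
exported = export_bias_compensated(conv, kernel_size=3)
# Shader side: 9 taps each add (bias / 9), so the total is bias, matching PyTorch.
```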
+ +--- + +## Files Modified +- `training/train_cnn.py`: Bias fix, RGBA output, --debug-hex +- `tools/cnn_test.cc`: --debug-hex, remove linear_png +- `workspaces/main/shaders/cnn/cnn_weights_generated.wgsl`: Regenerated with fixed bias + +## Testing +```bash +# Train with fixed export +./training/train_cnn.py --input training/input/ --target training/output/ \ + --layers 3 --kernel_sizes 3,3,3 --epochs 5000 + +# Generate ground truth +./training/train_cnn.py --infer input.png --export-only checkpoint.pth \ + --output ground_truth.png --debug-hex + +# Run GPU tool +./build/cnn_test input.png tool_output.png --debug-hex + +# Compare hex output for first 8 pixels +``` + +--- + +## Status +✅ Bias accumulation bug fixed +✅ RGBA output with alpha preservation +✅ Debug hex comparison tool +✅ Weights regenerated + +Commit: `8ff8c56` diff --git a/cnn_v1/docs/CNN_DEBUG.md b/cnn_v1/docs/CNN_DEBUG.md new file mode 100644 index 0000000..ba220a0 --- /dev/null +++ b/cnn_v1/docs/CNN_DEBUG.md @@ -0,0 +1,43 @@ +# CNN Effect Black Screen Bug - Resolution (2026-02) + +## Problem +CNN post-processing effect showed black screen when activated at 11.50s, despite scene rendering correctly before CNN started. + +## Root Causes + +### Bug 1: Framebuffer Capture Timing +**Location**: `src/gpu/effect.cc` +**Issue**: Capture ran INSIDE post-effect loop after ping-pong buffer swaps. CNN layers 1+ captured wrong buffer (output being written to, not scene). +**Fix**: Moved capture before loop starts (lines 308-346). Capture now copies `framebuffer_a` to `captured_frame` auxiliary texture ONCE before any post-effects run. + +### Bug 2: Missing Uniforms Update ⚠️ CRITICAL +**Location**: `src/effects/cnn_effect.cc` +**Issue**: `CNNEffect::update_bind_group()` never updated `uniforms_` buffer. `uniforms.resolution` uninitialized (0,0 or garbage) → UV calculation `p.xy / uniforms.resolution` produced NaN → all texture samples black. +**Fix**: Added uniforms update before bind group creation (lines 132-142): +```cpp +const CommonPostProcessUniforms u = { + .resolution = {(float)width_, (float)height_}, + .aspect_ratio = (float)width_ / (float)height_, + .time = 0.0f, + .beat = 0.0f, + .audio_intensity = 0.0f, +}; +uniforms_.update(ctx_.queue, u); +``` + +## Key Lessons + +1. **All post-process effects MUST update `uniforms_` buffer** - Required for UV calculations and shader parameters +2. **Framebuffer capture timing is critical** - Must happen before post-chain ping-pong starts +3. **Uninitialized uniforms cause silent failures** - Produces black output without validation errors +4. 
**Post-effects must render or chain breaks** - `loadOp=Load` preserves previous (black) content if no draw call executes + +## Files Modified +- `src/gpu/effect.cc`: Lines 308-346 (capture timing) +- `src/effects/cnn_effect.cc`: Lines 132-142 (uniforms update) + +## Verification +Test: `demo64k --seek 11.5` +- ✅ Scene visible with RotatingCube +- ✅ CNN stylization applied +- ✅ All 3 layers process with correct original texture reference diff --git a/cnn_v1/docs/CNN_FLATTEN_ANALYSIS.md b/cnn_v1/docs/CNN_FLATTEN_ANALYSIS.md new file mode 100644 index 0000000..8664157 --- /dev/null +++ b/cnn_v1/docs/CNN_FLATTEN_ANALYSIS.md @@ -0,0 +1,189 @@ +# CNN Shader Flatten Mode - Technical Analysis + +**Status:** Analysis complete - flatten mode NOT RECOMMENDED + +**Date:** February 2026 + +--- + +## Context + +Current CNN architecture uses **3 sequential render passes** (linear chaining): +- **Layer 0:** 5×5 conv (7→4 channels) → framebuffer +- **Layer 1:** 3×3 conv (7→4 channels) → reads L0 output, writes framebuffer +- **Layer 2:** 3×3 conv (7→1 channel) → reads L1 output, blends with original + +Proposed **"flatten mode"**: Collapse all layers into **single shader pass** using intermediate arrays, eliminating framebuffer read/write between layers. + +--- + +## Current Architecture + +**Shader Structure:** +- 1 pipeline with layer branching (`layer_index` uniform) +- 5 bindings: sampler, input texture, uniforms, layer params, original capture +- Total shader size: ~8 KB (snippets + weights) + +**Performance Profile:** +- 3 render pass dispatches +- 2 framebuffer writes + reads between layers +- Memory bandwidth: ~2× framebuffer size per layer +- Register pressure: Low (per-layer isolation) + +**Weight Buffer:** 290 vec4s (4.6 KB) - already unified + +--- + +## Flatten Approaches Evaluated + +### Option A: Full Flatten (All 3 Layers) + +**Cascading Receptive Field:** + +To compute final output at position (x, y): +- Layer 2 needs 3×3 neighborhood of Layer 1 outputs +- Each Layer 1 output needs 3×3 neighborhood of Layer 0 outputs +- Each Layer 0 output needs 5×5 neighborhood of input samples + +**Effective input sampling:** 9×9 pixels (vs current 5×5 max) + +**Intermediate Storage (per thread/pixel):** +``` +Layer 0 outputs: 5×5 positions × 4 channels = 100 floats +Layer 1 outputs: 3×3 positions × 4 channels = 36 floats + TOTAL = 136 floats (544 bytes) +``` + +**GPU Register Pressure:** +- Modern GPUs: 32-64 KB registers per SM, shared across warps +- 544 bytes/thread → max 64 threads/SM (**low occupancy**) +- Current multi-pass: ~4-8 bytes/thread (high occupancy) + +**Pros:** +- 1 dispatch vs 3 (reduce CPU overhead) +- Zero framebuffer bandwidth between layers + +**Cons:** +- **Severe register pressure** (10-20× increase) +- Reduced occupancy → potential performance loss +- Complex shader (harder debug, larger binary) +- 9×9 input sampling + +**Assessment:** ❌ **Not Recommended** +Register cost outweighs bandwidth savings. + +--- + +### Option B: Partial Flatten (Layers 1 + 2) + +Keep Layer 0 separate, flatten only Layers 1 and 2. + +**Pass Structure:** +1. **Pass 1:** Layer 0 (5×5 conv) → framebuffer +2. 
**Pass 2 (flattened):** Compute Layer 1 + Layer 2 in single shader + +**Intermediate Storage:** +``` +Layer 0 samples: 3×3 × 4 = 36 floats (read once) +Layer 1 outputs: 3×3 × 4 = 36 floats (computed) + TOTAL = 72 floats (288 bytes) +``` + +**Receptive Field:** 5×5 Layer 0 samples required for 3×3 Layer 1 outputs + +**Pros:** +- 2 passes vs 3 (33% reduction) +- 1 framebuffer write saved +- More manageable register usage + +**Cons:** +- Still significant register pressure (288 bytes vs ~8 bytes baseline) +- Medium complexity increase +- Layer 0 (heaviest kernel) still separate + +**Assessment:** ⚠️ **Marginal Benefit** +Saves 1 pass but register cost still high. + +--- + +### Option C: Keep Current Multi-Pass ✅ + +**Rationale:** +- Current architecture well-suited to GPU design (high throughput via parallelism) +- Minimal register usage → high occupancy → hides memory latency +- Framebuffer bandwidth cost < register pressure cost +- Clean separation aids debugging/iteration +- Modular (easy to add/remove layers) + +**Alternative Optimizations (if bandwidth critical):** +1. Merge passes via render pass load/store ops (Vulkan subpasses) +2. Reduce intermediate channel count (4→3 or 2) +3. Hybrid: Compute shaders + workgroup shared memory +4. Layer pruning (2-layer vs 3-layer quality comparison) + +--- + +## Recommendation + +**✅ Keep current multi-pass architecture** + +### Decision Matrix + +| Factor | Multi-Pass | Partial Flatten | Full Flatten | +|--------|-----------|----------------|--------------| +| Register pressure | ✅ Low | ⚠️ High | ❌ Extreme | +| Occupancy | ✅ High | ⚠️ Medium | ❌ Low | +| Memory bandwidth | ⚠️ Medium | ✅ Lower | ✅ Lowest | +| Shader complexity | ✅ Simple | ⚠️ Medium | ❌ High | +| Debuggability | ✅ Easy | ⚠️ Harder | ❌ Very hard | +| Binary size | ✅ Small | ⚠️ Larger | ⚠️ Largest | + +**Modern GPU Architecture Favors:** +- High parallelism (many small threads) over complex threads +- Hiding latency via occupancy over minimizing operations +- Memory bandwidth via caching, not elimination + +--- + +## Alternative: Compute Shader + Shared Memory + +**If bandwidth becomes critical:** +- Use compute shader with workgroup shared memory +- Load tile + halos into shared memory (9×9 input samples) +- Compute all 3 layers for tile interior (avoids redundant sampling) +- Requires explicit synchronization (`workgroupBarrier`) + +**Trade-offs:** +- ✅ Low register pressure + low bandwidth +- ❌ Compute pipeline complexity (no render pass integration) +- ❌ Tile edge handling +- ❌ Larger code size + +--- + +## Conclusion + +Current 3-pass architecture is **appropriate for demo64k**: +- Size-efficient (modular shaders) +- Performance adequate (bandwidth not bottleneck) +- Maintainable (clean layer isolation) + +**Flatten mode not recommended** unless profiling reveals specific bandwidth constraint. + +### Size Optimization Alternatives (Better ROI) + +If size optimization critical, focus on: +1. **Weight quantization:** 4.6 KB → ~2 KB (8-bit or 4-bit quantization) +2. **Kernel size reduction:** 5×5 → 3×3 for Layer 0 (200 vec4s → 72 vec4s) +3. **Channel reduction:** 7 inputs → 4 inputs (remove UV/grayscale channels) + +These yield better size/performance than shader architecture changes. 
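
As a quick cross-check of the per-pixel storage figures quoted above (not part of the original analysis), the intermediate cost of flattening can be estimated in a few lines of Python. `flatten_cost` is a hypothetical helper; the kernel and channel values mirror the current 5×5 → 3×3 → 3×3 configuration with 4/4/1 output channels.

```python
def flatten_cost(kernels, channels):
    """Rough per-pixel storage for flattening consecutive conv layers into one pass.

    kernels[i]  : kernel size of layer i, e.g. [5, 3, 3]
    channels[i] : output channels of layer i, e.g. [4, 4, 1]
    Returns (floats, bytes) of intermediate outputs that must stay live: the outputs
    of every layer except the last, evaluated over the spatial span required by the
    layers stacked on top of it.
    """
    stored = 0
    span = 1
    # Walk backwards: each layer widens the span its inputs must cover.
    for i in range(len(kernels) - 1, 0, -1):
        span += kernels[i] - 1          # 1 -> 3 -> 5 for two stacked 3x3 layers
        stored += span * span * channels[i - 1]
    return stored, stored * 4           # 4 bytes per f32

# Full flatten of the current net: 5x5, 3x3, 3x3 with 4/4/1 channels
print(flatten_cost([5, 3, 3], [4, 4, 1]))   # -> (136, 544): matches the figures above
```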
+ +--- + +## References + +- `CNN_V1_EFFECT.md` - CNN implementation details +- `CNN.md` - High-level CNN design +- `../src/cnn_effect.cc` - Current implementation +- `workspaces/main/shaders/cnn_*.wgsl` - Shader snippets diff --git a/cnn_v1/docs/CNN_RGBD_GRAYSCALE_SUMMARY.md b/cnn_v1/docs/CNN_RGBD_GRAYSCALE_SUMMARY.md new file mode 100644 index 0000000..3439f2c --- /dev/null +++ b/cnn_v1/docs/CNN_RGBD_GRAYSCALE_SUMMARY.md @@ -0,0 +1,136 @@ +# CNN RGBD→Grayscale Architecture Implementation + +## Summary + +Implemented CNN architecture upgrade: RGBD input → grayscale output with 7-channel augmented input. + +## Changes Made + +### Architecture + +**Input:** RGBD (4 channels: RGB + inverse depth D=1/z) +**Output:** Grayscale (1 channel) +**Layer Input:** 7 channels = [RGBD, UV coords, grayscale] all normalized to [-1,1] + +**Layer Configuration:** +- Inner layers (0..N-2): Conv2d(7→4) - output RGBD with tanh activation +- Final layer (N-1): Conv2d(7→1) - output grayscale, no activation + +### Input Normalization (all to [-1,1]) + +- **RGBD:** `(rgbd - 0.5) * 2` +- **UV coords:** `(uv - 0.5) * 2` +- **Grayscale:** `dot(original.rgb, vec3<f32>(0.2126, 0.7152, 0.0722))` (computed once, passed as parameter) + +**Rationale:** Zero-centered inputs for tanh activation, better gradient flow. + +### Modified Files + +**Training (`/Users/skal/demo/training/train_cnn.py`):** +1. Removed `CoordConv2d` class +2. Updated `SimpleCNN`: + - Inner layers: `Conv2d(7, 4)` - RGBD output + - Final layer: `Conv2d(7, 1)` - grayscale output +3. Updated `forward()`: + - Normalize RGBD/coords/gray to [-1,1] + - Concatenate 7-channel input for each layer + - Apply tanh (inner) or none (final) + - Denormalize final output +4. Updated `export_weights_to_wgsl()`: + - Inner: `array<array<f32, 8>, 36>` (9 pos × 4 ch × 8 values) + - Final: `array<array<f32, 8>, 9>` (9 pos × 8 values) +5. Updated `generate_layer_shader()`: + - Use `cnn_conv3x3_7to4` for inner layers + - Use `cnn_conv3x3_7to1` for final layer + - Denormalize outputs from [-1,1] to [0,1] +6. Updated `ImagePairDataset`: + - Load RGBA input (was RGB) + +**Shaders (`/Users/skal/demo/workspaces/main/shaders/cnn/cnn_conv3x3.wgsl`):** +1. Added `cnn_conv3x3_7to4()`: + - 7-channel input: [RGBD, uv_x, uv_y, gray] (gray passed as parameter) + - 4-channel output: RGBD + - Weights: `array<array<f32, 8>, 36>` +2. Added `cnn_conv3x3_7to1()`: + - 7-channel input: [RGBD, uv_x, uv_y, gray] (gray passed as parameter) + - 1-channel output: grayscale + - Weights: `array<array<f32, 8>, 9>` +3. Optimized: gray computed once in caller using `dot()`, not per-function + +**Documentation (`/Users/skal/demo/doc/CNN_EFFECT.md`):** +1. Updated architecture section with RGBD→grayscale pipeline +2. Updated training data requirements (RGBA input) +3. Updated weight storage format + +### No C++ Changes + +CNNLayerParams and bind groups remain unchanged. + +## Data Flow + +1. Layer 0 captures original RGBD to `captured_frame` +2. Each layer: + - Samples previous layer output (RGBD in [0,1]) + - Normalizes RGBD to [-1,1] + - Computes gray once using `dot()` (fs_main level) + - Normalizes UV coords to [-1,1] (inside conv functions) + - Concatenates 7-channel input + - Applies convolution with layer-specific weights + - Outputs RGBD (inner) or grayscale (final) in [-1,1] + - Applies tanh (inner only) + - Denormalizes to [0,1] for texture storage + - Blends with original + +## Next Steps + +1. 
**Prepare RGBD training data:** + - Input: RGBA images (RGB + depth in alpha) + - Target: Grayscale stylized output + +2. **Train network:** + ```bash + python3 training/train_cnn.py \ + --input training/input \ + --target training/output \ + --layers 3 \ + --epochs 1000 + ``` + +3. **Verify generated shaders:** + - Check `cnn_weights_generated.wgsl` structure + - Check `cnn_layer.wgsl` uses new conv functions + +4. **Test in demo:** + ```bash + cmake --build build -j4 + ./build/demo64k + ``` + +## Design Rationale + +**Why [-1,1] normalization?** +- Centered inputs for tanh (operates best around 0) +- Better gradient flow +- Standard ML practice for normalized data + +**Why RGBD throughout vs RGB?** +- Depth information propagates through network +- Enables depth-aware stylization +- Consistent 4-channel processing + +**Why 7-channel input?** +- Coordinates: position-dependent effects (vignettes) +- Grayscale: luminance-aware processing +- RGBD: full color+depth information +- Enables richer feature learning + +## Testing Checklist + +- [ ] Train network with RGBD input data +- [ ] Verify `cnn_weights_generated.wgsl` structure +- [ ] Verify `cnn_layer.wgsl` uses `7to4`/`7to1` functions +- [ ] Build demo without errors +- [ ] Visual test: inner layers show RGBD evolution +- [ ] Visual test: final layer produces grayscale +- [ ] Visual test: blending works correctly +- [ ] Compare quality with previous RGB→RGB architecture diff --git a/cnn_v1/docs/CNN_TEST_TOOL.md b/cnn_v1/docs/CNN_TEST_TOOL.md new file mode 100644 index 0000000..4307894 --- /dev/null +++ b/cnn_v1/docs/CNN_TEST_TOOL.md @@ -0,0 +1,244 @@ +# CNN Shader Testing Tool + +Standalone tool for validating trained CNN shaders with GPU-to-CPU readback. Supports both CNN v1 (render pipeline) and v2 (compute, storage buffer). + +--- + +## Purpose + +- Validate trained weights against ground truth +- Debug CNN layer behavior in isolation +- Generate test outputs for training workflow +- Match Python training script's inference mode + +--- + +## Architecture + +**Two implementations:** + +1. **CNN v1** (render pipeline, texture atlas weights) + - 3 fixed layers + - RGBA16Float intermediates + - BGRA8Unorm final output + +2. 
**CNN v2** (compute shaders, storage buffer weights) + - Dynamic layer count from binary + - 7D static features (RGBD + UV + sin + bias) + - RGBA32Uint packed f16 intermediates + - Storage buffer: ~3-5 KB weights + +**Core GPU utility:** `src/gpu/texture_readback.{h,cc}` +- Synchronous texture-to-CPU readback +- Supports RGBA16Float, RGBA32Uint, BGRA8Unorm +- Protected with STRIP_ALL (0 bytes in release) + +--- + +## Usage + +```bash +cnn_test input.png output.png [OPTIONS] + +OPTIONS: + --cnn-version N CNN version: 1 (default) or 2 (ignored with --weights) + --weights PATH Load weights from .bin (forces CNN v2, overrides layer config) + --blend F Final blend amount (0.0-1.0, default: 1.0) + --format ppm|png Output format (default: png) + --layers N Number of CNN layers (1-10, v1 only, default: 3, ignored with --weights) + --save-intermediates DIR Save intermediate layers to directory + --debug-hex Print first 8 pixels as hex (debug) + --help Show usage +``` + +**Examples:** +```bash +# CNN v1 (render pipeline, 3 layers) +./build/cnn_test input.png output.png --cnn-version 1 + +# CNN v2 (compute, storage buffer, uses asset system weights) +./build/cnn_test input.png output.png --cnn-version 2 + +# CNN v2 with runtime weight loading (loads layer config from .bin) +./build/cnn_test input.png output.png --weights checkpoints/checkpoint_epoch_100.pth.bin + +# 50% blend with original (v2) +./build/cnn_test input.png output.png --cnn-version 2 --blend 0.5 + +# Debug hex dump +./build/cnn_test input.png output.png --cnn-version 2 --debug-hex +``` + +**Important:** When using `--weights`, the layer count and kernel sizes are read from the binary file header, overriding any `--layers` or `--cnn-version` arguments. + +--- + +## Implementation Details + +### Core Readback Utility + +**File:** `src/gpu/texture_readback.{h,cc}` + +**Function:** +```cpp +std::vector<uint8_t> read_texture_pixels( + WGPUInstance instance, + WGPUDevice device, + WGPUTexture texture, + int width, + int height); +``` + +**Features:** +- Returns BGRA8 format (4 bytes per pixel) +- Synchronous blocking operation +- Cross-platform async callback handling (Win32 vs Native API) +- Automatic staging buffer creation and cleanup + +**Refactored OffscreenRenderTarget:** +```cpp +std::vector<uint8_t> OffscreenRenderTarget::read_pixels() { +#if !defined(STRIP_ALL) + return read_texture_pixels(instance_, device_, texture_, width_, height_); +#else + return std::vector<uint8_t>(); +#endif +} +``` + +### CNN v1 Pipeline (Render) + +**Fixed 3-layer architecture:** +- Ping-pong RGBA16Float textures +- CNNLayerParams (binding 3): layer_index, blend_amount +- Shader composer resolves #include directives + +### CNN v2 Pipeline (Compute) + +**Dynamic layer architecture:** +1. **Static features compute:** Generate 7D features (RGBD + UV + sin + bias) +2. **Layer computes:** N layers from binary weights (3-5 typically) + - Storage buffer weights (read-only) + - RGBA32Uint packed f16 textures (ping-pong) + - CNNv2LayerParams: kernel_size, channels, weight_offset, blend +3. 
**Readback:** RGBA32Uint → f16 decode → u8 clamp + +**Binary format:** Header (20B) + layer info (20B×N) + f16 weights + +**Weight Loading:** +- **Without `--weights`:** Loads from asset system (`ASSET_WEIGHTS_CNN_V2`) +- **With `--weights PATH`:** Loads from external `.bin` file (e.g., checkpoint exports) + - Layer count and kernel sizes parsed from binary header + - Overrides any `--layers` or `--cnn-version` arguments + - Enables runtime testing of training checkpoints without rebuild + +--- + +## Build Integration + +**CMakeLists.txt:** + +1. Added `src/gpu/texture_readback.cc` to GPU_SOURCES (both sections) +2. Tool target: +```cmake +add_executable(cnn_test + tools/cnn_test.cc + src/tests/common/webgpu_test_fixture.cc + src/tests/common/offscreen_render_target.cc + ${PLATFORM_SOURCES} + ${GEN_DEMO_CC}) + +target_link_libraries(cnn_test PRIVATE + gpu util procedural ${DEMO_LIBS}) + +add_dependencies(cnn_test generate_demo_assets) + +target_compile_definitions(cnn_test PRIVATE + STB_IMAGE_IMPLEMENTATION + STB_IMAGE_WRITE_IMPLEMENTATION) +``` + +**Build:** +```bash +cmake -S . -B build -DDEMO_BUILD_TOOLS=ON +cmake --build build -j4 +``` + +--- + +## Validation Workflow (CNN v2) + +### 1. Train and Export +```bash +# Train and export weights +./scripts/train_cnn_v2_full.sh --epochs 200 --batch-size 16 +``` + +### 2. Tool Inference +```bash +# Run tool with v2 +./build/cnn_test training/input/img_000.png output.png --cnn-version 2 +``` + +### 3. Visual Comparison +Compare output.png with training/target_X/img_000.png + +--- + +## Status + +**CNN v1:** Builds and runs, produces incorrect output (all white). Use CNNEffect in demo for visual validation. + +**CNN v2:** ⚠️ Partially functional. Readback works but output differs from HTML validation tool. +- Loads binary weights from `workspaces/main/weights/cnn_v2_weights.bin` +- Matches CNNv2Effect architecture +- **Known Issue:** Visual output differs from `tools/cnn_v2_test/index.html` despite matching shader code +- Root cause under investigation (weight indexing? texture sampling? activation clamping?) +- Use HTML tool (`tools/cnn_v2_test/index.html`) for accurate validation + +--- + +## Technical Notes (Readback Fix) + +**Original Bug:** Buffer mapping returned `WGPUMapAsyncStatus_Unknown` (status=5) + +**Root Cause:** Callback mode mismatch +- Used `WGPUCallbackMode_WaitAnyOnly` (fires only during `wgpuInstanceWaitAny`) +- Called `wgpuInstanceProcessEvents` in wait loop (wrong API for this mode) +- Callback never fired → timeout → empty buffer + +**Fix Applied:** +1. Changed callback mode to `WGPUCallbackMode_AllowProcessEvents` +2. Replaced `wgpuInstanceProcessEvents` with `wgpuDevicePoll(device, true, nullptr)` +3. 
Added pre-mapping device poll to ensure copy completes + +**Relevant Code:** `src/gpu/texture_readback.cc` lines 97-110 + +**Reference:** WebGPU spec - Asynchronous Operations, Callback Modes + +--- + +## Limitations + +- **CNN v1:** Produces incorrect output, use for debugging only +- **Single image:** Batch processing requires shell loop +- **No real-time preview:** Offline processing only +- **PNG input:** stb_image (JPEG/PNG/BMP/TGA also supported) + +--- + +## Technical Notes + +**CNN v2 f16 decoding:** +- RGBA32Uint texture stores 8×f16 as 4×u32 +- Custom decoder: extract u16, decode f16→f32, clamp [0,1]→u8 +- Handles denormals, infinity, NaN + +**Cross-platform:** +- macOS, Linux (native WebGPU) +- Windows (mingw-w64 cross-compile) + +**Size impact:** +- Debug/STRIP_ALL=OFF: compiled +- STRIP_ALL=ON: 0 bytes (compiled out) +- FINAL_STRIP=ON: tool not built diff --git a/cnn_v1/docs/CNN_V1_EFFECT.md b/cnn_v1/docs/CNN_V1_EFFECT.md new file mode 100644 index 0000000..40f095e --- /dev/null +++ b/cnn_v1/docs/CNN_V1_EFFECT.md @@ -0,0 +1,400 @@ +# CNN Post-Processing Effect + +Neural network-based stylization for rendered scenes. + +--- + +## Overview + +Trainable convolutional neural network layers for artistic stylization (painterly, sketch, cel-shaded effects) with minimal runtime overhead. + +**Key Features:** +- Position-aware layer 0 (coordinate input for vignetting, edge effects) +- Multi-layer convolutions (3×3, 5×5, 7×7 kernels) with automatic chaining +- Original input available to all layers via framebuffer capture +- Configurable final blend with original scene +- Modular WGSL shader architecture +- Hardcoded weights (trained offline via PyTorch) +- ~5-8 KB binary footprint + +--- + +## Architecture + +### RGBD → Grayscale Pipeline + +**Input:** RGBD (RGB + inverse depth D=1/z) +**Output:** Grayscale (1 channel) +**Layer Input:** 7 channels = [RGBD, UV coords, grayscale] all normalized to [-1,1] + +**Architecture:** +- **Inner layers (0..N-2):** Conv2d(7→4) - output RGBD +- **Final layer (N-1):** Conv2d(7→1) - output grayscale + +```wgsl +// Inner layers: 7→4 (RGBD output, vec4-optimized) +fn cnn_conv3x3_7to4( + tex: texture_2d<f32>, + samp: sampler, + uv: vec2<f32>, + resolution: vec2<f32>, + gray: f32, # Grayscale [-1,1] + weights: array<vec4<f32>, 72> # 9 pos × 4 ch × 2 vec4 (8 floats per filter) +) -> vec4<f32> + +// Final layer: 7→1 (grayscale output, vec4-optimized) +fn cnn_conv3x3_7to1( + tex: texture_2d<f32>, + samp: sampler, + uv: vec2<f32>, + resolution: vec2<f32>, + gray: f32, + weights: array<vec4<f32>, 18> # 9 pos × 2 vec4 (8 floats per filter) +) -> f32 +``` + +**Input normalization:** +- **fs_main** normalizes textures once: `(tex - 0.5) * 2` → [-1,1] +- **Conv functions** normalize UV coords: `(uv - 0.5) * 2` → [-1,1] +- **Grayscale** computed once in fs_main using dot product: `dot(original.rgb, vec3(0.2126, 0.7152, 0.0722))` +- **Inter-layer data** stays in [-1,1] (no denormalization) +- **Final output** denormalized for display: `(result + 1.0) * 0.5` → [0,1] + +**Activation:** tanh for inner layers (output stays [-1,1]), none for final layer + +### Multi-Layer Architecture + +CNNEffect supports multi-layer networks via automatic effect chaining: + +1. **Timeline specifies total layers**: `CNNEffect layers=3 blend=0.7` +2. **Compiler expands to chain**: 3 separate CNNEffect instances (layer 0→1→2) +3. **Framebuffer capture**: Layer 0 captures original input to `"captured_frame"` +4. **Original input binding**: All layers access original via `@binding(4)` +5. 
**Final blend**: Last layer blends result with original: `mix(original, result, 0.7)` + +**Framebuffer Capture API:** +- `Effect::needs_framebuffer_capture()` - effect requests pre-capture +- MainSequence automatically blits input → `"captured_frame"` auxiliary texture +- Generic mechanism usable by any effect + +### File Structure + +``` +src/effects/ + cnn_effect.h/cc # CNNEffect class + framebuffer capture + +workspaces/main/shaders/cnn/ + cnn_activation.wgsl # tanh, ReLU, sigmoid, leaky_relu + cnn_conv3x3.wgsl # 3×3 convolution (standard + coord-aware) + cnn_conv5x5.wgsl # 5×5 convolution (standard + coord-aware) + cnn_conv7x7.wgsl # 7×7 convolution (standard + coord-aware) + cnn_weights_generated.wgsl # Weight arrays (auto-generated by train_cnn.py) + cnn_layer.wgsl # Main shader with layer switches (auto-generated by train_cnn.py) +``` + +--- + +## Training Workflow + +### 1. Prepare Training Data + +Input/target image pairs: +``` +training/input/img_000.png # RGBA (RGB + alpha) +training/output/img_000.png # Grayscale target +``` + +**Note:** Alpha channel can be depth (1/z) or constant (255). Network learns from RGB primarily. + +### 2. Train Network + +**Patch-based (Recommended)** - Preserves natural pixel scale: +```bash +python3 training/train_cnn.py \ + --input training/input --target training/output \ + --patch-size 32 --patches-per-image 64 --detector harris \ + --layers 3 --kernel-sizes 3,5,3 \ + --epochs 5000 --batch-size 16 --checkpoint-every 1000 +``` + +**Detectors:** `harris` (corners), `fast` (features), `shi-tomasi` (corners), `gradient` (edges) + +**Full-image (Legacy)** - Resizes to 256×256: +```bash +python3 training/train_cnn.py \ + --input training/input --target training/output \ + --layers 3 --kernel-sizes 3,5,3 \ + --epochs 10000 --batch-size 8 --checkpoint-every 1000 +``` + +**Auto-generates:** +- `cnn_weights_generated.wgsl` - Weight arrays +- `cnn_layer.wgsl` - Layer shader + +### 3. Export & Validate + +```bash +# Export shaders +./training/train_cnn.py --export-only checkpoints/checkpoint_epoch_5000.pth + +# Generate ground truth +./training/train_cnn.py --infer input.png \ + --export-only checkpoints/checkpoint_epoch_5000.pth --output ground_truth.png +``` + +### 4. Rebuild Demo + +```bash +cmake --build build -j4 && ./build/demo64k +``` + +--- + +## Usage + +### C++ Integration + +**Single layer (manual):** +```cpp +#include "effects/cnn_effect.h" + +CNNEffectParams p; +p.layer_index = 0; +p.total_layers = 1; +p.blend_amount = 1.0f; +auto cnn = std::make_shared<CNNEffect>(ctx, p); +timeline.add_effect(cnn, start_time, end_time); +``` + +**Multi-layer (automatic via timeline compiler):** + +Use timeline syntax - `seq_compiler` expands to multiple instances. 
+ +### Timeline Examples + +**Single-layer CNN (full stylization):** +``` +SEQUENCE 10.0 0 + EFFECT + Hybrid3DEffect 0.00 5.00 + EFFECT + CNNEffect 0.50 5.00 layers=1 +``` + +**Multi-layer CNN with blend:** +``` +SEQUENCE 10.0 0 + EFFECT + Hybrid3DEffect 0.00 5.00 + EFFECT + CNNEffect 0.50 5.00 layers=3 blend=0.7 +``` + +Expands to: +```cpp +// Layer 0 (captures original, blend=1.0) +{ + CNNEffectParams p; + p.layer_index = 0; + p.total_layers = 3; + p.blend_amount = 1.0f; + seq->add_effect(std::make_shared<CNNEffect>(ctx, p), 0.50f, 5.00f, 1); +} +// Layer 1 (blend=1.0) +{ + CNNEffectParams p; + p.layer_index = 1; + p.total_layers = 3; + p.blend_amount = 1.0f; + seq->add_effect(std::make_shared<CNNEffect>(ctx, p), 0.50f, 5.00f, 2); +} +// Layer 2 (final blend=0.7) +{ + CNNEffectParams p; + p.layer_index = 2; + p.total_layers = 3; + p.blend_amount = 0.7f; + seq->add_effect(std::make_shared<CNNEffect>(ctx, p), 0.50f, 5.00f, 3); +} +``` + +--- + +## Shader Structure + +**Bindings:** +```wgsl +@group(0) @binding(0) var smplr: sampler; +@group(0) @binding(1) var txt: texture_2d<f32>; // Current layer input +@group(0) @binding(2) var<uniform> uniforms: CommonUniforms; +@group(0) @binding(3) var<uniform> params: CNNLayerParams; +@group(0) @binding(4) var original_input: texture_2d<f32>; // Layer 0 input (captured) +``` + +**Fragment shader logic:** +```wgsl +@fragment fn fs_main(@builtin(position) p: vec4<f32>) -> @location(0) vec4<f32> { + let uv = p.xy / uniforms.resolution; + let original_raw = textureSample(original_input, smplr, uv); + let original = (original_raw - 0.5) * 2.0; // Normalize to [-1,1] + let gray = dot(original.rgb, vec3<f32>(0.2126, 0.7152, 0.0722)); + var result = vec4<f32>(0.0); + + if (params.layer_index == 0) { + result = cnn_conv3x3_7to4_src(txt, smplr, uv, uniforms.resolution, + weights_layer0); + result = cnn_tanh(result); + } + else if (params.layer_index == 1) { + result = cnn_conv5x5_7to4(txt, smplr, uv, uniforms.resolution, + gray, weights_layer1); + result = cnn_tanh(result); + } + // ... other layers + + // Blend with ORIGINAL input (not previous layer) + return mix(original_raw, result, params.blend_amount); +} +``` + +**Weight Storage (vec4-optimized):** + +**Inner layers (7→4 RGBD output):** +```wgsl +// Structure: array<vec4<f32>, 72> +// 9 pos × 4 ch × 2 vec4 (8 floats per filter: [rgba][uv,gray,1]) +const weights_layer0: array<vec4<f32>, 72> = array( + vec4<f32>(w0_r, w0_g, w0_b, w0_d), // pos0_ch0 (rgba weights) + vec4<f32>(w0_u, w0_v, w0_gray, bias0), // pos0_ch0 (uv, gray, bias) + vec4<f32>(w1_r, w1_g, w1_b, w1_d), // pos0_ch1 (rgba weights) + vec4<f32>(w1_u, w1_v, w1_gray, bias1), // pos0_ch1 (uv, gray, bias) + // ... 68 more vec4s +); +``` + +**Final layer (7→1 grayscale output):** +```wgsl +// Structure: array<vec4<f32>, 18> +// 9 pos × 2 vec4 (8 floats per filter: [rgba][uv,gray,1]) +const weights_layerN: array<vec4<f32>, 18> = array( + vec4<f32>(w0_r, w0_g, w0_b, w0_d), // pos0 (rgba weights) + vec4<f32>(w0_u, w0_v, w0_gray, bias0), // pos0 (uv, gray, bias) + // ... 16 more vec4s +); +``` + +**Optimization:** Bias integrated as 4th component via `vec4(uv, gray, 1.0)` input. Two dot4 operations replace 8 scalar MADs. 
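
The interleaved vec4 layout above can be sketched as a small export helper. This is illustrative only: `pack_7to4_weights` is a hypothetical function, and the assumption that the exporter orders input channels as [r, g, b, d, u, v, gray] (with the bias pre-divided per the 2026-02 bias fix) should be verified against the actual `train_cnn.py` code.

```python
import torch.nn as nn

def pack_7to4_weights(conv: nn.Conv2d) -> list[list[float]]:
    """Pack an inner 7->4 layer into interleaved vec4 pairs.

    Assumes conv.weight has shape [4, 7, k, k] with input channels ordered
    [r, g, b, d, u, v, gray] and conv.bias has shape [4].  Emits, per kernel
    position and output channel, two vec4s:
        [w_r, w_g, w_b, w_d]  and  [w_u, w_v, w_gray, bias / (k*k)]
    """
    w = conv.weight.detach()
    b = conv.bias.detach()
    out_ch, _, k, _ = w.shape
    num_positions = k * k
    vec4s = []
    for dy in range(k):
        for dx in range(k):
            for oc in range(out_ch):
                f = w[oc, :, dy, dx]
                vec4s.append([f[0].item(), f[1].item(), f[2].item(), f[3].item()])
                vec4s.append([f[4].item(), f[5].item(), f[6].item(),
                              b[oc].item() / num_positions])
    return vec4s

layer = nn.Conv2d(7, 4, kernel_size=3, padding=1)
assert len(pack_7to4_weights(layer)) == 72   # 9 positions x 4 channels x 2 vec4s
```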
+ +--- + +## Size Budget + +| Component | Size | Notes | +|-----------|------|-------| +| Activation functions | ~200 B | 4 functions | +| Conv3x3 (standard + coord) | ~500 B | Both variants | +| Conv5x5 (standard + coord) | ~700 B | Both variants | +| Conv7x7 (standard + coord) | ~900 B | Both variants | +| Main shader | ~800 B | Layer composition | +| C++ implementation | ~300 B | Effect class | +| **Coord weights** | **+32 B** | Per-layer overhead (layer 0 only) | +| **RGBA weights** | **2-6 KB** | Depends on depth/kernel sizes | +| **Total** | **5-9 KB** | Acceptable for 64k | + +**Optimization strategies:** +- Quantize weights (float32 → int8) +- Prune near-zero weights +- Use separable convolutions + +--- + +## Testing + +```bash +./build/test_demo_effects # CNN construction/shader tests +./build/demo64k # Visual test +``` + +--- + +## Blend Parameter Behavior + +**blend_amount** controls final compositing with original: +- `blend=0.0`: Pure original (no CNN effect) +- `blend=0.5`: 50% original + 50% CNN +- `blend=1.0`: Pure CNN output (full stylization) + +**Important:** Blend uses captured layer 0 input, not previous layer output. + +**Example use cases:** +- `blend=1.0`: Full stylization (default) +- `blend=0.7`: Subtle effect preserving original details +- `blend=0.3`: Light artistic touch + +## Troubleshooting + +**Shader compilation fails:** +- Check `cnn_weights_generated.wgsl` syntax +- Verify snippets registered in `shaders.cc::InitShaderComposer()` +- Ensure `cnn_layer.wgsl` has 5 bindings (including `original_input`) + +**Black/corrupted output:** +- Weights untrained (identity placeholder) +- Check `captured_frame` auxiliary texture is registered +- Verify layer priorities in timeline are sequential + +**Wrong blend result:** +- Ensure layer 0 has `needs_framebuffer_capture() == true` +- Check MainSequence framebuffer capture logic +- Verify `original_input` binding is populated + +**Training loss not decreasing:** +- Lower learning rate (`--learning-rate 0.0001`) +- More epochs (`--epochs 1000`) +- Check input/target image alignment + +--- + +## Vec4 Optimization + +**Architecture:** Weights stored as vec4 pairs for SIMD efficiency. 
+ +**Input representation:** +```wgsl +let rgbd = textureSample(...); // vec4: [r, g, b, d] +let in1 = vec4<f32>(uv_norm, gray, 1.0); // vec4: [u, v, gray, 1.0] +``` + +**Weight indexing:** +```wgsl +var pos = 0; // Direct weight array index +for (var dy = -1; dy <= 1; dy++) { + for (var dx = -1; dx <= 1; dx++) { + // Unrolled channel loop (4 output channels) + sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1); + sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1); + sum.b += dot(weights[pos+4], rgbd) + dot(weights[pos+5], in1); + sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1); + pos += 8; // 4 channels × 2 vec4s per channel + } +} +``` + +**Benefits:** +- **SIMD-native:** GPU executes `dot(vec4, vec4)` as single instruction (4 parallel MADs) +- **Memory bandwidth:** 2 vec4 loads vs 8 scalar loads (better cache alignment) +- **Bias integration:** Free via `[..., 1.0]` component (no separate add) +- **Code simplicity:** Eliminates inner loop, direct indexing with `pos` +- **Performance:** 2-3× GPU throughput improvement over scalar version + +**Weight layout per filter (8 floats):** +- vec4[0]: [w_r, w_g, w_b, w_d] (rgba input weights) +- vec4[1]: [w_u, w_v, w_gray, bias] (uv, grayscale, bias) + +**3×3 kernel sizes:** +- Inner layer (7→4): 72 vec4s (9 pos × 4 ch × 2 vec4 = 2304 bytes) +- Final layer (7→1): 18 vec4s (9 pos × 1 ch × 2 vec4 = 288 bytes) + +--- + +## References + +- **Training Script:** `training/train_cnn.py` +- **Shader Composition:** `doc/SEQUENCE.md` +- **Effect System:** `src/gpu/effect.h` diff --git a/cnn_v1/shaders/cnn_activation.wgsl b/cnn_v1/shaders/cnn_activation.wgsl new file mode 100644 index 0000000..4fe771e --- /dev/null +++ b/cnn_v1/shaders/cnn_activation.wgsl @@ -0,0 +1,18 @@ +// CNN activation functions +// 4 functions: tanh, ReLU, sigmoid, leaky_relu + +fn cnn_tanh(x: vec4<f32>) -> vec4<f32> { + return tanh(x); +} + +fn cnn_relu(x: vec4<f32>) -> vec4<f32> { + return max(vec4<f32>(0.0), x); +} + +fn cnn_sigmoid(x: vec4<f32>) -> vec4<f32> { + return 1.0 / (1.0 + exp(-x)); +} + +fn cnn_leaky_relu(x: vec4<f32>, alpha: f32) -> vec4<f32> { + return max(alpha * x, x); +} diff --git a/cnn_v1/shaders/cnn_conv1x1.wgsl b/cnn_v1/shaders/cnn_conv1x1.wgsl new file mode 100644 index 0000000..f77cfa8 --- /dev/null +++ b/cnn_v1/shaders/cnn_conv1x1.wgsl @@ -0,0 +1,100 @@ +// 1x1 convolution (vec4-optimized) + +// Inner layers: 7→4 channels (vec4-optimized) +// Assumes 'tex' is already normalized to [-1,1] +fn cnn_conv1x1_7to4( + tex: texture_2d<f32>, + samp: sampler, + uv: vec2<f32>, + resolution: vec2<f32>, + gray: f32, + weights: array<vec4<f32>, 8> +) -> vec4<f32> { + let step = 1.0 / resolution; + let uv_norm = (uv - 0.5) * 2.0; + + var sum = vec4<f32>(0.0); + var pos = 0; + + for (var dy = -0; dy <= 0; dy++) { + for (var dx = -0; dx <= 0; dx++) { + let offset = vec2<f32>(f32(dx), f32(dy)) * step; + let rgbd = textureSample(tex, samp, uv + offset); + let in1 = vec4<f32>(uv_norm, gray, 1.0); + + sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1); + sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1); + sum.b += dot(weights[pos+4], rgbd) + dot(weights[pos+5], in1); + sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1); + pos += 8; + } + } + + return sum; +} + +// Source layer: 7→4 channels (vec4-optimized) +// Normalizes [0,1] input to [-1,1] internally +fn cnn_conv1x1_7to4_src( + tex: texture_2d<f32>, + samp: sampler, + uv: vec2<f32>, + resolution: vec2<f32>, + weights: array<vec4<f32>, 8> +) -> vec4<f32> { 
+ let step = 1.0 / resolution; + + var original = (textureSample(tex, samp, uv) - 0.5) * 2.0; + let gray = dot(original.rgb, vec3<f32>(0.2126, 0.7152, 0.0722)); + let uv_norm = (uv - 0.5) * 2.0; + let in1 = vec4<f32>(uv_norm, gray, 1.0); + + var sum = vec4<f32>(0.0); + var pos = 0; + + for (var dy = -0; dy <= 0; dy++) { + for (var dx = -0; dx <= 0; dx++) { + let offset = vec2<f32>(f32(dx), f32(dy)) * step; + var rgbd = (textureSample(tex, samp, uv + offset) - 0.5) * 2.0; + + sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1); + sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1); + sum.b += dot(weights[pos+4], rgbd) + dot(weights[pos+5], in1); + sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1); + pos += 8; + } + } + + return sum; +} + +// Final layer: 7→1 channel (vec4-optimized) +// Assumes 'tex' is already normalized to [-1,1] +// Returns raw sum (activation applied at call site) +fn cnn_conv1x1_7to1( + tex: texture_2d<f32>, + samp: sampler, + uv: vec2<f32>, + resolution: vec2<f32>, + gray: f32, + weights: array<vec4<f32>, 2> +) -> f32 { + let step = 1.0 / resolution; + let uv_norm = (uv - 0.5) * 2.0; + let in1 = vec4<f32>(uv_norm, gray, 1.0); + + var sum = 0.0; + var pos = 0; + + for (var dy = -0; dy <= 0; dy++) { + for (var dx = -0; dx <= 0; dx++) { + let offset = vec2<f32>(f32(dx), f32(dy)) * step; + let rgbd = textureSample(tex, samp, uv + offset); + + sum += dot(weights[pos], rgbd) + dot(weights[pos+1], in1); + pos += 2; + } + } + + return sum; +} diff --git a/cnn_v1/shaders/cnn_conv3x3.wgsl b/cnn_v1/shaders/cnn_conv3x3.wgsl new file mode 100644 index 0000000..f7d11b1 --- /dev/null +++ b/cnn_v1/shaders/cnn_conv3x3.wgsl @@ -0,0 +1,100 @@ +// 3x3 convolution (vec4-optimized) + +// Inner layers: 7→4 channels (vec4-optimized) +// Assumes 'tex' is already normalized to [-1,1] +fn cnn_conv3x3_7to4( + tex: texture_2d<f32>, + samp: sampler, + uv: vec2<f32>, + resolution: vec2<f32>, + gray: f32, + weights: array<vec4<f32>, 72> +) -> vec4<f32> { + let step = 1.0 / resolution; + let uv_norm = (uv - 0.5) * 2.0; + + var sum = vec4<f32>(0.0); + var pos = 0; + + for (var dy = -1; dy <= 1; dy++) { + for (var dx = -1; dx <= 1; dx++) { + let offset = vec2<f32>(f32(dx), f32(dy)) * step; + let rgbd = textureSample(tex, samp, uv + offset); + let in1 = vec4<f32>(uv_norm, gray, 1.0); + + sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1); + sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1); + sum.b += dot(weights[pos+4], rgbd) + dot(weights[pos+5], in1); + sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1); + pos += 8; + } + } + + return sum; +} + +// Source layer: 7→4 channels (vec4-optimized) +// Normalizes [0,1] input to [-1,1] internally +fn cnn_conv3x3_7to4_src( + tex: texture_2d<f32>, + samp: sampler, + uv: vec2<f32>, + resolution: vec2<f32>, + weights: array<vec4<f32>, 72> +) -> vec4<f32> { + let step = 1.0 / resolution; + + let original = (textureSample(tex, samp, uv) - 0.5) * 2.0; + let gray = dot(original.rgb, vec3<f32>(0.2126, 0.7152, 0.0722)); + let uv_norm = (uv - 0.5) * 2.0; + let in1 = vec4<f32>(uv_norm, gray, 1.0); + + var sum = vec4<f32>(0.0); + var pos = 0; + + for (var dy = -1; dy <= 1; dy++) { + for (var dx = -1; dx <= 1; dx++) { + let offset = vec2<f32>(f32(dx), f32(dy)) * step; + let rgbd = (textureSample(tex, samp, uv + offset) - 0.5) * 2.0; + + sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1); + sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1); + sum.b += dot(weights[pos+4], rgbd) + 
dot(weights[pos+5], in1); + sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1); + pos += 8; + } + } + + return sum; +} + +// Final layer: 7→1 channel (vec4-optimized) +// Assumes 'tex' is already normalized to [-1,1] +// Returns raw sum (activation applied at call site) +fn cnn_conv3x3_7to1( + tex: texture_2d<f32>, + samp: sampler, + uv: vec2<f32>, + resolution: vec2<f32>, + gray: f32, + weights: array<vec4<f32>, 18> +) -> f32 { + let step = 1.0 / resolution; + let uv_norm = (uv - 0.5) * 2.0; + let in1 = vec4<f32>(uv_norm, gray, 1.0); + + var sum = 0.0; + var pos = 0; + + for (var dy = -1; dy <= 1; dy++) { + for (var dx = -1; dx <= 1; dx++) { + let offset = vec2<f32>(f32(dx), f32(dy)) * step; + let rgbd = textureSample(tex, samp, uv + offset); + + sum += dot(weights[pos], rgbd) + dot(weights[pos+1], in1); + pos += 2; + } + } + + return sum; +} diff --git a/cnn_v1/shaders/cnn_conv5x5.wgsl b/cnn_v1/shaders/cnn_conv5x5.wgsl new file mode 100644 index 0000000..9328d75 --- /dev/null +++ b/cnn_v1/shaders/cnn_conv5x5.wgsl @@ -0,0 +1,101 @@ +// 5×5 variant for 7→4 channels (vec4-optimized) +// Assumes 'tex' is already normalized to [-1,1] +// UV coordinates remain in [0,1] and are normalized internally +// weights: array<vec4<f32>, 200> (25 pos × 4 ch × 2 vec4) +fn cnn_conv5x5_7to4( + tex: texture_2d<f32>, + samp: sampler, + uv: vec2<f32>, + resolution: vec2<f32>, + gray: f32, + weights: array<vec4<f32>, 200> +) -> vec4<f32> { + let step = 1.0 / resolution; + let uv_norm = (uv - 0.5) * 2.0; + let in1 = vec4<f32>(uv_norm, gray, 1.0); + + var sum = vec4<f32>(0.0); + var pos = 0; + + for (var dy = -2; dy <= 2; dy++) { + for (var dx = -2; dx <= 2; dx++) { + let offset = vec2<f32>(f32(dx), f32(dy)) * step; + let rgbd = textureSample(tex, samp, uv + offset); + + sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1); + sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1); + sum.b += dot(weights[pos+4], rgbd) + dot(weights[pos+5], in1); + sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1); + pos += 8; + } + } + + return sum; +} + +// 5×5 variant for 7→1 channel (vec4-optimized) +// Assumes 'tex' is already normalized to [-1,1] +// UV coordinates remain in [0,1] and are normalized internally +// weights: array<vec4<f32>, 50> (25 pos × 2 vec4) +fn cnn_conv5x5_7to1( + tex: texture_2d<f32>, + samp: sampler, + uv: vec2<f32>, + resolution: vec2<f32>, + gray: f32, + weights: array<vec4<f32>, 50> +) -> f32 { + let step = 1.0 / resolution; + let uv_norm = (uv - 0.5) * 2.0; + let in1 = vec4<f32>(uv_norm, gray, 1.0); + + var sum = 0.0; + var pos = 0; + + for (var dy = -2; dy <= 2; dy++) { + for (var dx = -2; dx <= 2; dx++) { + let offset = vec2<f32>(f32(dx), f32(dy)) * step; + let rgbd = textureSample(tex, samp, uv + offset); + + sum += dot(weights[pos], rgbd) + dot(weights[pos+1], in1); + pos += 2; + } + } + + return sum; +} + +// Source layer: 7→4 channels (vec4-optimized) +// Normalizes [0,1] input to [-1,1] internally +fn cnn_conv5x5_7to4_src( + tex: texture_2d<f32>, + samp: sampler, + uv: vec2<f32>, + resolution: vec2<f32>, + weights: array<vec4<f32>, 200> +) -> vec4<f32> { + let step = 1.0 / resolution; + + let original = (textureSample(tex, samp, uv) - 0.5) * 2.0; + let gray = dot(original.rgb, vec3<f32>(0.2126, 0.7152, 0.0722)); + let uv_norm = (uv - 0.5) * 2.0; + let in1 = vec4<f32>(uv_norm, gray, 1.0); + + var sum = vec4<f32>(0.0); + var pos = 0; + + for (var dy = -2; dy <= 2; dy++) { + for (var dx = -2; dx <= 2; dx++) { + let offset = vec2<f32>(f32(dx), f32(dy)) * 
step; + let rgbd = (textureSample(tex, samp, uv + offset) - 0.5) * 2.0; + + sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1); + sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1); + sum.b += dot(weights[pos+4], rgbd) + dot(weights[pos+5], in1); + sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1); + pos += 8; + } + } + + return sum; +} diff --git a/cnn_v1/shaders/cnn_conv7x7.wgsl b/cnn_v1/shaders/cnn_conv7x7.wgsl new file mode 100644 index 0000000..e68d644 --- /dev/null +++ b/cnn_v1/shaders/cnn_conv7x7.wgsl @@ -0,0 +1,53 @@ +// 7x7 convolution with 49 samples +// Applies mat4 weights per sample + +fn cnn_conv7x7( + tex: texture_2d<f32>, + samp: sampler, + uv: vec2<f32>, + resolution: vec2<f32>, + weights: array<mat4x4<f32>, 49>, + bias: vec4<f32> +) -> vec4<f32> { + let step = 1.0 / resolution; + var sum = bias; + var idx = 0; + + for (var dy = -3; dy <= 3; dy++) { + for (var dx = -3; dx <= 3; dx++) { + let offset = vec2<f32>(f32(dx), f32(dy)) * step; + let sample = textureSample(tex, samp, uv + offset); + sum += weights[idx] * sample; + idx++; + } + } + + return sum; +} + +fn cnn_conv7x7_with_coord( + tex: texture_2d<f32>, + samp: sampler, + uv: vec2<f32>, + resolution: vec2<f32>, + rgba_weights: array<mat4x4<f32>, 49>, + coord_weights: mat2x4<f32>, + bias: vec4<f32> +) -> vec4<f32> { + let step = 1.0 / resolution; + var sum = bias; + + sum += coord_weights * uv; + + var idx = 0; + for (var dy = -3; dy <= 3; dy++) { + for (var dx = -3; dx <= 3; dx++) { + let offset = vec2<f32>(f32(dx), f32(dy)) * step; + let rgba = textureSample(tex, samp, uv + offset); + sum += rgba_weights[idx] * rgba; + idx++; + } + } + + return sum; +} diff --git a/cnn_v1/shaders/cnn_layer.wgsl b/cnn_v1/shaders/cnn_layer.wgsl new file mode 100644 index 0000000..cbd1686 --- /dev/null +++ b/cnn_v1/shaders/cnn_layer.wgsl @@ -0,0 +1,55 @@ +// CNN layer shader - uses modular convolution snippets +// Supports multi-pass rendering with residual connections +// DO NOT EDIT - Generated by train_cnn.py + +@group(0) @binding(0) var smplr: sampler; +@group(0) @binding(1) var txt: texture_2d<f32>; + +#include "common_uniforms" +#include "cnn_activation" +#include "cnn_conv3x3" +#include "cnn_conv5x5" +#include "cnn_weights_generated" + +struct CNNLayerParams { + layer_index: i32, + blend_amount: f32, + _pad: vec2<f32>, +}; + +@group(0) @binding(2) var<uniform> uniforms: CommonUniforms; +@group(0) @binding(3) var<uniform> params: CNNLayerParams; +@group(0) @binding(4) var original_input: texture_2d<f32>; + +@vertex fn vs_main(@builtin(vertex_index) i: u32) -> @builtin(position) vec4<f32> { + var pos = array<vec2<f32>, 3>( + vec2<f32>(-1.0, -1.0), vec2<f32>(3.0, -1.0), vec2<f32>(-1.0, 3.0) + ); + return vec4<f32>(pos[i], 0.0, 1.0); +} + +@fragment fn fs_main(@builtin(position) p: vec4<f32>) -> @location(0) vec4<f32> { + // Match PyTorch linspace + let uv = (p.xy - 0.5) / (uniforms.resolution - 1.0); + let original_raw = textureSample(original_input, smplr, uv); + let original = (original_raw - 0.5) * 2.0; // Normalize to [-1,1] + let gray = (dot(original_raw.rgb, vec3<f32>(0.2126, 0.7152, 0.0722)) - 0.5) * 2.0; + var result = vec4<f32>(0.0); + + // Layer 0: 7→4 (RGBD output, normalizes [0,1] input) + if (params.layer_index == 0) { + result = cnn_conv5x5_7to4_src(txt, smplr, uv, uniforms.resolution, weights_layer0); + result = cnn_tanh(result); + } + else if (params.layer_index == 1) { + result = cnn_conv3x3_7to4(txt, smplr, uv, uniforms.resolution, gray, weights_layer1); + result = 
cnn_tanh(result); // Keep in [-1,1] + } + else if (params.layer_index == 2) { + let sum = cnn_conv3x3_7to1(txt, smplr, uv, uniforms.resolution, gray, weights_layer2); + let gray_out = 1.0 / (1.0 + exp(-sum)); // Sigmoid activation + result = vec4<f32>(gray_out, gray_out, gray_out, 1.0); + return mix(original_raw, result, params.blend_amount); // [0,1] + } + return result; // [-1,1] +} diff --git a/cnn_v1/shaders/cnn_weights_generated.wgsl b/cnn_v1/shaders/cnn_weights_generated.wgsl new file mode 100644 index 0000000..510f86f --- /dev/null +++ b/cnn_v1/shaders/cnn_weights_generated.wgsl @@ -0,0 +1,302 @@ +// Auto-generated CNN weights (vec4-optimized) +// DO NOT EDIT - Generated by train_cnn.py + +const weights_layer0: array<vec4<f32>, 200> = array( + vec4<f32>(0.235493, 0.070711, -0.007171, 0.029242), + vec4<f32>(0.010796, -0.007094, 0.104870, -0.001741), + vec4<f32>(-0.363645, 0.625662, 0.044248, 0.046890), + vec4<f32>(0.016731, -0.099652, 0.198682, -0.002050), + vec4<f32>(-0.738196, -1.196639, -0.153794, 0.059818), + vec4<f32>(-0.012392, 0.206094, -1.159788, 0.001624), + vec4<f32>(-0.089846, -0.097056, 0.533546, -0.256308), + vec4<f32>(0.052460, 0.007740, -0.025518, -0.011569), + vec4<f32>(0.024563, -0.123127, -0.189236, -0.034605), + vec4<f32>(0.027494, 0.077022, -0.073083, -0.001741), + vec4<f32>(0.127897, -1.191688, -0.289229, -0.057213), + vec4<f32>(-0.017651, -0.095915, -0.540725, -0.002050), + vec4<f32>(0.459141, 1.047422, 1.008783, 0.082279), + vec4<f32>(-0.148789, 0.141891, 0.964934, 0.001624), + vec4<f32>(-0.458732, -0.253084, 0.429181, -0.267647), + vec4<f32>(0.029582, 0.043901, -0.332350, -0.011569), + vec4<f32>(-0.089206, -0.379760, -0.267976, -0.033062), + vec4<f32>(-0.059616, 0.042331, -0.297211, -0.001741), + vec4<f32>(0.347450, 0.349807, -0.107598, -0.038193), + vec4<f32>(-0.054979, -0.022737, 0.368773, -0.002050), + vec4<f32>(1.185666, 2.203693, 1.743948, 0.015765), + vec4<f32>(-0.004807, 0.138734, 2.114184, 0.001624), + vec4<f32>(-0.397312, -0.423930, 0.436068, -0.309529), + vec4<f32>(-0.025822, 0.061618, -0.358850, -0.011569), + vec4<f32>(0.031591, -0.133625, -0.210201, -0.058735), + vec4<f32>(0.026377, 0.074180, -0.075918, -0.001741), + vec4<f32>(-0.632064, -0.365984, -0.183357, -0.064294), + vec4<f32>(-0.038233, -0.027135, -0.529794, -0.002050), + vec4<f32>(-0.079942, -0.108489, 0.284420, 0.068003), + vec4<f32>(-0.033783, 0.131316, -0.006431, 0.001624), + vec4<f32>(-0.096003, -0.037157, 0.523401, -0.332369), + vec4<f32>(0.098362, 0.049597, 0.024988, -0.011569), + vec4<f32>(-0.042374, 0.215371, 0.044488, -0.079190), + vec4<f32>(-0.108483, 0.244548, 0.195395, -0.001741), + vec4<f32>(0.121079, 0.214838, 0.292411, -0.013912), + vec4<f32>(0.098564, -0.117552, 0.392438, -0.002050), + vec4<f32>(-0.994368, -0.526871, 0.165568, 0.006371), + vec4<f32>(-0.142932, 0.234835, -0.612723, 0.001624), + vec4<f32>(-0.430247, -0.230031, 0.035994, -0.340101), + vec4<f32>(-0.134622, -0.045299, -0.264801, -0.011569), + vec4<f32>(-0.116651, 0.042012, -0.004781, 0.018667), + vec4<f32>(0.000405, -0.068494, 0.084279, -0.001741), + vec4<f32>(0.180754, -0.853766, -0.384955, 0.013426), + vec4<f32>(0.038369, 0.010519, -0.437544, -0.002050), + vec4<f32>(0.373661, 0.677625, 0.617145, -0.028541), + vec4<f32>(0.071383, 0.012678, 0.734573, 0.001624), + vec4<f32>(-0.187586, -0.167658, 0.445526, -0.213674), + vec4<f32>(-0.054012, -0.048233, -0.111101, -0.011569), + vec4<f32>(-0.329708, 0.124956, 0.150447, 0.038372), + vec4<f32>(0.042139, -0.014901, 0.056693, -0.001741), + vec4<f32>(0.547166, 1.493724, 
0.572366, 0.044038), + vec4<f32>(-0.055818, 0.022352, 1.209448, -0.002050), + vec4<f32>(-0.669255, -0.481531, -0.593402, 0.125846), + vec4<f32>(-0.086191, -0.012315, -0.692654, 0.001624), + vec4<f32>(-0.667836, -0.543086, 0.253854, -0.236805), + vec4<f32>(0.045048, 0.047535, -0.607491, -0.011569), + vec4<f32>(-0.262418, 0.247133, 0.225155, -0.084126), + vec4<f32>(0.017065, 0.007371, 0.103683, -0.001741), + vec4<f32>(0.216644, 1.179116, 0.436799, 0.041116), + vec4<f32>(0.006571, 0.012147, 0.674660, -0.002050), + vec4<f32>(0.290965, -0.022340, -0.616338, 0.021808), + vec4<f32>(-0.091234, -0.016764, 0.116976, 0.001624), + vec4<f32>(-0.689736, -0.685681, 0.342797, -0.213249), + vec4<f32>(0.040683, 0.038921, -0.663171, -0.011569), + vec4<f32>(-0.150412, 0.018053, -0.103426, 0.026070), + vec4<f32>(0.016183, -0.090006, 0.028738, -0.001741), + vec4<f32>(0.851827, -0.499315, 0.146696, 0.047324), + vec4<f32>(0.059725, 0.031269, 0.184268, -0.002050), + vec4<f32>(0.160719, -0.309456, -0.432633, -0.021171), + vec4<f32>(-0.060075, -0.052701, -0.248520, 0.001624), + vec4<f32>(-0.217727, 0.354527, 0.663356, -0.267530), + vec4<f32>(-0.032714, 0.000761, 0.246687, -0.011569), + vec4<f32>(0.077123, 0.069934, 0.077986, 0.004388), + vec4<f32>(-0.107897, 0.103689, 0.072698, -0.001741), + vec4<f32>(-0.216285, -0.206663, -0.497913, -0.019433), + vec4<f32>(0.042063, -0.036315, -0.306115, -0.002050), + vec4<f32>(0.351038, 0.116104, -0.046132, 0.022280), + vec4<f32>(-0.026460, -0.025197, 0.286924, 0.001624), + vec4<f32>(-0.480131, -0.253209, -0.259724, -0.353796), + vec4<f32>(-0.069436, -0.026651, -0.285359, -0.011569), + vec4<f32>(0.225811, -0.092313, -0.152689, 0.007505), + vec4<f32>(0.120530, 0.012846, -0.020303, -0.001741), + vec4<f32>(0.305262, 0.699468, 0.474383, -0.002565), + vec4<f32>(-0.036377, 0.008052, 0.424588, -0.002050), + vec4<f32>(0.557323, 0.489104, 0.312243, 0.072877), + vec4<f32>(0.096476, -0.012612, 0.586454, 0.001624), + vec4<f32>(-0.370964, -0.252666, 0.235903, -0.299915), + vec4<f32>(-0.066341, -0.008435, -0.158507, -0.011569), + vec4<f32>(0.070604, -0.016186, -0.079075, 0.015055), + vec4<f32>(0.042533, -0.085281, -0.014053, -0.001741), + vec4<f32>(-1.115748, -0.531544, -0.207050, -0.040691), + vec4<f32>(0.010035, -0.008330, -0.718958, -0.002050), + vec4<f32>(-1.404958, -2.000416, -1.884062, 0.014171), + vec4<f32>(0.019375, -0.078894, -1.999592, 0.001624), + vec4<f32>(-1.144367, -0.681485, 0.145197, -0.310542), + vec4<f32>(0.071912, -0.001021, -0.817277, -0.011569), + vec4<f32>(-0.018298, 0.109930, -0.067419, -0.031281), + vec4<f32>(0.072086, -0.047123, -0.018405, -0.001741), + vec4<f32>(-2.926982, -5.479454, -1.936543, 0.034851), + vec4<f32>(0.005592, 0.052238, -4.695754, -0.002050), + vec4<f32>(0.504616, -0.384917, -0.623795, 0.009371), + vec4<f32>(-0.105685, -0.049385, -0.154266, 0.001624), + vec4<f32>(-1.428979, -0.829611, 0.160294, -0.239524), + vec4<f32>(0.054180, -0.058797, -0.939519, -0.011569), + vec4<f32>(0.088147, -0.158820, -0.199674, -0.083067), + vec4<f32>(0.073984, -0.059593, -0.103344, -0.001741), + vec4<f32>(0.465084, 2.259005, 0.899806, -0.010464), + vec4<f32>(0.058231, -0.075668, 1.383652, -0.002050), + vec4<f32>(-0.162736, -0.899540, -0.559890, 0.066380), + vec4<f32>(0.029594, 0.036117, -0.780812, 0.001624), + vec4<f32>(-0.605431, 0.342970, 0.671602, -0.313734), + vec4<f32>(0.072950, 0.058100, 0.232742, -0.011569), + vec4<f32>(0.161941, -0.017279, -0.010904, -0.041589), + vec4<f32>(-0.118079, 0.090886, 0.001212, -0.001741), + vec4<f32>(-0.136354, 0.155269, 0.058437, 
-0.043499), + vec4<f32>(0.029368, 0.079326, -0.060807, -0.002050), + vec4<f32>(0.222824, 0.267939, 0.010260, 0.093258), + vec4<f32>(-0.091763, 0.028527, 0.290062, 0.001624), + vec4<f32>(-0.584501, -0.074002, -0.187352, -0.247388), + vec4<f32>(-0.067679, -0.036398, -0.237425, -0.011569), + vec4<f32>(-0.026121, -0.231360, 0.002505, -0.096021), + vec4<f32>(0.073173, -0.059323, -0.128630, -0.001741), + vec4<f32>(-0.118509, -0.931686, -0.328151, 0.027222), + vec4<f32>(0.006670, -0.094619, -0.605555, -0.002050), + vec4<f32>(0.260254, 0.186958, 0.235441, -0.030871), + vec4<f32>(0.111987, -0.056380, 0.227175, 0.001624), + vec4<f32>(0.012446, -0.068683, 0.273271, -0.315052), + vec4<f32>(-0.020011, 0.046984, 0.026316, -0.011569), + vec4<f32>(0.149830, 0.108146, 0.141757, 0.040947), + vec4<f32>(-0.060874, -0.004303, 0.196782, -0.001741), + vec4<f32>(1.031257, 1.493831, 0.443644, -0.089572), + vec4<f32>(-0.035087, 0.049431, 1.193984, -0.002050), + vec4<f32>(-0.204666, -0.340174, -0.045684, 0.053997), + vec4<f32>(0.000214, -0.073696, -0.299299, 0.001624), + vec4<f32>(-1.040674, -0.828753, 0.007912, -0.326534), + vec4<f32>(0.040669, -0.036526, -0.794626, -0.011569), + vec4<f32>(-0.018212, -0.031610, 0.259871, -0.041978), + vec4<f32>(0.021055, -0.061307, -0.004348, -0.001741), + vec4<f32>(0.002720, 0.570871, 0.371837, -0.076940), + vec4<f32>(0.023420, 0.006175, 0.318983, -0.002050), + vec4<f32>(0.259713, 0.294528, 0.907401, 0.043367), + vec4<f32>(-0.087576, -0.053953, 0.273380, 0.001624), + vec4<f32>(-1.177213, -0.464727, 0.211285, -0.266637), + vec4<f32>(0.075274, -0.007404, -0.703821, -0.011569), + vec4<f32>(-0.089204, -0.053316, 0.280138, -0.056155), + vec4<f32>(0.030981, -0.005136, 0.038455, -0.001741), + vec4<f32>(0.936459, -0.196866, 0.270033, -0.096884), + vec4<f32>(0.025329, -0.032176, 0.473732, -0.002050), + vec4<f32>(0.312348, 0.234105, 0.580837, 0.099177), + vec4<f32>(0.019877, -0.096514, 0.450075, 0.001624), + vec4<f32>(-1.099700, -0.203693, 0.157253, -0.331450), + vec4<f32>(-0.033353, -0.072074, -0.453590, -0.011569), + vec4<f32>(-0.084598, -0.039735, 0.162495, -0.070988), + vec4<f32>(-0.038491, 0.071525, 0.034601, -0.001741), + vec4<f32>(-0.199528, -0.475454, -0.297979, 0.037322), + vec4<f32>(-0.003106, 0.003258, -0.475664, -0.002050), + vec4<f32>(-0.282845, 0.058921, -0.300971, -0.011632), + vec4<f32>(-0.102320, 0.065302, -0.035173, 0.001624), + vec4<f32>(-0.515296, 0.497936, 0.313751, -0.245144), + vec4<f32>(-0.126936, 0.016721, 0.233370, -0.011569), + vec4<f32>(-0.220154, 0.069414, 0.194344, 0.000786), + vec4<f32>(0.037788, -0.095021, -0.055585, -0.001741), + vec4<f32>(-0.186244, 0.434960, 0.138978, -0.017604), + vec4<f32>(0.014466, 0.055976, 0.306540, -0.002050), + vec4<f32>(0.000614, -0.087365, -0.327816, 0.025776), + vec4<f32>(0.227096, -0.143725, -0.046319, 0.001624), + vec4<f32>(0.468607, -0.441809, -0.025186, -0.260166), + vec4<f32>(0.018770, -0.067388, -0.240128, -0.011569), + vec4<f32>(-0.013968, 0.032027, -0.111361, -0.023976), + vec4<f32>(0.041929, -0.033460, 0.001994, -0.001741), + vec4<f32>(0.005203, -0.837762, -0.287991, -0.026139), + vec4<f32>(-0.077592, 0.021388, -0.524153, -0.002050), + vec4<f32>(0.250865, 0.313428, -0.248465, 0.059517), + vec4<f32>(0.034922, -0.054528, 0.257107, 0.001624), + vec4<f32>(0.010692, -0.067238, 0.233031, -0.310017), + vec4<f32>(0.176915, -0.059644, 0.016072, -0.011569), + vec4<f32>(0.016422, 0.016187, -0.037382, -0.083725), + vec4<f32>(0.002691, -0.110865, -0.012957, -0.001741), + vec4<f32>(0.095561, 0.396829, 0.128803, 0.037097), + 
vec4<f32>(0.019823, 0.093399, 0.310928, -0.002050), + vec4<f32>(-0.193791, -0.079385, 0.332894, 0.039734), + vec4<f32>(0.119291, -0.053947, 0.020449, 0.001624), + vec4<f32>(-0.446965, -0.003325, 0.231982, -0.298212), + vec4<f32>(0.063248, -0.060392, -0.103558, -0.011569), + vec4<f32>(-0.044501, -0.246630, -0.254448, -0.025872), + vec4<f32>(0.044620, -0.074284, -0.183828, -0.001741), + vec4<f32>(-0.369636, -0.171104, -0.485456, -0.085980), + vec4<f32>(-0.053131, 0.016452, -0.377567, -0.002050), + vec4<f32>(-0.183644, -0.028271, 0.226453, 0.010102), + vec4<f32>(0.039391, -0.132828, -0.009034, 0.001624), + vec4<f32>(-0.644046, -0.335421, 0.011161, -0.222670), + vec4<f32>(0.091183, 0.005457, -0.472058, -0.011569), + vec4<f32>(0.045107, 0.080623, -0.132791, 0.064920), + vec4<f32>(-0.110745, 0.109524, 0.092569, -0.001741), + vec4<f32>(0.064397, 0.190407, 0.257845, 0.024637), + vec4<f32>(-0.042557, 0.128625, 0.317239, -0.002050), + vec4<f32>(-0.362482, 0.271381, -0.115412, 0.103104), + vec4<f32>(0.088766, 0.042583, 0.069687, 0.001624), + vec4<f32>(-0.353634, 0.554832, 0.442496, -0.351794), + vec4<f32>(-0.140207, -0.064649, 0.346336, -0.011569) +); + +const weights_layer1: array<vec4<f32>, 72> = array( + vec4<f32>(-0.059078, -0.087833, -0.048345, -0.276761), + vec4<f32>(-0.101904, 0.058647, -0.405575, -0.064215), + vec4<f32>(-0.382952, 0.579364, -0.051813, -0.155723), + vec4<f32>(-0.140997, -0.006771, 0.212267, 0.120289), + vec4<f32>(-0.152651, -0.134768, -0.076617, -0.506104), + vec4<f32>(0.089304, 0.078492, 0.541122, 0.129289), + vec4<f32>(0.739323, -0.014103, -0.012980, -0.112747), + vec4<f32>(-0.089971, -0.088661, -0.520901, 0.158290), + vec4<f32>(0.819725, 2.866048, 0.080441, 0.380885), + vec4<f32>(0.035196, 0.028422, -0.748029, -0.064215), + vec4<f32>(-0.551722, 0.995924, -0.203047, -0.220742), + vec4<f32>(-0.081721, 0.039584, 0.581791, 0.120289), + vec4<f32>(-0.752329, -0.482903, -0.317275, 0.515372), + vec4<f32>(-0.087637, 0.040969, 0.481261, 0.129289), + vec4<f32>(0.532382, -0.653574, 0.078268, 0.139585), + vec4<f32>(-0.089350, -0.072701, -1.289249, 0.158290), + vec4<f32>(0.384272, -0.051717, 0.428463, -0.006561), + vec4<f32>(0.034003, 0.036653, -0.778556, -0.064215), + vec4<f32>(-0.788796, 0.332339, -0.181283, -0.213141), + vec4<f32>(0.196044, -0.062422, 0.724631, 0.120289), + vec4<f32>(-0.416297, -0.520778, -0.009510, -0.304383), + vec4<f32>(0.094475, -0.033135, 0.942838, 0.129289), + vec4<f32>(0.887455, 0.054078, 0.193434, 0.268549), + vec4<f32>(-0.055369, -0.042953, -0.172902, 0.158290), + vec4<f32>(0.419144, -0.159019, 0.189637, -0.235703), + vec4<f32>(-0.098285, 0.021026, -0.041846, -0.064215), + vec4<f32>(-1.009575, 0.934207, -0.120383, -0.243756), + vec4<f32>(-0.054562, 0.123804, 0.004157, 0.120289), + vec4<f32>(-0.504099, 0.696545, -0.850290, 0.493131), + vec4<f32>(-0.090043, -0.020600, -1.148702, 0.129289), + vec4<f32>(0.302269, -0.662429, 0.315052, -0.276341), + vec4<f32>(-0.084626, -0.029208, -0.799132, 0.158290), + vec4<f32>(0.318365, 2.531235, 0.349606, 0.231242), + vec4<f32>(0.053525, -0.031474, -0.570432, -0.064215), + vec4<f32>(-0.635031, 0.498836, 0.009884, -0.465079), + vec4<f32>(0.059087, 0.038415, 0.009928, 0.120289), + vec4<f32>(-0.522592, -3.781285, 0.418296, -0.608186), + vec4<f32>(0.100879, -0.083891, 1.653884, 0.129289), + vec4<f32>(0.258571, 2.590279, 0.221239, -0.143175), + vec4<f32>(0.121409, -0.084177, -1.397735, 0.158290), + vec4<f32>(0.907284, -0.034063, 0.573987, -0.125626), + vec4<f32>(-0.017610, -0.059485, -0.242599, -0.064215), + vec4<f32>(-0.748146, 
0.686047, -0.074510, -0.248879), + vec4<f32>(-0.034986, -0.121423, -0.406087, 0.120289), + vec4<f32>(-0.559352, -2.921763, -0.718019, -0.764524), + vec4<f32>(0.165658, 0.097044, 0.773885, 0.129289), + vec4<f32>(0.006276, -0.801820, 0.215264, 0.115919), + vec4<f32>(0.081513, -0.023028, -0.590423, 0.158290), + vec4<f32>(-0.207850, 0.088171, -0.173170, 0.351969), + vec4<f32>(-0.042732, -0.024059, -0.087492, -0.064215), + vec4<f32>(-0.711148, 0.312318, -0.145549, -0.113749), + vec4<f32>(0.053038, 0.093166, -0.473856, 0.120289), + vec4<f32>(-0.343481, -0.137305, -0.340862, 0.445920), + vec4<f32>(-0.070473, -0.024914, -0.735660, 0.129289), + vec4<f32>(0.212955, -0.200508, 0.105125, -0.165284), + vec4<f32>(-0.123633, 0.052941, 0.099918, 0.158290), + vec4<f32>(0.362468, -0.709693, 0.281097, -0.155976), + vec4<f32>(-0.034566, 0.002014, 0.443026, -0.064215), + vec4<f32>(-0.346208, 1.179972, -0.563868, -0.424647), + vec4<f32>(0.012676, -0.023351, -0.703819, 0.120289), + vec4<f32>(-0.476282, -0.001002, -0.456911, -0.143433), + vec4<f32>(0.061018, -0.051173, -0.992671, 0.129289), + vec4<f32>(0.340925, -0.869046, 0.333377, -0.070414), + vec4<f32>(0.022279, 0.022837, -0.389711, 0.158290), + vec4<f32>(0.217347, -0.092030, -0.004346, 0.209850), + vec4<f32>(-0.116637, -0.096003, -0.333961, -0.064215), + vec4<f32>(-0.105262, 0.443411, -0.443104, 0.032732), + vec4<f32>(0.014939, 0.058855, -0.723723, 0.120289), + vec4<f32>(-0.598907, -0.166341, -0.635385, 0.463685), + vec4<f32>(0.151976, 0.049510, 0.155364, 0.129289), + vec4<f32>(0.138981, -0.109141, 0.272429, 0.190495), + vec4<f32>(-0.005729, 0.020860, -0.062157, 0.158290) +); + +const weights_layer2: array<vec4<f32>, 18> = array( + vec4<f32>(0.043207, -0.056041, 0.131565, 0.116278), + vec4<f32>(-0.038849, -0.028105, -0.112979, 0.023741), + vec4<f32>(-0.010112, -0.085145, 0.257510, 0.245113), + vec4<f32>(0.041108, 0.049255, -0.082008, 0.023741), + vec4<f32>(0.012368, -0.035856, 0.018924, 0.174452), + vec4<f32>(0.052554, 0.039427, -0.279445, 0.023741), + vec4<f32>(-0.160061, -0.232735, 0.256951, 0.208887), + vec4<f32>(-0.088352, 0.100106, 0.103566, 0.023741), + vec4<f32>(-0.406607, -1.336396, 0.454171, 0.310834), + vec4<f32>(-0.061166, 0.105463, 1.572779, 0.023741), + vec4<f32>(-0.188413, -0.523344, 0.082813, 0.209113), + vec4<f32>(0.052509, -0.069748, -0.065008, 0.023741), + vec4<f32>(-0.124016, 0.005237, 0.177859, 0.138953), + vec4<f32>(0.072167, 0.070582, -0.209545, 0.023741), + vec4<f32>(-0.384457, -0.186386, 0.273595, 0.235457), + vec4<f32>(-0.032392, -0.086899, -0.006561, 0.023741), + vec4<f32>(-0.195800, 0.017395, 0.023080, 0.181437), + vec4<f32>(-0.035524, -0.095398, -0.204917, 0.023741) +); + diff --git a/cnn_v1/src/cnn_v1_effect.cc b/cnn_v1/src/cnn_v1_effect.cc new file mode 100644 index 0000000..1f44619 --- /dev/null +++ b/cnn_v1/src/cnn_v1_effect.cc @@ -0,0 +1,129 @@ +// CNN post-processing effect implementation +// Neural network-based stylization with modular WGSL + +#include "cnn_v1_effect.h" +#include "gpu/bind_group_builder.h" +#include "gpu/effect.h" +#include "gpu/pipeline_builder.h" +#include "gpu/post_process_helper.h" +#include "gpu/sampler_cache.h" +#include "gpu/shader_composer.h" +#include "gpu/shaders.h" + +// Create custom pipeline with 5 bindings (includes original texture) +static WGPURenderPipeline create_cnn_pipeline(WGPUDevice device, + WGPUTextureFormat format, + const char* shader_code) { + WGPUBindGroupLayout bgl = + BindGroupLayoutBuilder() + .sampler(0, WGPUShaderStage_Fragment) + .texture(1, WGPUShaderStage_Fragment) + 
.uniform(2, WGPUShaderStage_Vertex | WGPUShaderStage_Fragment) + .uniform(3, WGPUShaderStage_Fragment) + .texture(4, WGPUShaderStage_Fragment) + .build(device); + + WGPURenderPipeline pipeline = RenderPipelineBuilder(device) + .shader(shader_code) + .bind_group_layout(bgl) + .format(format) + .build(); + + wgpuBindGroupLayoutRelease(bgl); + return pipeline; +} + +CNNv1Effect::CNNv1Effect(const GpuContext& ctx) + : PostProcessEffect(ctx), layer_index_(0), total_layers_(1), + blend_amount_(1.0f), input_view_(nullptr), original_view_(nullptr), + bind_group_(nullptr) { + pipeline_ = + create_cnn_pipeline(ctx_.device, ctx_.format, cnn_layer_shader_wgsl); +} + +CNNv1Effect::CNNv1Effect(const GpuContext& ctx, const CNNv1EffectParams& params) + : PostProcessEffect(ctx), layer_index_(params.layer_index), + total_layers_(params.total_layers), blend_amount_(params.blend_amount), + input_view_(nullptr), original_view_(nullptr), bind_group_(nullptr) { + pipeline_ = + create_cnn_pipeline(ctx_.device, ctx_.format, cnn_layer_shader_wgsl); +} + +void CNNv1Effect::init(MainSequence* demo) { + PostProcessEffect::init(demo); + demo_ = demo; + params_buffer_.init(ctx_.device); + + // Register auxiliary texture for layer 0 (width_/height_ set by resize()) + if (layer_index_ == 0) { + demo_->register_auxiliary_texture("captured_frame", width_, height_); + } + + // Initialize uniforms BEFORE any bind group creation + uniforms_.update(ctx_.queue, get_common_uniforms()); + + CNNv1LayerParams params = {layer_index_, blend_amount_, {0.0f, 0.0f}}; + params_buffer_.update(ctx_.queue, params); +} + +void CNNv1Effect::resize(int width, int height) { + if (width == width_ && height == height_) + return; + + PostProcessEffect::resize(width, height); + + // Only layer 0 owns the captured_frame texture + if (layer_index_ == 0 && demo_) { + demo_->resize_auxiliary_texture("captured_frame", width, height); + } +} + +void CNNv1Effect::render(WGPURenderPassEncoder pass, + const CommonPostProcessUniforms& uniforms) { + if (!bind_group_) { + fprintf(stderr, "CNN render: no bind_group\n"); + return; + } + + float effective_blend = blend_amount_; + if (beat_modulated_) { + effective_blend = blend_amount_ * uniforms.beat_phase * beat_scale_; + } + + CNNv1LayerParams params = {layer_index_, effective_blend, {0.0f, 0.0f}}; + params_buffer_.update(ctx_.queue, params); + + wgpuRenderPassEncoderSetPipeline(pass, pipeline_); + wgpuRenderPassEncoderSetBindGroup(pass, 0, bind_group_, 0, nullptr); + wgpuRenderPassEncoderDraw(pass, 3, 1, 0, 0); +} + +void CNNv1Effect::update_bind_group(WGPUTextureView input_view) { + input_view_ = input_view; + + // Update common uniforms (CRITICAL for UV calculation!) + uniforms_.update(ctx_.queue, get_common_uniforms()); + + // All layers: get captured frame (original input from layer 0) + if (demo_) { + original_view_ = demo_->get_auxiliary_view("captured_frame"); + } + + // Create bind group with original texture + if (bind_group_) + wgpuBindGroupRelease(bind_group_); + + WGPUBindGroupLayout bgl = wgpuRenderPipelineGetBindGroupLayout(pipeline_, 0); + // Use clamp (not repeat) to match PyTorch Conv2d zero-padding behavior + WGPUSampler sampler = + SamplerCache::Get().get_or_create(ctx_.device, SamplerCache::clamp()); + + bind_group_ = + BindGroupBuilder() + .sampler(0, sampler) + .texture(1, input_view_) + .buffer(2, uniforms_.get().buffer, uniforms_.get().size) + .buffer(3, params_buffer_.get().buffer, params_buffer_.get().size) + .texture(4, original_view_ ? 
original_view_ : input_view_) + .build(ctx_.device, bgl); +} diff --git a/cnn_v1/src/cnn_v1_effect.h b/cnn_v1/src/cnn_v1_effect.h new file mode 100644 index 0000000..e820275 --- /dev/null +++ b/cnn_v1/src/cnn_v1_effect.h @@ -0,0 +1,53 @@ +// CNN post-processing effect header +// Multi-layer neural network stylization + +#pragma once +#include "gpu/effect.h" +#include "gpu/uniform_helper.h" + +struct CNNv1LayerParams { + int layer_index; + float blend_amount; // Blend: mix(input, output, blend_amount) + float _pad[2]; +}; +static_assert(sizeof(CNNv1LayerParams) == 16); + +struct CNNv1EffectParams { + int layer_index = 0; // Which layer to render (0-based) + int total_layers = 1; // Total number of layers in the CNN + float blend_amount = 1.0f; // Final blend with original input +}; + +class CNNv1Effect : public PostProcessEffect { + public: + explicit CNNv1Effect(const GpuContext& ctx); + explicit CNNv1Effect(const GpuContext& ctx, const CNNv1EffectParams& params); + + void init(MainSequence* demo) override; + void resize(int width, int height) override; + void render(WGPURenderPassEncoder pass, + const CommonPostProcessUniforms& uniforms) override; + void update_bind_group(WGPUTextureView input_view) override; + + // Layer 0 needs framebuffer capture for original input + bool needs_framebuffer_capture() const override { + return layer_index_ == 0; + } + + void set_beat_modulation(bool enabled, float scale = 1.0f) { + beat_modulated_ = enabled; + beat_scale_ = scale; + } + + private: + int layer_index_; + int total_layers_; + float blend_amount_; + bool beat_modulated_ = false; + float beat_scale_ = 1.0f; + WGPUTextureView input_view_; + WGPUTextureView original_view_; + UniformBuffer<CNNv1LayerParams> params_buffer_; + WGPUBindGroup bind_group_; + MainSequence* demo_ = nullptr; +}; diff --git a/cnn_v1/training/train_cnn.py b/cnn_v1/training/train_cnn.py new file mode 100755 index 0000000..4171dcb --- /dev/null +++ b/cnn_v1/training/train_cnn.py @@ -0,0 +1,943 @@ +#!/usr/bin/env python3 +""" +CNN Training Script for Image-to-Image Transformation + +Trains a convolutional neural network on multiple input/target image pairs. 
+ +Usage: + # Training + python3 train_cnn.py --input input_dir/ --target target_dir/ [options] + + # Inference (generate ground truth) + python3 train_cnn.py --infer image.png --export-only checkpoint.pth --output result.png + +Example: + python3 train_cnn.py --input ./input --target ./output --layers 3 --epochs 100 + python3 train_cnn.py --infer input.png --export-only checkpoints/checkpoint_epoch_10000.pth +""" + +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms +from PIL import Image +import numpy as np +import cv2 +import os +import sys +import argparse +import glob + + +class ImagePairDataset(Dataset): + """Dataset for loading matching input/target image pairs""" + + def __init__(self, input_dir, target_dir, transform=None): + self.input_dir = input_dir + self.target_dir = target_dir + self.transform = transform + + # Find all images in input directory + input_patterns = ['*.png', '*.jpg', '*.jpeg', '*.PNG', '*.JPG', '*.JPEG'] + self.image_pairs = [] + + for pattern in input_patterns: + input_files = glob.glob(os.path.join(input_dir, pattern)) + for input_path in input_files: + filename = os.path.basename(input_path) + # Try to find matching target with same name but any supported extension + target_path = None + for ext in ['png', 'jpg', 'jpeg', 'PNG', 'JPG', 'JPEG']: + base_name = os.path.splitext(filename)[0] + candidate = os.path.join(target_dir, f"{base_name}.{ext}") + if os.path.exists(candidate): + target_path = candidate + break + + if target_path: + self.image_pairs.append((input_path, target_path)) + + if not self.image_pairs: + raise ValueError(f"No matching image pairs found between {input_dir} and {target_dir}") + + print(f"Found {len(self.image_pairs)} matching image pairs") + + def __len__(self): + return len(self.image_pairs) + + def __getitem__(self, idx): + input_path, target_path = self.image_pairs[idx] + + # Load RGBD input (4 channels: RGB + Depth) + input_img = Image.open(input_path).convert('RGBA') + target_img = Image.open(target_path).convert('RGB') + + if self.transform: + input_img = self.transform(input_img) + target_img = self.transform(target_img) + + return input_img, target_img + + +class PatchDataset(Dataset): + """Dataset for extracting salient patches from image pairs""" + + def __init__(self, input_dir, target_dir, patch_size=32, patches_per_image=64, + detector='harris', transform=None): + self.input_dir = input_dir + self.target_dir = target_dir + self.patch_size = patch_size + self.patches_per_image = patches_per_image + self.detector = detector + self.transform = transform + + # Find all image pairs + input_patterns = ['*.png', '*.jpg', '*.jpeg', '*.PNG', '*.JPG', '*.JPEG'] + self.image_pairs = [] + + for pattern in input_patterns: + input_files = glob.glob(os.path.join(input_dir, pattern)) + for input_path in input_files: + filename = os.path.basename(input_path) + target_path = None + for ext in ['png', 'jpg', 'jpeg', 'PNG', 'JPG', 'JPEG']: + base_name = os.path.splitext(filename)[0] + candidate = os.path.join(target_dir, f"{base_name}.{ext}") + if os.path.exists(candidate): + target_path = candidate + break + + if target_path: + self.image_pairs.append((input_path, target_path)) + + if not self.image_pairs: + raise ValueError(f"No matching image pairs found between {input_dir} and {target_dir}") + + print(f"Found {len(self.image_pairs)} image pairs") + print(f"Extracting {patches_per_image} patches per image using {detector} detector") + 
print(f"Total patches: {len(self.image_pairs) * patches_per_image}") + + def __len__(self): + return len(self.image_pairs) * self.patches_per_image + + def _detect_salient_points(self, img_array): + """Detect salient points using specified detector""" + gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY) + h, w = gray.shape + half_patch = self.patch_size // 2 + + if self.detector == 'harris': + # Harris corner detection + corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2, + qualityLevel=0.01, minDistance=half_patch) + elif self.detector == 'fast': + # FAST feature detection + fast = cv2.FastFeatureDetector_create(threshold=20) + keypoints = fast.detect(gray, None) + corners = np.array([[kp.pt[0], kp.pt[1]] for kp in keypoints[:self.patches_per_image * 2]]) + corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None + elif self.detector == 'shi-tomasi': + # Shi-Tomasi corner detection (goodFeaturesToTrack with different params) + corners = cv2.goodFeaturesToTrack(gray, self.patches_per_image * 2, + qualityLevel=0.01, minDistance=half_patch, + useHarrisDetector=False) + elif self.detector == 'gradient': + # High-gradient regions + grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3) + grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3) + gradient_mag = np.sqrt(grad_x**2 + grad_y**2) + + # Find top gradient locations + threshold = np.percentile(gradient_mag, 95) + y_coords, x_coords = np.where(gradient_mag > threshold) + + if len(x_coords) > self.patches_per_image * 2: + indices = np.random.choice(len(x_coords), self.patches_per_image * 2, replace=False) + x_coords = x_coords[indices] + y_coords = y_coords[indices] + + corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)]) + corners = corners.reshape(-1, 1, 2) if len(corners) > 0 else None + else: + raise ValueError(f"Unknown detector: {self.detector}") + + # Fallback to random if no corners found + if corners is None or len(corners) == 0: + x_coords = np.random.randint(half_patch, w - half_patch, self.patches_per_image) + y_coords = np.random.randint(half_patch, h - half_patch, self.patches_per_image) + corners = np.array([[x, y] for x, y in zip(x_coords, y_coords)]) + corners = corners.reshape(-1, 1, 2) + + # Filter valid corners (within bounds) + valid_corners = [] + for corner in corners: + x, y = int(corner[0][0]), int(corner[0][1]) + if half_patch <= x < w - half_patch and half_patch <= y < h - half_patch: + valid_corners.append((x, y)) + if len(valid_corners) >= self.patches_per_image: + break + + # Fill with random if not enough + while len(valid_corners) < self.patches_per_image: + x = np.random.randint(half_patch, w - half_patch) + y = np.random.randint(half_patch, h - half_patch) + valid_corners.append((x, y)) + + return valid_corners + + def __getitem__(self, idx): + img_idx = idx // self.patches_per_image + patch_idx = idx % self.patches_per_image + + input_path, target_path = self.image_pairs[img_idx] + + # Load images + input_img = Image.open(input_path).convert('RGBA') + target_img = Image.open(target_path).convert('RGB') + + # Detect salient points (use input image for detection) + input_array = np.array(input_img)[:, :, :3] # Use RGB for detection + corners = self._detect_salient_points(input_array) + + # Extract patch at specified index + x, y = corners[patch_idx] + half_patch = self.patch_size // 2 + + # Crop patches + input_patch = input_img.crop((x - half_patch, y - half_patch, + x + half_patch, y + half_patch)) + target_patch = target_img.crop((x - half_patch, y - half_patch, + x + 
half_patch, y + half_patch)) + + if self.transform: + input_patch = self.transform(input_patch) + target_patch = self.transform(target_patch) + + return input_patch, target_patch + + +class SimpleCNN(nn.Module): + """CNN for RGBD→RGB with 7-channel input (RGBD + UV + gray) + + Internally computes grayscale, expands to 3-channel RGB output. + """ + + def __init__(self, num_layers=1, kernel_sizes=None): + super(SimpleCNN, self).__init__() + + if kernel_sizes is None: + kernel_sizes = [3] * num_layers + + assert len(kernel_sizes) == num_layers, "kernel_sizes must match num_layers" + + self.kernel_sizes = kernel_sizes + self.layers = nn.ModuleList() + + for i, kernel_size in enumerate(kernel_sizes): + padding = kernel_size // 2 + if i < num_layers - 1: + # Inner layers: 7→4 (RGBD output) + self.layers.append(nn.Conv2d(7, 4, kernel_size=kernel_size, padding=padding, bias=True)) + else: + # Final layer: 7→1 (grayscale output) + self.layers.append(nn.Conv2d(7, 1, kernel_size=kernel_size, padding=padding, bias=True)) + + def forward(self, x, return_intermediates=False): + # x: [B,4,H,W] - RGBD input (D = 1/z) + B, C, H, W = x.shape + + intermediates = [] if return_intermediates else None + + # Normalize RGBD to [-1,1] + x_norm = (x - 0.5) * 2.0 + + # Compute normalized coordinates [-1,1] + y_coords = torch.linspace(-1, 1, H, device=x.device).view(1,1,H,1).expand(B,1,H,W) + x_coords = torch.linspace(-1, 1, W, device=x.device).view(1,1,1,W).expand(B,1,H,W) + + # Compute grayscale from original RGB (Rec.709) and normalize to [-1,1] + gray = 0.2126*x[:,0:1] + 0.7152*x[:,1:2] + 0.0722*x[:,2:3] # [B,1,H,W] in [0,1] + gray = (gray - 0.5) * 2.0 # [-1,1] + + # Layer 0 + layer0_input = torch.cat([x_norm, x_coords, y_coords, gray], dim=1) # [B,7,H,W] + out = self.layers[0](layer0_input) # [B,4,H,W] + out = torch.tanh(out) # [-1,1] + if return_intermediates: + intermediates.append(out.clone()) + + # Inner layers + for i in range(1, len(self.layers)-1): + layer_input = torch.cat([out, x_coords, y_coords, gray], dim=1) + out = self.layers[i](layer_input) + out = torch.tanh(out) + if return_intermediates: + intermediates.append(out.clone()) + + # Final layer (grayscale→RGB) + final_input = torch.cat([out, x_coords, y_coords, gray], dim=1) + out = self.layers[-1](final_input) # [B,1,H,W] grayscale + out = torch.sigmoid(out) # Map to [0,1] with smooth gradients + final_out = out.expand(-1, 3, -1, -1) # [B,3,H,W] expand to RGB + + if return_intermediates: + return final_out, intermediates + return final_out + + +def generate_layer_shader(output_path, num_layers, kernel_sizes): + """Generate cnn_layer.wgsl with proper layer switches""" + + with open(output_path, 'w') as f: + f.write("// CNN layer shader - uses modular convolution snippets\n") + f.write("// Supports multi-pass rendering with residual connections\n") + f.write("// DO NOT EDIT - Generated by train_cnn.py\n\n") + f.write("@group(0) @binding(0) var smplr: sampler;\n") + f.write("@group(0) @binding(1) var txt: texture_2d<f32>;\n\n") + f.write("#include \"common_uniforms\"\n") + f.write("#include \"cnn_activation\"\n") + + # Include necessary conv functions + conv_sizes = set(kernel_sizes) + for ks in sorted(conv_sizes): + f.write(f"#include \"cnn_conv{ks}x{ks}\"\n") + f.write("#include \"cnn_weights_generated\"\n\n") + + f.write("struct CNNLayerParams {\n") + f.write(" layer_index: i32,\n") + f.write(" blend_amount: f32,\n") + f.write(" _pad: vec2<f32>,\n") + f.write("};\n\n") + f.write("@group(0) @binding(2) var<uniform> uniforms: CommonUniforms;\n") + 
f.write("@group(0) @binding(3) var<uniform> params: CNNLayerParams;\n") + f.write("@group(0) @binding(4) var original_input: texture_2d<f32>;\n\n") + f.write("@vertex fn vs_main(@builtin(vertex_index) i: u32) -> @builtin(position) vec4<f32> {\n") + f.write(" var pos = array<vec2<f32>, 3>(\n") + f.write(" vec2<f32>(-1.0, -1.0), vec2<f32>(3.0, -1.0), vec2<f32>(-1.0, 3.0)\n") + f.write(" );\n") + f.write(" return vec4<f32>(pos[i], 0.0, 1.0);\n") + f.write("}\n\n") + f.write("@fragment fn fs_main(@builtin(position) p: vec4<f32>) -> @location(0) vec4<f32> {\n") + f.write(" // Match PyTorch linspace\n") + f.write(" let uv = (p.xy - 0.5) / (uniforms.resolution - 1.0);\n") + f.write(" let original_raw = textureSample(original_input, smplr, uv);\n") + f.write(" let original = (original_raw - 0.5) * 2.0; // Normalize to [-1,1]\n") + f.write(" let gray = (dot(original_raw.rgb, vec3<f32>(0.2126, 0.7152, 0.0722)) - 0.5) * 2.0;\n") + f.write(" var result = vec4<f32>(0.0);\n\n") + + # Generate layer switches + for layer_idx in range(num_layers): + is_final = layer_idx == num_layers - 1 + ks = kernel_sizes[layer_idx] + conv_fn = f"cnn_conv{ks}x{ks}_7to4" if not is_final else f"cnn_conv{ks}x{ks}_7to1" + + if layer_idx == 0: + conv_fn_src = f"cnn_conv{ks}x{ks}_7to4_src" + f.write(f" // Layer 0: 7→4 (RGBD output, normalizes [0,1] input)\n") + f.write(f" if (params.layer_index == {layer_idx}) {{\n") + f.write(f" result = {conv_fn_src}(txt, smplr, uv, uniforms.resolution, weights_layer{layer_idx});\n") + f.write(f" result = cnn_tanh(result);\n") + f.write(f" }}\n") + elif not is_final: + f.write(f" else if (params.layer_index == {layer_idx}) {{\n") + f.write(f" result = {conv_fn}(txt, smplr, uv, uniforms.resolution, gray, weights_layer{layer_idx});\n") + f.write(f" result = cnn_tanh(result); // Keep in [-1,1]\n") + f.write(f" }}\n") + else: + f.write(f" else if (params.layer_index == {layer_idx}) {{\n") + f.write(f" let sum = {conv_fn}(txt, smplr, uv, uniforms.resolution, gray, weights_layer{layer_idx});\n") + f.write(f" let gray_out = 1.0 / (1.0 + exp(-sum)); // Sigmoid activation\n") + f.write(f" result = vec4<f32>(gray_out, gray_out, gray_out, 1.0);\n") + f.write(f" return mix(original_raw, result, params.blend_amount); // [0,1]\n") + f.write(f" }}\n") + + f.write(" return result; // [-1,1]\n") + f.write("}\n") + + +def export_weights_to_wgsl(model, output_path, kernel_sizes): + """Export trained weights to WGSL format (vec4-optimized)""" + + with open(output_path, 'w') as f: + f.write("// Auto-generated CNN weights (vec4-optimized)\n") + f.write("// DO NOT EDIT - Generated by train_cnn.py\n\n") + + for i, layer in enumerate(model.layers): + weights = layer.weight.data.cpu().numpy() + bias = layer.bias.data.cpu().numpy() + out_ch, in_ch, kh, kw = weights.shape + num_positions = kh * kw + + is_final = (i == len(model.layers) - 1) + + if is_final: + # Final layer: 7→1, structure: array<vec4<f32>, 18> (9 pos × 2 vec4) + # Input: [rgba, uv_gray_1] → 2 vec4s per position + f.write(f"const weights_layer{i}: array<vec4<f32>, {num_positions * 2}> = array(\n") + for pos in range(num_positions): + row, col = pos // kw, pos % kw + # First vec4: [w0, w1, w2, w3] (rgba) + v0 = [f"{weights[0, in_c, row, col]:.6f}" for in_c in range(4)] + # Second vec4: [w4, w5, w6, bias] (uv, gray, 1) + v1 = [f"{weights[0, in_c, row, col]:.6f}" for in_c in range(4, 7)] + v1.append(f"{bias[0] / num_positions:.6f}") + f.write(f" vec4<f32>({', '.join(v0)}),\n") + f.write(f" vec4<f32>({', '.join(v1)})") + f.write(",\n" if pos < 
num_positions-1 else "\n") + f.write(");\n\n") + else: + # Inner layers: 7→4, structure: array<vec4<f32>, 72> (36 entries × 2 vec4) + # Each filter: 2 vec4s for [rgba][uv_gray_1] inputs + num_vec4s = num_positions * 4 * 2 + f.write(f"const weights_layer{i}: array<vec4<f32>, {num_vec4s}> = array(\n") + for pos in range(num_positions): + row, col = pos // kw, pos % kw + for out_c in range(4): + # First vec4: [w0, w1, w2, w3] (rgba) + v0 = [f"{weights[out_c, in_c, row, col]:.6f}" for in_c in range(4)] + # Second vec4: [w4, w5, w6, bias] (uv, gray, 1) + v1 = [f"{weights[out_c, in_c, row, col]:.6f}" for in_c in range(4, 7)] + v1.append(f"{bias[out_c] / num_positions:.6f}") + idx = (pos * 4 + out_c) * 2 + f.write(f" vec4<f32>({', '.join(v0)}),\n") + f.write(f" vec4<f32>({', '.join(v1)})") + f.write(",\n" if idx < num_vec4s-2 else "\n") + f.write(");\n\n") + + +def generate_conv_base_function(kernel_size, output_path): + """Generate cnn_conv{K}x{K}_7to4() function for inner layers (vec4-optimized)""" + + k = kernel_size + num_positions = k * k + radius = k // 2 + + with open(output_path, 'a') as f: + f.write(f"\n// Inner layers: 7→4 channels (vec4-optimized)\n") + f.write(f"// Assumes 'tex' is already normalized to [-1,1]\n") + f.write(f"fn cnn_conv{k}x{k}_7to4(\n") + f.write(f" tex: texture_2d<f32>,\n") + f.write(f" samp: sampler,\n") + f.write(f" uv: vec2<f32>,\n") + f.write(f" resolution: vec2<f32>,\n") + f.write(f" gray: f32,\n") + f.write(f" weights: array<vec4<f32>, {num_positions * 8}>\n") + f.write(f") -> vec4<f32> {{\n") + f.write(f" let step = 1.0 / resolution;\n") + f.write(f" let uv_norm = (uv - 0.5) * 2.0;\n\n") + f.write(f" var sum = vec4<f32>(0.0);\n") + f.write(f" var pos = 0;\n\n") + + # Convolution loop + f.write(f" for (var dy = -{radius}; dy <= {radius}; dy++) {{\n") + f.write(f" for (var dx = -{radius}; dx <= {radius}; dx++) {{\n") + f.write(f" let offset = vec2<f32>(f32(dx), f32(dy)) * step;\n") + f.write(f" let rgbd = textureSample(tex, samp, uv + offset);\n") + f.write(f" let in1 = vec4<f32>(uv_norm, gray, 1.0);\n\n") + + # Accumulate + f.write(f" sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1);\n") + f.write(f" sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1);\n") + f.write(f" sum.b += dot(weights[pos+4], rgbd) + dot(weights[pos+5], in1);\n") + f.write(f" sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1);\n") + f.write(f" pos += 8;\n") + f.write(f" }}\n") + f.write(f" }}\n\n") + + f.write(f" return sum;\n") + f.write(f"}}\n") + + +def generate_conv_src_function(kernel_size, output_path): + """Generate cnn_conv{K}x{K}_7to4_src() function for layer 0 (vec4-optimized)""" + + k = kernel_size + num_positions = k * k + radius = k // 2 + + with open(output_path, 'a') as f: + f.write(f"\n// Source layer: 7→4 channels (vec4-optimized)\n") + f.write(f"// Normalizes [0,1] input to [-1,1] internally\n") + f.write(f"fn cnn_conv{k}x{k}_7to4_src(\n") + f.write(f" tex: texture_2d<f32>,\n") + f.write(f" samp: sampler,\n") + f.write(f" uv: vec2<f32>,\n") + f.write(f" resolution: vec2<f32>,\n") + f.write(f" weights: array<vec4<f32>, {num_positions * 8}>\n") + f.write(f") -> vec4<f32> {{\n") + f.write(f" let step = 1.0 / resolution;\n\n") + + # Normalize center pixel for gray channel + f.write(f" let original = (textureSample(tex, samp, uv) - 0.5) * 2.0;\n") + f.write(f" let gray = dot(original.rgb, vec3<f32>(0.2126, 0.7152, 0.0722));\n") + f.write(f" let uv_norm = (uv - 0.5) * 2.0;\n") + f.write(f" let in1 = vec4<f32>(uv_norm, gray, 1.0);\n\n") + + 
f.write(f" var sum = vec4<f32>(0.0);\n") + f.write(f" var pos = 0;\n\n") + + # Convolution loop + f.write(f" for (var dy = -{radius}; dy <= {radius}; dy++) {{\n") + f.write(f" for (var dx = -{radius}; dx <= {radius}; dx++) {{\n") + f.write(f" let offset = vec2<f32>(f32(dx), f32(dy)) * step;\n") + f.write(f" let rgbd = (textureSample(tex, samp, uv + offset) - 0.5) * 2.0;\n\n") + + # Accumulate with dot products (unrolled) + f.write(f" sum.r += dot(weights[pos+0], rgbd) + dot(weights[pos+1], in1);\n") + f.write(f" sum.g += dot(weights[pos+2], rgbd) + dot(weights[pos+3], in1);\n") + f.write(f" sum.b += dot(weights[pos+4], rgbd) + dot(weights[pos+5], in1);\n") + f.write(f" sum.a += dot(weights[pos+6], rgbd) + dot(weights[pos+7], in1);\n") + f.write(f" pos += 8;\n") + f.write(f" }}\n") + f.write(f" }}\n\n") + + f.write(f" return sum;\n") + f.write(f"}}\n") + + +def generate_conv_final_function(kernel_size, output_path): + """Generate cnn_conv{K}x{K}_7to1() function for final layer (vec4-optimized)""" + + k = kernel_size + num_positions = k * k + radius = k // 2 + + with open(output_path, 'a') as f: + f.write(f"\n// Final layer: 7→1 channel (vec4-optimized)\n") + f.write(f"// Assumes 'tex' is already normalized to [-1,1]\n") + f.write(f"// Returns raw sum (activation applied at call site)\n") + f.write(f"fn cnn_conv{k}x{k}_7to1(\n") + f.write(f" tex: texture_2d<f32>,\n") + f.write(f" samp: sampler,\n") + f.write(f" uv: vec2<f32>,\n") + f.write(f" resolution: vec2<f32>,\n") + f.write(f" gray: f32,\n") + f.write(f" weights: array<vec4<f32>, {num_positions * 2}>\n") + f.write(f") -> f32 {{\n") + f.write(f" let step = 1.0 / resolution;\n") + f.write(f" let uv_norm = (uv - 0.5) * 2.0;\n") + f.write(f" let in1 = vec4<f32>(uv_norm, gray, 1.0);\n\n") + f.write(f" var sum = 0.0;\n") + f.write(f" var pos = 0;\n\n") + + # Convolution loop + f.write(f" for (var dy = -{radius}; dy <= {radius}; dy++) {{\n") + f.write(f" for (var dx = -{radius}; dx <= {radius}; dx++) {{\n") + f.write(f" let offset = vec2<f32>(f32(dx), f32(dy)) * step;\n") + f.write(f" let rgbd = textureSample(tex, samp, uv + offset);\n\n") + + # Accumulate with dot products + f.write(f" sum += dot(weights[pos], rgbd) + dot(weights[pos+1], in1);\n") + f.write(f" pos += 2;\n") + f.write(f" }}\n") + f.write(f" }}\n\n") + + f.write(f" return sum;\n") + f.write(f"}}\n") + + +def train(args): + """Main training loop""" + + # Setup device + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}") + + # Prepare dataset + if args.patch_size: + # Patch-based training (preserves natural scale) + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + dataset = PatchDataset(args.input, args.target, + patch_size=args.patch_size, + patches_per_image=args.patches_per_image, + detector=args.detector, + transform=transform) + else: + # Full-image training (resize mode) + transform = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + ]) + dataset = ImagePairDataset(args.input, args.target, transform=transform) + + dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True) + + # Parse kernel sizes + kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')] + if len(kernel_sizes) == 1 and args.layers > 1: + kernel_sizes = kernel_sizes * args.layers + + # Create model + model = SimpleCNN(num_layers=args.layers, kernel_sizes=kernel_sizes).to(device) + + # Loss and optimizer + criterion = nn.MSELoss() + optimizer = optim.Adam(model.parameters(), 
lr=args.learning_rate) + + # Resume from checkpoint + start_epoch = 0 + if args.resume: + if os.path.exists(args.resume): + print(f"Loading checkpoint from {args.resume}...") + checkpoint = torch.load(args.resume, map_location=device) + model.load_state_dict(checkpoint['model_state']) + optimizer.load_state_dict(checkpoint['optimizer_state']) + start_epoch = checkpoint['epoch'] + 1 + print(f"Resumed from epoch {start_epoch}") + else: + print(f"Warning: Checkpoint file '{args.resume}' not found, starting from scratch") + + # Compute valid center region (exclude conv padding borders) + num_layers = args.layers + border = num_layers # Each 3x3 layer needs 1px, accumulates across layers + + # Early stopping setup + loss_history = [] + early_stop_triggered = False + + # Training loop + print(f"\nTraining for {args.epochs} epochs (starting from epoch {start_epoch})...") + print(f"Computing loss on center region only (excluding {border}px border)") + if args.early_stop_patience > 0: + print(f"Early stopping: patience={args.early_stop_patience}, eps={args.early_stop_eps}") + + for epoch in range(start_epoch, args.epochs): + epoch_loss = 0.0 + for batch_idx, (inputs, targets) in enumerate(dataloader): + inputs, targets = inputs.to(device), targets.to(device) + + optimizer.zero_grad() + outputs = model(inputs) + + # Only compute loss on center pixels with valid neighborhoods + if border > 0 and outputs.shape[2] > 2*border and outputs.shape[3] > 2*border: + outputs_center = outputs[:, :, border:-border, border:-border] + targets_center = targets[:, :, border:-border, border:-border] + loss = criterion(outputs_center, targets_center) + else: + loss = criterion(outputs, targets) + + loss.backward() + optimizer.step() + + epoch_loss += loss.item() + + avg_loss = epoch_loss / len(dataloader) + if (epoch + 1) % 10 == 0: + print(f"Epoch [{epoch+1}/{args.epochs}], Loss: {avg_loss:.6f}") + + # Early stopping check + if args.early_stop_patience > 0: + loss_history.append(avg_loss) + if len(loss_history) >= args.early_stop_patience: + oldest_loss = loss_history[-args.early_stop_patience] + loss_change = abs(avg_loss - oldest_loss) + if loss_change < args.early_stop_eps: + print(f"Early stopping triggered at epoch {epoch+1}") + print(f"Loss change over last {args.early_stop_patience} epochs: {loss_change:.8f} < {args.early_stop_eps}") + early_stop_triggered = True + break + + # Save checkpoint + if args.checkpoint_every > 0 and (epoch + 1) % args.checkpoint_every == 0: + checkpoint_dir = args.checkpoint_dir or 'training/checkpoints' + os.makedirs(checkpoint_dir, exist_ok=True) + checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth') + torch.save({ + 'epoch': epoch, + 'model_state': model.state_dict(), + 'optimizer_state': optimizer.state_dict(), + 'loss': avg_loss, + 'kernel_sizes': kernel_sizes, + 'num_layers': args.layers + }, checkpoint_path) + print(f"Saved checkpoint to {checkpoint_path}") + + # Export weights and shader + output_path = args.output or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl' + print(f"\nExporting weights to {output_path}...") + os.makedirs(os.path.dirname(output_path), exist_ok=True) + export_weights_to_wgsl(model, output_path, kernel_sizes) + + # Generate layer shader + shader_dir = os.path.dirname(output_path) + shader_path = os.path.join(shader_dir, 'cnn_layer.wgsl') + print(f"Generating layer shader to {shader_path}...") + generate_layer_shader(shader_path, args.layers, kernel_sizes) + + # Generate conv shader files for all kernel sizes + for 
ks in set(kernel_sizes): + conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl') + + # Create file with header if it doesn't exist + if not os.path.exists(conv_path): + print(f"Creating {conv_path}...") + with open(conv_path, 'w') as f: + f.write(f"// {ks}x{ks} convolution (vec4-optimized)\n") + generate_conv_base_function(ks, conv_path) + generate_conv_src_function(ks, conv_path) + generate_conv_final_function(ks, conv_path) + print(f"Generated complete {conv_path}") + continue + + # File exists, check for missing functions + with open(conv_path, 'r') as f: + content = f.read() + + # Generate base 7to4 if missing + if f"cnn_conv{ks}x{ks}_7to4" not in content: + generate_conv_base_function(ks, conv_path) + print(f"Added base 7to4 to {conv_path}") + with open(conv_path, 'r') as f: + content = f.read() + + # Generate _src variant if missing + if f"cnn_conv{ks}x{ks}_7to4_src" not in content: + generate_conv_src_function(ks, conv_path) + print(f"Added _src variant to {conv_path}") + with open(conv_path, 'r') as f: + content = f.read() + + # Generate 7to1 final layer if missing + if f"cnn_conv{ks}x{ks}_7to1" not in content: + generate_conv_final_function(ks, conv_path) + print(f"Added 7to1 variant to {conv_path}") + + print("Training complete!") + + +def export_from_checkpoint(checkpoint_path, output_path=None): + """Export WGSL files from checkpoint without training""" + + if not os.path.exists(checkpoint_path): + print(f"Error: Checkpoint file '{checkpoint_path}' not found") + sys.exit(1) + + print(f"Loading checkpoint from {checkpoint_path}...") + checkpoint = torch.load(checkpoint_path, map_location='cpu') + + kernel_sizes = checkpoint['kernel_sizes'] + num_layers = checkpoint['num_layers'] + + # Recreate model + model = SimpleCNN(num_layers=num_layers, kernel_sizes=kernel_sizes) + model.load_state_dict(checkpoint['model_state']) + + # Export weights + output_path = output_path or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl' + print(f"Exporting weights to {output_path}...") + os.makedirs(os.path.dirname(output_path), exist_ok=True) + export_weights_to_wgsl(model, output_path, kernel_sizes) + + # Generate layer shader + shader_dir = os.path.dirname(output_path) + shader_path = os.path.join(shader_dir, 'cnn_layer.wgsl') + print(f"Generating layer shader to {shader_path}...") + generate_layer_shader(shader_path, num_layers, kernel_sizes) + + # Generate conv shader files for all kernel sizes + for ks in set(kernel_sizes): + conv_path = os.path.join(shader_dir, f'cnn_conv{ks}x{ks}.wgsl') + + # Create file with header if it doesn't exist + if not os.path.exists(conv_path): + print(f"Creating {conv_path}...") + with open(conv_path, 'w') as f: + f.write(f"// {ks}x{ks} convolution (vec4-optimized)\n") + generate_conv_base_function(ks, conv_path) + generate_conv_src_function(ks, conv_path) + generate_conv_final_function(ks, conv_path) + print(f"Generated complete {conv_path}") + continue + + # File exists, check for missing functions + with open(conv_path, 'r') as f: + content = f.read() + + # Generate base 7to4 if missing + if f"cnn_conv{ks}x{ks}_7to4" not in content: + generate_conv_base_function(ks, conv_path) + print(f"Added base 7to4 to {conv_path}") + with open(conv_path, 'r') as f: + content = f.read() + + # Generate _src variant if missing + if f"cnn_conv{ks}x{ks}_7to4_src" not in content: + generate_conv_src_function(ks, conv_path) + print(f"Added _src variant to {conv_path}") + with open(conv_path, 'r') as f: + content = f.read() + + # Generate 7to1 final layer if 
missing + if f"cnn_conv{ks}x{ks}_7to1" not in content: + generate_conv_final_function(ks, conv_path) + print(f"Added 7to1 variant to {conv_path}") + + print("Export complete!") + + +def infer_from_checkpoint(checkpoint_path, input_path, output_path, patch_size=32, save_intermediates=None, zero_weights=False, debug_hex=False): + """Run sliding-window inference to match WGSL shader behavior + + Outputs RGBA PNG (RGB from model + alpha from input). + """ + + if not os.path.exists(checkpoint_path): + print(f"Error: Checkpoint '{checkpoint_path}' not found") + sys.exit(1) + + if not os.path.exists(input_path): + print(f"Error: Input image '{input_path}' not found") + sys.exit(1) + + print(f"Loading checkpoint from {checkpoint_path}...") + checkpoint = torch.load(checkpoint_path, map_location='cpu') + + # Reconstruct model + model = SimpleCNN( + num_layers=checkpoint['num_layers'], + kernel_sizes=checkpoint['kernel_sizes'] + ) + model.load_state_dict(checkpoint['model_state']) + + # Debug: Zero out all weights and biases + if zero_weights: + print("DEBUG: Zeroing out all weights and biases") + for layer in model.layers: + with torch.no_grad(): + layer.weight.zero_() + layer.bias.zero_() + + model.eval() + + # Load image + print(f"Loading input image: {input_path}") + img = Image.open(input_path).convert('RGBA') + img_tensor = transforms.ToTensor()(img).unsqueeze(0) # [1,4,H,W] + W, H = img.size + + # Process full image with sliding window (matches WGSL shader) + print(f"Processing full image ({W}×{H}) with sliding window...") + with torch.no_grad(): + if save_intermediates: + output_tensor, intermediates = model(img_tensor, return_intermediates=True) + else: + output_tensor = model(img_tensor) # [1,3,H,W] RGB + + # Convert to numpy and append alpha + output = output_tensor.squeeze(0).permute(1, 2, 0).numpy() # [H,W,3] RGB + alpha = img_tensor[0, 3:4, :, :].permute(1, 2, 0).numpy() # [H,W,1] alpha from input + output_rgba = np.concatenate([output, alpha], axis=2) # [H,W,4] RGBA + + # Debug: print first 8 pixels as hex + if debug_hex: + output_u8 = (output_rgba * 255).astype(np.uint8) + print("First 8 pixels (RGBA hex):") + for i in range(min(8, output_u8.shape[0] * output_u8.shape[1])): + y, x = i // output_u8.shape[1], i % output_u8.shape[1] + r, g, b, a = output_u8[y, x] + print(f" [{i}] 0x{r:02X}{g:02X}{b:02X}{a:02X}") + + # Save final output as RGBA + print(f"Saving output to: {output_path}") + os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else '.', exist_ok=True) + output_img = Image.fromarray((output_rgba * 255).astype(np.uint8), mode='RGBA') + output_img.save(output_path) + + # Save intermediates if requested + if save_intermediates: + os.makedirs(save_intermediates, exist_ok=True) + print(f"Saving {len(intermediates)} intermediate layers to: {save_intermediates}") + for layer_idx, layer_tensor in enumerate(intermediates): + # Convert [-1,1] to [0,1] for visualization + layer_data = (layer_tensor.squeeze(0).permute(1, 2, 0).numpy() + 1.0) * 0.5 + layer_u8 = (layer_data.clip(0, 1) * 255).astype(np.uint8) + + # Debug: print first 8 pixels as hex + if debug_hex: + print(f"Layer {layer_idx} first 8 pixels (RGBA hex):") + for i in range(min(8, layer_u8.shape[0] * layer_u8.shape[1])): + y, x = i // layer_u8.shape[1], i % layer_u8.shape[1] + if layer_u8.shape[2] == 4: + r, g, b, a = layer_u8[y, x] + print(f" [{i}] 0x{r:02X}{g:02X}{b:02X}{a:02X}") + else: + r, g, b = layer_u8[y, x] + print(f" [{i}] 0x{r:02X}{g:02X}{b:02X}") + + # Save all 4 channels for intermediate 
layers + if layer_data.shape[2] == 4: + layer_img = Image.fromarray(layer_u8, mode='RGBA') + else: + layer_img = Image.fromarray(layer_u8) + layer_path = os.path.join(save_intermediates, f'layer_{layer_idx}.png') + layer_img.save(layer_path) + print(f" Saved layer {layer_idx} to {layer_path}") + + print("Done!") + + +def main(): + parser = argparse.ArgumentParser(description='Train CNN for image-to-image transformation') + parser.add_argument('--input', help='Input image directory (training) or single image (inference)') + parser.add_argument('--target', help='Target image directory') + parser.add_argument('--layers', type=int, default=1, help='Number of CNN layers (default: 1)') + parser.add_argument('--kernel_sizes', default='3', help='Comma-separated kernel sizes (default: 3)') + parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs (default: 100)') + parser.add_argument('--batch_size', type=int, default=4, help='Batch size (default: 4)') + parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (default: 0.001)') + parser.add_argument('--output', help='Output path (WGSL for training/export, PNG for inference)') + parser.add_argument('--checkpoint-every', type=int, default=0, help='Save checkpoint every N epochs (default: 0 = disabled)') + parser.add_argument('--checkpoint-dir', help='Checkpoint directory (default: training/checkpoints)') + parser.add_argument('--resume', help='Resume from checkpoint file') + parser.add_argument('--export-only', help='Export WGSL from checkpoint without training') + parser.add_argument('--infer', help='Run inference on single image (requires --export-only for checkpoint)') + parser.add_argument('--patch-size', type=int, help='Extract patches of this size (e.g., 32) instead of resizing (default: None = resize to 256x256)') + parser.add_argument('--patches-per-image', type=int, default=64, help='Number of patches to extract per image (default: 64)') + parser.add_argument('--detector', default='harris', choices=['harris', 'fast', 'shi-tomasi', 'gradient'], + help='Salient point detector for patch extraction (default: harris)') + parser.add_argument('--early-stop-patience', type=int, default=0, help='Stop if loss changes less than eps over N epochs (default: 0 = disabled)') + parser.add_argument('--early-stop-eps', type=float, default=1e-6, help='Loss change threshold for early stopping (default: 1e-6)') + parser.add_argument('--save-intermediates', help='Directory to save intermediate layer outputs (inference only)') + parser.add_argument('--zero-weights', action='store_true', help='Zero out all weights/biases during inference (debug only)') + parser.add_argument('--debug-hex', action='store_true', help='Print first 8 pixels as hex (debug only)') + + args = parser.parse_args() + + # Inference mode + if args.infer: + checkpoint = args.export_only + if not checkpoint: + print("Error: --infer requires --export-only <checkpoint>") + sys.exit(1) + output_path = args.output or 'inference_output.png' + patch_size = args.patch_size or 32 + infer_from_checkpoint(checkpoint, args.infer, output_path, patch_size, args.save_intermediates, args.zero_weights, args.debug_hex) + return + + # Export-only mode + if args.export_only: + export_from_checkpoint(args.export_only, args.output) + return + + # Validate directories for training + if not args.input or not args.target: + print("Error: --input and --target required for training (or use --export-only)") + sys.exit(1) + + if not os.path.isdir(args.input): + 
print(f"Error: Input directory '{args.input}' does not exist") + sys.exit(1) + + if not os.path.isdir(args.target): + print(f"Error: Target directory '{args.target}' does not exist") + sys.exit(1) + + train(args) + + +if __name__ == "__main__": + main() |

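One more note on the `// Match PyTorch linspace` comment in the generated `cnn_layer.wgsl`: with fragment positions at pixel centres (`p.xy = i + 0.5`), the shader's `uv = (p.xy - 0.5) / (resolution - 1.0)` followed by `(uv - 0.5) * 2.0` lands exactly on `torch.linspace(-1, 1, N)`, which is what `SimpleCNN.forward` feeds as coordinate channels. A minimal check, illustrative only:

```python
# Verify the shader UV convention against torch.linspace (not part of the repo).
import torch

N = 7                                    # resolution along one axis
i = torch.arange(N, dtype=torch.float32)
uv = ((i + 0.5) - 0.5) / (N - 1)         # fragment centre i + 0.5 -> shader uv
uv_norm = (uv - 0.5) * 2.0               # shader normalisation to [-1, 1]
assert torch.allclose(uv_norm, torch.linspace(-1.0, 1.0, N))
```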