summaryrefslogtreecommitdiff
path: root/cnn_v3
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v3')
-rw-r--r--cnn_v3/docs/CNN_V3.md38
-rw-r--r--cnn_v3/docs/HOWTO.md170
-rw-r--r--cnn_v3/docs/HOW_TO_CNN.md173
-rw-r--r--cnn_v3/docs/cnn_v3_architecture.pngbin0 -> 254783 bytes
-rw-r--r--cnn_v3/docs/gen_architecture_png.py238
-rw-r--r--cnn_v3/shaders/cnn_v3_bottleneck.wgsl32
-rw-r--r--cnn_v3/src/cnn_v3_effect.cc2
-rw-r--r--cnn_v3/test_vectors.h310
-rw-r--r--cnn_v3/tools/index.html1
-rw-r--r--cnn_v3/tools/shaders.js18
-rw-r--r--cnn_v3/tools/tester.js45
-rw-r--r--cnn_v3/tools/weights.js4
-rw-r--r--cnn_v3/training/cnn_v3_utils.py45
-rw-r--r--cnn_v3/training/export_cnn_v3_weights.py44
-rw-r--r--cnn_v3/training/gen_test_vectors.py72
-rw-r--r--cnn_v3/training/infer_cnn_v3.py219
-rw-r--r--cnn_v3/training/train_cnn_v3.py74
17 files changed, 1111 insertions, 374 deletions
diff --git a/cnn_v3/docs/CNN_V3.md b/cnn_v3/docs/CNN_V3.md
index 4d58811..d775e2b 100644
--- a/cnn_v3/docs/CNN_V3.md
+++ b/cnn_v3/docs/CNN_V3.md
@@ -27,33 +27,7 @@ CNN v3 is a next-generation post-processing effect using:
### Pipeline Overview
-```
-G-Buffer (albedo, normal, depth, matID, UV)
- │
- ▼
- FiLM Conditioning
- (beat_time, audio_intensity, style_params)
- │ → γ[], β[] per channel
- ▼
- U-Net
- ┌─────────────────────────────────────────┐
- │ Encoder │
- │ enc0 (H×W, 4ch) ────────────skip──────┤
- │ ↓ down (avg pool 2×2) │
- │ enc1 (H/2×W/2, 8ch) ────────skip──────┤
- │ ↓ down │
- │ bottleneck (H/4×W/4, 8ch) │
- │ │
- │ Decoder │
- │ ↑ up (nearest ×2) + skip enc1 │
- │ dec1 (H/2×W/2, 4ch) │
- │ ↑ up + skip enc0 │
- │ dec0 (H×W, 4ch) │
- └─────────────────────────────────────────┘
- │
- ▼
- output RGBA (H×W)
-```
+![CNN v3 U-Net + FiLM Architecture](cnn_v3_architecture.png)
FiLM is applied **inside each encoder/decoder block**, after each convolution.
@@ -352,11 +326,11 @@ All f16, little-endian, same packing as v2 (`pack2x16float`).
|-----------|---------|------|-----------|
| enc0: Conv(20→4, 3×3) | 20×4×9=720 | +4 | 724 |
| enc1: Conv(4→8, 3×3) | 4×8×9=288 | +8 | 296 |
-| bottleneck: Conv(8→8, 1×1) | 8×8×1=64 | +8 | 72 |
+| bottleneck: Conv(8→8, 3×3, dil=2) | 8×8×9=576 | +8 | 584 |
| dec1: Conv(16→4, 3×3) | 16×4×9=576 | +4 | 580 |
| dec0: Conv(8→4, 3×3) | 8×4×9=288 | +4 | 292 |
| FiLM MLP (5→16→40) | 5×16+16×40=720 | +16+40 | 776 |
-| **Total** | | | **~3.9 KB f16** |
+| **Total conv** | | | **~4.84 KB f16** |
Skip connections: dec1 input = 8ch (bottleneck) + 8ch (enc1 skip) = 16ch.
dec0 input = 4ch (dec1) + 4ch (enc0 skip) = 8ch.
@@ -541,7 +515,7 @@ class CNNv3(nn.Module):
nn.Conv2d(enc_channels[0], enc_channels[1], 3, padding=1),
])
# Bottleneck
- self.bottleneck = nn.Conv2d(enc_channels[1], enc_channels[1], 1)
+ self.bottleneck = nn.Conv2d(enc_channels[1], enc_channels[1], 3, padding=2, dilation=2)
# Decoder (skip connections: concat → double channels)
self.dec = nn.ModuleList([
nn.Conv2d(enc_channels[1]*2, enc_channels[0], 3, padding=1),
@@ -709,7 +683,7 @@ Parity results:
Pass 0: pack_gbuffer.wgsl — assemble G-buffer channels into storage texture
Pass 1: cnn_v3_enc0.wgsl — encoder level 0 (20→4ch, 3×3)
Pass 2: cnn_v3_enc1.wgsl — encoder level 1 (4→8ch, 3×3) + downsample
-Pass 3: cnn_v3_bottleneck.wgsl — bottleneck (8→8, 1×1)
+Pass 3: cnn_v3_bottleneck.wgsl — bottleneck (8→8, 3×3, dilation=2)
Pass 4: cnn_v3_dec1.wgsl — decoder level 1: upsample + skip + (16→4, 3×3)
Pass 5: cnn_v3_dec0.wgsl — decoder level 0: upsample + skip + (8→4, 3×3)
Pass 6: cnn_v3_output.wgsl — sigmoid + composite to framebuffer
@@ -816,7 +790,7 @@ Status bar shows which channels are loaded.
| `PACK_SHADER` | `STATIC_SHADER` | 20ch into feat_tex0 + feat_tex1 (rgba32uint each) |
| `ENC0_SHADER` | part of `CNN_SHADER` | Conv(20→4, 3×3) + FiLM + ReLU; writes enc0_tex |
| `ENC1_SHADER` | | Conv(4→8, 3×3) + FiLM + ReLU + avg_pool2×2; writes enc1_tex (half-res) |
-| `BOTTLENECK_SHADER` | | Conv(8→8, 1×1) + FiLM + ReLU; writes bn_tex |
+| `BOTTLENECK_SHADER` | | Conv(8→8, 3×3, dilation=2) + ReLU; writes bn_tex |
| `DEC1_SHADER` | | nearest upsample×2 + concat(bn, enc1_skip) + Conv(16→4, 3×3) + FiLM + ReLU |
| `DEC0_SHADER` | | nearest upsample×2 + concat(dec1, enc0_skip) + Conv(8→4, 3×3) + FiLM + ReLU |
| `OUTPUT_SHADER` | | Conv(4→4, 1×1) + sigmoid → composites to canvas |
diff --git a/cnn_v3/docs/HOWTO.md b/cnn_v3/docs/HOWTO.md
index 5cfc371..9a3efdf 100644
--- a/cnn_v3/docs/HOWTO.md
+++ b/cnn_v3/docs/HOWTO.md
@@ -233,12 +233,13 @@ channel-dropout training.
```bash
python3 cnn_v3/training/pack_photo_sample.py \
- --photo cnn_v3/training/input/photo1.jpg \
+ --photo input/photo1.jpg \
+ --target target/photo1_styled.png \
--output dataset/photos/sample_001/
```
-The output `target.png` defaults to the input photo (no style). Copy in
-your stylized version as `target.png` before training.
+`--target` is required and must be a stylized ground-truth image at the same
+resolution as the photo. The script writes it as `target.png` in the sample dir.
### Dataset layout
@@ -285,10 +286,31 @@ python3 train_cnn_v3.py \
--patch-size 32 --detector random
```
+### Single-sample training
+
+Use `--single-sample <dir>` to train on one specific sample directory.
+Implies `--full-image` and `--batch-size 1` automatically.
+
+```bash
+# Pack input/target pair into a sample directory first
+python3 pack_photo_sample.py \
+ --photo input/photo1.png \
+ --target target/photo1_styled.png \
+ --output dataset/simple/sample_001/
+
+# Train on that sample only
+python3 train_cnn_v3.py \
+ --single-sample dataset/simple/sample_001/ \
+ --epochs 500
+```
+
+All other flags (`--epochs`, `--lr`, `--checkpoint-dir`, `--enc-channels`, etc.) work normally.
+
### Key flags
| Flag | Default | Notes |
|------|---------|-------|
+| `--single-sample DIR` | — | Train on one sample dir; implies `--full-image`, `--batch-size 1` |
| `--input DIR` | `training/dataset` | Root with `full/` or `simple/` subdirs |
| `--input-mode` | `simple` | `simple`=photos, `full`=Blender G-buffer |
| `--patch-size N` | `64` | Patch crop size |
@@ -417,10 +439,10 @@ FiLM γ/β are computed CPU-side by the FiLM MLP (Phase 4) and uploaded each fra
|-------|---------|------|-----------|
| enc0 | 20×4×9=720 | +4 | 724 |
| enc1 | 4×8×9=288 | +8 | 296 |
-| bottleneck | 8×8×1=64 | +8 | 72 |
+| bottleneck | 8×8×9=576 | +8 | 584 |
| dec1 | 16×4×9=576 | +4 | 580 |
| dec0 | 8×4×9=288 | +4 | 292 |
-| **Total** | | | **2064 f16 = ~4 KB** |
+| **Total** | | | **2476 f16 = ~4.84 KB** |
**Asset IDs** (registered in `workspaces/main/assets.txt` + `src/effects/shaders.cc`):
`SHADER_CNN_V3_COMMON`, `SHADER_CNN_V3_ENC0`, `SHADER_CNN_V3_ENC1`,
@@ -587,9 +609,145 @@ Visualization panel still works.
---
-## 10. See Also
+## 10. Python / WGSL Parity Check (infer_cnn_v3 + cnn_test)
+
+Two complementary tools for comparing PyTorch inference against the live WGSL
+compute shaders on the same input image.
+
+### 10a. infer_cnn_v3.py — PyTorch reference inference
+
+**Location:** `cnn_v3/training/infer_cnn_v3.py`
+
+Runs the trained `CNNv3` model in Python and saves the RGBA output as PNG.
+
+**Simple mode** (single PNG, geometry zeroed):
+```bash
+cd cnn_v3/training
+python3 infer_cnn_v3.py photo.png out_python.png \
+ --checkpoint checkpoints/checkpoint_epoch_200.pth
+```
+
+**Full mode** (sample directory with all G-buffer files):
+```bash
+python3 infer_cnn_v3.py dataset/simple/sample_000/ out_python.png \
+ --checkpoint checkpoints/checkpoint_epoch_200.pth
+```
+
+**Identity FiLM** — bypass MLP, use γ=1 β=0 (matches C++ `cnn_test` default):
+```bash
+python3 infer_cnn_v3.py photo.png out_python.png \
+ --checkpoint checkpoints/checkpoint_epoch_200.pth \
+ --identity-film
+```
+
+**Options:**
+
+| Flag | Default | Description |
+|------|---------|-------------|
+| `--checkpoint CKPT` | auto-find latest | Path to `.pth` checkpoint |
+| `--enc-channels C` | from checkpoint | `4,8` — must match training config |
+| `--cond F F F F F` | `0 0 0 0 0` | FiLM conditioning (beat_phase, beat_norm, audio, style0, style1) |
+| `--identity-film` | off | Bypass FiLM MLP, use γ=1 β=0 |
+| `--blend F` | `1.0` | Blend with albedo: 0=input, 1=CNN |
+| `--debug-hex` | off | Print first 8 output pixels as hex |
+
+In **simple mode**, geometry channels are zeroed: `normal=(0.5,0.5)` (oct-encodes
+to ≈(0,0,1)), `depth=0`, `matid=0`, `shadow=1`, `transp=0`.
+
+The checkpoint `config` dict (saved by `train_cnn_v3.py`) sets `enc_channels`
+and `film_cond_dim` automatically; `--enc-channels` is only needed if the
+checkpoint lacks a config key.
+
+---
+
+### 10b. cnn_test — WGSL / GPU reference inference
+
+**Location:** `tools/cnn_test.cc` **Binary:** `build/cnn_test`
+
+Packs the same 20-channel feature tensor as `infer_cnn_v3.py`, uploads it to
+GPU, runs the five `CNNv3Effect` compute passes, and saves the RGBA16Float
+output as PNG.
+
+**Build** (requires `DEMO_BUILD_TESTS=ON` or `DEMO_WORKSPACE=main`):
+```bash
+cmake -B build -DDEMO_BUILD_TESTS=ON && cmake --build build -j4 --target cnn_test
+```
+
+**Simple mode:**
+```bash
+./build/cnn_test photo.png out_gpu.png --weights workspaces/main/weights/cnn_v3_weights.bin
+```
+
+**Full mode** (sample directory):
+```bash
+./build/cnn_test dataset/simple/sample_000/albedo.png out_gpu.png \
+ --sample-dir dataset/simple/sample_000/ \
+ --weights workspaces/main/weights/cnn_v3_weights.bin
+```
+
+**Options:**
+
+| Flag | Description |
+|------|-------------|
+| `--sample-dir DIR` | Load all G-buffer files (albedo/normal/depth/matid/shadow/transp) |
+| `--weights FILE` | `cnn_v3_weights.bin` (uses asset-embedded weights if omitted) |
+| `--debug-hex` | Print first 8 output pixels as hex |
+| `--help` | Show usage |
+
+FiLM is always **identity** (γ=1, β=0) — matching the C++ `CNNv3Effect` default
+until GPU-side FiLM MLP evaluation is added.
+
+---
+
+### 10c. Side-by-side comparison
+
+For a pixel-accurate comparison, use `--identity-film` in Python and `--debug-hex`
+in both tools:
+
+```bash
+cd cnn_v3/training
+
+# 1. Python inference (identity FiLM)
+python3 infer_cnn_v3.py photo.png out_python.png \
+ --checkpoint checkpoints/checkpoint_epoch_200.pth \
+ --identity-film --debug-hex
+
+# 2. GPU inference (always identity FiLM)
+./build/cnn_test photo.png out_gpu.png \
+ --weights workspaces/main/weights/cnn_v3_weights.bin \
+ --debug-hex
+```
+
+Both tools print first 8 pixels in the same format:
+```
+ [0] 0x7F804000 (0.4980 0.5020 0.2510 0.0000)
+```
+
+**Expected delta:** ≤ 1/255 (≈ 4e-3) per channel, matching the parity test
+(`test_cnn_v3_parity`). Larger deltas indicate a weight mismatch — re-export
+with `export_cnn_v3_weights.py` and verify the `.bin` size is 4952 bytes.
+
+---
+
+### 10d. Feature format note
+
+Both tools pack features in **training format** ([0,1] oct-encoded normals),
+not the runtime `gbuf_pack.wgsl` format (which remaps normals to [-1,1]).
+This makes `infer_cnn_v3.py` ↔ `cnn_test` directly comparable.
+
+The live pipeline (`GBufferEffect → gbuf_pack.wgsl → CNNv3Effect`) uses [-1,1]
+normals — that is the intended inference distribution after a full training run
+with `--input-mode full` (Blender renders). For training on photos
+(`--input-mode simple`), [0,1] normals are correct since channel dropout
+teaches the network to handle absent geometry.
+
+---
+
+## 11. See Also
- `cnn_v3/docs/CNN_V3.md` — Full architecture design (U-Net, FiLM, feature layout)
- `doc/EFFECT_WORKFLOW.md` — General effect integration guide
- `cnn_v2/docs/CNN_V2.md` — Reference implementation (simpler, operational)
- `src/tests/gpu/test_demo_effects.cc` — GBufferEffect + GBufViewEffect tests
+- `src/tests/gpu/test_cnn_v3_parity.cc` — Zero/random weight parity tests
+- `cnn_v3/training/export_cnn_v3_weights.py` — Export trained checkpoint → `.bin`
diff --git a/cnn_v3/docs/HOW_TO_CNN.md b/cnn_v3/docs/HOW_TO_CNN.md
index 4966a61..09db97c 100644
--- a/cnn_v3/docs/HOW_TO_CNN.md
+++ b/cnn_v3/docs/HOW_TO_CNN.md
@@ -28,26 +28,13 @@ CNN v3 is a 2-level U-Net with FiLM conditioning, designed to run in real-time a
**Architecture:**
-```
-Input: 20-channel G-buffer feature textures (rgba32uint)
- │
- enc0 ──── Conv(20→4, 3×3) + FiLM + ReLU ┐ full res
- │ ↘ skip │
- enc1 ──── AvgPool2×2 + Conv(4→8, 3×3) + FiLM ┐ ½ res
- │ ↘ skip │
- bottleneck AvgPool2×2 + Conv(8→8, 1×1) + ReLU ¼ res (no FiLM)
- │ │
- dec1 ←── upsample×2 + cat(enc1 skip) + Conv(16→4, 3×3) + FiLM
- │ │ ½ res
- dec0 ←── upsample×2 + cat(enc0 skip) + Conv(8→4, 3×3) + FiLM + sigmoid
- full res → RGBA output
-```
+![CNN v3 U-Net + FiLM Architecture](cnn_v3_architecture.png)
**FiLM MLP:** `Linear(5→16) → ReLU → Linear(16→40)` trained jointly with U-Net.
- Input: `[beat_phase, beat_norm, audio_intensity, style_p0, style_p1]`
- Output: 40 γ/β values controlling style across all 4 FiLM layers
-**Weight budget:** ~3.9 KB f16 (fits ≤6 KB target)
+**Weight budget:** ~4.84 KB f16 conv (fits ≤6 KB target)
**Two data paths:**
- **Simple mode** — real photos with zeroed geometric channels (normal, depth, matid)
@@ -58,13 +45,13 @@ Input: 20-channel G-buffer feature textures (rgba32uint)
```
photos/Blender → pack → dataset/ → train_cnn_v3.py → checkpoint.pth
- export_cnn_v3_weights.py
- ┌─────────┴──────────┐
- cnn_v3_weights.bin cnn_v3_film_mlp.bin
- │
- CNNv3Effect::upload_weights()
- │
- demo / HTML tool
+ export_cnn_v3_weights.py [--html]
+ ┌──────────┴────────────┬──────────────┐
+ cnn_v3_weights.bin cnn_v3_film_mlp.bin weights.js
+ │ (HTML tool
+ CNNv3Effect::upload_weights() defaults)
+ │
+ demo
```
---
@@ -107,15 +94,6 @@ The network learns the mapping `albedo → target`. If you pass the same image a
input and target, the network learns identity (useful as sanity check, not for real
training). Confirm `target.png` looks correct before running training.
-**Alternative — pack without a target yet:**
-```bash
-python3 pack_photo_sample.py \
- --photo /path/to/photo.png \
- --output dataset/simple/sample_001/
-# target.png defaults to a copy of the input; replace it before training:
-cp my_stylized_version.png dataset/simple/sample_001/target.png
-```
-
**Batch packing:**
```bash
for f in photos/*.png; do
@@ -284,55 +262,78 @@ The U-Net conv weights and FiLM MLP train **jointly** in a single run. No separa
### Prerequisites
+`train_cnn_v3.py` and `export_cnn_v3_weights.py` carry inline `uv` dependency metadata
+(`# /// script`). Use `uv run` — no manual `pip install` needed:
+
```bash
-pip install torch torchvision pillow numpy opencv-python
cd cnn_v3/training
+uv run train_cnn_v3.py --input dataset/ --epochs 1 --patch-size 32 --detector random
```
-**With `uv` (no pip needed):** dependencies are declared inline in `train_cnn_v3.py`
-and installed automatically on first run:
+**Without `uv` (manual pip):**
```bash
+pip install torch torchvision pillow numpy opencv-python
cd cnn_v3/training
-uv run train_cnn_v3.py --input dataset/ --epochs 1 --patch-size 32 --detector random
+python3 train_cnn_v3.py ...
```
+The pack scripts (`pack_photo_sample.py`, `pack_blender_sample.py`) and
+`gen_test_vectors.py` do **not** have uv metadata — run them with `python3` directly
+(they only need `numpy`, `pillow`, and optionally `openexr`).
+
### Quick-start commands
**Smoke test — 1 epoch, validates end-to-end without GPU:**
```bash
-python3 train_cnn_v3.py --input dataset/ --epochs 1 \
+uv run train_cnn_v3.py --input dataset/ --epochs 1 \
--patch-size 32 --detector random
```
**Standard photo training (patch-based):**
```bash
-python3 train_cnn_v3.py \
+uv run train_cnn_v3.py \
--input dataset/ \
--input-mode simple \
- --epochs 200
+ --epochs 200 \
+ --edge-loss-weight 0.1 \
+ --film-warmup-epochs 50
```
**Blender G-buffer training:**
```bash
-python3 train_cnn_v3.py \
+uv run train_cnn_v3.py \
--input dataset/ \
--input-mode full \
- --epochs 200
+ --epochs 200 \
+ --edge-loss-weight 0.1 \
+ --film-warmup-epochs 50
```
**Full-image mode (better global coherence, slower):**
```bash
-python3 train_cnn_v3.py \
+uv run train_cnn_v3.py \
--input dataset/ \
--input-mode full \
--full-image --image-size 256 \
--epochs 500
```
+**Single-sample training (overfit on one input/target pair):**
+```bash
+# Pack first
+./gen_sample.sh input/photo.png target/photo_styled.png dataset/simple/sample_001/
+
+# Train — --full-image and --batch-size 1 are implied
+uv run train_cnn_v3.py \
+ --single-sample dataset/simple/sample_001/ \
+ --epochs 500
+```
+
### Flag reference
| Flag | Default | Notes |
|------|---------|-------|
+| `--single-sample DIR` | — | Train on one sample dir; implies `--full-image`, `--batch-size 1` |
| `--input DIR` | `training/dataset` | Dataset root; always set explicitly |
| `--input-mode` | `simple` | `simple`=photos, `full`=Blender G-buffer |
| `--epochs N` | 200 | 500 recommended for full-image mode |
@@ -340,7 +341,8 @@ python3 train_cnn_v3.py \
| `--lr F` | 1e-3 | Reduce to 1e-4 if loss oscillates or NaN |
| `--patch-size N` | 64 | Smaller = faster epoch, less spatial context |
| `--patches-per-image N` | 256 | Reduce for small datasets |
-| `--detector` | `harris` | `random` for smoke tests; `shi-tomasi` as alternative |
+| `--detector` | `harris` | `random` for smoke tests; also `shi-tomasi`, `fast`, `gradient` |
+| `--patch-search-window N` | 0 | Search ±N px in target to find best alignment (grayscale MSE) per patch; 0=disabled. Use when source and target are not perfectly co-registered (e.g. photo + hand-painted target). Offsets cached at dataset init. |
| `--channel-dropout-p F` | 0.3 | Lower if all samples have geometry (Blender only) |
| `--full-image` | off | Resize full image instead of patch crops |
| `--image-size N` | 256 | Resize target; only used with `--full-image` |
@@ -348,12 +350,15 @@ python3 train_cnn_v3.py \
| `--film-cond-dim N` | 5 | Must match `CNNv3FiLMParams` field count in C++ |
| `--checkpoint-dir DIR` | `checkpoints/` | Set per-experiment |
| `--checkpoint-every N` | 50 | 0 to disable intermediate checkpoints |
+| `--resume [CKPT]` | — | Resume from checkpoint path; if path missing, uses latest in `--checkpoint-dir` |
+| `--edge-loss-weight F` | 0.1 | Sobel gradient loss weight alongside MSE; improves style/edge capture; 0=MSE only |
+| `--film-warmup-epochs N` | 50 | Freeze FiLM MLP for first N epochs (phase-1), then unfreeze at lr×0.1; 0=joint training |
### Architecture at startup
The model prints its parameter count:
```
-Model: enc=[4, 8] film_cond_dim=5 params=2740 (~5.4 KB f16)
+Model: enc=[4, 8] film_cond_dim=5 params=3252 (~6.4 KB f16)
```
If `params` is much higher, `--enc-channels` was changed; update C++ constants accordingly.
@@ -454,17 +459,30 @@ The final checkpoint is always written even if `--checkpoint-every 0`.
## 3. Exporting Weights
-Converts a trained `.pth` checkpoint to two raw binary files for the C++ runtime.
+Converts a trained `.pth` checkpoint to two raw binary files for the C++ runtime,
+and optionally updates the HTML tool's embedded defaults.
+**Standard export (C++ runtime only):**
```bash
cd cnn_v3/training
-python3 export_cnn_v3_weights.py checkpoints/checkpoint_epoch_200.pth \
+uv run export_cnn_v3_weights.py checkpoints/checkpoint_epoch_200.pth \
--output ../../workspaces/main/weights/
```
+**Export + update HTML tool defaults (`cnn_v3/tools/weights.js`):**
+```bash
+uv run export_cnn_v3_weights.py checkpoints/checkpoint_epoch_200.pth \
+ --output ../../workspaces/main/weights/ \
+ --html
+```
+
+`--html` base64-encodes both `.bin` files and rewrites `cnn_v3/tools/weights.js`
+so the HTML tool loads the new weights as its embedded defaults at startup.
+Use `--html-output PATH` to write to a different `weights.js` location.
+
Output files are registered in `workspaces/main/assets.txt` as:
```
-WEIGHTS_CNN_V3, BINARY, weights/cnn_v3_weights.bin, "CNN v3 conv weights (f16, 3928 bytes)"
+WEIGHTS_CNN_V3, BINARY, weights/cnn_v3_weights.bin, "CNN v3 conv weights (f16, 4952 bytes)"
WEIGHTS_CNN_V3_FILM_MLP, BINARY, weights/cnn_v3_film_mlp.bin, "CNN v3 FiLM MLP weights (f32, 3104 bytes)"
```
@@ -476,10 +494,10 @@ WEIGHTS_CNN_V3_FILM_MLP, BINARY, weights/cnn_v3_film_mlp.bin, "CNN v3 FiLM MLP w
|-------|-----------|-------|
| enc0 Conv(20→4,3×3)+bias | 724 | — |
| enc1 Conv(4→8,3×3)+bias | 296 | — |
-| bottleneck Conv(8→8,1×1)+bias | 72 | — |
+| bottleneck Conv(8→8,3×3,dil=2)+bias | 584 | — |
| dec1 Conv(16→4,3×3)+bias | 580 | — |
| dec0 Conv(8→4,3×3)+bias | 292 | — |
-| **Total** | **1964 f16** | **3928 bytes** |
+| **Total** | **2476 f16** | **4952 bytes** |
**`cnn_v3_film_mlp.bin`** — FiLM MLP weights as raw f32, row-major:
@@ -509,8 +527,8 @@ Checkpoint: epoch=200 loss=0.012345
enc_channels=[4, 8] film_cond_dim=5
cnn_v3_weights.bin
- 1964 f16 values → 982 u32 → 3928 bytes
- Upload via CNNv3Effect::upload_weights(queue, data, 3928)
+ 2476 f16 values → 1238 u32 → 4952 bytes
+ Upload via CNNv3Effect::upload_weights(queue, data, 4952)
cnn_v3_film_mlp.bin
L0: weight (16, 5) + bias (16,)
@@ -543,10 +561,12 @@ It owns:
```
SEQUENCE 0 0 "Scene with CNN v3"
- EFFECT + GBufferEffect prev_cnn -> gbuf_feat0 gbuf_feat1 0 60
- EFFECT + CNNv3Effect gbuf_feat0 gbuf_feat1 -> sink 0 60
+ EFFECT + GBufferEffect source -> gbuf_feat0 gbuf_feat1 0 60
+ EFFECT + CNNv3Effect gbuf_feat0 gbuf_feat1 -> sink 0 60
```
+Temporal feedback (`prev_cnn`) is wired automatically by `wire_dag()` — no explicit input needed in the `.seq` file.
+
Or direct C++:
```cpp
#include "cnn_v3/src/cnn_v3_effect.h"
@@ -636,8 +656,8 @@ Do not reference them from outside the effect unless debugging.
```bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
-cmake --build build -j$(nproc)
-./build/demo
+cmake --build build -j4
+./build/demo64k
```
### Expected visual output
@@ -733,13 +753,14 @@ If results drift after shader edits, verify these invariants match the Python re
## 7. HTML WebGPU Tool
-**Location:** `cnn_v3/tools/` — three files, no build step.
+**Location:** `cnn_v3/tools/` — four files, no build step.
| File | Lines | Contents |
|------|-------|----------|
-| `index.html` | 147 | HTML + CSS |
-| `shaders.js` | 252 | WGSL shader constants, weight-offset constants |
-| `tester.js` | 540 | `CNNv3Tester` class, event wiring |
+| `index.html` | 168 | HTML + CSS |
+| `shaders.js` | 312 | WGSL shader constants, weight-offset constants |
+| `tester.js` | 913 | `CNNv3Tester` class, inference pipeline, layer viz |
+| `weights.js` | 7 | Embedded default weights (base64); auto-generated by `--html` |
### Usage
@@ -750,32 +771,27 @@ python3 -m http.server 8080
# Open: http://localhost:8080/cnn_v3/tools/
```
-Or on macOS with Chrome:
+Weights are **loaded automatically at startup** from `weights.js` (embedded base64).
+If the tool is served from the repo root, it also tries to fetch the latest
+`workspaces/main/weights/*.bin` over HTTP and uses those if available.
+Use the **↺ Reload** button to re-fetch after updating weights on disk.
+
+To update the embedded defaults after a training run, use `--html` (§3):
```bash
-open -a "Google Chrome" --args --allow-file-access-from-files
-open cnn_v3/tools/index.html
+uv run export_cnn_v3_weights.py checkpoints/checkpoint.pth \
+ --output ../../workspaces/main/weights/ --html
```
### Workflow
-1. **Drop `cnn_v3_weights.bin`** onto the left "weights" drop zone.
-2. **Drop a PNG or video** onto the centre canvas → CNN runs immediately.
-3. _(Optional)_ **Drop `cnn_v3_film_mlp.bin`** → FiLM sliders become active.
-4. Adjust **beat_phase / beat_norm / audio_int / style_p0 / style_p1** sliders → reruns on change.
-5. Click layer buttons (**Feat · Enc0 · Enc1 · BN · Dec1 · Output**) in the right panel to inspect activations.
-6. **Save PNG** to export the current output.
+1. **Drop a PNG or video** onto the canvas → CNN runs immediately (weights pre-loaded).
+2. Adjust **beat_phase / beat_norm / audio_int / style_p0 / style_p1** sliders.
+3. Click layer buttons (**Feat · Enc0 · Enc1 · BN · Dec1 · Output**) to inspect activations.
+4. **Save PNG** to export the current output.
+5. _(Optional)_ Drop updated `.bin` files onto the left panel to override embedded weights.
Keyboard: `[SPACE]` toggle original · `[D]` diff×10.
-### Input files
-
-| File | Format | Notes |
-|------|--------|-------|
-| `cnn_v3_weights.bin` | raw u32 (no header) | 982 u32 = 1964 f16 = ~3.9 KB |
-| `cnn_v3_film_mlp.bin` | raw f32 | 776 f32 = 3.1 KB; optional — identity FiLM used if absent |
-
-Both produced by `export_cnn_v3_weights.py` (§3).
-
### Texture chain
| Texture | Format | Size |
@@ -801,7 +817,7 @@ all geometric channels (normal, depth, depth_grad, mat_id, prev) = 0.
### Pitfalls
- `rgba32uint` and `rgba16float` textures both need `STORAGE_BINDING | TEXTURE_BINDING` usage.
-- Weight offsets are **f16 indices** (enc0=0, enc1=724, bn=1020, dec1=1092, dec0=1672).
+- Weight offsets are **f16 indices** (enc0=0, enc1=724, bn=1020, dec1=1604, dec0=2184).
- Uniform buffer layouts must match WGSL `Params` structs exactly (padding included).
---
@@ -816,7 +832,7 @@ all geometric channels (normal, depth, depth_grad, mat_id, prev) = 0.
| `cnn_v3/training/pack_photo_sample.py` | Photo → zeroed-geometry sample directory |
| `cnn_v3/training/cnn_v3_utils.py` | Dataset class, feature assembly, channel dropout, salient-point detection |
| `cnn_v3/training/train_cnn_v3.py` | CNNv3 model definition, training loop, CLI |
-| `cnn_v3/training/export_cnn_v3_weights.py` | Checkpoint → `cnn_v3_weights.bin` + `cnn_v3_film_mlp.bin` |
+| `cnn_v3/training/export_cnn_v3_weights.py` | Checkpoint → `cnn_v3_weights.bin` + `cnn_v3_film_mlp.bin`; `--html` rewrites `weights.js` |
| `cnn_v3/training/gen_test_vectors.py` | NumPy reference forward pass + C header generator |
| `cnn_v3/test_vectors.h` | Compiled-in test vectors (auto-generated, do not edit) |
| `cnn_v3/src/cnn_v3_effect.h` | C++ class, Params structs, `CNNv3FiLMParams` API |
@@ -827,6 +843,7 @@ all geometric channels (normal, depth, depth_grad, mat_id, prev) = 0.
| `cnn_v3/tools/index.html` | HTML tool — UI shell + CSS |
| `cnn_v3/tools/shaders.js` | HTML tool — inline WGSL shaders + weight-offset constants |
| `cnn_v3/tools/tester.js` | HTML tool — CNNv3Tester class, inference pipeline, layer viz |
+| `cnn_v3/tools/weights.js` | HTML tool — embedded default weights (base64, auto-generated) |
| `cnn_v2/tools/cnn_v2_test/index.html` | HTML tool reference pattern (v2) |
---
diff --git a/cnn_v3/docs/cnn_v3_architecture.png b/cnn_v3/docs/cnn_v3_architecture.png
new file mode 100644
index 0000000..2116c2b
--- /dev/null
+++ b/cnn_v3/docs/cnn_v3_architecture.png
Binary files differ
diff --git a/cnn_v3/docs/gen_architecture_png.py b/cnn_v3/docs/gen_architecture_png.py
new file mode 100644
index 0000000..bd60a97
--- /dev/null
+++ b/cnn_v3/docs/gen_architecture_png.py
@@ -0,0 +1,238 @@
+#!/usr/bin/env python3
+# /// script
+# requires-python = ">=3.10"
+# dependencies = ["matplotlib"]
+# ///
+"""Generate CNN v3 U-Net + FiLM architecture diagram → cnn_v3_architecture.png"""
+
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import matplotlib.patches as mpatches
+from matplotlib.patches import FancyBboxPatch
+from matplotlib.path import Path
+import matplotlib.patheffects as pe
+
+# ---------------------------------------------------------------------------
+# Canvas
+# ---------------------------------------------------------------------------
+BG = '#0F172A'
+fig = plt.figure(figsize=(17, 10), facecolor=BG)
+ax = fig.add_axes([0, 0, 1, 1], facecolor=BG)
+ax.set_xlim(0, 17)
+ax.set_ylim(0, 10)
+ax.axis('off')
+
+# ---------------------------------------------------------------------------
+# Palette
+# ---------------------------------------------------------------------------
+C_ENC = '#3B82F6' # encoder — blue
+C_BN = '#8B5CF6' # bottleneck — violet
+C_DEC = '#10B981' # decoder — emerald
+C_MLP = '#EC4899' # FiLM MLP — pink
+C_FILM = '#F59E0B' # FiLM γ/β arrows — amber
+C_IO = '#475569' # input/output — slate
+C_SKP = '#F97316' # skip connections — orange
+C_ARR = '#94A3B8' # main flow arrows — cool-grey
+C_TXT = '#F1F5F9' # text — near-white
+C_DIM = '#64748B' # dim labels — slate
+
+# ---------------------------------------------------------------------------
+# Geometry — two-column U layout
+# ---------------------------------------------------------------------------
+EX, DX = 3.8, 13.2 # encoder / decoder centre-x
+BX = 8.5 # bottleneck centre-x
+
+BW = 4.6 # block width (enc / dec)
+BH = 0.95 # block height (enc / dec)
+BW_BN = 5.4 # bottleneck wider
+BH_BN = 0.95
+BH_IO = 0.72
+
+# y positions (top = high number)
+Y_IN = 8.90
+Y_E0 = 7.50 # enc0 full res
+Y_E1 = 5.80 # enc1 ½ res
+Y_BN = 3.20 # bottleneck ¼ res
+Y_D1 = 5.80 # dec1 ½ res
+Y_D0 = 7.50 # dec0 full res
+Y_OUT = 8.90
+
+Y_MLP = 1.25 # FiLM MLP
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def box(cx, cy, w, h, color, line1, line2='', lfs=9.5, sfs=8.0, alpha=0.92):
+ r = FancyBboxPatch((cx - w/2, cy - h/2), w, h,
+ boxstyle='round,pad=0.10',
+ fc=color, ec='white', lw=1.3, alpha=alpha, zorder=3)
+ ax.add_patch(r)
+ dy = 0.18 if line2 else 0
+ ax.text(cx, cy + dy, line1, ha='center', va='center',
+ fontsize=lfs, fontweight='bold', color='white', zorder=4,
+ fontfamily='DejaVu Sans Mono')
+ if line2:
+ ax.text(cx, cy - 0.18, line2, ha='center', va='center',
+ fontsize=sfs, color='white', alpha=0.80, zorder=4)
+
+
+def arrow(x0, y0, x1, y1, color=C_ARR, lw=1.8, dashed=False,
+ rad=0.0, label='', lx=None, ly=None):
+ ls = (0, (5, 3)) if dashed else 'solid'
+ cs = f'arc3,rad={rad}' if rad else 'arc3,rad=0'
+ ax.annotate('', xy=(x1, y1), xytext=(x0, y0),
+ arrowprops=dict(arrowstyle='->', color=color, lw=lw,
+ linestyle=ls, mutation_scale=13,
+ connectionstyle=cs),
+ zorder=2)
+ if label:
+ ax.text(lx if lx else (x0+x1)/2,
+ ly if ly else (y0+y1)/2,
+ label, ha='center', va='center', fontsize=7.5,
+ color=color, zorder=5,
+ bbox=dict(fc=BG, ec='none', alpha=0.75,
+ boxstyle='round,pad=0.15'))
+
+
+def dim_label(x, y, txt):
+ ax.text(x, y, txt, ha='center', va='center',
+ fontsize=8.5, color=C_DIM, style='italic')
+
+
+# ---------------------------------------------------------------------------
+# Blocks
+# ---------------------------------------------------------------------------
+
+box(EX, Y_IN, BW, BH_IO, C_IO, 'G-Buffer Features',
+ '20 channels · full res')
+
+box(EX, Y_E0, BW, BH, C_ENC, 'enc0 Conv(20→4, 3×3) + FiLM + ReLU',
+ 'full res · 4 ch')
+
+box(EX, Y_E1, BW, BH, C_ENC, 'enc1 Conv(4→8, 3×3) + FiLM + ReLU',
+ '½ res · 8 ch · (AvgPool↓ on input)')
+
+box(BX, Y_BN, BW_BN, BH_BN, C_BN,
+ 'bottleneck Conv(8→8, 3×3, dilation=2) + ReLU',
+ '¼ res · 8 ch · no FiLM · effective RF ≈ 10 px @ ½res')
+
+box(DX, Y_D1, BW, BH, C_DEC, 'dec1 Conv(16→4, 3×3) + FiLM + ReLU',
+ '½ res · 4 ch · (upsample↑ + cat enc1 skip)')
+
+box(DX, Y_D0, BW, BH, C_DEC, 'dec0 Conv(8→4, 3×3) + FiLM + sigmoid',
+ 'full res · 4 ch · (upsample↑ + cat enc0 skip)')
+
+box(DX, Y_OUT, BW, BH_IO, C_IO, 'RGBA Output',
+ '4 channels · full res')
+
+box(BX, Y_MLP, 9.2, 1.10, C_MLP,
+ 'FiLM MLP Linear(5→16) → ReLU → Linear(16→40)',
+ 'in: beat_phase · beat_norm · audio_intensity · style_p0 · style_p1'
+ ' → γ/β (×2) for enc0(4) enc1(8) dec1(4) dec0(4) = 40 values',
+ sfs=7.5)
+
+# ---------------------------------------------------------------------------
+# Main-flow arrows
+# ---------------------------------------------------------------------------
+
+# Input → enc0
+arrow(EX, Y_IN - BH_IO/2 - .04, EX, Y_E0 + BH/2 + .04)
+
+# enc0 → enc1 (AvgPool label beside)
+arrow(EX, Y_E0 - BH/2 - .04, EX, Y_E1 + BH/2 + .04,
+ label='AvgPool\n 2×2', lx=EX + 0.72, ly=(Y_E0 + Y_E1)/2)
+
+# enc1 → bottleneck (curve down-right)
+arrow(EX, Y_E1 - BH/2 - .04,
+ BX - BW_BN/2 - .04, Y_BN,
+ rad=-0.28,
+ label='AvgPool\n 2×2', lx=(EX + BX)/2 - 0.5, ly=Y_BN + 0.90)
+
+# bottleneck → dec1 (curve right-up)
+arrow(BX + BW_BN/2 + .04, Y_BN,
+ DX, Y_D1 - BH/2 - .04,
+ rad=-0.28,
+ label='upsample\n 2×', lx=(BX + DX)/2 + 0.5, ly=Y_D1 - 0.90)
+
+# dec1 → dec0
+arrow(DX, Y_D1 + BH/2 + .04, DX, Y_D0 - BH/2 - .04,
+ label='upsample\n 2×', lx=DX - 0.72, ly=(Y_D1 + Y_D0)/2)
+
+# dec0 → output
+arrow(DX, Y_D0 + BH/2 + .04, DX, Y_OUT - BH_IO/2 - .04)
+
+# ---------------------------------------------------------------------------
+# Skip connections
+# ---------------------------------------------------------------------------
+
+# enc0 skip → dec0
+arrow(EX + BW/2 + .04, Y_E0,
+ DX - BW/2 - .04, Y_D0,
+ color=C_SKP, lw=1.6, dashed=True,
+ label='skip enc0 (4 ch)', ly=Y_E0 + 0.40)
+
+# enc1 skip → dec1
+arrow(EX + BW/2 + .04, Y_E1,
+ DX - BW/2 - .04, Y_D1,
+ color=C_SKP, lw=1.6, dashed=True,
+ label='skip enc1 (8 ch)', ly=Y_E1 + 0.40)
+
+# ---------------------------------------------------------------------------
+# FiLM γ/β arrows (MLP → each FiLM layer)
+# ---------------------------------------------------------------------------
+film_targets = [
+ (EX, Y_E0 - BH/2 - .04), # enc0 bottom
+ (EX, Y_E1 - BH/2 - .04), # enc1 bottom
+ (DX, Y_D1 - BH/2 - .04), # dec1 bottom
+ (DX, Y_D0 - BH/2 - .04), # dec0 bottom
+]
+for tx, ty in film_targets:
+ ax.annotate('', xy=(tx, ty),
+ xytext=(BX + (tx - BX) * 0.05, Y_MLP + 0.55 + .04),
+ arrowprops=dict(arrowstyle='->', color=C_FILM, lw=1.2,
+ linestyle=(0, (3, 3)), mutation_scale=10,
+ connectionstyle='arc3,rad=0.18'),
+ zorder=2)
+
+ax.text(8.5, 4.30, 'γ / β', ha='center', va='center',
+ fontsize=9, color=C_FILM, alpha=0.85, style='italic', zorder=5)
+
+# ---------------------------------------------------------------------------
+# Resolution markers (left margin)
+# ---------------------------------------------------------------------------
+for y, lbl in [(Y_E0, 'full res'), (Y_E1, '½ res'), (Y_BN, '¼ res')]:
+ dim_label(0.62, y, lbl)
+ ax.plot([0.95, 1.10], [y, y], color=C_DIM, lw=0.8, zorder=1)
+
+# ---------------------------------------------------------------------------
+# Legend
+# ---------------------------------------------------------------------------
+legend_items = [
+ mpatches.Patch(fc=C_ENC, ec='white', lw=0.8, label='Encoder'),
+ mpatches.Patch(fc=C_BN, ec='white', lw=0.8, label='Bottleneck'),
+ mpatches.Patch(fc=C_DEC, ec='white', lw=0.8, label='Decoder'),
+ mpatches.Patch(fc=C_MLP, ec='white', lw=0.8, label='FiLM MLP'),
+ mpatches.Patch(fc=C_IO, ec='white', lw=0.8, label='I/O'),
+ plt.Line2D([0], [0], color=C_SKP, lw=1.6, ls='--', label='Skip connection'),
+ plt.Line2D([0], [0], color=C_FILM, lw=1.2, ls=(0, (3,3)), label='FiLM γ/β'),
+]
+leg = ax.legend(handles=legend_items, loc='lower right',
+ bbox_to_anchor=(0.99, 0.01),
+ framealpha=0.15, facecolor=BG, edgecolor=C_DIM,
+ fontsize=8, labelcolor=C_TXT, ncol=1)
+
+# ---------------------------------------------------------------------------
+# Title
+# ---------------------------------------------------------------------------
+ax.text(8.5, 9.68, 'CNN v3 — U-Net + FiLM Architecture',
+ ha='center', va='center', fontsize=14, fontweight='bold', color=C_TXT)
+
+# ---------------------------------------------------------------------------
+# Save
+# ---------------------------------------------------------------------------
+import pathlib
+out = pathlib.Path(__file__).parent / 'cnn_v3_architecture.png'
+fig.savefig(out, dpi=180, bbox_inches='tight', facecolor=BG, edgecolor='none')
+print(f'Saved → {out} ({out.stat().st_size // 1024} KB)')
diff --git a/cnn_v3/shaders/cnn_v3_bottleneck.wgsl b/cnn_v3/shaders/cnn_v3_bottleneck.wgsl
index e24586f..e30682b 100644
--- a/cnn_v3/shaders/cnn_v3_bottleneck.wgsl
+++ b/cnn_v3/shaders/cnn_v3_bottleneck.wgsl
@@ -1,17 +1,18 @@
// CNN v3 — Bottleneck
-// AvgPool2x2(enc1) + Conv(8->8, 1x1) + ReLU (no FiLM)
+// AvgPool2x2(enc1) + Conv(8->8, 3x3, dilation=2) + ReLU (no FiLM)
//
-// Input: enc1_tex (rgba32uint, 8xf16) half-res
-// Output: bottleneck_out (rgba32uint, 8xf16) quarter-res (dispatch at quarter-res dims)
+// Input: enc1_tex (rgba32uint, 8xf16) half-res
+// Output: bottleneck_out (rgba32uint, 8xf16) quarter-res (dispatch at quarter-res dims)
//
// Weight layout (f16, OIHW + bias):
-// [0 .. 8*8*1) conv: w[out][in] (1x1 kernel)
-// [64 .. +8) bias: b[out]
+// [0 .. 8*8*9) conv: w[out][in][ky*3+kx] (3x3 kernel, OIHW)
+// [576 .. +8) bias: b[out]
#include "cnn_v3/common"
-const BN_IN: u32 = 8u;
-const BN_OUT: u32 = 8u;
+const BN_IN: u32 = 8u;
+const BN_OUT: u32 = 8u;
+const BN_DILATION: i32 = 2;
struct Params {
weight_offset: u32,
@@ -24,7 +25,7 @@ struct Params {
@group(0) @binding(3) var bottleneck_out: texture_storage_2d<rgba32uint, write>;
// Avg-pool 2x2 from enc1_tex at quarter-res coord qcoord.
-// Returns zeros for OOB quarter-res coords (zero-padding for the 1x1 conv).
+// Returns zeros for OOB quarter-res coords (zero-padding for the 3x3 conv).
fn load_enc1_avg(qcoord: vec2i, half_dims: vec2i) -> array<f32, 8> {
let quart_dims = half_dims / 2;
if (qcoord.x < 0 || qcoord.y < 0 || qcoord.x >= quart_dims.x || qcoord.y >= quart_dims.y) {
@@ -50,14 +51,19 @@ fn bottleneck_main(@builtin(global_invocation_id) id: vec3u) {
let coord = vec2i(id.xy);
if (coord.x >= quart_dims.x || coord.y >= quart_dims.y) { return; }
- let wo = params.weight_offset;
- let feat = load_enc1_avg(coord, half_dims);
+ let wo = params.weight_offset;
var out: array<f32, BN_OUT>;
for (var o: u32 = 0u; o < BN_OUT; o++) {
- var sum = get_w(wo, BN_OUT * BN_IN + o); // bias (1x1 kernel: no spatial idx)
- for (var i: u32 = 0u; i < BN_IN; i++) {
- sum += get_w(wo, o * BN_IN + i) * feat[i];
+ var sum = get_w(wo, BN_OUT * BN_IN * 9u + o); // bias (at end of 3x3 conv weights)
+ for (var ky: i32 = -1; ky <= 1; ky++) {
+ for (var kx: i32 = -1; kx <= 1; kx++) {
+ let feat = load_enc1_avg(coord + vec2i(kx, ky) * BN_DILATION, half_dims);
+ let ki = u32(ky + 1) * 3u + u32(kx + 1);
+ for (var i: u32 = 0u; i < BN_IN; i++) {
+ sum += get_w(wo, o * BN_IN * 9u + i * 9u + ki) * feat[i];
+ }
+ }
}
out[o] = max(0.0, sum);
}
diff --git a/cnn_v3/src/cnn_v3_effect.cc b/cnn_v3/src/cnn_v3_effect.cc
index bfbb17b..1391eba 100644
--- a/cnn_v3/src/cnn_v3_effect.cc
+++ b/cnn_v3/src/cnn_v3_effect.cc
@@ -25,7 +25,7 @@
//
static const uint32_t kEnc0Weights = 20 * 4 * 9 + 4; // Conv(20→4,3×3)+bias
static const uint32_t kEnc1Weights = 4 * 8 * 9 + 8; // Conv(4→8,3×3)+bias
-static const uint32_t kBnWeights = 8 * 8 * 1 + 8; // Conv(8→8,1×1)+bias
+static const uint32_t kBnWeights = 8 * 8 * 9 + 8; // Conv(8→8,3×3,dilation=2)+bias
static const uint32_t kDec1Weights = 16 * 4 * 9 + 4; // Conv(16→4,3×3)+bias
static const uint32_t kDec0Weights = 8 * 4 * 9 + 4; // Conv(8→4,3×3)+bias
diff --git a/cnn_v3/test_vectors.h b/cnn_v3/test_vectors.h
index 6d1abc5..3e256a3 100644
--- a/cnn_v3/test_vectors.h
+++ b/cnn_v3/test_vectors.h
@@ -9,78 +9,78 @@ static const int kCnnV3TestH = 8;
// 256 u32 values
static const uint32_t kCnnV3TestFeat0U32[256] = {
- 0x2ccd39ebu, 0x3acb39d7u, 0x3814378fu, 0x3bc134ffu, 0x35e739ddu, 0x33073198u, 0x3b8a376cu, 0x32e339e0u,
- 0x360d3ae0u, 0x33bc3ad0u, 0x3b6a38f3u, 0x398b3420u, 0x30d23be6u, 0x39652da8u, 0x2c8f3570u, 0x3a08379bu,
- 0x355c3490u, 0x38293ac4u, 0x37243abeu, 0x39353ba0u, 0x3b152f6bu, 0x308837d4u, 0x398030e3u, 0x34962a10u,
- 0x370f3079u, 0x382d36a1u, 0x281a3479u, 0x35fb38eau, 0x2ef43936u, 0x33f2230eu, 0x364e374eu, 0x360c3a7bu,
- 0x38c0383eu, 0x381f2597u, 0x36be3584u, 0x3a432e6bu, 0x25b33b8au, 0x3a1c38d0u, 0x3a4d348fu, 0x3b6f390fu,
- 0x296c3bd9u, 0x3860371eu, 0x356130b2u, 0x24283be7u, 0x3abe373du, 0x37ad352fu, 0x37993bd3u, 0x2a9f3031u,
- 0x34413b90u, 0x2dce3808u, 0x3b7136c7u, 0x3bc53805u, 0x38093424u, 0x372c3ae0u, 0x3ad83479u, 0x383f363du,
- 0x31f83bd0u, 0x27f434d3u, 0x32683645u, 0x31cd3971u, 0x34373966u, 0x359535afu, 0x377739bcu, 0x3ad235c8u,
- 0x32d83893u, 0x357b3b33u, 0x37ea28fdu, 0x33a22fefu, 0x302f39fau, 0x3b7f3a75u, 0x39af38dau, 0x3bf139b5u,
- 0x31363577u, 0x38443827u, 0x38e831b1u, 0x3b6c233bu, 0x2910343cu, 0x33b02eeeu, 0x28333462u, 0x322d3478u,
- 0x362a360fu, 0x353f356du, 0x26742dbeu, 0x3a0e3278u, 0x3b6e3bedu, 0x38413809u, 0x3a313509u, 0x3ac13a1eu,
- 0x36f33b2au, 0x3a743a23u, 0x3b6f34efu, 0x3bf42e0au, 0x2df83a29u, 0x28603940u, 0x3a653a29u, 0x3adb38d2u,
- 0x346a2e44u, 0x296f36a0u, 0x343e372cu, 0x36cd3649u, 0x34533b09u, 0x36d13b26u, 0x3805353fu, 0x341e36afu,
- 0x30dc3805u, 0x388735a2u, 0x3a97369du, 0x3bc2341cu, 0x3bbe3a47u, 0x308c3ab5u, 0x31703836u, 0x38ac3a8cu,
- 0x3b703437u, 0x38832f5fu, 0x2b8839c5u, 0x3a8738c8u, 0x38192c52u, 0x394e3423u, 0x3b7f2f98u, 0x31f43b28u,
- 0x38b3352cu, 0x371539bfu, 0x2eaa3100u, 0x37493c00u, 0x37b83afbu, 0x2d9e3b61u, 0x3b702f4cu, 0x35093b94u,
- 0x373d35afu, 0x321536a9u, 0x340e3b30u, 0x2c4c39a4u, 0x393b28f6u, 0x393e356du, 0x3b992e04u, 0x3b0339fdu,
- 0x351f305eu, 0x384c35e5u, 0x2bc334c0u, 0x341335e7u, 0x324d362du, 0x39043431u, 0x35873636u, 0x3a2d3845u,
- 0x38b33610u, 0x382d3bbbu, 0x3a593b47u, 0x36de2b84u, 0x3be53996u, 0x2df03756u, 0x300d387fu, 0x38103a03u,
- 0x3af439cau, 0x38e63908u, 0x3abd3a09u, 0x28aa3af4u, 0x32ec3873u, 0x39303ae2u, 0x320536b9u, 0x39a1356du,
- 0x2dfd328au, 0x3a1d3b1bu, 0x34ad3265u, 0x39aa3bc7u, 0x34ec38e2u, 0x290f34c9u, 0x298739d4u, 0x39d61cf9u,
- 0x3a0d3b97u, 0x37c7378cu, 0x353236fau, 0x36e6382cu, 0x3b2f38c9u, 0x2d0a3bf6u, 0x31c83628u, 0x349935a2u,
- 0x3a1d3196u, 0x3b5b37f1u, 0x2c49282cu, 0x2d233674u, 0x3be33434u, 0x325732b0u, 0x37f83897u, 0x360738a5u,
- 0x306f3a9du, 0x398536dbu, 0x35ea3af2u, 0x2c6d388bu, 0x2c6d3173u, 0x349c39d3u, 0x2c4039cau, 0x3aaf3ae6u,
- 0x26152db1u, 0x3ad42b34u, 0x38633383u, 0x3a5d36d2u, 0x380137d3u, 0x30ce3beau, 0x2aa03aa5u, 0x3b1737a4u,
- 0x397b3952u, 0x36b23437u, 0x382c35deu, 0x353b3765u, 0x340334e3u, 0x30cc35d7u, 0x38d13afau, 0x398d3048u,
- 0x339a3ac8u, 0x206930d2u, 0x3a192a0cu, 0x29bf3be6u, 0x2c9939fcu, 0x3a0c38bdu, 0x219935bfu, 0x3bee38c3u,
- 0x3210341fu, 0x38712feeu, 0x3a5738c6u, 0x3b243a06u, 0x33ea3a72u, 0x34c23872u, 0x3b753547u, 0x3bcc3975u,
- 0x384d36acu, 0x2ede37cbu, 0x38393393u, 0x3b742c50u, 0x32562fedu, 0x2e343a1fu, 0x39ce3b34u, 0x39892c64u,
- 0x3a0f390eu, 0x39bf3aa0u, 0x352938b0u, 0x3ba83994u, 0x395138b3u, 0x3a0d36feu, 0x31223bfbu, 0x3851327au,
- 0x389337b4u, 0x36782a48u, 0x38ae38aau, 0x39c33942u, 0x3a523922u, 0x384d3900u, 0x2e7a38d9u, 0x3838345fu,
- 0x396f3afcu, 0x38bd2dc9u, 0x39df3318u, 0x38bf3a9fu, 0x356b38bbu, 0x3aea3724u, 0x382839c9u, 0x2a7335e4u,
+ 0x322d3094u, 0x3b8e35f9u, 0x3384380bu, 0x356a2ec5u, 0x36223a87u, 0x3a6c3a2eu, 0x38df2f18u, 0x366f3bacu,
+ 0x3b632a85u, 0x382d3b12u, 0x32c7386bu, 0x37fa3a8eu, 0x39772856u, 0x3aa23b9bu, 0x2ad3346fu, 0x3b7f3b86u,
+ 0x39233842u, 0x36b73767u, 0x2e312fa5u, 0x3ab13373u, 0x334130abu, 0x32e23864u, 0x38823139u, 0x390235e4u,
+ 0x30d53b4cu, 0x3b383b4fu, 0x390d3aa2u, 0x391d38f6u, 0x24383986u, 0x38af3baau, 0x36093a40u, 0x38142259u,
+ 0x36fe380fu, 0x33ff356fu, 0x36013838u, 0x31893bc2u, 0x34b4351du, 0x37fd3859u, 0x3b9a3926u, 0x398b3490u,
+ 0x34332669u, 0x27ef376bu, 0x396d38f1u, 0x382239f9u, 0x365638d5u, 0x2e662948u, 0x3bf7393du, 0x3876240cu,
+ 0x3a9d3a02u, 0x38f6385du, 0x3adf3993u, 0x3b692fd4u, 0x3ab126a9u, 0x323a2ce9u, 0x37201bfau, 0x3150355au,
+ 0x36703738u, 0x3b253a24u, 0x2ff63938u, 0x3b4e34bau, 0x36822daeu, 0x3b9b3b8au, 0x39573694u, 0x3a07374fu,
+ 0x309d280bu, 0x337138eau, 0x359f3954u, 0x3a8e3b18u, 0x3a2f37e3u, 0x37e83457u, 0x3ae33252u, 0x3b383a96u,
+ 0x3bad3b05u, 0x3b74334fu, 0x36c33892u, 0x357b387cu, 0x33b9349eu, 0x37f22d47u, 0x390b3b1au, 0x36dc382bu,
+ 0x32b2376eu, 0x32593a95u, 0x3a1439bcu, 0x3ae73899u, 0x3b0e34a8u, 0x3a6439d6u, 0x3ac53951u, 0x36b93bf2u,
+ 0x39f53a83u, 0x3b6a373fu, 0x38863650u, 0x333a2ec8u, 0x36583abau, 0x33df364eu, 0x3a7237deu, 0x2c7d3b29u,
+ 0x377a3899u, 0x372838eau, 0x378d3661u, 0x380238a8u, 0x3a8b378eu, 0x357639f7u, 0x3ad43a68u, 0x38a930e9u,
+ 0x39ea3491u, 0x395f33a4u, 0x38173415u, 0x361a3b97u, 0x3be53b02u, 0x314f3b00u, 0x281d3a8fu, 0x3af7364bu,
+ 0x38433983u, 0x3a803635u, 0x377f39adu, 0x335c3b24u, 0x39243174u, 0x33ea3bc7u, 0x307733fdu, 0x333f3ae2u,
+ 0x3bed3807u, 0x38742237u, 0x3a763819u, 0x369135afu, 0x39ed3160u, 0x30603a47u, 0x3b25364cu, 0x34c8198bu,
+ 0x35583871u, 0x375c345au, 0x383d31cfu, 0x389a39a7u, 0x3ac12df6u, 0x3a1e3199u, 0x3a4335c5u, 0x31f9329au,
+ 0x283737f4u, 0x39cb3336u, 0x2d2c3ab3u, 0x3a613b0eu, 0x39963af5u, 0x38333965u, 0x3b5a3939u, 0x350d2e6fu,
+ 0x3b8f2ca3u, 0x39673720u, 0x3bee3abbu, 0x3a65312du, 0x2a423b19u, 0x35ad3a08u, 0x381d3930u, 0x30543428u,
+ 0x2e9d2f7cu, 0x359f391au, 0x398932efu, 0x3850397fu, 0x362b3b7bu, 0x2ccf3ab0u, 0x3be839ebu, 0x38a33ac6u,
+ 0x35a73904u, 0x3a2a3970u, 0x37e13bfcu, 0x38c42bd9u, 0x33d52f9eu, 0x39d93543u, 0x314e31e2u, 0x3afc29c1u,
+ 0x291d398cu, 0x3878273eu, 0x38c63485u, 0x3b6336f4u, 0x396f349bu, 0x3ba62aebu, 0x39ea3bd9u, 0x330a3772u,
+ 0x39e43a80u, 0x3738331au, 0x3a9c3768u, 0x39253979u, 0x34543933u, 0x29d835f3u, 0x36ee3a4cu, 0x33da3703u,
+ 0x38b432b4u, 0x2c1c3371u, 0x36063a24u, 0x36e73615u, 0x35223a85u, 0x3b843a10u, 0x36e83949u, 0x375439fbu,
+ 0x383436a1u, 0x2eac3515u, 0x2fed36a3u, 0x38753691u, 0x28a33b72u, 0x375338f9u, 0x33fc2530u, 0x32f02f95u,
+ 0x366c3465u, 0x140e383bu, 0x2dfd312eu, 0x35443866u, 0x33193863u, 0x3b882634u, 0x300f2eefu, 0x3bda30b1u,
+ 0x38e238f1u, 0x2da93be5u, 0x32873bccu, 0x36b938fcu, 0x3b733625u, 0x3bfa30c6u, 0x39313611u, 0x2b5f3bbeu,
+ 0x388b3b62u, 0x30c639a3u, 0x39633844u, 0x30f6374du, 0x3ad633d0u, 0x39ac286au, 0x1faa3bffu, 0x39653127u,
+ 0x38b82baeu, 0x38b53979u, 0x399435d8u, 0x32a538c1u, 0x3b0e3881u, 0x378c3956u, 0x2d7f3525u, 0x21ba33d4u,
+ 0x331f3be5u, 0x31663a85u, 0x36b1348au, 0x3a633531u, 0x3b013ba9u, 0x3a3730eau, 0x3b4f30bcu, 0x35623825u,
+ 0x220c3106u, 0x3b5033efu, 0x3bc23a61u, 0x38bd2e73u, 0x3858341du, 0x34893521u, 0x31de3897u, 0x39353782u,
+ 0x3b72301au, 0x3a8e380cu, 0x39ae393bu, 0x3b0039bbu, 0x347438e9u, 0x38da2e5eu, 0x33b92c3fu, 0x38642bc5u,
};
// 256 u32 values
static const uint32_t kCnnV3TestFeat1U32[256] = {
- 0xc863b415u, 0x249c220fu, 0x603452c6u, 0x00000000u, 0x316a194cu, 0x291db2cbu, 0x5f96105bu, 0x00000000u,
- 0xeb343d39u, 0xf1b365e6u, 0x61b71b05u, 0x00000000u, 0x8151bb9eu, 0xfc56bec5u, 0x3c1e7c24u, 0x00000000u,
- 0xf1d859a5u, 0x1b1270e5u, 0x39d19474u, 0x00000000u, 0x569b30dcu, 0x097e59b6u, 0xd0d3b912u, 0x00000000u,
- 0xdafc8a80u, 0x6222c0d8u, 0xd61d6364u, 0x00000000u, 0xc5c2f0c4u, 0xcd28e9d7u, 0xcd7e12c4u, 0x00000000u,
- 0x92cfbc01u, 0x1c5ebffdu, 0xec699bb5u, 0x00000000u, 0x9bd12023u, 0xe6b94175u, 0xf58751d1u, 0x00000000u,
- 0x2fe9e259u, 0x66f28558u, 0x314748e3u, 0x00000000u, 0x0d0aabfcu, 0xf7666903u, 0xec5d90aau, 0x00000000u,
- 0xee86a635u, 0xe237f413u, 0xa61606fcu, 0x00000000u, 0x85ab0fd7u, 0xfdd13bdbu, 0x8d6075e2u, 0x00000000u,
- 0xa476623cu, 0x3634aa37u, 0xbf284477u, 0x00000000u, 0xd1c78653u, 0xadb3feedu, 0x7fa4408au, 0x00000000u,
- 0x32a77b6au, 0x08ac3716u, 0xa0976732u, 0x00000000u, 0xaeda1174u, 0xc5ca1e59u, 0xf353b939u, 0x00000000u,
- 0x7f53105cu, 0xd44334dfu, 0xb75edbe4u, 0x00000000u, 0x46f67512u, 0xd859d32du, 0x0da6b677u, 0x00000000u,
- 0x9950dc38u, 0xf0badec3u, 0xa8b1d193u, 0x00000000u, 0xefe357bdu, 0x0e606587u, 0x884c5ed2u, 0x00000000u,
- 0xc7d63411u, 0xa46ee9f4u, 0xe16ad66fu, 0x00000000u, 0x766cf523u, 0xaebf1396u, 0x6b75be3bu, 0x00000000u,
- 0xdf433db5u, 0x1e942c35u, 0x410dffe5u, 0x00000000u, 0x18c4cc46u, 0xb3bcd975u, 0x3b94557eu, 0x00000000u,
- 0x512fefb1u, 0xd62e1684u, 0x5c34ef2bu, 0x00000000u, 0x25554402u, 0x055e5375u, 0x3a08ec40u, 0x00000000u,
- 0xea28d1a6u, 0x8c71f892u, 0xfead5d3du, 0x00000000u, 0x3712d6e9u, 0x59fa8772u, 0x29c7e9cdu, 0x00000000u,
- 0x65fc32ecu, 0x90357e43u, 0xcee18a15u, 0x00000000u, 0x5e3b5c50u, 0xc583129du, 0xa04bf996u, 0x00000000u,
- 0x4ab43782u, 0xe9864a08u, 0x6f2ab1c6u, 0x00000000u, 0x26a77c61u, 0xf673703cu, 0xe9d6c9cfu, 0x00000000u,
- 0x0caebeeeu, 0xe709951fu, 0xf2875771u, 0x00000000u, 0xd43f1577u, 0x41477617u, 0xa19bf431u, 0x00000000u,
- 0x89ca27c9u, 0x9ec1ee6cu, 0x9dcf44adu, 0x00000000u, 0xa3a370ddu, 0x83958e74u, 0xb0c45102u, 0x00000000u,
- 0x86cfafcau, 0x04382d70u, 0x09083cf1u, 0x00000000u, 0xf5458e26u, 0xe8c4a35bu, 0x95ea20cbu, 0x00000000u,
- 0x2cb1e624u, 0xc80e252fu, 0x24aeadb9u, 0x00000000u, 0x60958ae8u, 0x5471b135u, 0x032c76bcu, 0x00000000u,
- 0xce983976u, 0x827df87du, 0x50f5f0adu, 0x00000000u, 0x81d7362fu, 0x00000e99u, 0x6fde87aeu, 0x00000000u,
- 0x85033eb4u, 0x56f7b265u, 0xd493d37cu, 0x00000000u, 0x3ff49a3cu, 0x23487a39u, 0x870d2e4fu, 0x00000000u,
- 0xe3249135u, 0x60123a68u, 0x0befa03du, 0x00000000u, 0xf84d74b5u, 0x71bd7da9u, 0x2c44f6cbu, 0x00000000u,
- 0x9d98f068u, 0x51d59a46u, 0xf0131dceu, 0x00000000u, 0x4b40fe50u, 0x8cd5b0fbu, 0x8b164f67u, 0x00000000u,
- 0x3e10a2d3u, 0x7fd0d4b7u, 0x1bec231fu, 0x00000000u, 0xa4cc2cd6u, 0xc22121ffu, 0xf33350e7u, 0x00000000u,
- 0x536659b7u, 0x49043fc2u, 0x8c7ec0d7u, 0x00000000u, 0xb1597a41u, 0xfe1228f2u, 0x066908e4u, 0x00000000u,
- 0x3d0194e7u, 0x432be415u, 0x4160b66fu, 0x00000000u, 0x76b6560au, 0xdf770ab8u, 0x07ef4642u, 0x00000000u,
- 0xd0dafe5cu, 0x9e1f95f4u, 0x9d7dbecdu, 0x00000000u, 0xada5c397u, 0x1d8b6a84u, 0xbf29cf46u, 0x00000000u,
- 0x3f858ef0u, 0x843e3a0cu, 0xad47e23fu, 0x00000000u, 0x9a9c1e18u, 0x52b851a8u, 0x65648845u, 0x00000000u,
- 0x79fca3a8u, 0x0a8f8f09u, 0xb9dde8cbu, 0x00000000u, 0x199671dfu, 0x7565be28u, 0xa7add019u, 0x00000000u,
- 0x14948e21u, 0xfedcb64du, 0x6091bc31u, 0x00000000u, 0x040bae5bu, 0xa89c3b59u, 0x8ebdcac3u, 0x00000000u,
+ 0xee7c0a1du, 0x290beb5au, 0x34aedb72u, 0x00000000u, 0x9c43a772u, 0x9ac02fbau, 0xca762320u, 0x00000000u,
+ 0xed95234bu, 0xd266c660u, 0x23e572b0u, 0x00000000u, 0x4f3e3e4cu, 0xe9f050c2u, 0x8c8848c4u, 0x00000000u,
+ 0xddf4a20bu, 0x90217921u, 0x0cbbcb9bu, 0x00000000u, 0x790f2266u, 0xd31ceb5cu, 0xa7b58b42u, 0x00000000u,
+ 0x21fdd340u, 0x35c8450eu, 0xdab84239u, 0x00000000u, 0xfaafaf58u, 0xc0bd647bu, 0x191bc271u, 0x00000000u,
+ 0x9e839693u, 0xd447d632u, 0xa3e3cd34u, 0x00000000u, 0x9816acb2u, 0x77a4c5f5u, 0x3eaeccfbu, 0x00000000u,
+ 0x47e04ba9u, 0xbee48e8du, 0x11df34c8u, 0x00000000u, 0x15a08a3cu, 0x658be5c3u, 0xc6403f48u, 0x00000000u,
+ 0xa8337739u, 0x97094582u, 0x88bce4acu, 0x00000000u, 0x1c5a2203u, 0x54f080bcu, 0x145a7a01u, 0x00000000u,
+ 0xc216a0ffu, 0xc036cf58u, 0x42127f23u, 0x00000000u, 0x4afdd8fau, 0x5144b748u, 0xe3a9493du, 0x00000000u,
+ 0x7d1010ddu, 0xc31737aeu, 0x72e658f1u, 0x00000000u, 0xb2bc988bu, 0x874068abu, 0x4752b9ecu, 0x00000000u,
+ 0xe055263eu, 0xb57d6353u, 0xc4f356bdu, 0x00000000u, 0xf2b9ce80u, 0x3faf6989u, 0x1770771eu, 0x00000000u,
+ 0x950fc854u, 0x537f6518u, 0x6f8f1b03u, 0x00000000u, 0x3c137b49u, 0x660207d5u, 0x64ac0a72u, 0x00000000u,
+ 0x59be07efu, 0xbe09834bu, 0x97b811efu, 0x00000000u, 0x7967f639u, 0x1cdaeda5u, 0x921b66a8u, 0x00000000u,
+ 0x2cce2e38u, 0x506c746au, 0x6a374c25u, 0x00000000u, 0x242b888du, 0x63b59666u, 0x4455c37cu, 0x00000000u,
+ 0xd98a0ed3u, 0xdc14021au, 0x012b5d82u, 0x00000000u, 0x9a37ff7fu, 0xa3fb2747u, 0x60c3dd9du, 0x00000000u,
+ 0x7818642eu, 0xca374746u, 0x60c22570u, 0x00000000u, 0x10804844u, 0x5f5ca629u, 0x40ff019fu, 0x00000000u,
+ 0x61fa17b2u, 0x3ae80a51u, 0x265e1089u, 0x00000000u, 0xfc40da19u, 0x20fd6d3au, 0xb4c2e06fu, 0x00000000u,
+ 0xb7b31acdu, 0x9e273818u, 0xe955351fu, 0x00000000u, 0x0146b1d6u, 0x4d3790ceu, 0x2f2ef0b7u, 0x00000000u,
+ 0x93b16f10u, 0xa2b2d58cu, 0xe5dcdf1fu, 0x00000000u, 0x61354928u, 0x3c63db78u, 0xec9da3a4u, 0x00000000u,
+ 0xac48ee35u, 0xc3c4f767u, 0x71ea1e0bu, 0x00000000u, 0x7287c339u, 0x63988fb6u, 0xbfe036acu, 0x00000000u,
+ 0x35eae594u, 0xf9b41907u, 0x2d097146u, 0x00000000u, 0x7602d6deu, 0x508a8127u, 0xa47c939bu, 0x00000000u,
+ 0xae41d19eu, 0xeb2d9aadu, 0xca0a22dbu, 0x00000000u, 0x3fa92484u, 0x34e77d30u, 0xe2f5759du, 0x00000000u,
+ 0x7ce514bbu, 0x18f8b09du, 0xd3314b39u, 0x00000000u, 0xa600b305u, 0x068bd432u, 0xc86814d2u, 0x00000000u,
+ 0x9b7cfb72u, 0x9d56d54bu, 0xdd6c8907u, 0x00000000u, 0x7edb5e71u, 0x7615827du, 0x9e0a75a4u, 0x00000000u,
+ 0x32a1e232u, 0x26d36ecdu, 0xd801ced0u, 0x00000000u, 0x372fa45eu, 0x811cb66bu, 0x45181f97u, 0x00000000u,
+ 0x3aff4aa1u, 0x9908111eu, 0xcd679c4eu, 0x00000000u, 0x71206dc3u, 0x2383b298u, 0x3e95f804u, 0x00000000u,
+ 0x2a217f2du, 0xe1ffcadau, 0x51ccb6e1u, 0x00000000u, 0x5fb9577bu, 0x122f7d23u, 0x722f227fu, 0x00000000u,
+ 0xe9f6f5f2u, 0x68e22b74u, 0xa6b7e5eeu, 0x00000000u, 0x2e93d042u, 0x2497b6f1u, 0xbb4be878u, 0x00000000u,
+ 0x10d4106bu, 0x72ce2922u, 0x511385eau, 0x00000000u, 0x04296d0bu, 0x87fd229fu, 0xf6c99a1cu, 0x00000000u,
+ 0x11b3b25eu, 0xd0d5e251u, 0x8a07a0e6u, 0x00000000u, 0xb93b2f92u, 0x18b76f8du, 0xde7cce09u, 0x00000000u,
+ 0x02ec3339u, 0xe824852au, 0xa8660512u, 0x00000000u, 0x5665b9b3u, 0x01d16dd3u, 0x9c67c9b7u, 0x00000000u,
+ 0x16622051u, 0x9bdad41eu, 0xc5ecdbb8u, 0x00000000u, 0x446dc047u, 0x3d1cea2eu, 0x38d1dcddu, 0x00000000u,
+ 0x398f04ebu, 0x1d29069eu, 0x3fec755bu, 0x00000000u, 0xa8c8d0adu, 0x4d71c198u, 0xc7ea4e97u, 0x00000000u,
};
-// 982 u32 values
-static const uint32_t kCnnV3TestWeightsU32[982] = {
+// 1238 u32 values
+static const uint32_t kCnnV3TestWeightsU32[1238] = {
0xa8b23143u, 0x2f9432e3u, 0x3491b3cbu, 0x317e3104u, 0xa79fb324u, 0x3419acf6u, 0x32322d86u, 0xb13da859u,
0xb4302831u, 0x2d0e324au, 0xad9630f5u, 0x338c3485u, 0xb1dd3158u, 0xb461a51du, 0x2f07b2a3u, 0x347d30b3u,
0xacf9aeb0u, 0xb1f6a4adu, 0xa377b31bu, 0x2e85b13eu, 0x3263a8d4u, 0xaf352fb1u, 0x31da3261u, 0xb010ac52u,
@@ -203,91 +203,123 @@ static const uint32_t kCnnV3TestWeightsU32[982] = {
0xa83f2c18u, 0xb41ca864u, 0x338c31d0u, 0xb22cb4b2u, 0x279a33c1u, 0xb1b5b2b8u, 0x30512e25u, 0x345a2ba3u,
0xafab9b4bu, 0xad64a2feu, 0xb45cb14bu, 0x300fadadu, 0xa8acb49fu, 0x2c3d2d88u, 0x31f63150u, 0xb3a03011u,
0x2bf1a3acu, 0xb464b0e3u, 0xa6eeb14fu, 0xb235aa9cu, 0x3416323bu, 0x3420b1bcu, 0x3414b4a1u, 0xb4af3457u,
- 0x3484310du, 0x348533cbu, 0xb40d27bbu, 0x2c5f32b7u, 0xaa5b2c68u, 0xb2a72984u,
+ 0x3484310du, 0x348533cbu, 0xb40d27bbu, 0x2c5f32b7u, 0xaa5b2c68u, 0xb2a72984u, 0xb414309bu, 0x32b33069u,
+ 0x1e0aa43bu, 0x3482af36u, 0xad08307au, 0xb162b23eu, 0x3440a58bu, 0xb178307fu, 0xacad32e7u, 0xb0f632c1u,
+ 0x34192c8eu, 0x2f69b0a6u, 0xb2b534aeu, 0x2eb0b3e7u, 0xb41eae27u, 0x30dfa396u, 0xae56b020u, 0x222b32a3u,
+ 0xa81e3295u, 0x2dca3459u, 0x3365b360u, 0xb2e19e98u, 0x2f34b2abu, 0xb019b458u, 0xa886b2ebu, 0x22b8aa94u,
+ 0xb47eb03bu, 0xacd92c64u, 0xb3832dd0u, 0xb0d5b4abu, 0xac11a6adu, 0xacb131f5u, 0x2b2f24adu, 0x20a6b497u,
+ 0xaa0cadf5u, 0x316eb3adu, 0xb496343fu, 0x31112bc9u, 0x3185b022u, 0x341f2d15u, 0xb465349eu, 0x2738a83bu,
+ 0xae49b2c8u, 0xb4a534aeu, 0x3294a74bu, 0xa235aec3u, 0xa3b83497u, 0xb44eb316u, 0xb07f3447u, 0xb3dc18feu,
+ 0x3421a9ddu, 0x348615eeu, 0x1996b0a1u, 0xa7f332e7u, 0x32d3b03cu, 0x24b8ac3au, 0xb2053493u, 0xb480afa0u,
+ 0xb1c2ac27u, 0xb21e2eeau, 0xb08b2eb6u, 0xadcead8fu, 0xa5253029u, 0x32c5ad53u, 0xb17f2987u, 0xae0b33afu,
+ 0x9aa3b46du, 0xb105b338u, 0xb31730bfu, 0x343231e5u, 0x300a2c17u, 0x34bb301au, 0xb279ae16u, 0x251b21e3u,
+ 0x2c58b22fu, 0x341bb4aau, 0xb46cb085u, 0xb0fdb386u, 0xb47cb057u, 0xb1e5b03du, 0xac69aca9u, 0xae9cae2fu,
+ 0xb48fb3e1u, 0x30edb1b8u, 0x341d34b6u, 0x24e3192fu, 0x3142af1fu, 0x329c3115u, 0xa90b3398u, 0x31e23120u,
+ 0x341faf5bu, 0x34bfb3cau, 0xb3cf3130u, 0xb4792e00u, 0x31bf3130u, 0x32da2bddu, 0xb04db3b8u, 0xb464aa97u,
+ 0xb082a7f4u, 0xa9c1ac1eu, 0xb0693349u, 0xa9af338fu, 0x162cae9du, 0xb0a9aa51u, 0xb2af1696u, 0x290dadb0u,
+ 0x3238aaa6u, 0x3483b0acu, 0x347d3177u, 0xb2df327eu, 0xb2562410u, 0x2a77321cu, 0x3420b08bu, 0x28e8b363u,
+ 0xb43c303eu, 0x32112b84u, 0x1f86b427u, 0x2e42b0a3u, 0x3432b352u, 0xb2073394u, 0x2abbaec9u, 0xa8673030u,
+ 0xb39ab299u, 0xa6dc34ccu, 0xa16a3327u, 0xb3ea340eu, 0x3420b369u, 0xaf1d344cu, 0xa74ead90u, 0xb1f3aa70u,
+ 0xb0bd33a6u, 0xb4282fe2u, 0x2de7b46eu, 0x2df8ae2fu, 0x3452b3cbu, 0x333930c5u, 0xaee8b2fbu, 0x25b6ad0eu,
+ 0xb438afcdu, 0xb0b6ad09u, 0xb1d2ac61u, 0x2ce0b092u, 0xadf0ac4bu, 0x31382535u, 0x2ab9aca7u, 0x22c1347au,
+ 0x31a333deu, 0xa972b43cu, 0x34ac2f9eu, 0xb3d2a665u, 0xb32c28c3u, 0x1cb730d4u, 0x3317304au, 0x2c512cf4u,
+ 0x329330e3u, 0xb4733316u, 0xb1732851u, 0x2db332ebu, 0xb1fdaa20u, 0x2fd3ae2eu, 0xb3ceb1adu, 0x31133373u,
+ 0xaffab1c4u, 0x2fff3488u, 0xaf632c3eu, 0xb46cafb7u, 0xb4633063u, 0x3068b4c1u, 0x30ed344fu, 0xa049a45bu,
+ 0xaebca8e8u, 0xa94a22acu, 0x33a52b8au, 0xb40b34c1u, 0xb221ac6eu, 0xb015adaeu, 0x3112b240u, 0x3406988fu,
+ 0xb428b47du, 0xb408ab6eu, 0x34aab08eu, 0xb1ccb197u, 0x94eb29a8u, 0xacbc2a2du, 0xb2f03246u, 0x2f49a980u,
+ 0xad023312u, 0xb4232934u, 0xb423b254u, 0xb0123060u, 0xb42a304cu, 0x327132f6u, 0xb492b3e4u, 0x32cab442u,
+ 0x276ab118u, 0x31ada9aau, 0x0e7f9ed2u, 0xb2b834b2u, 0xb44e3259u, 0x336ba2deu, 0x2f1d2e58u, 0xaa41b08bu,
+ 0x2296ad20u, 0xaea6a5cdu, 0xb0c9af78u, 0xb2b9ad2fu, 0x2bd83325u, 0x2f72b308u, 0xb10a32adu, 0xb4b8b2b5u,
+ 0x3109b459u, 0xb45f34adu, 0xb41c30c3u, 0x30eb2b13u, 0xb4b2ad68u, 0x34b72b4fu, 0xb1f6b0a7u, 0x283eb338u,
+ 0x319d2b68u, 0x338930dcu, 0xb0da31dfu, 0xafc8284bu, 0x3426ae89u, 0x348e2efcu, 0x25c0aa62u, 0xb38a9febu,
+ 0x243fb10eu, 0x3424b427u, 0xb1ccb339u, 0xb3bd3118u, 0x305533afu, 0x2f5eb424u, 0x30f12d0eu, 0x3031324du,
+ 0xaed12a9eu, 0x34632f93u, 0x2e502ab9u, 0x30eba8d4u, 0xb28534c7u, 0x260fb1b7u, 0x297fa1b9u, 0xab5ab454u,
+ 0x2a8b2a5fu, 0x303a2e0bu, 0x31932d6fu, 0x25c32ccau, 0xb3a82c14u, 0x2435b05bu, 0x2ee03329u, 0x2b16b3ddu,
+ 0x307eb158u, 0x2b2d3249u, 0xae332b04u, 0x32fea821u, 0x2211304au, 0xb451ad0fu,
};
// 256 uint16 values (raw f16 bits)
static const uint16_t kCnnV3ExpectedEnc0U16[256] = {
- 0x0000u, 0x0000u, 0x350cu, 0x3b3cu, 0x19bcu, 0x0000u, 0x0000u, 0x3d10u,
- 0x31e9u, 0x0000u, 0x35d0u, 0x39c3u, 0x0000u, 0x0000u, 0x2c6fu, 0x35fbu,
- 0x39b9u, 0x0000u, 0x0000u, 0x3538u, 0x2ebbu, 0x0000u, 0x34f8u, 0x0000u,
- 0x0000u, 0x0000u, 0x0000u, 0x3c96u, 0x0000u, 0x3029u, 0x0000u, 0x0000u,
- 0x0000u, 0x0000u, 0x0000u, 0x405au, 0x0000u, 0x367eu, 0x0000u, 0x3d2fu,
- 0x383bu, 0x0000u, 0x342cu, 0x3f97u, 0x0000u, 0x3c3cu, 0x0000u, 0x424eu,
- 0x0000u, 0x0000u, 0x0000u, 0x3a3au, 0x0000u, 0x3d8fu, 0x0000u, 0x3fd4u,
- 0x307du, 0x0000u, 0x0000u, 0x3f68u, 0x0000u, 0x0000u, 0x0000u, 0x3c81u,
- 0x0000u, 0x0000u, 0x398fu, 0x3ffeu, 0x0000u, 0x0000u, 0x0000u, 0x3ec1u,
- 0x0000u, 0x39b8u, 0x0000u, 0x3c61u, 0x0000u, 0x2e3au, 0x3699u, 0x41deu,
- 0x0000u, 0x0000u, 0x0000u, 0x3d2cu, 0x329au, 0x0000u, 0x0000u, 0x41a9u,
- 0x2d70u, 0x342fu, 0x0000u, 0x4066u, 0x2c77u, 0x0000u, 0x37b7u, 0x3842u,
- 0x2b9au, 0x0000u, 0x3655u, 0x4001u, 0x340au, 0x0000u, 0x30f5u, 0x41a5u,
- 0x0000u, 0x0000u, 0x0000u, 0x3d05u, 0x0000u, 0x0000u, 0x30a6u, 0x40a3u,
- 0x0000u, 0x0000u, 0x0000u, 0x4263u, 0x0000u, 0x0000u, 0x0000u, 0x3e62u,
- 0x0000u, 0x0000u, 0x0000u, 0x42d7u, 0x0000u, 0x0000u, 0x0000u, 0x3de8u,
- 0x0000u, 0x0000u, 0x0000u, 0x3f4du, 0x0000u, 0x38d4u, 0x3a61u, 0x3fb7u,
- 0x0000u, 0x0000u, 0x0000u, 0x404cu, 0x3811u, 0x31a4u, 0x0000u, 0x3edfu,
- 0x0000u, 0x0000u, 0x0000u, 0x3f30u, 0x0000u, 0x0000u, 0x0000u, 0x3ec7u,
- 0x27dau, 0x0000u, 0x0000u, 0x3efeu, 0x0000u, 0x3027u, 0x0000u, 0x39ceu,
- 0x28e8u, 0x0000u, 0x0000u, 0x4121u, 0x0000u, 0x0000u, 0x0000u, 0x40eeu,
- 0x3b70u, 0x3379u, 0x0000u, 0x40d3u, 0x0000u, 0x0000u, 0x0000u, 0x3d88u,
- 0x329du, 0x0000u, 0x0000u, 0x3fafu, 0x35c0u, 0x0000u, 0x374cu, 0x40ceu,
- 0x32b4u, 0x2c9au, 0x0000u, 0x4094u, 0x3105u, 0x31f4u, 0x34e9u, 0x3cd7u,
- 0x0000u, 0x0000u, 0x344bu, 0x3cd1u, 0x0000u, 0x2d13u, 0x0000u, 0x3e7eu,
- 0x0000u, 0x2eacu, 0x0000u, 0x4123u, 0x0000u, 0x36edu, 0x0000u, 0x3c69u,
- 0x0000u, 0x0000u, 0x0000u, 0x41d5u, 0x0000u, 0x36e4u, 0x0000u, 0x4049u,
- 0x0000u, 0x0000u, 0x0000u, 0x401du, 0x0000u, 0x38d1u, 0x333au, 0x3b08u,
- 0x0000u, 0x0000u, 0x0000u, 0x3d12u, 0x0000u, 0x0000u, 0x0000u, 0x3e6eu,
- 0x0000u, 0x0000u, 0x0000u, 0x4028u, 0x0000u, 0x0000u, 0x0000u, 0x3f64u,
- 0x0000u, 0x0000u, 0x0000u, 0x3e4bu, 0x2eeau, 0x393cu, 0x0000u, 0x4007u,
- 0x0000u, 0x267fu, 0x0000u, 0x3eabu, 0x35b4u, 0x38f9u, 0x0000u, 0x3e6bu,
+ 0x3c3fu, 0x0000u, 0x2aeeu, 0x3cdfu, 0x0000u, 0x0000u, 0x3a34u, 0x0000u,
+ 0x33e1u, 0x251du, 0x29e7u, 0x3dd0u, 0x0000u, 0x3996u, 0x2e7du, 0x3847u,
+ 0x259bu, 0x29a6u, 0x3a17u, 0x0000u, 0x3022u, 0x0000u, 0x3c4bu, 0x3c15u,
+ 0x0000u, 0x0000u, 0x38e0u, 0x3a98u, 0x0000u, 0x37dbu, 0x0000u, 0x0000u,
+ 0x0000u, 0x0000u, 0x0000u, 0x4027u, 0x0000u, 0x393cu, 0x0000u, 0x3c3bu,
+ 0x0000u, 0x31c4u, 0x3918u, 0x3f6fu, 0x0000u, 0x0000u, 0x0000u, 0x3c35u,
+ 0x0000u, 0x0000u, 0x0000u, 0x403eu, 0x0000u, 0x32b6u, 0x0000u, 0x4008u,
+ 0x3440u, 0x0000u, 0x0000u, 0x4003u, 0x0000u, 0x0000u, 0x0000u, 0x3d6bu,
+ 0x0000u, 0x0000u, 0x0000u, 0x4115u, 0x0000u, 0x0000u, 0x0000u, 0x3bcdu,
+ 0x30acu, 0x301eu, 0x3a8eu, 0x40e1u, 0x0000u, 0x0000u, 0x2dc0u, 0x401au,
+ 0x0000u, 0x0000u, 0x3638u, 0x3df2u, 0x0000u, 0x3c65u, 0x0000u, 0x3feau,
+ 0x2d79u, 0x0000u, 0x2e52u, 0x3f56u, 0x0000u, 0x0000u, 0x0000u, 0x3e3fu,
+ 0x34d0u, 0x0000u, 0x0000u, 0x3c46u, 0x38b0u, 0x3324u, 0x0000u, 0x4018u,
+ 0x0000u, 0x3385u, 0x0000u, 0x408du, 0x31ddu, 0x3585u, 0x40bau, 0x4009u,
+ 0x0000u, 0x2fd2u, 0x0000u, 0x4147u, 0x3baau, 0x0000u, 0x0000u, 0x3c42u,
+ 0x0000u, 0x0000u, 0x3378u, 0x3fc6u, 0x30cbu, 0x0000u, 0x3978u, 0x3440u,
+ 0x0000u, 0x0000u, 0x0000u, 0x38eeu, 0x0000u, 0x0000u, 0x0000u, 0x4117u,
+ 0x0000u, 0x0000u, 0x0000u, 0x4089u, 0x0000u, 0x3647u, 0x0000u, 0x43cfu,
+ 0x3752u, 0x2d2bu, 0x0000u, 0x3c2bu, 0x0000u, 0x3615u, 0x39cau, 0x0000u,
+ 0x0000u, 0x0000u, 0x0000u, 0x3e2du, 0x0000u, 0x0000u, 0x0000u, 0x3e18u,
+ 0x0000u, 0x0000u, 0x0000u, 0x3d99u, 0x2ca5u, 0x0000u, 0x0000u, 0x3d64u,
+ 0x0000u, 0x2b7fu, 0x0000u, 0x3f9eu, 0x0000u, 0x0000u, 0x0000u, 0x4133u,
+ 0x0000u, 0x0000u, 0x0000u, 0x3fc4u, 0x0000u, 0x0000u, 0x0000u, 0x3c91u,
+ 0x0000u, 0x2a5du, 0x0000u, 0x4166u, 0x0000u, 0x0000u, 0x0000u, 0x4089u,
+ 0x3165u, 0x0000u, 0x0000u, 0x3f6eu, 0x0000u, 0x0000u, 0x358du, 0x417fu,
+ 0x0000u, 0x356cu, 0x0000u, 0x4243u, 0x3c04u, 0x0000u, 0x0000u, 0x406bu,
+ 0x0000u, 0x315bu, 0x0000u, 0x40b7u, 0x0000u, 0x34beu, 0x0000u, 0x4108u,
+ 0x0000u, 0x390au, 0x2607u, 0x408fu, 0x0000u, 0x0000u, 0x0000u, 0x3b05u,
+ 0x3407u, 0x0000u, 0x0000u, 0x3d13u, 0x0000u, 0x33b5u, 0x0000u, 0x3dafu,
+ 0x0000u, 0x0000u, 0x0000u, 0x3d80u, 0x0000u, 0x2f2fu, 0x0000u, 0x3d4cu,
+ 0x0000u, 0x0000u, 0x0000u, 0x416eu, 0x0000u, 0x0000u, 0x0000u, 0x402au,
+ 0x0000u, 0x3b06u, 0x0000u, 0x3f77u, 0x0000u, 0x37fbu, 0x0000u, 0x4060u,
};
// kCnnV3Dec1HW = (W/2) x (H/2) = 4 x 4
// 64 uint16 values (raw f16 bits)
static const uint16_t kCnnV3ExpectedDec1U16[64] = {
- 0x0000u, 0x2692u, 0x3823u, 0x397eu, 0x0000u, 0x22dcu, 0x35dcu, 0x35f9u,
- 0x0000u, 0x3936u, 0x24b5u, 0x3434u, 0x0000u, 0x3b63u, 0x0000u, 0x32fcu,
- 0x0000u, 0x2913u, 0x3523u, 0x33d6u, 0x0000u, 0x3023u, 0x2575u, 0x0000u,
- 0x0000u, 0x39edu, 0x0000u, 0x0000u, 0x0000u, 0x3c91u, 0x0000u, 0x0000u,
- 0x0000u, 0x0000u, 0x0000u, 0x0000u, 0x0000u, 0x0000u, 0x0000u, 0x0000u,
- 0x0000u, 0x3754u, 0x0000u, 0x0000u, 0x318cu, 0x3a4du, 0x0000u, 0x0000u,
- 0x3206u, 0x32deu, 0x0000u, 0x0000u, 0x317du, 0x3437u, 0x0000u, 0x0000u,
- 0x312au, 0x357fu, 0x0000u, 0x0000u, 0x0000u, 0x39b5u, 0x0000u, 0x0000u,
+ 0x38dcu, 0x3d03u, 0x0000u, 0x39b0u, 0x3965u, 0x3dd1u, 0x30fdu, 0x3adau,
+ 0x387au, 0x3c79u, 0x3114u, 0x3c0eu, 0x0000u, 0x3a66u, 0x2ed6u, 0x3816u,
+ 0x3a16u, 0x3dbau, 0x0000u, 0x3a4du, 0x3cf6u, 0x3fccu, 0x0000u, 0x3c1cu,
+ 0x367bu, 0x3f06u, 0x0000u, 0x3b5cu, 0x0000u, 0x39ecu, 0x3660u, 0x3781u,
+ 0x3936u, 0x3accu, 0x0000u, 0x38dbu, 0x3d0fu, 0x3e45u, 0x0000u, 0x38bau,
+ 0x3905u, 0x3b8eu, 0x265du, 0x3c1eu, 0x0000u, 0x3881u, 0x2c6cu, 0x0000u,
+ 0x3905u, 0x3c23u, 0x0000u, 0x3271u, 0x3837u, 0x35e1u, 0x0000u, 0x0000u,
+ 0x3961u, 0x3c10u, 0x0000u, 0x0000u, 0x3594u, 0x3af9u, 0x382cu, 0x0000u,
};
// 256 uint16 values (raw f16 bits)
static const uint16_t kCnnV3ExpectedOutputU16[256] = {
- 0x3800u, 0x3934u, 0x3800u, 0x38aau, 0x384au, 0x3800u, 0x3800u, 0x3917u,
- 0x38d5u, 0x3800u, 0x3800u, 0x38f2u, 0x3800u, 0x38c9u, 0x3800u, 0x38d4u,
- 0x3800u, 0x3800u, 0x3800u, 0x3800u, 0x3800u, 0x38dau, 0x3800u, 0x3800u,
- 0x3800u, 0x383eu, 0x3800u, 0x3800u, 0x3800u, 0x3800u, 0x3800u, 0x3800u,
- 0x396du, 0x38eeu, 0x3800u, 0x3a87u, 0x3899u, 0x3800u, 0x3800u, 0x3972u,
- 0x3a4au, 0x3800u, 0x3800u, 0x3847u, 0x386du, 0x3800u, 0x3800u, 0x3a70u,
- 0x3800u, 0x381fu, 0x3800u, 0x3800u, 0x3800u, 0x3945u, 0x3800u, 0x392eu,
- 0x3800u, 0x3800u, 0x3800u, 0x3844u, 0x3800u, 0x3800u, 0x3820u, 0x3800u,
- 0x3a6du, 0x3832u, 0x3800u, 0x3ab0u, 0x3909u, 0x3800u, 0x3800u, 0x3a12u,
- 0x3873u, 0x3800u, 0x3800u, 0x39b8u, 0x3a9au, 0x3800u, 0x3800u, 0x3a41u,
- 0x3800u, 0x3800u, 0x3800u, 0x38d0u, 0x3952u, 0x3800u, 0x3800u, 0x398cu,
- 0x3800u, 0x3800u, 0x3800u, 0x3a21u, 0x3800u, 0x3800u, 0x3800u, 0x3800u,
- 0x3950u, 0x3800u, 0x3800u, 0x3abdu, 0x39ccu, 0x3800u, 0x3800u, 0x39e0u,
- 0x3800u, 0x3800u, 0x3800u, 0x3a62u, 0x38d7u, 0x3800u, 0x3800u, 0x3a23u,
- 0x3858u, 0x3800u, 0x3800u, 0x39f8u, 0x3800u, 0x3800u, 0x3800u, 0x3a01u,
- 0x38e7u, 0x3800u, 0x3800u, 0x3822u, 0x38fcu, 0x3800u, 0x3832u, 0x3800u,
- 0x3840u, 0x383au, 0x3800u, 0x3b39u, 0x390du, 0x3800u, 0x3800u, 0x399bu,
- 0x3800u, 0x3800u, 0x3800u, 0x39c2u, 0x3802u, 0x3800u, 0x3800u, 0x3a41u,
- 0x398bu, 0x3800u, 0x3800u, 0x39fau, 0x3800u, 0x3800u, 0x3800u, 0x396au,
- 0x38d3u, 0x3800u, 0x3800u, 0x3888u, 0x3909u, 0x3800u, 0x3800u, 0x3800u,
- 0x3863u, 0x3800u, 0x3800u, 0x3ae8u, 0x3a06u, 0x3800u, 0x3800u, 0x3a7du,
- 0x38c1u, 0x3800u, 0x3800u, 0x3a20u, 0x38cdu, 0x3800u, 0x3800u, 0x390cu,
- 0x3820u, 0x3800u, 0x3800u, 0x39d5u, 0x3863u, 0x3800u, 0x3800u, 0x389cu,
- 0x3800u, 0x3800u, 0x3800u, 0x38bcu, 0x3887u, 0x3800u, 0x3866u, 0x3800u,
- 0x38bbu, 0x3800u, 0x3800u, 0x3a8du, 0x394cu, 0x3800u, 0x3800u, 0x39b9u,
- 0x394au, 0x3800u, 0x3800u, 0x3977u, 0x3800u, 0x3800u, 0x3800u, 0x3906u,
- 0x3800u, 0x3800u, 0x386bu, 0x3a02u, 0x38bbu, 0x3800u, 0x3800u, 0x39d7u,
- 0x38a2u, 0x3800u, 0x3800u, 0x3800u, 0x3899u, 0x3800u, 0x3811u, 0x3800u,
- 0x3830u, 0x3800u, 0x387au, 0x3918u, 0x386au, 0x3800u, 0x38acu, 0x39f0u,
- 0x39c7u, 0x3800u, 0x38beu, 0x3988u, 0x38c3u, 0x3800u, 0x3930u, 0x39d5u,
- 0x397bu, 0x3800u, 0x3918u, 0x3a09u, 0x394cu, 0x3800u, 0x3952u, 0x3961u,
- 0x3980u, 0x3800u, 0x392eu, 0x3872u, 0x39c2u, 0x3800u, 0x3903u, 0x3800u,
+ 0x3988u, 0x391du, 0x3800u, 0x390au, 0x3800u, 0x39e6u, 0x3800u, 0x3836u,
+ 0x3959u, 0x39e8u, 0x3800u, 0x3817u, 0x38c4u, 0x39cbu, 0x3800u, 0x392au,
+ 0x3837u, 0x3961u, 0x3800u, 0x3884u, 0x38a4u, 0x391fu, 0x3800u, 0x3800u,
+ 0x3943u, 0x38e9u, 0x3800u, 0x3800u, 0x3920u, 0x397fu, 0x3800u, 0x3800u,
+ 0x3a53u, 0x3800u, 0x3800u, 0x39deu, 0x393cu, 0x3956u, 0x3800u, 0x3b15u,
+ 0x3960u, 0x383cu, 0x3800u, 0x3aa5u, 0x38b9u, 0x3966u, 0x3800u, 0x3a4bu,
+ 0x38eau, 0x392au, 0x3800u, 0x3b2fu, 0x38c2u, 0x3800u, 0x3800u, 0x3aafu,
+ 0x3a59u, 0x3879u, 0x3800u, 0x3a5bu, 0x3924u, 0x3933u, 0x3800u, 0x38c0u,
+ 0x393bu, 0x3800u, 0x3800u, 0x3a0bu, 0x38ecu, 0x385cu, 0x3800u, 0x3b25u,
+ 0x3968u, 0x384bu, 0x3800u, 0x39dbu, 0x3800u, 0x3972u, 0x3800u, 0x3b7cu,
+ 0x38b9u, 0x3800u, 0x3800u, 0x3b3fu, 0x388eu, 0x3898u, 0x3800u, 0x39d2u,
+ 0x38fau, 0x3800u, 0x3800u, 0x391eu, 0x3872u, 0x3966u, 0x3800u, 0x38c1u,
+ 0x38c5u, 0x3800u, 0x3800u, 0x3a4au, 0x3a61u, 0x3800u, 0x3800u, 0x3b9cu,
+ 0x38edu, 0x3800u, 0x3800u, 0x3b9du, 0x3844u, 0x38a2u, 0x3800u, 0x3b5au,
+ 0x3800u, 0x38edu, 0x3800u, 0x3a57u, 0x3800u, 0x3828u, 0x3800u, 0x3ad7u,
+ 0x3810u, 0x3800u, 0x3800u, 0x3aa6u, 0x38ceu, 0x38e7u, 0x3800u, 0x3800u,
+ 0x3921u, 0x3800u, 0x3800u, 0x3a61u, 0x3a11u, 0x3800u, 0x3800u, 0x3b23u,
+ 0x3994u, 0x3800u, 0x3800u, 0x3b95u, 0x3995u, 0x3800u, 0x3800u, 0x3b83u,
+ 0x38c6u, 0x3a05u, 0x3800u, 0x3b7cu, 0x3887u, 0x385au, 0x3800u, 0x3b0bu,
+ 0x38efu, 0x3800u, 0x3800u, 0x398eu, 0x39edu, 0x38d8u, 0x3800u, 0x381bu,
+ 0x3932u, 0x3800u, 0x3800u, 0x3a29u, 0x3992u, 0x3800u, 0x3800u, 0x3ac4u,
+ 0x394du, 0x3800u, 0x3800u, 0x3b3bu, 0x384bu, 0x3800u, 0x3800u, 0x3b07u,
+ 0x3991u, 0x384cu, 0x3800u, 0x3b38u, 0x392eu, 0x3834u, 0x3800u, 0x3ab9u,
+ 0x397fu, 0x3800u, 0x3800u, 0x3948u, 0x38d1u, 0x3800u, 0x3800u, 0x3825u,
+ 0x3938u, 0x3800u, 0x3800u, 0x39a1u, 0x3991u, 0x3800u, 0x3800u, 0x3ac0u,
+ 0x3998u, 0x3800u, 0x3800u, 0x3adfu, 0x3973u, 0x3800u, 0x3800u, 0x3b7bu,
+ 0x39fdu, 0x3800u, 0x3800u, 0x3b0du, 0x3991u, 0x3800u, 0x3800u, 0x3a5du,
+ 0x38b6u, 0x3800u, 0x3800u, 0x39cau, 0x38acu, 0x3840u, 0x3800u, 0x3825u,
+ 0x3813u, 0x3800u, 0x3800u, 0x398fu, 0x3800u, 0x3800u, 0x3800u, 0x3a33u,
+ 0x3800u, 0x3800u, 0x3800u, 0x398eu, 0x3845u, 0x3800u, 0x3800u, 0x3a2du,
+ 0x384fu, 0x3800u, 0x3800u, 0x3a2eu, 0x3800u, 0x3800u, 0x3800u, 0x3a3fu,
+ 0x3834u, 0x3800u, 0x3800u, 0x39ebu, 0x387eu, 0x3839u, 0x393au, 0x3989u,
};
diff --git a/cnn_v3/tools/index.html b/cnn_v3/tools/index.html
index 26fee9b..6c7b406 100644
--- a/cnn_v3/tools/index.html
+++ b/cnn_v3/tools/index.html
@@ -162,6 +162,7 @@ video{display:none}
</div>
<script src="shaders.js"></script>
+<script src="weights.js"></script>
<script src="tester.js"></script>
</body>
</html>
diff --git a/cnn_v3/tools/shaders.js b/cnn_v3/tools/shaders.js
index 6c49864..36f53c8 100644
--- a/cnn_v3/tools/shaders.js
+++ b/cnn_v3/tools/shaders.js
@@ -1,9 +1,10 @@
'use strict';
// CNN v3 WGSL shaders — matches cnn_v3/shaders/*.wgsl exactly.
-// Weight offsets (f16 index): enc0=0, enc1=724, bn=1020, dec1=1092, dec0=1672, total=1964
+// Weight offsets (f16 index): enc0=0, enc1=724, bn=1020, dec1=1604, dec0=2184, total=2476
+// BN is now Conv(8→8, 3×3, dilation=2): 8*8*9+8=584 weights (was 72 for 1×1)
-const ENC0_OFF=0, ENC1_OFF=724, BN_OFF=1020, DEC1_OFF=1092, DEC0_OFF=1672;
-const TOTAL_F16=1964, TOTAL_U32=982;
+const ENC0_OFF=0, ENC1_OFF=724, BN_OFF=1020, DEC1_OFF=1604, DEC0_OFF=2184;
+const TOTAL_F16=2476, TOTAL_U32=1238;
// Inlined helpers — prepended to shaders that need them.
const H = `
@@ -108,7 +109,7 @@ fn main(@builtin(global_invocation_id) id:vec3u){
pack2x16float(vec2f(o[4],o[5])),pack2x16float(vec2f(o[6],o[7]))));
}`;
-// Bottleneck: AvgPool(enc1) + Conv(8→8, 1×1) + ReLU → rgba32uint quarter-res (no FiLM)
+// Bottleneck: AvgPool(enc1) + Conv(8→8, 3×3, dilation=2) + ReLU → rgba32uint quarter-res (no FiLM)
// Params (16 bytes): wo u32 _pad×3
const BN_SHADER=H+`
struct P{wo:u32,_a:u32,_b:u32,_c:u32}
@@ -129,10 +130,13 @@ fn avg(qc:vec2i,hd:vec2i)->array<f32,8>{
fn main(@builtin(global_invocation_id) id:vec3u){
let hd=vec2i(textureDimensions(e1)); let qd=hd/2; let c=vec2i(id.xy);
if(c.x>=qd.x||c.y>=qd.y){return;}
- let ft=avg(c,hd); var o:array<f32,8>;
+ var o:array<f32,8>;
for(var oc:u32=0u;oc<8u;oc++){
- var s=get_w(p.wo,64u+oc);
- for(var i:u32=0u;i<8u;i++){s+=get_w(p.wo,oc*8u+i)*ft[i];}
+ var s=get_w(p.wo,576u+oc);
+ for(var ky:i32=-1;ky<=1;ky++){for(var kx:i32=-1;kx<=1;kx++){
+ let ft=avg(c+vec2i(kx,ky)*2,hd); let ki=u32(ky+1)*3u+u32(kx+1);
+ for(var i:u32=0u;i<8u;i++){s+=get_w(p.wo,oc*72u+i*9u+ki)*ft[i];}
+ }}
o[oc]=max(0.,s);
}
textureStore(out,c,vec4u(pack2x16float(vec2f(o[0],o[1])),pack2x16float(vec2f(o[2],o[3])),
diff --git a/cnn_v3/tools/tester.js b/cnn_v3/tools/tester.js
index 0412cae..81c869d 100644
--- a/cnn_v3/tools/tester.js
+++ b/cnn_v3/tools/tester.js
@@ -52,29 +52,34 @@ class CNNv3Tester {
async preload() {
const base = '../../workspaces/main/weights/';
const files = [
- {url: base+'cnn_v3_weights.bin', isFilm: false},
- {url: base+'cnn_v3_film_mlp.bin', isFilm: true},
+ {url: base+'cnn_v3_weights.bin', isFilm: false, b64: CNN_V3_WEIGHTS_B64},
+ {url: base+'cnn_v3_film_mlp.bin', isFilm: true, b64: CNN_V3_FILM_MLP_B64},
];
- for (const {url, isFilm} of files) {
+ for (const {url, isFilm, b64} of files) {
+ let buf = null;
+ const name = url.split('/').pop();
try {
const r = await fetch(url);
- if (!r.ok) { this.log(`preload skip: ${url.split('/').pop()} (${r.status})`); continue; }
- const buf = await r.arrayBuffer();
- const name = url.split('/').pop();
- if (isFilm) {
- this.filmMlp = this.parseFilm(buf);
- const el = document.getElementById('fDrop');
- el.textContent = `✓ ${name}`; el.classList.add('ok');
- document.getElementById('fSt').textContent = 'FiLM MLP loaded';
- document.getElementById('fSt').style.color = '#28a745';
- } else {
- this.weightsU32 = this.parseWeights(buf); this.weightsBuffer = buf;
- if (this.weightsGPU) { this.weightsGPU.destroy(); this.weightsGPU = null; }
- const el = document.getElementById('wDrop');
- el.textContent = `✓ ${name}`; el.classList.add('ok');
- }
- this.log(`Preloaded: ${name}`);
- } catch(e) { this.log(`preload error (${url.split('/').pop()}): ${e.message}`, 'err'); }
+ if (r.ok) { buf = await r.arrayBuffer(); this.log(`Preloaded: ${name}`); }
+ } catch(_) {}
+ if (!buf) {
+ const s = atob(b64); const u = new Uint8Array(s.length);
+ for (let i = 0; i < s.length; i++) u[i] = s.charCodeAt(i);
+ buf = u.buffer;
+ this.log(`Loaded embedded: ${name}`);
+ }
+ if (isFilm) {
+ this.filmMlp = this.parseFilm(buf);
+ const el = document.getElementById('fDrop');
+ el.textContent = `✓ ${name}`; el.classList.add('ok');
+ document.getElementById('fSt').textContent = 'FiLM MLP loaded';
+ document.getElementById('fSt').style.color = '#28a745';
+ } else {
+ this.weightsU32 = this.parseWeights(buf); this.weightsBuffer = buf;
+ if (this.weightsGPU) { this.weightsGPU.destroy(); this.weightsGPU = null; }
+ const el = document.getElementById('wDrop');
+ el.textContent = `✓ ${name}`; el.classList.add('ok');
+ }
}
if (this.weightsU32) {
if (this.image || this.isVideo) this.run();
diff --git a/cnn_v3/tools/weights.js b/cnn_v3/tools/weights.js
new file mode 100644
index 0000000..dde1ed4
--- /dev/null
+++ b/cnn_v3/tools/weights.js
@@ -0,0 +1,4 @@
+'use strict';
+// Auto-generated by export_cnn_v3_weights.py --html — do not edit by hand.
+const CNN_V3_WEIGHTS_B64='ias6I32xLDG5Masbdq4qIz+xrLQcshe3Ja1drluwb7crtHi38DZ8OL02eTh4Oe44HDTzN381TpwQqDCpCiP2pjipZywPL7CjNipXJc2qwiraJoetwijzphmqfCimJRgsX6tvqeuie6cRqoMpBhvSpbUiWSMFIlqrzCnHDiSiE6zJpYshR6udJMAmdRSkqHMVq6v0J68o6SL3p/mroia3I0uqEKobrMSdOqY5LAmgACqWqjch/SpcopUq6iJkJpEs6CumqH+lYqvLqjUs0ip5oKAkdqq9CTcnfamjJgSp36TmLZAsQLHTotyn27QPs1a0L7RwtOGwerXMswizdDeQOPw4BDj+OB85ADUuNhU4r7QGsAG0b7PsrP6ynrUosfuzALX4tNi3yrbNtDu2RbhRtSy4/TTgNrE1DDV2NyE2ijIMNqIyMibho5KvrSA1olmnQy31LbSm1KZPLHgpYqytLEKsuqVeHq+sWTWQObez6ze1PLI0krIhOdgsdaEKN6C2EzOBOzQ2WLePNx8wMrTRMiW4J7LMOBcvrrq/LkWtdq1vLIIgdKqAJSAvxa0SKXsoR6/QnXosIarzLJwjxa/VKmclhabVKL4lTSo7LNopdaspqzIp3Z4oqU2pNyxpqQ6fWKvCq3cmtqksmaEgF6hfoI2fc6rSJ0SoxxpYoRGfvijcqTqmEiNpLHKa6yDtKqUsgaj5IVsnopQupcSsNabjKqkq4SuGqjWIeavyItemJhvpK2WsI6snKwKruqijqY+sZzM/NsGzjDWZOTU0oKnJNmMtI7DBMSyx4C9JOZs0UbTvMl+nd7OMJYa1z6jONRQwy7d/EQir56pzMUqvhCzONFIj8rJVJZOvU60gLiuwMyxgNSwh9rPcLpiuDazUMWGs/6iENDsw1bSUKHCvsLFKL4UtSqloIxUw57DtoHEtYayELMEiLCDlqpymv6nIqhWqIq+9sfOwH7JjtLG0HrJAsLety7RKtbm0ibSysym0ebE1r4cdSzRJNYw0LzWVN6s3ZThgOOA4m6uKLSqsuCuOM1sucrK7pp+qPLF2J8+sIR2BM5uaH7ARLzGtn6eSK90npywwK1aoZKV6rC0qyxREq6asNCmZJVGo2is9qZ2VhSpLmYosHiZ0o3WWqCJ/qSAoXiY7Kk0owJ+QK+SoR6vnptEV3apbqVglNizUoiWigqZupOKdCayWKW8r1imMqvYiyyu6pOYr8imLKd4pGyyAK0ekOagyp1Mo36/ur8q0wa+Zsrmx567Er3e0FbXLtKO0Fa/1suazq611rHWupzVoNuY12ze3OLk32DgVOac5N7a1s5G0SrVpsZO0VLXJsvGxVLQ1tB20YbWGsTawL7JDsMSwtTXBN0E3AzhuOGM4GDcVOYM5hrGVLaSmV6DHMwowdLO9Lw2wtiHHp3ms1SUyKiccY6Z6qWciQhnhqO8o1CY9rNclpKwpKoYjlyDzo/WrwiQdpcCrXiqmqh4qgZ0TJvynaCzGnUCsTyWfLL+nqSpMHYkkOaTpmian9iiRmL0q4SWtJd4n9KiXqQKmrhVtKAusHZ/KKJGqISYXJxUhXCkEmKGsVSorpK0sDKxUpygmHCA+KWqquyfzqoqosabFqi2hoCsZp7or5ilzKtsmrxmCo2is0ycBJjUsvKyzrD8o9KkGq+aqtSz4qMYlcKTuKLOo4yClIAshSKzLJIKoDqCYpxwshqwSingh55sqkSQrgqa+JH4qQSVRKBId0JUYrDWdciwRIderTKW9LGasCSyuKpKnrKriqomXayzqJn6oMh5OIUwr4R1EKOSkaqbBqnUqoJ0mK9KfXiugH8ckgyyIqsKr8ic8ojklvCYcq86nKiaoKpeooqusoQ+s3Si6oQQfvaCgK88qviqtq0ShcBhLJxKstaXCqgEYTSrrnM2goK9ft+m5NCcrHicvOjDPoWMztSiasJ4lU7U/N0A2rbSlLO426i/psy60Qrgtti61kLDcsEmzRqfYNbMbVzc/sGYnTq6LrjIsIzDgrZAqQa+/sEcpE5jkLrCuya6uqGqkabBbIWunbyzBroyw9JkIMDkp7K7NryUoiyrIKC2wqCtvrUIo8q2ArdOtkCrdrgKuRLFXMGwtiC8+r78wSrCGsIwwRy7MMPUvKbCZqmWwzSkqIdEwOyWYKQCxX621LHmwDCFhK6UwERn4LFcrOS5SrWSnEa7qqMUmnC9uin+u1jBzN6stMjaAsf24abC+MJuxVbGZNIKrZDiBsM+xbzVOLZKzrzQtOsW3rqkiMUC8PKnlwZW9TDEMpLem4DBfrRwvQ61PLnKg7q4KtiAy0TTtqWg0jTFzMn01kjGFtX4vvjVZuW0qJTPVtNIwxzNzuNMlF6uWudsZgLaQuBW0qLrNJLEweDDdplgkAa5TrAgxbKa5MyoznTp4K7Ez3zhhqZMxHTT4Mvkl37LXr0KmILF+NrIsA6oCvMexUT3Pvw69sjjowffAyLEPsEgfMZOMqmkwB5iuptgv7K8KuKyxFaBZupe2P7G2u6a2ma6ELwc0lbEeL6I0LiwKLrM0+SY1O3Q5ITdVsJ+zSTN8vOe0iCZBI8cwTjFAKo0vYiyhrCOtvLB6rAgyfTb8tHoc3KuOrycy1TSSONswxrWINDOx2bfWLFavDrH8MsY3G7QppUM7JsAuOOs8o79nozyvW6+3LLcwZSLboPKrKCtbJH+ovKt6tV0q8rKWtFut2rgSNaozfrzPvIg0HTXvJLkvViTFL4ktfjDbNlsmIbUYKXc1ka/ROE23XLVmKlQxkat9MeK0MrAtN982nrgRMa43YjIXtUU276BsuD24Izj4OXMwQLSLLl0x8zI+uJ4xp6tctGg0RDBsNyM0VrpHuRO71DEmKaq8tadjIzy0lb0krDu6QTZxsBKudbkrtIc4GjiHORQ9Bjf3Nzk6PzVQOWk8o6gHIxsuXS0tMoczyTLoJoE1kDSdOC07MzO/N5w4XCihM+c4sDJBLls4EjFVrJ4y/jScMKsxc67nqHa2h7OcsvG3TbdbtpG47rNnNOA0aDBGNnQ1KjH9N1I3/LQ/uGS8brktur+6ji3mta6pir0Ruti9qLLMLla637i7tuS7PbESrfkf2TRVKBEsgTYppricvSxdqDkgOawrKqIoHqzmqOQpGStPKkqpIiQVodasTKhUKAgq9zEhJE6vRi1PKZUrfCyZLkk2Oi5BMewsW7FBsWk0JC+BrkouNq5VqZayjiqpLauwtTDnLtKzobX7IqOmkKfUKuUr+LaWsiuyHqupNLCt9KpBL6Om1zEUM+g2fLvgtvq8YblxuSy5I7gGtLK8UD0KOaA6jqwaLbo87zE+uZ6p5Tdnscu3TTZnMd80/y1Kp7q4HDFHMIWTSjKPLQgxR7ImtPC0DrPiNpMxz7mjtPazZLIvr2QuzTimLj23wjT1phOylzpkpU+04bsNuwc5n7kau52+p7T4uku5TbF0Phw+qzF5PPM70SkuP0M+pyYCKpcxFyBerTgzCK51sSg0VCi6KampeiwHIIYpZirjKU0jKqUjH/weJKAnI14rr6sapGgsHTPsM4Kw96J2NEC0NzYpNa223zYMN2s2I7TeMxStDyfENpMx6TBTsiuzDTEHrUYoJaXBrHirNi4PNQ46obXxuGOwSilyqPY4RLVetoe0dzDtsN21ArFet4+04z7PNBU30DsLuW67QDAuvCS8BbgtMfY277XitNkzaB9WLho5kzccs2+4ZTkBKyG2lzufMGW3Aj0OprcxQDi9tjukFTeYtW4vCLyhNni8lrqLPCa4OLwPNGq9IDUotpG7p6gmuZe8HDVNukq7xLzwpJK8kDDEO2M0Z7AGOJIx6r1DJHG0W7TgOcMwr7lRNQCy6yrjKqO3p6mYLs2zKCBsKl+xPqxPongsKKQuK3MsbiwUrdsrVCgqJv8s6yq5pBAiYyRnrGylCLUcsKm3K6mKrlG00zZhNxAxrTTtIjqwjikDsKI1lCyqs24h3ahiKTa3jbCRsO22ECQCr963R7LqNOqqva5jN1st9ykBOHgf+bLKmWiwFbM1pXOxCzi/q8M1ni+ssbAwwzTgMIk3kjwCPFY9MbFXo3C09bSksnK0mK5hsCesTbTVtaS07DTEpkYoCTydOlI89K0ltYeyJSymrHusIDhZNkA2/zBFMymwBLE/rBmxzLrVuTG79661ItEiNCxTsBs2cjqIMu81ma4VN4829rDzp6YwzLtavGO6AbtIMZ63C7u1LO+5w76EuQ693Le6Mlom3LT1NIqi8rJFNtKliSjXpxefyiIOKE8fnSj0KpusFSCWqCklBiQjKeMj5ayBnYYphiUiqFsw4i8CtHU0dqL2LEY1DSl4M5g2GbZosbUwyTFFMyy027SQri+u9pXuMTmvpbDYr0Gxf6iOJxyseaWKsM4jErSosJit4qsrLHmyEjA2I+MyBCkLOJI0PzNYMT41OTRSMOwx0jUIo80wcTFzpDIsAzMFtm20ALNms8mrIalPtDexMLKCNqguODZENfQukzQ7NKMJwi7CKzyzD7bWNKUrd6SGN641bzTyMd6xjRwYrM6z67P+MUKchZvgJ8+2C6pHsh+48a8PrWm1IbRdsD2wuK/Nr8eva7J1rae2t7ObLr4dca47q6Wls6b5rGoqOS68r8mxW7boIPWvwbLToXCuwbDcNXk0MTOtMwgr8ShONF8x5TF5tp+uTLZQtYevpLRutIeW+65GpysyOzY1tZenHiiZt0m1CrUkr/8uYZIMMDU1/jKrtIUlmSsnnDw21yp6L0c44DD0L5k1MTNBMJow5jH/MbYyPC96LRI2XjUyq28mvRpBrL6l2a4GJRWWS65srz2yGraAKqSyh7F8qmGhIbLmNTo0BzPWMjkurShmNN8wpzFbtn2uL7ZOtcGu1bSltEWewK/4rI0ycDaYtB2qkZ0zt4m1b7RFsykvCCxEMQw2NDPFtGoxvq7BJas2iR4iL8k30zEJL6k14jMbLnwvDi0JLaIvHy/6KuA0bTRFI0gtqyxloqitLRjaKbYk8izCNeo1LzQyOEE3NjcMNtU35DXXLwAz4jCJOAY6njTkOEQ43zA1Ns43ITfhN8A3BDjrNGs2zTUBM7w1mTT8Nh44yTc7Nl03TDcNNc4zdTXQMgU0KjO8NIUyZTQ4OXU4TzjcOPM3xDjLOEQ5wThuMhwphTHFJO4xQDNVL0kxOy8wLzgiIas2LlEvay2YHyittiW9sto0jjFhOw==';
+const CNN_V3_FILM_MLP_B64='3JR3PmW94L1BYem+rRCuvtBlqz0Wsa49ZSGRPVixHjxR64y+EqMcvtQAMjypcGE+37fhvtL8lD7HpeW++M6cPvIjYr4f1j8+OWUovpNyqr6x5VE9cBP4POZUSzxw/bg83+JnPX5qAr8zwDK+TzKQPWkdmT5fmsC+ZnXMvi8piz3DTba+DqKNviGjvb7BAyW+y7ajvUgKzbyCNrK96745vj5/p70s9GC9i6GKvKQiCr3Z4c298O8FvxLSTb5I0Qo+HuiqvcieCT6DHGy8sd7GuwdUx7s5zA68eC27u48Lv71xkLq+HtTNvsV13b5meBE+O0eAPHpPmTzOJY088l9vPDUwrjyf0MG+NX4dvseuz77AXnI93E3uPSkEYL6XE5O+K95EPqzOVL6lAeq9yk39vCEhRr4QcIi+KhVZvogH870TRLG+acdtvvqzB7/jigO//eUMP+Lzr74MwKq80fz9vHYujD6E6K2+IwUXP+3jjr4KnVQ++F1cvvz9Qb6snIq7lurPPKbmCj6ZqnA+t5s7vgDEsj6NZaM+rq89PvmRhD7JNWM+GUGHPVz1ijziwBA+hMBIPVA+trwvHM89wFALPmqfyz1bBzs+s4BAvhyRur2zA7o+ahQPvSyNfbsT5E09+WJ9Pkl/+L1LkoQ+fqxZvjvHFz4AcIe5KRgQPqFuJj5eOrQ+Sc2WPsEWi73F3BS+lgMtP0H/AD74d8O7+PNpPt7zGz+GB3e+VmggPxhtUD7bZgA/MvQ6PiZdTL4079s+MJFfvdSUmb2MaRE+PHRXPmQl0L0a2jS+EJsUPVj5r71eclY+igMBPiCd3b1qwy++HHcjvrCyBD6wM3C9yjhpPsmzBL2gpqw8xE5qPm9A/z1jkBA++eOSPe/oHT3QBTu9P7UMPgoWXT7B9W8+Bv9MvrdjSTxANMS7xJgzPvKzlz3uup09gGQTPuom+z0AxP09XsYNvmSFbLxv+3U+1wZzvlnPCb45cGU9ujnsviDN/zwg8+69KMKqPRl+hz6ddUG+6VukvfT1ID5rWR++w0nmvd7YI7+J4RO9T5ZuvQWiCr43WRO/JIsSPkaEJL/oZAS9wH0qv+pOLT7qoni+M9oSv4Dx4TuAkyW7mEFYPZhRib3QUqQ9hPVJvrCL5TwA3gC97EwGPhSV8L3k4x2+gMNXPJaPD77Q5y29pCWUvaS0r71+GIs9ECCDPpNcYjwyNvU8fwwRv8zjb74fQzE+/BFVvg3a/L5xX0Y++qYCv3ALgb1e4SS/QA19vh8MzL3fI4q+WBK9PZaOI75gG0i+TBvOvcINGT7gwTE83Ep8vg17e760kWs+UodWPpZrdD5Ms+49sENrPsYZKj76mFK+GAtyu1AEUz7baU0+gWfbPetWFTxZCoC9SA1YvnoDlj2HIvu91GgMvhyvxz1W+D4+GApiPabLO7wA5Ko644UCvgxTGz3LnI27LQgRvrnhw7y/McS9rXb0PkKb072Wi0a+F9PdPba+Az/hZR6+/8n4PnKPfj6Vx+c+8EacPIonVb4HE9k+PWtqvRQXJ760rxc9XHonvrSjiL7Y6Wk+X78oPXDCGT67Nny++E5iPXdDHT1sqDi+B4ItvoAH3r2zzdq85XhePIZolj5J0hY+s/isPaatrb7AdNs+8r+2vd8pDD6qzMM9eouUPsOeMz3BcII+FvpZPiF1pj6WdUG+j/devalgoTypoga9c0LEPc3fur2xte29np00viIJ5bzvYS49eWp6viSBDr6IYUe9SiOpvoBpdr3doXK+wDKou59EljxBZdC9oQC+PFTUWz0YBpY+T+5gvDaGIT+RZBs84iqJPiO15z7wWw8/nFHsvdvGID+Q2FE9OojpPqxLqD02EkS+dbfNPtnUlz2Ukgi+AiD9PPMD+z2V2i49j6osvt2EcD5sotK9a/FvvWeJ1j2C5vM9cPZ0vWTyQz72hWM+X9dMPa7bQrwyUZQ9GB83PfxbHD4g9FU+dX8XPjALx7yEcb+9up8MvhaPkD3QiYc91xN6vR4cT76kCnu+Ws0VvmA9W7yv4G89ZvsgPiVaor2ok/k9BIWJPZ5+7rno6QM+fbTrvKHqZj3NOOy8MI0HPXi5VL7iUmQ+KlTCvRSdSb7/IpU8fpXYvZHMgT2K2Us+UShZPo5z/70zGWA897+RvtLUiL4dUoi9n2wGvjaqzzznj/Q9UE0DvtqEub2WgXu++SI3PoQMk76Qa7E9PZeHvEyiXD5pKze+zpu/PepKg75irr286FJiPhIvw7tmUfk9nmWhPvCF5b2dbwU+KgwYPr6Egr5wVR6+Q0mBPoJx37xmKsg7nEbsvR+DJb0fjcw8ZlI7vTQInDw5elu+z1rWvRbWw72wLds8Qa3xvciA37008EG8EFEcvpXUKz7zSxw9tTgLPtwUgL2By/Q+Zv2+PQGEa75hnzk+mOKSPumBOj7TR6k9EBtfPuuk5z4kzPo95mqtvduNIz5nbhQ8WXzxPfMPXj6bNCU+43hzPmlAxr1LbM298vepPXV0CTyGzA4+uuxMvQIuXr4wo4Y+gG/9O73WCT6+VY29QlN9Pcs5Pj60l129snY8vjA9jD6cniU+2xM+vuAR9zshZ9Q9DufPPXYA3T3mliM+hMWzvZxOX76a5wQ+z5MvvT7p0r1+YLA98dwXPoI7ED5IrOy+V2TKPfAEEr7pdiy+rWIkvu7Z9L2RqRG+QBKju/A1Fr6QulK+rIGmvkLYHT6WOQa+4hGuvRn3o7u3Xy490raePrChM75vM20+B3g0vWQsMz5wWe69eI1IPgDCszoM0eA9aChGvvWVtb2gl/49XO2APacEGT57JqQ8mm7yPQmfij7BtZ4997PLPU49RT6skDs+6cy5va3cT72YRT69OqFpPpBEET1Ba/q9pWadPG+tVj7+1T8+0OgJvqrNSDy4bZE+90RAPljRlT6BR7Q+SKUhPoFsMz0BgVU9eDulve35Oz6Aksi7mSxxvWnAFD44KjU8FoCRvlp8Mb4+d/Q4mNqnPce5yTyQvcs+hfypPc51Ar6/P+Q92RvPPQBLmjuyIsG+eEwnvikEjT7Dtq6+O61Fvu3Ohrz7FRU+FzGvvYL5nj5chL+9eQcRvsfvhz0de4k+lwiNvuAUMj7GkTg+5JKKPeASCT7EhLy7Q++5PVd2o70jiaO9HvEzPZ/UK74CcQE/7tV3vnItBbyarB8+7oEcPvcT/L0UN/8+5ODOPZ/Tpz7gIQG+O1sZPpgroj4sATw9zVYjvvjOVD59ezc+9p8aviKffD1qmzA8R7zGvaI9rr0L/6g9hefZvTgVY77bpta9LMzuvcA/eD43miY+nTOWOa53Nb684CC+rLrmvXRhVT4IuxC+DNO0PYknQr27t/U9Y0PvPRpzCz7kGdQ9ql7BPcBAiLto9qe9HrjlvXb5T7335Zy9vFgju8NHVb5fvWo+GRQnvoaXXD3Vtiw9nonVPeokEbyDXYc+YLW7PJdgAz6isls+Qk7GveS5xr1Gits+cohJvXRoeTuwqR2+3wgHP+tNBr0B4GG9HkO3PhHJaj8QvZU9lbHfPmwI7j1ZgGI/WPXrvZutJz21hbI+N+YEPezs6bvmUFY+dfZbPd/Rnz7izxK+GquBPbJHZb0HPzs+jrptPVg8Aj8wAVi9N6jgPv5Jbr6MM00+MqIxPrwAZz0oERA+fEoLvudJH71ysgY/vhOYPaEPDr7asVQ+D8quPqWcmj1ewQg/Rqc7vrgRtz5gSTc+/Km0vEqH4j3BZ6M97evKvTmcJb5Htp492/KfPsn4vj2lx069ZroOPqXGXz74XFY7cRZWPiJgdz56er4+MAZfvg6wE71r3yc7uwqFPs7Ny71CA4o+k+GLPs7aej2b4789izC8vfOUjz5D2OU+RhUgvhVKgT5ak20+yOWBPjh0Q73A3qS9qYawvbt4YT7O1dI+mQoQP0LcDr6YwiY+yi7vvVpA8b78DAi+4jt7v5U5P75FGvu9lpeTPnhZ/r7zTIg+Hx+Avj7xOD/oW5o+D35evsYRHr5BbGi+3fQnPX3iyT2Ob3s+T6bzPUKchD5O6d6+IxjNPjWTW7wa0vk+n14dvhmT5z5tW/Y+8yOQvoRtZD6JwbU93pZLP1mnjz4C1bs9VUXBPqtl7z0=';
diff --git a/cnn_v3/training/cnn_v3_utils.py b/cnn_v3/training/cnn_v3_utils.py
index bef4091..68c0798 100644
--- a/cnn_v3/training/cnn_v3_utils.py
+++ b/cnn_v3/training/cnn_v3_utils.py
@@ -128,10 +128,11 @@ def _upsample_nearest(a: np.ndarray, h: int, w: int) -> np.ndarray:
def assemble_features(albedo: np.ndarray, normal: np.ndarray,
depth: np.ndarray, matid: np.ndarray,
- shadow: np.ndarray, transp: np.ndarray) -> np.ndarray:
+ shadow: np.ndarray, transp: np.ndarray,
+ prev: np.ndarray | None = None) -> np.ndarray:
"""Build (H,W,20) f32 feature tensor.
- prev set to zero (no temporal history during training).
+ prev: (H,W,3) f32 [0,1] previous frame RGB, or None → zeros.
mip1/mip2 computed from albedo. depth_grad computed via finite diff.
dif (ch18) = max(0, dot(oct_decode(normal), KEY_LIGHT)) * shadow.
"""
@@ -140,7 +141,8 @@ def assemble_features(albedo: np.ndarray, normal: np.ndarray,
mip1 = _upsample_nearest(pyrdown(albedo), h, w)
mip2 = _upsample_nearest(pyrdown(pyrdown(albedo)), h, w)
dgrad = depth_gradient(depth)
- prev = np.zeros((h, w, 3), dtype=np.float32)
+ if prev is None:
+ prev = np.zeros((h, w, 3), dtype=np.float32)
nor3 = oct_decode(normal)
diffuse = np.maximum(0.0, (nor3 * _KEY_LIGHT).sum(-1))
dif = diffuse * shadow
@@ -286,7 +288,8 @@ class CNNv3Dataset(Dataset):
channel_dropout_p: float = 0.3,
detector: str = 'harris',
augment: bool = True,
- patch_search_window: int = 0):
+ patch_search_window: int = 0,
+ single_sample: str = ''):
self.patch_size = patch_size
self.patches_per_image = patches_per_image
self.image_size = image_size
@@ -296,16 +299,18 @@ class CNNv3Dataset(Dataset):
self.augment = augment
self.patch_search_window = patch_search_window
- root = Path(dataset_dir)
- subdir = 'full' if input_mode == 'full' else 'simple'
- search_dir = root / subdir
- if not search_dir.exists():
- search_dir = root
-
- self.samples = sorted([
- d for d in search_dir.iterdir()
- if d.is_dir() and (d / 'albedo.png').exists()
- ])
+ if single_sample:
+ self.samples = [Path(single_sample)]
+ else:
+ root = Path(dataset_dir)
+ subdir = 'full' if input_mode == 'full' else 'simple'
+ search_dir = root / subdir
+ if not search_dir.exists():
+ search_dir = root
+ self.samples = sorted([
+ d for d in search_dir.iterdir()
+ if d.is_dir() and (d / 'albedo.png').exists()
+ ])
if not self.samples:
raise RuntimeError(f"No samples found in {search_dir}")
@@ -345,11 +350,13 @@ class CNNv3Dataset(Dataset):
shadow = load_gray(sd / 'shadow.png')
transp = load_gray(sd / 'transp.png')
h, w = albedo.shape[:2]
+ prev_path = sd / 'prev.png'
+ prev = load_rgb(prev_path) if prev_path.exists() else None
target_img = Image.open(sd / 'target.png').convert('RGBA')
if target_img.size != (w, h):
target_img = target_img.resize((w, h), Image.LANCZOS)
target = np.asarray(target_img, dtype=np.float32) / 255.0
- return albedo, normal, depth, matid, shadow, transp, target
+ return albedo, normal, depth, matid, shadow, transp, prev, target
def __getitem__(self, idx):
if self.full_image:
@@ -357,7 +364,7 @@ class CNNv3Dataset(Dataset):
else:
sample_idx = idx // self.patches_per_image
- albedo, normal, depth, matid, shadow, transp, target = self._cache[sample_idx]
+ albedo, normal, depth, matid, shadow, transp, prev, target = self._cache[sample_idx]
h, w = albedo.shape[:2]
if self.full_image:
@@ -379,6 +386,8 @@ class CNNv3Dataset(Dataset):
matid = _resize_gray(matid)
shadow = _resize_gray(shadow)
transp = _resize_gray(transp)
+ if prev is not None:
+ prev = _resize_img(prev)
target = _resize_img(target)
else:
ps = self.patch_size
@@ -395,6 +404,8 @@ class CNNv3Dataset(Dataset):
matid = matid[sl]
shadow = shadow[sl]
transp = transp[sl]
+ if prev is not None:
+ prev = prev[sl]
# Apply cached target offset (if search was enabled at init).
if self._target_offsets:
@@ -405,7 +416,7 @@ class CNNv3Dataset(Dataset):
else:
target = target[sl]
- feat = assemble_features(albedo, normal, depth, matid, shadow, transp)
+ feat = assemble_features(albedo, normal, depth, matid, shadow, transp, prev)
if self.augment:
feat = apply_channel_dropout(feat,
diff --git a/cnn_v3/training/export_cnn_v3_weights.py b/cnn_v3/training/export_cnn_v3_weights.py
index 99f3a81..78f5f25 100644
--- a/cnn_v3/training/export_cnn_v3_weights.py
+++ b/cnn_v3/training/export_cnn_v3_weights.py
@@ -15,8 +15,8 @@ Outputs
<output_dir>/cnn_v3_weights.bin
Conv+bias weights for all 5 passes, packed as f16-pairs-in-u32.
Matches the format expected by CNNv3Effect::upload_weights().
- Layout: enc0 (724) | enc1 (296) | bottleneck (72) | dec1 (580) | dec0 (292)
- = 1964 f16 values = 982 u32 = 3928 bytes.
+ Layout: enc0 (724) | enc1 (296) | bottleneck (584) | dec1 (580) | dec0 (292)
+ = 2476 f16 values = 1238 u32 = 4952 bytes.
<output_dir>/cnn_v3_film_mlp.bin
FiLM MLP weights as raw f32: L0_W (5×16) L0_b (16) L1_W (16×40) L1_b (40).
@@ -31,6 +31,7 @@ Usage
"""
import argparse
+import base64
import struct
import sys
from pathlib import Path
@@ -47,13 +48,13 @@ from train_cnn_v3 import CNNv3
# cnn_v3/src/cnn_v3_effect.cc (kEnc0Weights, kEnc1Weights, …)
# cnn_v3/training/gen_test_vectors.py (same constants)
# ---------------------------------------------------------------------------
-ENC0_WEIGHTS = 20 * 4 * 9 + 4 # Conv(20→4,3×3)+bias = 724
-ENC1_WEIGHTS = 4 * 8 * 9 + 8 # Conv(4→8,3×3)+bias = 296
-BN_WEIGHTS = 8 * 8 * 1 + 8 # Conv(8→8,1×1)+bias = 72
-DEC1_WEIGHTS = 16 * 4 * 9 + 4 # Conv(16→4,3×3)+bias = 580
-DEC0_WEIGHTS = 8 * 4 * 9 + 4 # Conv(8→4,3×3)+bias = 292
+ENC0_WEIGHTS = 20 * 4 * 9 + 4 # Conv(20→4,3×3)+bias = 724
+ENC1_WEIGHTS = 4 * 8 * 9 + 8 # Conv(4→8,3×3)+bias = 296
+BN_WEIGHTS = 8 * 8 * 9 + 8 # Conv(8→8,3×3,dil=2)+bias = 584
+DEC1_WEIGHTS = 16 * 4 * 9 + 4 # Conv(16→4,3×3)+bias = 580
+DEC0_WEIGHTS = 8 * 4 * 9 + 4 # Conv(8→4,3×3)+bias = 292
TOTAL_F16 = ENC0_WEIGHTS + ENC1_WEIGHTS + BN_WEIGHTS + DEC1_WEIGHTS + DEC0_WEIGHTS
-# = 1964
+# = 2476
def pack_weights_u32(w_f16: np.ndarray) -> np.ndarray:
@@ -158,13 +159,40 @@ def export_weights(checkpoint_path: str, output_dir: str) -> None:
print(f"\nDone → {out}/")
+_WEIGHTS_JS_DEFAULT = Path(__file__).parent.parent / 'tools' / 'weights.js'
+
+
+def update_weights_js(weights_bin: Path, film_mlp_bin: Path,
+ js_path: Path = _WEIGHTS_JS_DEFAULT) -> None:
+ """Encode both .bin files as base64 and write cnn_v3/tools/weights.js."""
+ w_b64 = base64.b64encode(weights_bin.read_bytes()).decode('ascii')
+ f_b64 = base64.b64encode(film_mlp_bin.read_bytes()).decode('ascii')
+ js_path.write_text(
+ "'use strict';\n"
+ "// Auto-generated by export_cnn_v3_weights.py --html — do not edit by hand.\n"
+ f"const CNN_V3_WEIGHTS_B64='{w_b64}';\n"
+ f"const CNN_V3_FILM_MLP_B64='{f_b64}';\n"
+ )
+ print(f"\nweights.js → {js_path}")
+ print(f" CNN_V3_WEIGHTS_B64 {len(w_b64)} chars ({weights_bin.stat().st_size} bytes)")
+ print(f" CNN_V3_FILM_MLP_B64 {len(f_b64)} chars ({film_mlp_bin.stat().st_size} bytes)")
+
+
def main() -> None:
p = argparse.ArgumentParser(description='Export CNN v3 trained weights to .bin')
p.add_argument('checkpoint', help='Path to .pth checkpoint file')
p.add_argument('--output', default='export',
help='Output directory (default: export/)')
+ p.add_argument('--html', action='store_true',
+ help=f'Also update {_WEIGHTS_JS_DEFAULT} with base64-encoded weights')
+ p.add_argument('--html-output', default=None, metavar='PATH',
+ help='Override default weights.js path (implies --html)')
args = p.parse_args()
export_weights(args.checkpoint, args.output)
+ if args.html or args.html_output:
+ out = Path(args.output)
+ js_path = Path(args.html_output) if args.html_output else _WEIGHTS_JS_DEFAULT
+ update_weights_js(out / 'cnn_v3_weights.bin', out / 'cnn_v3_film_mlp.bin', js_path)
if __name__ == '__main__':
diff --git a/cnn_v3/training/gen_test_vectors.py b/cnn_v3/training/gen_test_vectors.py
index 640971c..2eb889c 100644
--- a/cnn_v3/training/gen_test_vectors.py
+++ b/cnn_v3/training/gen_test_vectors.py
@@ -23,7 +23,7 @@ DEC0_IN, DEC0_OUT = 8, 4
ENC0_WEIGHTS = ENC0_IN * ENC0_OUT * 9 + ENC0_OUT # 724
ENC1_WEIGHTS = ENC1_IN * ENC1_OUT * 9 + ENC1_OUT # 296
-BN_WEIGHTS = BN_IN * BN_OUT * 1 + BN_OUT # 72
+BN_WEIGHTS = BN_IN * BN_OUT * 9 + BN_OUT # 584 (3x3 dilation=2)
DEC1_WEIGHTS = DEC1_IN * DEC1_OUT * 9 + DEC1_OUT # 580
DEC0_WEIGHTS = DEC0_IN * DEC0_OUT * 9 + DEC0_OUT # 292
@@ -32,30 +32,8 @@ ENC1_OFFSET = ENC0_OFFSET + ENC0_WEIGHTS
BN_OFFSET = ENC1_OFFSET + ENC1_WEIGHTS
DEC1_OFFSET = BN_OFFSET + BN_WEIGHTS
DEC0_OFFSET = DEC1_OFFSET + DEC1_WEIGHTS
-TOTAL_F16 = DEC0_OFFSET + DEC0_WEIGHTS # 1964 + 292 = 2256? let me check
-# 724 + 296 + 72 + 580 + 292 = 1964 ... actually let me recount
-# ENC0: 20*4*9 + 4 = 720+4 = 724
-# ENC1: 4*8*9 + 8 = 288+8 = 296
-# BN: 8*8*1 + 8 = 64+8 = 72
-# DEC1: 16*4*9 + 4 = 576+4 = 580
-# DEC0: 8*4*9 + 4 = 288+4 = 292
-# Total = 724+296+72+580+292 = 1964 ... but HOWTO.md says 2064. Let me recheck.
-# DEC1: 16*4*9 = 576 ... but the shader says Conv(16->4) which is IN=16, OUT=4
-# weight idx: o * DEC1_IN * 9 + i * 9 + ki where o<DEC1_OUT, i<DEC1_IN
-# So total conv weights = DEC1_OUT * DEC1_IN * 9 = 4*16*9 = 576, bias = 4
-# Total DEC1 = 580. OK that's right.
-# Let me add: 724+296+72+580+292 = 1964. But HOWTO says 2064?
-# DEC1: Conv(16->4) = OUT*IN*K^2 = 4*16*9 = 576 + bias 4 = 580. HOWTO says 576+4=580 OK.
-# Total = 724+296+72+580+292 = let me sum: 724+296=1020, +72=1092, +580=1672, +292=1964.
-# Hmm, HOWTO.md says 2064. Let me recheck HOWTO weight table:
-# enc0: 20*4*9=720 +4 = 724
-# enc1: 4*8*9=288 +8 = 296
-# bottleneck: 8*8*1=64 +8 = 72
-# dec1: 16*4*9=576 +4 = 580
-# dec0: 8*4*9=288 +4 = 292
-# Total = 724+296+72+580+292 = 1964
-# The HOWTO says 2064 but I get 1964... 100 difference. Possible typo in doc.
-# I'll use the correct value derived from the formulas: 1964.
+TOTAL_F16 = DEC0_OFFSET + DEC0_WEIGHTS
+# 724 + 296 + 584 + 580 + 292 = 2476 (BN is now 3x3 dilation=2, was 72)
# ---------------------------------------------------------------------------
# Helpers
@@ -140,35 +118,41 @@ def enc1_forward(enc0, w, gamma_lo, gamma_hi, beta_lo, beta_hi):
def bottleneck_forward(enc1, w):
"""
- AvgPool2x2(enc1, clamp-border) + Conv(8->8, 1x1) + ReLU
+ AvgPool2x2(enc1, clamp-border) + Conv(8->8, 3x3, dilation=2) + ReLU
→ rgba32uint (f16, quarter-res). No FiLM.
enc1: (hH, hW, 8) f32 — half-res
+ Matches cnn_v3_bottleneck.wgsl exactly.
"""
hH, hW = enc1.shape[:2]
qH, qW = hH // 2, hW // 2
wo = BN_OFFSET
- # AvgPool2x2 with clamp (matches load_enc1_avg in WGSL)
- avg = np.zeros((qH, qW, BN_IN), dtype=np.float32)
- for qy in range(qH):
- for qx in range(qW):
- s = np.zeros(BN_IN, dtype=np.float32)
- for dy in range(2):
- for dx in range(2):
- hy = min(qy * 2 + dy, hH - 1)
- hx = min(qx * 2 + dx, hW - 1)
- s += enc1[hy, hx, :]
- avg[qy, qx, :] = s * 0.25
+ def load_enc1_avg(qy, qx):
+ """Avg-pool 2x2 from enc1 at quarter-res coord. Zero for OOB (matches WGSL)."""
+ if qy < 0 or qx < 0 or qy >= qH or qx >= qW:
+ return np.zeros(BN_IN, dtype=np.float32)
+ s = np.zeros(BN_IN, dtype=np.float32)
+ for dy in range(2):
+ for dx in range(2):
+ hy = min(qy * 2 + dy, hH - 1)
+ hx = min(qx * 2 + dx, hW - 1)
+ s += enc1[hy, hx, :]
+ return s * 0.25
- # 1x1 conv (no spatial loop, just channel dot-product)
+ # 3x3 conv with dilation=2 in quarter-res space
out = np.zeros((qH, qW, BN_OUT), dtype=np.float32)
for o in range(BN_OUT):
- bias = get_w(w, wo, BN_OUT * BN_IN + o)
- s = np.full((qH, qW), bias, dtype=np.float32)
- for i in range(BN_IN):
- wv = get_w(w, wo, o * BN_IN + i)
- s += wv * avg[:, :, i]
- out[:, :, o] = np.maximum(0.0, s)
+ bias = get_w(w, wo, BN_OUT * BN_IN * 9 + o)
+ for qy in range(qH):
+ for qx in range(qW):
+ s = bias
+ for ky in range(-1, 2):
+ for kx in range(-1, 2):
+ feat = load_enc1_avg(qy + ky * 2, qx + kx * 2) # dilation=2
+ ki = (ky + 1) * 3 + (kx + 1)
+ for i in range(BN_IN):
+ s += get_w(w, wo, o * BN_IN * 9 + i * 9 + ki) * feat[i]
+ out[qy, qx, o] = max(0.0, s)
return np.float16(out).astype(np.float32) # pack2x16float boundary
diff --git a/cnn_v3/training/infer_cnn_v3.py b/cnn_v3/training/infer_cnn_v3.py
new file mode 100644
index 0000000..ca1c72a
--- /dev/null
+++ b/cnn_v3/training/infer_cnn_v3.py
@@ -0,0 +1,219 @@
+#!/usr/bin/env python3
+# /// script
+# requires-python = ">=3.10"
+# dependencies = ["torch", "numpy", "pillow", "opencv-python"]
+# ///
+"""CNN v3 PyTorch inference — compare with cnn_test (WGSL/GPU output).
+
+Simple mode (single PNG): albedo = photo, geometry channels zeroed.
+Full mode (sample dir): loads all G-buffer files via assemble_features.
+
+Usage:
+ python3 infer_cnn_v3.py photo.png out.png --checkpoint checkpoints/ckpt.pth
+ python3 infer_cnn_v3.py sample_000/ out.png --checkpoint ckpt.pth
+ python3 infer_cnn_v3.py photo.png out.png --checkpoint ckpt.pth --identity-film
+ python3 infer_cnn_v3.py photo.png out.png --checkpoint ckpt.pth --cond 0.5 0.0 0.8 0.0 0.0
+"""
+
+import argparse
+import sys
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+
+sys.path.insert(0, str(Path(__file__).parent))
+from train_cnn_v3 import CNNv3
+from cnn_v3_utils import assemble_features, load_rgb, load_rg, load_depth16, load_gray
+
+
+# ---------------------------------------------------------------------------
+# Feature loading
+# ---------------------------------------------------------------------------
+
+def load_sample_dir(sample_dir: Path) -> np.ndarray:
+ """Load all G-buffer files from a sample directory → (H,W,20) f32."""
+ return assemble_features(
+ load_rgb(sample_dir / 'albedo.png'),
+ load_rg(sample_dir / 'normal.png'),
+ load_depth16(sample_dir / 'depth.png'),
+ load_gray(sample_dir / 'matid.png'),
+ load_gray(sample_dir / 'shadow.png'),
+ load_gray(sample_dir / 'transp.png'),
+ )
+
+
+def load_simple(image_path: Path) -> np.ndarray:
+ """Photo → (H,W,20) f32 with geometry channels zeroed.
+
+ normal=(0.5,0.5) is the oct-encoded "no normal" (decodes to ~(0,0,1)).
+ shadow=1.0 (fully lit), transp=0.0 (opaque).
+ """
+ albedo = load_rgb(image_path)
+ h, w = albedo.shape[:2]
+ normal = np.full((h, w, 2), 0.5, dtype=np.float32)
+ depth = np.zeros((h, w), dtype=np.float32)
+ matid = np.zeros((h, w), dtype=np.float32)
+ shadow = np.ones((h, w), dtype=np.float32)
+ transp = np.zeros((h, w), dtype=np.float32)
+ return assemble_features(albedo, normal, depth, matid, shadow, transp)
+
+
+# ---------------------------------------------------------------------------
+# Inference
+# ---------------------------------------------------------------------------
+
+def pad_to_multiple(feat: np.ndarray, m: int = 4) -> tuple:
+ """Pad (H,W,C) so H and W are multiples of m. Returns (padded, (ph, pw))."""
+ h, w = feat.shape[:2]
+ ph = (m - h % m) % m
+ pw = (m - w % m) % m
+ if ph == 0 and pw == 0:
+ return feat, (0, 0)
+ return np.pad(feat, ((0, ph), (0, pw), (0, 0))), (ph, pw)
+
+
+def run_identity_film(model: CNNv3, feat: torch.Tensor) -> torch.Tensor:
+ """Forward with identity FiLM (γ=1, β=0). Matches C++ cnn_test default."""
+ c0, c1 = model.enc_channels
+ B = feat.shape[0]
+ dev = feat.device
+
+ skip0 = F.relu(model.enc0(feat))
+
+ x = F.avg_pool2d(skip0, 2)
+ skip1 = F.relu(model.enc1(x))
+
+ x = F.relu(model.bottleneck(F.avg_pool2d(skip1, 2)))
+
+ x = F.relu(model.dec1(
+ torch.cat([F.interpolate(x, scale_factor=2, mode='nearest'), skip1], dim=1)
+ ))
+
+ x = F.relu(model.dec0(
+ torch.cat([F.interpolate(x, scale_factor=2, mode='nearest'), skip0], dim=1)
+ ))
+
+ return torch.sigmoid(x)
+
+
+# ---------------------------------------------------------------------------
+# Output helpers
+# ---------------------------------------------------------------------------
+
+def save_png(path: Path, out: np.ndarray) -> None:
+ """Save (H,W,4) f32 [0,1] RGBA as PNG."""
+ rgba8 = (np.clip(out, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8)
+ Image.fromarray(rgba8, 'RGBA').save(path)
+
+
+def print_debug_hex(out: np.ndarray, n: int = 8) -> None:
+ """Print first n pixels as hex RGBA + float values."""
+ flat = out.reshape(-1, 4)
+ for i in range(min(n, flat.shape[0])):
+ r, g, b, a = flat[i]
+ ri, gi, bi, ai = int(r*255+.5), int(g*255+.5), int(b*255+.5), int(a*255+.5)
+ print(f' [{i}] 0x{ri:02X}{gi:02X}{bi:02X}{ai:02X}'
+ f' ({r:.4f} {g:.4f} {b:.4f} {a:.4f})')
+
+
+# ---------------------------------------------------------------------------
+# Main
+# ---------------------------------------------------------------------------
+
+def main():
+ p = argparse.ArgumentParser(description='CNN v3 PyTorch inference')
+ p.add_argument('input', help='Input PNG or sample directory')
+ p.add_argument('output', help='Output PNG')
+ p.add_argument('--checkpoint', '-c', metavar='CKPT',
+ help='Path to .pth checkpoint (auto-finds latest if omitted)')
+ p.add_argument('--enc-channels', default='4,8',
+ help='Encoder channels (default: 4,8 — must match checkpoint)')
+ p.add_argument('--cond', nargs=5, type=float, metavar='F', default=[0.0]*5,
+ help='FiLM conditioning: 5 floats (beat_phase beat_norm audio style0 style1)')
+ p.add_argument('--identity-film', action='store_true',
+ help='Bypass FiLM MLP, use γ=1 β=0 (matches C++ cnn_test default)')
+ p.add_argument('--blend', type=float, default=1.0,
+ help='Blend with input albedo: 0=input 1=CNN (default 1.0)')
+ p.add_argument('--debug-hex', action='store_true',
+ help='Print first 8 output pixels as hex')
+ args = p.parse_args()
+
+ # --- Feature loading ---
+ inp = Path(args.input)
+ if inp.is_dir():
+ print(f'Mode: full ({inp})')
+ feat = load_sample_dir(inp)
+ albedo_rgb = load_rgb(inp / 'albedo.png')
+ else:
+ print(f'Mode: simple ({inp})')
+ feat = load_simple(inp)
+ albedo_rgb = load_rgb(inp)
+ orig_h, orig_w = feat.shape[:2]
+
+ feat_padded, (ph, pw) = pad_to_multiple(feat, 4)
+ H, W = feat_padded.shape[:2]
+ if ph or pw:
+ print(f'Padded {orig_w}×{orig_h} → {W}×{H}')
+ else:
+ print(f'Resolution: {W}×{H}')
+
+ # --- Load checkpoint ---
+ if args.checkpoint:
+ ckpt_path = Path(args.checkpoint)
+ else:
+ ckpts = sorted(Path('checkpoints').glob('checkpoint_epoch_*.pth'),
+ key=lambda f: int(f.stem.split('_')[-1]))
+ if not ckpts:
+ print('Error: no checkpoint found; use --checkpoint', file=sys.stderr)
+ sys.exit(1)
+ ckpt_path = ckpts[-1]
+ print(f'Checkpoint: {ckpt_path}')
+
+ ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
+ cfg = ckpt.get('config', {})
+ enc_channels = cfg.get('enc_channels', [int(c) for c in args.enc_channels.split(',')])
+ film_cond_dim = cfg.get('film_cond_dim', 5)
+ print(f'Architecture: enc={enc_channels} film_cond_dim={film_cond_dim}')
+
+ model = CNNv3(enc_channels=enc_channels, film_cond_dim=film_cond_dim)
+ model.load_state_dict(ckpt['model_state_dict'])
+ model.eval()
+
+ # --- Inference ---
+ feat_t = torch.from_numpy(feat_padded).permute(2, 0, 1).unsqueeze(0) # (1,20,H,W)
+ cond_t = torch.tensor([args.cond], dtype=torch.float32) # (1,5)
+
+ with torch.no_grad():
+ if args.identity_film:
+ print('FiLM: identity (γ=1, β=0)')
+ out_t = run_identity_film(model, feat_t)
+ else:
+ print(f'FiLM cond: {args.cond}')
+ out_t = model(feat_t, cond_t)
+
+ # (1,4,H,W) → crop padding → (orig_h, orig_w, 4)
+ out = out_t[0].permute(1, 2, 0).numpy()[:orig_h, :orig_w, :]
+
+ # Optional blend with albedo
+ if args.blend < 1.0:
+ h_in, w_in = albedo_rgb.shape[:2]
+ ab = albedo_rgb[:orig_h, :orig_w]
+ ones = np.ones((orig_h, orig_w, 1), dtype=np.float32)
+ src_rgba = np.concatenate([ab, ones], axis=-1)
+ out = src_rgba * (1.0 - args.blend) + out * args.blend
+
+ # --- Save ---
+ out_path = Path(args.output)
+ save_png(out_path, out)
+ print(f'Saved: {out_path}')
+
+ if args.debug_hex:
+ print('First 8 output pixels (RGBA):')
+ print_debug_hex(out)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py
index de10d6a..c790495 100644
--- a/cnn_v3/training/train_cnn_v3.py
+++ b/cnn_v3/training/train_cnn_v3.py
@@ -6,17 +6,21 @@
"""CNN v3 Training Script — U-Net + FiLM
Architecture:
- enc0 Conv(20→4, 3×3) + FiLM + ReLU H×W
- enc1 Conv(4→8, 3×3) + FiLM + ReLU + pool2 H/2×W/2
- bottleneck Conv(8→8, 1×1) + ReLU H/4×W/4
- dec1 upsample×2 + cat(enc1) Conv(16→4) + FiLM H/2×W/2
- dec0 upsample×2 + cat(enc0) Conv(8→4) + FiLM H×W
+ enc0 Conv(20→4, 3×3) + FiLM + ReLU H×W
+ enc1 Conv(4→8, 3×3) + FiLM + ReLU + pool2 H/2×W/2
+ bottleneck Conv(8→8, 3×3, dilation=2) + ReLU H/4×W/4
+ dec1 upsample×2 + cat(enc1) Conv(16→4) + FiLM H/2×W/2
+ dec0 upsample×2 + cat(enc0) Conv(8→4) + FiLM H×W
output sigmoid → RGBA
FiLM MLP: Linear(5→16) → ReLU → Linear(16→40)
40 = 2 × (γ+β) for enc0(4) enc1(8) dec1(4) dec0(4)
-Weight budget: ~5.4 KB f16 (fits ≤6 KB target)
+Weight budget: ~4.84 KB conv f16 (fits ≤6 KB target)
+
+Training improvements:
+ --edge-loss-weight Sobel edge loss alongside MSE (default 0.1)
+ --film-warmup-epochs Train U-Net only for N epochs before unfreezing FiLM MLP (default 50)
"""
import argparse
@@ -56,7 +60,7 @@ class CNNv3(nn.Module):
self.enc0 = nn.Conv2d(N_FEATURES, c0, 3, padding=1)
self.enc1 = nn.Conv2d(c0, c1, 3, padding=1)
- self.bottleneck = nn.Conv2d(c1, c1, 1)
+ self.bottleneck = nn.Conv2d(c1, c1, 3, padding=2, dilation=2)
self.dec1 = nn.Conv2d(c1 * 2, c0, 3, padding=1) # +skip enc1
self.dec0 = nn.Conv2d(c0 * 2, 4, 3, padding=1) # +skip enc0
@@ -96,6 +100,24 @@ class CNNv3(nn.Module):
# ---------------------------------------------------------------------------
+# Loss
+# ---------------------------------------------------------------------------
+
+def sobel_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ """Gradient loss via Sobel filters. No VGG dependency.
+ pred, target: (B, C, H, W) in [0, 1]. Returns scalar on same device."""
+ kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
+ dtype=pred.dtype, device=pred.device).view(1, 1, 3, 3)
+ ky = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
+ dtype=pred.dtype, device=pred.device).view(1, 1, 3, 3)
+ B, C, H, W = pred.shape
+ p = pred.view(B * C, 1, H, W)
+ t = target.view(B * C, 1, H, W)
+ return (F.mse_loss(F.conv2d(p, kx, padding=1), F.conv2d(t, kx, padding=1)) +
+ F.mse_loss(F.conv2d(p, ky, padding=1), F.conv2d(t, ky, padding=1)))
+
+
+# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
@@ -104,6 +126,10 @@ def train(args):
enc_channels = [int(c) for c in args.enc_channels.split(',')]
print(f"Device: {device}")
+ if args.single_sample:
+ args.full_image = True
+ args.batch_size = 1
+
dataset = CNNv3Dataset(
dataset_dir=args.input,
input_mode=args.input_mode,
@@ -115,6 +141,7 @@ def train(args):
detector=args.detector,
augment=True,
patch_search_window=args.patch_search_window,
+ single_sample=args.single_sample,
)
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
num_workers=0, drop_last=False)
@@ -124,11 +151,20 @@ def train(args):
print(f"Model: enc={enc_channels} film_cond_dim={args.film_cond_dim} "
f"params={nparams} (~{nparams*2/1024:.1f} KB f16)")
- optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+ # Phase 1: freeze FiLM MLP so U-Net convolutions stabilise first.
+ film_warmup = args.film_warmup_epochs
+ if film_warmup > 0:
+ for p in model.film_mlp.parameters():
+ p.requires_grad = False
+ print(f"FiLM MLP frozen for first {film_warmup} epochs (phase-1 warmup)")
+
+ optimizer = torch.optim.Adam(
+ filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
criterion = nn.MSELoss()
ckpt_dir = Path(args.checkpoint_dir)
ckpt_dir.mkdir(parents=True, exist_ok=True)
start_epoch = 1
+ film_unfrozen = (film_warmup == 0)
if args.resume:
ckpt_path = Path(args.resume)
@@ -163,6 +199,15 @@ def train(args):
for epoch in range(start_epoch, args.epochs + 1):
if interrupted:
break
+
+ # Phase 2: unfreeze FiLM MLP after warmup, rebuild optimizer at reduced LR.
+ if not film_unfrozen and epoch > film_warmup:
+ for p in model.film_mlp.parameters():
+ p.requires_grad = True
+ optimizer = torch.optim.Adam(model.parameters(), lr=args.lr * 0.1)
+ film_unfrozen = True
+ print(f"\nPhase 2: FiLM MLP unfrozen at epoch {epoch} (lr={args.lr*0.1:.2e})")
+
model.train()
epoch_loss = 0.0
n_batches = 0
@@ -172,7 +217,10 @@ def train(args):
break
feat, cond, target = feat.to(device), cond.to(device), target.to(device)
optimizer.zero_grad()
- loss = criterion(model(feat, cond), target)
+ pred = model(feat, cond)
+ loss = criterion(pred, target)
+ if args.edge_loss_weight > 0.0:
+ loss = loss + args.edge_loss_weight * sobel_loss(pred, target)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
@@ -210,6 +258,8 @@ def _checkpoint(model, optimizer, epoch, loss, args):
'enc_channels': [int(c) for c in args.enc_channels.split(',')],
'film_cond_dim': args.film_cond_dim,
'input_mode': args.input_mode,
+ 'edge_loss_weight': args.edge_loss_weight,
+ 'film_warmup_epochs': args.film_warmup_epochs,
},
}
@@ -222,6 +272,8 @@ def main():
p = argparse.ArgumentParser(description='Train CNN v3 (U-Net + FiLM)')
# Dataset
+ p.add_argument('--single-sample', default='', metavar='DIR',
+ help='Train on a single sample directory; implies --full-image and --batch-size 1')
p.add_argument('--input', default='training/dataset',
help='Dataset root (contains full/ or simple/ subdirs)')
p.add_argument('--input-mode', default='simple', choices=['simple', 'full'],
@@ -259,6 +311,10 @@ def main():
help='Save checkpoint every N epochs (0=disable)')
p.add_argument('--resume', default='', metavar='CKPT',
help='Resume from checkpoint path; if path missing, use latest in --checkpoint-dir')
+ p.add_argument('--edge-loss-weight', type=float, default=0.1,
+ help='Weight for Sobel edge loss alongside MSE (default 0.1; 0=disable)')
+ p.add_argument('--film-warmup-epochs', type=int, default=50,
+ help='Epochs to train U-Net only before unfreezing FiLM MLP (default 50; 0=joint)')
train(p.parse_args())