From ce6e5b99f26e4e7c69a3cacf360bd0d492de928c Mon Sep 17 00:00:00 2001 From: skal Date: Wed, 25 Mar 2026 10:05:42 +0100 Subject: feat(cnn_v3): 3×3 dilated bottleneck + Sobel loss + FiLM warmup + architecture PNG MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace 1×1 pointwise bottleneck with Conv(8→8, 3×3, dilation=2): effective RF grows from ~13px to ~29px at ¼res (~+1 KB weights) - Add Sobel edge loss in training (--edge-loss-weight, default 0.1) - Add FiLM 2-phase training: freeze MLP for warmup epochs then unfreeze at lr×0.1 (--film-warmup-epochs, default 50) - Update weight layout: BN 72→584 f16, total 1964→2476 f16 (4952 B) - Cascade offsets in C++ effect, JS tool, export/gen_test_vectors scripts - Regenerate test_vectors.h (1238 u32); parity max_err=9.77e-04 - Generate dark-theme U-Net+FiLM architecture PNG (gen_architecture_png.py) - Replace ASCII art in CNN_V3.md and HOW_TO_CNN.md with PNG embed handoff(Gemini): bottleneck dilation + Sobel loss + FiLM warmup landed. Next: run first real training pass (see cnn_v3/docs/HOWTO.md §3). --- cnn_v3/docs/CNN_V3.md | 38 +--- cnn_v3/docs/HOWTO.md | 6 +- cnn_v3/docs/HOW_TO_CNN.md | 41 ++-- cnn_v3/docs/cnn_v3_architecture.png | Bin 0 -> 254783 bytes cnn_v3/docs/gen_architecture_png.py | 238 ++++++++++++++++++++++++ cnn_v3/shaders/cnn_v3_bottleneck.wgsl | 32 ++-- cnn_v3/src/cnn_v3_effect.cc | 2 +- cnn_v3/test_vectors.h | 310 +++++++++++++++++-------------- cnn_v3/tools/shaders.js | 18 +- cnn_v3/training/export_cnn_v3_weights.py | 16 +- cnn_v3/training/gen_test_vectors.py | 74 +++----- cnn_v3/training/train_cnn_v3.py | 67 ++++++- src/tests/gpu/test_cnn_v3_parity.cc | 4 +- 13 files changed, 563 insertions(+), 283 deletions(-) create mode 100644 cnn_v3/docs/cnn_v3_architecture.png create mode 100644 cnn_v3/docs/gen_architecture_png.py diff --git a/cnn_v3/docs/CNN_V3.md b/cnn_v3/docs/CNN_V3.md index 4d58811..d775e2b 100644 --- a/cnn_v3/docs/CNN_V3.md +++ b/cnn_v3/docs/CNN_V3.md @@ -27,33 +27,7 @@ CNN v3 is a next-generation post-processing effect using: ### Pipeline Overview -``` -G-Buffer (albedo, normal, depth, matID, UV) - │ - ▼ - FiLM Conditioning - (beat_time, audio_intensity, style_params) - │ → γ[], β[] per channel - ▼ - U-Net - ┌─────────────────────────────────────────┐ - │ Encoder │ - │ enc0 (H×W, 4ch) ────────────skip──────┤ - │ ↓ down (avg pool 2×2) │ - │ enc1 (H/2×W/2, 8ch) ────────skip──────┤ - │ ↓ down │ - │ bottleneck (H/4×W/4, 8ch) │ - │ │ - │ Decoder │ - │ ↑ up (nearest ×2) + skip enc1 │ - │ dec1 (H/2×W/2, 4ch) │ - │ ↑ up + skip enc0 │ - │ dec0 (H×W, 4ch) │ - └─────────────────────────────────────────┘ - │ - ▼ - output RGBA (H×W) -``` +![CNN v3 U-Net + FiLM Architecture](cnn_v3_architecture.png) FiLM is applied **inside each encoder/decoder block**, after each convolution. @@ -352,11 +326,11 @@ All f16, little-endian, same packing as v2 (`pack2x16float`). |-----------|---------|------|-----------| | enc0: Conv(20→4, 3×3) | 20×4×9=720 | +4 | 724 | | enc1: Conv(4→8, 3×3) | 4×8×9=288 | +8 | 296 | -| bottleneck: Conv(8→8, 1×1) | 8×8×1=64 | +8 | 72 | +| bottleneck: Conv(8→8, 3×3, dil=2) | 8×8×9=576 | +8 | 584 | | dec1: Conv(16→4, 3×3) | 16×4×9=576 | +4 | 580 | | dec0: Conv(8→4, 3×3) | 8×4×9=288 | +4 | 292 | | FiLM MLP (5→16→40) | 5×16+16×40=720 | +16+40 | 776 | -| **Total** | | | **~3.9 KB f16** | +| **Total conv** | | | **~4.84 KB f16** | Skip connections: dec1 input = 8ch (bottleneck) + 8ch (enc1 skip) = 16ch. dec0 input = 4ch (dec1) + 4ch (enc0 skip) = 8ch. @@ -541,7 +515,7 @@ class CNNv3(nn.Module): nn.Conv2d(enc_channels[0], enc_channels[1], 3, padding=1), ]) # Bottleneck - self.bottleneck = nn.Conv2d(enc_channels[1], enc_channels[1], 1) + self.bottleneck = nn.Conv2d(enc_channels[1], enc_channels[1], 3, padding=2, dilation=2) # Decoder (skip connections: concat → double channels) self.dec = nn.ModuleList([ nn.Conv2d(enc_channels[1]*2, enc_channels[0], 3, padding=1), @@ -709,7 +683,7 @@ Parity results: Pass 0: pack_gbuffer.wgsl — assemble G-buffer channels into storage texture Pass 1: cnn_v3_enc0.wgsl — encoder level 0 (20→4ch, 3×3) Pass 2: cnn_v3_enc1.wgsl — encoder level 1 (4→8ch, 3×3) + downsample -Pass 3: cnn_v3_bottleneck.wgsl — bottleneck (8→8, 1×1) +Pass 3: cnn_v3_bottleneck.wgsl — bottleneck (8→8, 3×3, dilation=2) Pass 4: cnn_v3_dec1.wgsl — decoder level 1: upsample + skip + (16→4, 3×3) Pass 5: cnn_v3_dec0.wgsl — decoder level 0: upsample + skip + (8→4, 3×3) Pass 6: cnn_v3_output.wgsl — sigmoid + composite to framebuffer @@ -816,7 +790,7 @@ Status bar shows which channels are loaded. | `PACK_SHADER` | `STATIC_SHADER` | 20ch into feat_tex0 + feat_tex1 (rgba32uint each) | | `ENC0_SHADER` | part of `CNN_SHADER` | Conv(20→4, 3×3) + FiLM + ReLU; writes enc0_tex | | `ENC1_SHADER` | | Conv(4→8, 3×3) + FiLM + ReLU + avg_pool2×2; writes enc1_tex (half-res) | -| `BOTTLENECK_SHADER` | | Conv(8→8, 1×1) + FiLM + ReLU; writes bn_tex | +| `BOTTLENECK_SHADER` | | Conv(8→8, 3×3, dilation=2) + ReLU; writes bn_tex | | `DEC1_SHADER` | | nearest upsample×2 + concat(bn, enc1_skip) + Conv(16→4, 3×3) + FiLM + ReLU | | `DEC0_SHADER` | | nearest upsample×2 + concat(dec1, enc0_skip) + Conv(8→4, 3×3) + FiLM + ReLU | | `OUTPUT_SHADER` | | Conv(4→4, 1×1) + sigmoid → composites to canvas | diff --git a/cnn_v3/docs/HOWTO.md b/cnn_v3/docs/HOWTO.md index 1aead68..9a3efdf 100644 --- a/cnn_v3/docs/HOWTO.md +++ b/cnn_v3/docs/HOWTO.md @@ -439,10 +439,10 @@ FiLM γ/β are computed CPU-side by the FiLM MLP (Phase 4) and uploaded each fra |-------|---------|------|-----------| | enc0 | 20×4×9=720 | +4 | 724 | | enc1 | 4×8×9=288 | +8 | 296 | -| bottleneck | 8×8×1=64 | +8 | 72 | +| bottleneck | 8×8×9=576 | +8 | 584 | | dec1 | 16×4×9=576 | +4 | 580 | | dec0 | 8×4×9=288 | +4 | 292 | -| **Total** | | | **2064 f16 = ~4 KB** | +| **Total** | | | **2476 f16 = ~4.84 KB** | **Asset IDs** (registered in `workspaces/main/assets.txt` + `src/effects/shaders.cc`): `SHADER_CNN_V3_COMMON`, `SHADER_CNN_V3_ENC0`, `SHADER_CNN_V3_ENC1`, @@ -725,7 +725,7 @@ Both tools print first 8 pixels in the same format: **Expected delta:** ≤ 1/255 (≈ 4e-3) per channel, matching the parity test (`test_cnn_v3_parity`). Larger deltas indicate a weight mismatch — re-export -with `export_cnn_v3_weights.py` and verify the `.bin` size is 3928 bytes. +with `export_cnn_v3_weights.py` and verify the `.bin` size is 4952 bytes. --- diff --git a/cnn_v3/docs/HOW_TO_CNN.md b/cnn_v3/docs/HOW_TO_CNN.md index f5f1b1a..09db97c 100644 --- a/cnn_v3/docs/HOW_TO_CNN.md +++ b/cnn_v3/docs/HOW_TO_CNN.md @@ -28,26 +28,13 @@ CNN v3 is a 2-level U-Net with FiLM conditioning, designed to run in real-time a **Architecture:** -``` -Input: 20-channel G-buffer feature textures (rgba32uint) - │ - enc0 ──── Conv(20→4, 3×3) + FiLM + ReLU ┐ full res - │ ↘ skip │ - enc1 ──── AvgPool2×2 + Conv(4→8, 3×3) + FiLM ┐ ½ res - │ ↘ skip │ - bottleneck AvgPool2×2 + Conv(8→8, 1×1) + ReLU ¼ res (no FiLM) - │ │ - dec1 ←── upsample×2 + cat(enc1 skip) + Conv(16→4, 3×3) + FiLM - │ │ ½ res - dec0 ←── upsample×2 + cat(enc0 skip) + Conv(8→4, 3×3) + FiLM + sigmoid - full res → RGBA output -``` +![CNN v3 U-Net + FiLM Architecture](cnn_v3_architecture.png) **FiLM MLP:** `Linear(5→16) → ReLU → Linear(16→40)` trained jointly with U-Net. - Input: `[beat_phase, beat_norm, audio_intensity, style_p0, style_p1]` - Output: 40 γ/β values controlling style across all 4 FiLM layers -**Weight budget:** ~3.9 KB f16 (fits ≤6 KB target) +**Weight budget:** ~4.84 KB f16 conv (fits ≤6 KB target) **Two data paths:** - **Simple mode** — real photos with zeroed geometric channels (normal, depth, matid) @@ -307,7 +294,9 @@ uv run train_cnn_v3.py --input dataset/ --epochs 1 \ uv run train_cnn_v3.py \ --input dataset/ \ --input-mode simple \ - --epochs 200 + --epochs 200 \ + --edge-loss-weight 0.1 \ + --film-warmup-epochs 50 ``` **Blender G-buffer training:** @@ -315,7 +304,9 @@ uv run train_cnn_v3.py \ uv run train_cnn_v3.py \ --input dataset/ \ --input-mode full \ - --epochs 200 + --epochs 200 \ + --edge-loss-weight 0.1 \ + --film-warmup-epochs 50 ``` **Full-image mode (better global coherence, slower):** @@ -360,12 +351,14 @@ uv run train_cnn_v3.py \ | `--checkpoint-dir DIR` | `checkpoints/` | Set per-experiment | | `--checkpoint-every N` | 50 | 0 to disable intermediate checkpoints | | `--resume [CKPT]` | — | Resume from checkpoint path; if path missing, uses latest in `--checkpoint-dir` | +| `--edge-loss-weight F` | 0.1 | Sobel gradient loss weight alongside MSE; improves style/edge capture; 0=MSE only | +| `--film-warmup-epochs N` | 50 | Freeze FiLM MLP for first N epochs (phase-1), then unfreeze at lr×0.1; 0=joint training | ### Architecture at startup The model prints its parameter count: ``` -Model: enc=[4, 8] film_cond_dim=5 params=2740 (~5.4 KB f16) +Model: enc=[4, 8] film_cond_dim=5 params=3252 (~6.4 KB f16) ``` If `params` is much higher, `--enc-channels` was changed; update C++ constants accordingly. @@ -489,7 +482,7 @@ Use `--html-output PATH` to write to a different `weights.js` location. Output files are registered in `workspaces/main/assets.txt` as: ``` -WEIGHTS_CNN_V3, BINARY, weights/cnn_v3_weights.bin, "CNN v3 conv weights (f16, 3928 bytes)" +WEIGHTS_CNN_V3, BINARY, weights/cnn_v3_weights.bin, "CNN v3 conv weights (f16, 4952 bytes)" WEIGHTS_CNN_V3_FILM_MLP, BINARY, weights/cnn_v3_film_mlp.bin, "CNN v3 FiLM MLP weights (f32, 3104 bytes)" ``` @@ -501,10 +494,10 @@ WEIGHTS_CNN_V3_FILM_MLP, BINARY, weights/cnn_v3_film_mlp.bin, "CNN v3 FiLM MLP w |-------|-----------|-------| | enc0 Conv(20→4,3×3)+bias | 724 | — | | enc1 Conv(4→8,3×3)+bias | 296 | — | -| bottleneck Conv(8→8,1×1)+bias | 72 | — | +| bottleneck Conv(8→8,3×3,dil=2)+bias | 584 | — | | dec1 Conv(16→4,3×3)+bias | 580 | — | | dec0 Conv(8→4,3×3)+bias | 292 | — | -| **Total** | **1964 f16** | **3928 bytes** | +| **Total** | **2476 f16** | **4952 bytes** | **`cnn_v3_film_mlp.bin`** — FiLM MLP weights as raw f32, row-major: @@ -534,8 +527,8 @@ Checkpoint: epoch=200 loss=0.012345 enc_channels=[4, 8] film_cond_dim=5 cnn_v3_weights.bin - 1964 f16 values → 982 u32 → 3928 bytes - Upload via CNNv3Effect::upload_weights(queue, data, 3928) + 2476 f16 values → 1238 u32 → 4952 bytes + Upload via CNNv3Effect::upload_weights(queue, data, 4952) cnn_v3_film_mlp.bin L0: weight (16, 5) + bias (16,) @@ -824,7 +817,7 @@ all geometric channels (normal, depth, depth_grad, mat_id, prev) = 0. ### Pitfalls - `rgba32uint` and `rgba16float` textures both need `STORAGE_BINDING | TEXTURE_BINDING` usage. -- Weight offsets are **f16 indices** (enc0=0, enc1=724, bn=1020, dec1=1092, dec0=1672). +- Weight offsets are **f16 indices** (enc0=0, enc1=724, bn=1020, dec1=1604, dec0=2184). - Uniform buffer layouts must match WGSL `Params` structs exactly (padding included). --- diff --git a/cnn_v3/docs/cnn_v3_architecture.png b/cnn_v3/docs/cnn_v3_architecture.png new file mode 100644 index 0000000..2116c2b Binary files /dev/null and b/cnn_v3/docs/cnn_v3_architecture.png differ diff --git a/cnn_v3/docs/gen_architecture_png.py b/cnn_v3/docs/gen_architecture_png.py new file mode 100644 index 0000000..bd60a97 --- /dev/null +++ b/cnn_v3/docs/gen_architecture_png.py @@ -0,0 +1,238 @@ +#!/usr/bin/env python3 +# /// script +# requires-python = ">=3.10" +# dependencies = ["matplotlib"] +# /// +"""Generate CNN v3 U-Net + FiLM architecture diagram → cnn_v3_architecture.png""" + +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +from matplotlib.patches import FancyBboxPatch +from matplotlib.path import Path +import matplotlib.patheffects as pe + +# --------------------------------------------------------------------------- +# Canvas +# --------------------------------------------------------------------------- +BG = '#0F172A' +fig = plt.figure(figsize=(17, 10), facecolor=BG) +ax = fig.add_axes([0, 0, 1, 1], facecolor=BG) +ax.set_xlim(0, 17) +ax.set_ylim(0, 10) +ax.axis('off') + +# --------------------------------------------------------------------------- +# Palette +# --------------------------------------------------------------------------- +C_ENC = '#3B82F6' # encoder — blue +C_BN = '#8B5CF6' # bottleneck — violet +C_DEC = '#10B981' # decoder — emerald +C_MLP = '#EC4899' # FiLM MLP — pink +C_FILM = '#F59E0B' # FiLM γ/β arrows — amber +C_IO = '#475569' # input/output — slate +C_SKP = '#F97316' # skip connections — orange +C_ARR = '#94A3B8' # main flow arrows — cool-grey +C_TXT = '#F1F5F9' # text — near-white +C_DIM = '#64748B' # dim labels — slate + +# --------------------------------------------------------------------------- +# Geometry — two-column U layout +# --------------------------------------------------------------------------- +EX, DX = 3.8, 13.2 # encoder / decoder centre-x +BX = 8.5 # bottleneck centre-x + +BW = 4.6 # block width (enc / dec) +BH = 0.95 # block height (enc / dec) +BW_BN = 5.4 # bottleneck wider +BH_BN = 0.95 +BH_IO = 0.72 + +# y positions (top = high number) +Y_IN = 8.90 +Y_E0 = 7.50 # enc0 full res +Y_E1 = 5.80 # enc1 ½ res +Y_BN = 3.20 # bottleneck ¼ res +Y_D1 = 5.80 # dec1 ½ res +Y_D0 = 7.50 # dec0 full res +Y_OUT = 8.90 + +Y_MLP = 1.25 # FiLM MLP + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def box(cx, cy, w, h, color, line1, line2='', lfs=9.5, sfs=8.0, alpha=0.92): + r = FancyBboxPatch((cx - w/2, cy - h/2), w, h, + boxstyle='round,pad=0.10', + fc=color, ec='white', lw=1.3, alpha=alpha, zorder=3) + ax.add_patch(r) + dy = 0.18 if line2 else 0 + ax.text(cx, cy + dy, line1, ha='center', va='center', + fontsize=lfs, fontweight='bold', color='white', zorder=4, + fontfamily='DejaVu Sans Mono') + if line2: + ax.text(cx, cy - 0.18, line2, ha='center', va='center', + fontsize=sfs, color='white', alpha=0.80, zorder=4) + + +def arrow(x0, y0, x1, y1, color=C_ARR, lw=1.8, dashed=False, + rad=0.0, label='', lx=None, ly=None): + ls = (0, (5, 3)) if dashed else 'solid' + cs = f'arc3,rad={rad}' if rad else 'arc3,rad=0' + ax.annotate('', xy=(x1, y1), xytext=(x0, y0), + arrowprops=dict(arrowstyle='->', color=color, lw=lw, + linestyle=ls, mutation_scale=13, + connectionstyle=cs), + zorder=2) + if label: + ax.text(lx if lx else (x0+x1)/2, + ly if ly else (y0+y1)/2, + label, ha='center', va='center', fontsize=7.5, + color=color, zorder=5, + bbox=dict(fc=BG, ec='none', alpha=0.75, + boxstyle='round,pad=0.15')) + + +def dim_label(x, y, txt): + ax.text(x, y, txt, ha='center', va='center', + fontsize=8.5, color=C_DIM, style='italic') + + +# --------------------------------------------------------------------------- +# Blocks +# --------------------------------------------------------------------------- + +box(EX, Y_IN, BW, BH_IO, C_IO, 'G-Buffer Features', + '20 channels · full res') + +box(EX, Y_E0, BW, BH, C_ENC, 'enc0 Conv(20→4, 3×3) + FiLM + ReLU', + 'full res · 4 ch') + +box(EX, Y_E1, BW, BH, C_ENC, 'enc1 Conv(4→8, 3×3) + FiLM + ReLU', + '½ res · 8 ch · (AvgPool↓ on input)') + +box(BX, Y_BN, BW_BN, BH_BN, C_BN, + 'bottleneck Conv(8→8, 3×3, dilation=2) + ReLU', + '¼ res · 8 ch · no FiLM · effective RF ≈ 10 px @ ½res') + +box(DX, Y_D1, BW, BH, C_DEC, 'dec1 Conv(16→4, 3×3) + FiLM + ReLU', + '½ res · 4 ch · (upsample↑ + cat enc1 skip)') + +box(DX, Y_D0, BW, BH, C_DEC, 'dec0 Conv(8→4, 3×3) + FiLM + sigmoid', + 'full res · 4 ch · (upsample↑ + cat enc0 skip)') + +box(DX, Y_OUT, BW, BH_IO, C_IO, 'RGBA Output', + '4 channels · full res') + +box(BX, Y_MLP, 9.2, 1.10, C_MLP, + 'FiLM MLP Linear(5→16) → ReLU → Linear(16→40)', + 'in: beat_phase · beat_norm · audio_intensity · style_p0 · style_p1' + ' → γ/β (×2) for enc0(4) enc1(8) dec1(4) dec0(4) = 40 values', + sfs=7.5) + +# --------------------------------------------------------------------------- +# Main-flow arrows +# --------------------------------------------------------------------------- + +# Input → enc0 +arrow(EX, Y_IN - BH_IO/2 - .04, EX, Y_E0 + BH/2 + .04) + +# enc0 → enc1 (AvgPool label beside) +arrow(EX, Y_E0 - BH/2 - .04, EX, Y_E1 + BH/2 + .04, + label='AvgPool\n 2×2', lx=EX + 0.72, ly=(Y_E0 + Y_E1)/2) + +# enc1 → bottleneck (curve down-right) +arrow(EX, Y_E1 - BH/2 - .04, + BX - BW_BN/2 - .04, Y_BN, + rad=-0.28, + label='AvgPool\n 2×2', lx=(EX + BX)/2 - 0.5, ly=Y_BN + 0.90) + +# bottleneck → dec1 (curve right-up) +arrow(BX + BW_BN/2 + .04, Y_BN, + DX, Y_D1 - BH/2 - .04, + rad=-0.28, + label='upsample\n 2×', lx=(BX + DX)/2 + 0.5, ly=Y_D1 - 0.90) + +# dec1 → dec0 +arrow(DX, Y_D1 + BH/2 + .04, DX, Y_D0 - BH/2 - .04, + label='upsample\n 2×', lx=DX - 0.72, ly=(Y_D1 + Y_D0)/2) + +# dec0 → output +arrow(DX, Y_D0 + BH/2 + .04, DX, Y_OUT - BH_IO/2 - .04) + +# --------------------------------------------------------------------------- +# Skip connections +# --------------------------------------------------------------------------- + +# enc0 skip → dec0 +arrow(EX + BW/2 + .04, Y_E0, + DX - BW/2 - .04, Y_D0, + color=C_SKP, lw=1.6, dashed=True, + label='skip enc0 (4 ch)', ly=Y_E0 + 0.40) + +# enc1 skip → dec1 +arrow(EX + BW/2 + .04, Y_E1, + DX - BW/2 - .04, Y_D1, + color=C_SKP, lw=1.6, dashed=True, + label='skip enc1 (8 ch)', ly=Y_E1 + 0.40) + +# --------------------------------------------------------------------------- +# FiLM γ/β arrows (MLP → each FiLM layer) +# --------------------------------------------------------------------------- +film_targets = [ + (EX, Y_E0 - BH/2 - .04), # enc0 bottom + (EX, Y_E1 - BH/2 - .04), # enc1 bottom + (DX, Y_D1 - BH/2 - .04), # dec1 bottom + (DX, Y_D0 - BH/2 - .04), # dec0 bottom +] +for tx, ty in film_targets: + ax.annotate('', xy=(tx, ty), + xytext=(BX + (tx - BX) * 0.05, Y_MLP + 0.55 + .04), + arrowprops=dict(arrowstyle='->', color=C_FILM, lw=1.2, + linestyle=(0, (3, 3)), mutation_scale=10, + connectionstyle='arc3,rad=0.18'), + zorder=2) + +ax.text(8.5, 4.30, 'γ / β', ha='center', va='center', + fontsize=9, color=C_FILM, alpha=0.85, style='italic', zorder=5) + +# --------------------------------------------------------------------------- +# Resolution markers (left margin) +# --------------------------------------------------------------------------- +for y, lbl in [(Y_E0, 'full res'), (Y_E1, '½ res'), (Y_BN, '¼ res')]: + dim_label(0.62, y, lbl) + ax.plot([0.95, 1.10], [y, y], color=C_DIM, lw=0.8, zorder=1) + +# --------------------------------------------------------------------------- +# Legend +# --------------------------------------------------------------------------- +legend_items = [ + mpatches.Patch(fc=C_ENC, ec='white', lw=0.8, label='Encoder'), + mpatches.Patch(fc=C_BN, ec='white', lw=0.8, label='Bottleneck'), + mpatches.Patch(fc=C_DEC, ec='white', lw=0.8, label='Decoder'), + mpatches.Patch(fc=C_MLP, ec='white', lw=0.8, label='FiLM MLP'), + mpatches.Patch(fc=C_IO, ec='white', lw=0.8, label='I/O'), + plt.Line2D([0], [0], color=C_SKP, lw=1.6, ls='--', label='Skip connection'), + plt.Line2D([0], [0], color=C_FILM, lw=1.2, ls=(0, (3,3)), label='FiLM γ/β'), +] +leg = ax.legend(handles=legend_items, loc='lower right', + bbox_to_anchor=(0.99, 0.01), + framealpha=0.15, facecolor=BG, edgecolor=C_DIM, + fontsize=8, labelcolor=C_TXT, ncol=1) + +# --------------------------------------------------------------------------- +# Title +# --------------------------------------------------------------------------- +ax.text(8.5, 9.68, 'CNN v3 — U-Net + FiLM Architecture', + ha='center', va='center', fontsize=14, fontweight='bold', color=C_TXT) + +# --------------------------------------------------------------------------- +# Save +# --------------------------------------------------------------------------- +import pathlib +out = pathlib.Path(__file__).parent / 'cnn_v3_architecture.png' +fig.savefig(out, dpi=180, bbox_inches='tight', facecolor=BG, edgecolor='none') +print(f'Saved → {out} ({out.stat().st_size // 1024} KB)') diff --git a/cnn_v3/shaders/cnn_v3_bottleneck.wgsl b/cnn_v3/shaders/cnn_v3_bottleneck.wgsl index e24586f..e30682b 100644 --- a/cnn_v3/shaders/cnn_v3_bottleneck.wgsl +++ b/cnn_v3/shaders/cnn_v3_bottleneck.wgsl @@ -1,17 +1,18 @@ // CNN v3 — Bottleneck -// AvgPool2x2(enc1) + Conv(8->8, 1x1) + ReLU (no FiLM) +// AvgPool2x2(enc1) + Conv(8->8, 3x3, dilation=2) + ReLU (no FiLM) // -// Input: enc1_tex (rgba32uint, 8xf16) half-res -// Output: bottleneck_out (rgba32uint, 8xf16) quarter-res (dispatch at quarter-res dims) +// Input: enc1_tex (rgba32uint, 8xf16) half-res +// Output: bottleneck_out (rgba32uint, 8xf16) quarter-res (dispatch at quarter-res dims) // // Weight layout (f16, OIHW + bias): -// [0 .. 8*8*1) conv: w[out][in] (1x1 kernel) -// [64 .. +8) bias: b[out] +// [0 .. 8*8*9) conv: w[out][in][ky*3+kx] (3x3 kernel, OIHW) +// [576 .. +8) bias: b[out] #include "cnn_v3/common" -const BN_IN: u32 = 8u; -const BN_OUT: u32 = 8u; +const BN_IN: u32 = 8u; +const BN_OUT: u32 = 8u; +const BN_DILATION: i32 = 2; struct Params { weight_offset: u32, @@ -24,7 +25,7 @@ struct Params { @group(0) @binding(3) var bottleneck_out: texture_storage_2d; // Avg-pool 2x2 from enc1_tex at quarter-res coord qcoord. -// Returns zeros for OOB quarter-res coords (zero-padding for the 1x1 conv). +// Returns zeros for OOB quarter-res coords (zero-padding for the 3x3 conv). fn load_enc1_avg(qcoord: vec2i, half_dims: vec2i) -> array { let quart_dims = half_dims / 2; if (qcoord.x < 0 || qcoord.y < 0 || qcoord.x >= quart_dims.x || qcoord.y >= quart_dims.y) { @@ -50,14 +51,19 @@ fn bottleneck_main(@builtin(global_invocation_id) id: vec3u) { let coord = vec2i(id.xy); if (coord.x >= quart_dims.x || coord.y >= quart_dims.y) { return; } - let wo = params.weight_offset; - let feat = load_enc1_avg(coord, half_dims); + let wo = params.weight_offset; var out: array; for (var o: u32 = 0u; o < BN_OUT; o++) { - var sum = get_w(wo, BN_OUT * BN_IN + o); // bias (1x1 kernel: no spatial idx) - for (var i: u32 = 0u; i < BN_IN; i++) { - sum += get_w(wo, o * BN_IN + i) * feat[i]; + var sum = get_w(wo, BN_OUT * BN_IN * 9u + o); // bias (at end of 3x3 conv weights) + for (var ky: i32 = -1; ky <= 1; ky++) { + for (var kx: i32 = -1; kx <= 1; kx++) { + let feat = load_enc1_avg(coord + vec2i(kx, ky) * BN_DILATION, half_dims); + let ki = u32(ky + 1) * 3u + u32(kx + 1); + for (var i: u32 = 0u; i < BN_IN; i++) { + sum += get_w(wo, o * BN_IN * 9u + i * 9u + ki) * feat[i]; + } + } } out[o] = max(0.0, sum); } diff --git a/cnn_v3/src/cnn_v3_effect.cc b/cnn_v3/src/cnn_v3_effect.cc index bfbb17b..1391eba 100644 --- a/cnn_v3/src/cnn_v3_effect.cc +++ b/cnn_v3/src/cnn_v3_effect.cc @@ -25,7 +25,7 @@ // static const uint32_t kEnc0Weights = 20 * 4 * 9 + 4; // Conv(20→4,3×3)+bias static const uint32_t kEnc1Weights = 4 * 8 * 9 + 8; // Conv(4→8,3×3)+bias -static const uint32_t kBnWeights = 8 * 8 * 1 + 8; // Conv(8→8,1×1)+bias +static const uint32_t kBnWeights = 8 * 8 * 9 + 8; // Conv(8→8,3×3,dilation=2)+bias static const uint32_t kDec1Weights = 16 * 4 * 9 + 4; // Conv(16→4,3×3)+bias static const uint32_t kDec0Weights = 8 * 4 * 9 + 4; // Conv(8→4,3×3)+bias diff --git a/cnn_v3/test_vectors.h b/cnn_v3/test_vectors.h index 6d1abc5..3e256a3 100644 --- a/cnn_v3/test_vectors.h +++ b/cnn_v3/test_vectors.h @@ -9,78 +9,78 @@ static const int kCnnV3TestH = 8; // 256 u32 values static const uint32_t kCnnV3TestFeat0U32[256] = { - 0x2ccd39ebu, 0x3acb39d7u, 0x3814378fu, 0x3bc134ffu, 0x35e739ddu, 0x33073198u, 0x3b8a376cu, 0x32e339e0u, - 0x360d3ae0u, 0x33bc3ad0u, 0x3b6a38f3u, 0x398b3420u, 0x30d23be6u, 0x39652da8u, 0x2c8f3570u, 0x3a08379bu, - 0x355c3490u, 0x38293ac4u, 0x37243abeu, 0x39353ba0u, 0x3b152f6bu, 0x308837d4u, 0x398030e3u, 0x34962a10u, - 0x370f3079u, 0x382d36a1u, 0x281a3479u, 0x35fb38eau, 0x2ef43936u, 0x33f2230eu, 0x364e374eu, 0x360c3a7bu, - 0x38c0383eu, 0x381f2597u, 0x36be3584u, 0x3a432e6bu, 0x25b33b8au, 0x3a1c38d0u, 0x3a4d348fu, 0x3b6f390fu, - 0x296c3bd9u, 0x3860371eu, 0x356130b2u, 0x24283be7u, 0x3abe373du, 0x37ad352fu, 0x37993bd3u, 0x2a9f3031u, - 0x34413b90u, 0x2dce3808u, 0x3b7136c7u, 0x3bc53805u, 0x38093424u, 0x372c3ae0u, 0x3ad83479u, 0x383f363du, - 0x31f83bd0u, 0x27f434d3u, 0x32683645u, 0x31cd3971u, 0x34373966u, 0x359535afu, 0x377739bcu, 0x3ad235c8u, - 0x32d83893u, 0x357b3b33u, 0x37ea28fdu, 0x33a22fefu, 0x302f39fau, 0x3b7f3a75u, 0x39af38dau, 0x3bf139b5u, - 0x31363577u, 0x38443827u, 0x38e831b1u, 0x3b6c233bu, 0x2910343cu, 0x33b02eeeu, 0x28333462u, 0x322d3478u, - 0x362a360fu, 0x353f356du, 0x26742dbeu, 0x3a0e3278u, 0x3b6e3bedu, 0x38413809u, 0x3a313509u, 0x3ac13a1eu, - 0x36f33b2au, 0x3a743a23u, 0x3b6f34efu, 0x3bf42e0au, 0x2df83a29u, 0x28603940u, 0x3a653a29u, 0x3adb38d2u, - 0x346a2e44u, 0x296f36a0u, 0x343e372cu, 0x36cd3649u, 0x34533b09u, 0x36d13b26u, 0x3805353fu, 0x341e36afu, - 0x30dc3805u, 0x388735a2u, 0x3a97369du, 0x3bc2341cu, 0x3bbe3a47u, 0x308c3ab5u, 0x31703836u, 0x38ac3a8cu, - 0x3b703437u, 0x38832f5fu, 0x2b8839c5u, 0x3a8738c8u, 0x38192c52u, 0x394e3423u, 0x3b7f2f98u, 0x31f43b28u, - 0x38b3352cu, 0x371539bfu, 0x2eaa3100u, 0x37493c00u, 0x37b83afbu, 0x2d9e3b61u, 0x3b702f4cu, 0x35093b94u, - 0x373d35afu, 0x321536a9u, 0x340e3b30u, 0x2c4c39a4u, 0x393b28f6u, 0x393e356du, 0x3b992e04u, 0x3b0339fdu, - 0x351f305eu, 0x384c35e5u, 0x2bc334c0u, 0x341335e7u, 0x324d362du, 0x39043431u, 0x35873636u, 0x3a2d3845u, - 0x38b33610u, 0x382d3bbbu, 0x3a593b47u, 0x36de2b84u, 0x3be53996u, 0x2df03756u, 0x300d387fu, 0x38103a03u, - 0x3af439cau, 0x38e63908u, 0x3abd3a09u, 0x28aa3af4u, 0x32ec3873u, 0x39303ae2u, 0x320536b9u, 0x39a1356du, - 0x2dfd328au, 0x3a1d3b1bu, 0x34ad3265u, 0x39aa3bc7u, 0x34ec38e2u, 0x290f34c9u, 0x298739d4u, 0x39d61cf9u, - 0x3a0d3b97u, 0x37c7378cu, 0x353236fau, 0x36e6382cu, 0x3b2f38c9u, 0x2d0a3bf6u, 0x31c83628u, 0x349935a2u, - 0x3a1d3196u, 0x3b5b37f1u, 0x2c49282cu, 0x2d233674u, 0x3be33434u, 0x325732b0u, 0x37f83897u, 0x360738a5u, - 0x306f3a9du, 0x398536dbu, 0x35ea3af2u, 0x2c6d388bu, 0x2c6d3173u, 0x349c39d3u, 0x2c4039cau, 0x3aaf3ae6u, - 0x26152db1u, 0x3ad42b34u, 0x38633383u, 0x3a5d36d2u, 0x380137d3u, 0x30ce3beau, 0x2aa03aa5u, 0x3b1737a4u, - 0x397b3952u, 0x36b23437u, 0x382c35deu, 0x353b3765u, 0x340334e3u, 0x30cc35d7u, 0x38d13afau, 0x398d3048u, - 0x339a3ac8u, 0x206930d2u, 0x3a192a0cu, 0x29bf3be6u, 0x2c9939fcu, 0x3a0c38bdu, 0x219935bfu, 0x3bee38c3u, - 0x3210341fu, 0x38712feeu, 0x3a5738c6u, 0x3b243a06u, 0x33ea3a72u, 0x34c23872u, 0x3b753547u, 0x3bcc3975u, - 0x384d36acu, 0x2ede37cbu, 0x38393393u, 0x3b742c50u, 0x32562fedu, 0x2e343a1fu, 0x39ce3b34u, 0x39892c64u, - 0x3a0f390eu, 0x39bf3aa0u, 0x352938b0u, 0x3ba83994u, 0x395138b3u, 0x3a0d36feu, 0x31223bfbu, 0x3851327au, - 0x389337b4u, 0x36782a48u, 0x38ae38aau, 0x39c33942u, 0x3a523922u, 0x384d3900u, 0x2e7a38d9u, 0x3838345fu, - 0x396f3afcu, 0x38bd2dc9u, 0x39df3318u, 0x38bf3a9fu, 0x356b38bbu, 0x3aea3724u, 0x382839c9u, 0x2a7335e4u, + 0x322d3094u, 0x3b8e35f9u, 0x3384380bu, 0x356a2ec5u, 0x36223a87u, 0x3a6c3a2eu, 0x38df2f18u, 0x366f3bacu, + 0x3b632a85u, 0x382d3b12u, 0x32c7386bu, 0x37fa3a8eu, 0x39772856u, 0x3aa23b9bu, 0x2ad3346fu, 0x3b7f3b86u, + 0x39233842u, 0x36b73767u, 0x2e312fa5u, 0x3ab13373u, 0x334130abu, 0x32e23864u, 0x38823139u, 0x390235e4u, + 0x30d53b4cu, 0x3b383b4fu, 0x390d3aa2u, 0x391d38f6u, 0x24383986u, 0x38af3baau, 0x36093a40u, 0x38142259u, + 0x36fe380fu, 0x33ff356fu, 0x36013838u, 0x31893bc2u, 0x34b4351du, 0x37fd3859u, 0x3b9a3926u, 0x398b3490u, + 0x34332669u, 0x27ef376bu, 0x396d38f1u, 0x382239f9u, 0x365638d5u, 0x2e662948u, 0x3bf7393du, 0x3876240cu, + 0x3a9d3a02u, 0x38f6385du, 0x3adf3993u, 0x3b692fd4u, 0x3ab126a9u, 0x323a2ce9u, 0x37201bfau, 0x3150355au, + 0x36703738u, 0x3b253a24u, 0x2ff63938u, 0x3b4e34bau, 0x36822daeu, 0x3b9b3b8au, 0x39573694u, 0x3a07374fu, + 0x309d280bu, 0x337138eau, 0x359f3954u, 0x3a8e3b18u, 0x3a2f37e3u, 0x37e83457u, 0x3ae33252u, 0x3b383a96u, + 0x3bad3b05u, 0x3b74334fu, 0x36c33892u, 0x357b387cu, 0x33b9349eu, 0x37f22d47u, 0x390b3b1au, 0x36dc382bu, + 0x32b2376eu, 0x32593a95u, 0x3a1439bcu, 0x3ae73899u, 0x3b0e34a8u, 0x3a6439d6u, 0x3ac53951u, 0x36b93bf2u, + 0x39f53a83u, 0x3b6a373fu, 0x38863650u, 0x333a2ec8u, 0x36583abau, 0x33df364eu, 0x3a7237deu, 0x2c7d3b29u, + 0x377a3899u, 0x372838eau, 0x378d3661u, 0x380238a8u, 0x3a8b378eu, 0x357639f7u, 0x3ad43a68u, 0x38a930e9u, + 0x39ea3491u, 0x395f33a4u, 0x38173415u, 0x361a3b97u, 0x3be53b02u, 0x314f3b00u, 0x281d3a8fu, 0x3af7364bu, + 0x38433983u, 0x3a803635u, 0x377f39adu, 0x335c3b24u, 0x39243174u, 0x33ea3bc7u, 0x307733fdu, 0x333f3ae2u, + 0x3bed3807u, 0x38742237u, 0x3a763819u, 0x369135afu, 0x39ed3160u, 0x30603a47u, 0x3b25364cu, 0x34c8198bu, + 0x35583871u, 0x375c345au, 0x383d31cfu, 0x389a39a7u, 0x3ac12df6u, 0x3a1e3199u, 0x3a4335c5u, 0x31f9329au, + 0x283737f4u, 0x39cb3336u, 0x2d2c3ab3u, 0x3a613b0eu, 0x39963af5u, 0x38333965u, 0x3b5a3939u, 0x350d2e6fu, + 0x3b8f2ca3u, 0x39673720u, 0x3bee3abbu, 0x3a65312du, 0x2a423b19u, 0x35ad3a08u, 0x381d3930u, 0x30543428u, + 0x2e9d2f7cu, 0x359f391au, 0x398932efu, 0x3850397fu, 0x362b3b7bu, 0x2ccf3ab0u, 0x3be839ebu, 0x38a33ac6u, + 0x35a73904u, 0x3a2a3970u, 0x37e13bfcu, 0x38c42bd9u, 0x33d52f9eu, 0x39d93543u, 0x314e31e2u, 0x3afc29c1u, + 0x291d398cu, 0x3878273eu, 0x38c63485u, 0x3b6336f4u, 0x396f349bu, 0x3ba62aebu, 0x39ea3bd9u, 0x330a3772u, + 0x39e43a80u, 0x3738331au, 0x3a9c3768u, 0x39253979u, 0x34543933u, 0x29d835f3u, 0x36ee3a4cu, 0x33da3703u, + 0x38b432b4u, 0x2c1c3371u, 0x36063a24u, 0x36e73615u, 0x35223a85u, 0x3b843a10u, 0x36e83949u, 0x375439fbu, + 0x383436a1u, 0x2eac3515u, 0x2fed36a3u, 0x38753691u, 0x28a33b72u, 0x375338f9u, 0x33fc2530u, 0x32f02f95u, + 0x366c3465u, 0x140e383bu, 0x2dfd312eu, 0x35443866u, 0x33193863u, 0x3b882634u, 0x300f2eefu, 0x3bda30b1u, + 0x38e238f1u, 0x2da93be5u, 0x32873bccu, 0x36b938fcu, 0x3b733625u, 0x3bfa30c6u, 0x39313611u, 0x2b5f3bbeu, + 0x388b3b62u, 0x30c639a3u, 0x39633844u, 0x30f6374du, 0x3ad633d0u, 0x39ac286au, 0x1faa3bffu, 0x39653127u, + 0x38b82baeu, 0x38b53979u, 0x399435d8u, 0x32a538c1u, 0x3b0e3881u, 0x378c3956u, 0x2d7f3525u, 0x21ba33d4u, + 0x331f3be5u, 0x31663a85u, 0x36b1348au, 0x3a633531u, 0x3b013ba9u, 0x3a3730eau, 0x3b4f30bcu, 0x35623825u, + 0x220c3106u, 0x3b5033efu, 0x3bc23a61u, 0x38bd2e73u, 0x3858341du, 0x34893521u, 0x31de3897u, 0x39353782u, + 0x3b72301au, 0x3a8e380cu, 0x39ae393bu, 0x3b0039bbu, 0x347438e9u, 0x38da2e5eu, 0x33b92c3fu, 0x38642bc5u, }; // 256 u32 values static const uint32_t kCnnV3TestFeat1U32[256] = { - 0xc863b415u, 0x249c220fu, 0x603452c6u, 0x00000000u, 0x316a194cu, 0x291db2cbu, 0x5f96105bu, 0x00000000u, - 0xeb343d39u, 0xf1b365e6u, 0x61b71b05u, 0x00000000u, 0x8151bb9eu, 0xfc56bec5u, 0x3c1e7c24u, 0x00000000u, - 0xf1d859a5u, 0x1b1270e5u, 0x39d19474u, 0x00000000u, 0x569b30dcu, 0x097e59b6u, 0xd0d3b912u, 0x00000000u, - 0xdafc8a80u, 0x6222c0d8u, 0xd61d6364u, 0x00000000u, 0xc5c2f0c4u, 0xcd28e9d7u, 0xcd7e12c4u, 0x00000000u, - 0x92cfbc01u, 0x1c5ebffdu, 0xec699bb5u, 0x00000000u, 0x9bd12023u, 0xe6b94175u, 0xf58751d1u, 0x00000000u, - 0x2fe9e259u, 0x66f28558u, 0x314748e3u, 0x00000000u, 0x0d0aabfcu, 0xf7666903u, 0xec5d90aau, 0x00000000u, - 0xee86a635u, 0xe237f413u, 0xa61606fcu, 0x00000000u, 0x85ab0fd7u, 0xfdd13bdbu, 0x8d6075e2u, 0x00000000u, - 0xa476623cu, 0x3634aa37u, 0xbf284477u, 0x00000000u, 0xd1c78653u, 0xadb3feedu, 0x7fa4408au, 0x00000000u, - 0x32a77b6au, 0x08ac3716u, 0xa0976732u, 0x00000000u, 0xaeda1174u, 0xc5ca1e59u, 0xf353b939u, 0x00000000u, - 0x7f53105cu, 0xd44334dfu, 0xb75edbe4u, 0x00000000u, 0x46f67512u, 0xd859d32du, 0x0da6b677u, 0x00000000u, - 0x9950dc38u, 0xf0badec3u, 0xa8b1d193u, 0x00000000u, 0xefe357bdu, 0x0e606587u, 0x884c5ed2u, 0x00000000u, - 0xc7d63411u, 0xa46ee9f4u, 0xe16ad66fu, 0x00000000u, 0x766cf523u, 0xaebf1396u, 0x6b75be3bu, 0x00000000u, - 0xdf433db5u, 0x1e942c35u, 0x410dffe5u, 0x00000000u, 0x18c4cc46u, 0xb3bcd975u, 0x3b94557eu, 0x00000000u, - 0x512fefb1u, 0xd62e1684u, 0x5c34ef2bu, 0x00000000u, 0x25554402u, 0x055e5375u, 0x3a08ec40u, 0x00000000u, - 0xea28d1a6u, 0x8c71f892u, 0xfead5d3du, 0x00000000u, 0x3712d6e9u, 0x59fa8772u, 0x29c7e9cdu, 0x00000000u, - 0x65fc32ecu, 0x90357e43u, 0xcee18a15u, 0x00000000u, 0x5e3b5c50u, 0xc583129du, 0xa04bf996u, 0x00000000u, - 0x4ab43782u, 0xe9864a08u, 0x6f2ab1c6u, 0x00000000u, 0x26a77c61u, 0xf673703cu, 0xe9d6c9cfu, 0x00000000u, - 0x0caebeeeu, 0xe709951fu, 0xf2875771u, 0x00000000u, 0xd43f1577u, 0x41477617u, 0xa19bf431u, 0x00000000u, - 0x89ca27c9u, 0x9ec1ee6cu, 0x9dcf44adu, 0x00000000u, 0xa3a370ddu, 0x83958e74u, 0xb0c45102u, 0x00000000u, - 0x86cfafcau, 0x04382d70u, 0x09083cf1u, 0x00000000u, 0xf5458e26u, 0xe8c4a35bu, 0x95ea20cbu, 0x00000000u, - 0x2cb1e624u, 0xc80e252fu, 0x24aeadb9u, 0x00000000u, 0x60958ae8u, 0x5471b135u, 0x032c76bcu, 0x00000000u, - 0xce983976u, 0x827df87du, 0x50f5f0adu, 0x00000000u, 0x81d7362fu, 0x00000e99u, 0x6fde87aeu, 0x00000000u, - 0x85033eb4u, 0x56f7b265u, 0xd493d37cu, 0x00000000u, 0x3ff49a3cu, 0x23487a39u, 0x870d2e4fu, 0x00000000u, - 0xe3249135u, 0x60123a68u, 0x0befa03du, 0x00000000u, 0xf84d74b5u, 0x71bd7da9u, 0x2c44f6cbu, 0x00000000u, - 0x9d98f068u, 0x51d59a46u, 0xf0131dceu, 0x00000000u, 0x4b40fe50u, 0x8cd5b0fbu, 0x8b164f67u, 0x00000000u, - 0x3e10a2d3u, 0x7fd0d4b7u, 0x1bec231fu, 0x00000000u, 0xa4cc2cd6u, 0xc22121ffu, 0xf33350e7u, 0x00000000u, - 0x536659b7u, 0x49043fc2u, 0x8c7ec0d7u, 0x00000000u, 0xb1597a41u, 0xfe1228f2u, 0x066908e4u, 0x00000000u, - 0x3d0194e7u, 0x432be415u, 0x4160b66fu, 0x00000000u, 0x76b6560au, 0xdf770ab8u, 0x07ef4642u, 0x00000000u, - 0xd0dafe5cu, 0x9e1f95f4u, 0x9d7dbecdu, 0x00000000u, 0xada5c397u, 0x1d8b6a84u, 0xbf29cf46u, 0x00000000u, - 0x3f858ef0u, 0x843e3a0cu, 0xad47e23fu, 0x00000000u, 0x9a9c1e18u, 0x52b851a8u, 0x65648845u, 0x00000000u, - 0x79fca3a8u, 0x0a8f8f09u, 0xb9dde8cbu, 0x00000000u, 0x199671dfu, 0x7565be28u, 0xa7add019u, 0x00000000u, - 0x14948e21u, 0xfedcb64du, 0x6091bc31u, 0x00000000u, 0x040bae5bu, 0xa89c3b59u, 0x8ebdcac3u, 0x00000000u, + 0xee7c0a1du, 0x290beb5au, 0x34aedb72u, 0x00000000u, 0x9c43a772u, 0x9ac02fbau, 0xca762320u, 0x00000000u, + 0xed95234bu, 0xd266c660u, 0x23e572b0u, 0x00000000u, 0x4f3e3e4cu, 0xe9f050c2u, 0x8c8848c4u, 0x00000000u, + 0xddf4a20bu, 0x90217921u, 0x0cbbcb9bu, 0x00000000u, 0x790f2266u, 0xd31ceb5cu, 0xa7b58b42u, 0x00000000u, + 0x21fdd340u, 0x35c8450eu, 0xdab84239u, 0x00000000u, 0xfaafaf58u, 0xc0bd647bu, 0x191bc271u, 0x00000000u, + 0x9e839693u, 0xd447d632u, 0xa3e3cd34u, 0x00000000u, 0x9816acb2u, 0x77a4c5f5u, 0x3eaeccfbu, 0x00000000u, + 0x47e04ba9u, 0xbee48e8du, 0x11df34c8u, 0x00000000u, 0x15a08a3cu, 0x658be5c3u, 0xc6403f48u, 0x00000000u, + 0xa8337739u, 0x97094582u, 0x88bce4acu, 0x00000000u, 0x1c5a2203u, 0x54f080bcu, 0x145a7a01u, 0x00000000u, + 0xc216a0ffu, 0xc036cf58u, 0x42127f23u, 0x00000000u, 0x4afdd8fau, 0x5144b748u, 0xe3a9493du, 0x00000000u, + 0x7d1010ddu, 0xc31737aeu, 0x72e658f1u, 0x00000000u, 0xb2bc988bu, 0x874068abu, 0x4752b9ecu, 0x00000000u, + 0xe055263eu, 0xb57d6353u, 0xc4f356bdu, 0x00000000u, 0xf2b9ce80u, 0x3faf6989u, 0x1770771eu, 0x00000000u, + 0x950fc854u, 0x537f6518u, 0x6f8f1b03u, 0x00000000u, 0x3c137b49u, 0x660207d5u, 0x64ac0a72u, 0x00000000u, + 0x59be07efu, 0xbe09834bu, 0x97b811efu, 0x00000000u, 0x7967f639u, 0x1cdaeda5u, 0x921b66a8u, 0x00000000u, + 0x2cce2e38u, 0x506c746au, 0x6a374c25u, 0x00000000u, 0x242b888du, 0x63b59666u, 0x4455c37cu, 0x00000000u, + 0xd98a0ed3u, 0xdc14021au, 0x012b5d82u, 0x00000000u, 0x9a37ff7fu, 0xa3fb2747u, 0x60c3dd9du, 0x00000000u, + 0x7818642eu, 0xca374746u, 0x60c22570u, 0x00000000u, 0x10804844u, 0x5f5ca629u, 0x40ff019fu, 0x00000000u, + 0x61fa17b2u, 0x3ae80a51u, 0x265e1089u, 0x00000000u, 0xfc40da19u, 0x20fd6d3au, 0xb4c2e06fu, 0x00000000u, + 0xb7b31acdu, 0x9e273818u, 0xe955351fu, 0x00000000u, 0x0146b1d6u, 0x4d3790ceu, 0x2f2ef0b7u, 0x00000000u, + 0x93b16f10u, 0xa2b2d58cu, 0xe5dcdf1fu, 0x00000000u, 0x61354928u, 0x3c63db78u, 0xec9da3a4u, 0x00000000u, + 0xac48ee35u, 0xc3c4f767u, 0x71ea1e0bu, 0x00000000u, 0x7287c339u, 0x63988fb6u, 0xbfe036acu, 0x00000000u, + 0x35eae594u, 0xf9b41907u, 0x2d097146u, 0x00000000u, 0x7602d6deu, 0x508a8127u, 0xa47c939bu, 0x00000000u, + 0xae41d19eu, 0xeb2d9aadu, 0xca0a22dbu, 0x00000000u, 0x3fa92484u, 0x34e77d30u, 0xe2f5759du, 0x00000000u, + 0x7ce514bbu, 0x18f8b09du, 0xd3314b39u, 0x00000000u, 0xa600b305u, 0x068bd432u, 0xc86814d2u, 0x00000000u, + 0x9b7cfb72u, 0x9d56d54bu, 0xdd6c8907u, 0x00000000u, 0x7edb5e71u, 0x7615827du, 0x9e0a75a4u, 0x00000000u, + 0x32a1e232u, 0x26d36ecdu, 0xd801ced0u, 0x00000000u, 0x372fa45eu, 0x811cb66bu, 0x45181f97u, 0x00000000u, + 0x3aff4aa1u, 0x9908111eu, 0xcd679c4eu, 0x00000000u, 0x71206dc3u, 0x2383b298u, 0x3e95f804u, 0x00000000u, + 0x2a217f2du, 0xe1ffcadau, 0x51ccb6e1u, 0x00000000u, 0x5fb9577bu, 0x122f7d23u, 0x722f227fu, 0x00000000u, + 0xe9f6f5f2u, 0x68e22b74u, 0xa6b7e5eeu, 0x00000000u, 0x2e93d042u, 0x2497b6f1u, 0xbb4be878u, 0x00000000u, + 0x10d4106bu, 0x72ce2922u, 0x511385eau, 0x00000000u, 0x04296d0bu, 0x87fd229fu, 0xf6c99a1cu, 0x00000000u, + 0x11b3b25eu, 0xd0d5e251u, 0x8a07a0e6u, 0x00000000u, 0xb93b2f92u, 0x18b76f8du, 0xde7cce09u, 0x00000000u, + 0x02ec3339u, 0xe824852au, 0xa8660512u, 0x00000000u, 0x5665b9b3u, 0x01d16dd3u, 0x9c67c9b7u, 0x00000000u, + 0x16622051u, 0x9bdad41eu, 0xc5ecdbb8u, 0x00000000u, 0x446dc047u, 0x3d1cea2eu, 0x38d1dcddu, 0x00000000u, + 0x398f04ebu, 0x1d29069eu, 0x3fec755bu, 0x00000000u, 0xa8c8d0adu, 0x4d71c198u, 0xc7ea4e97u, 0x00000000u, }; -// 982 u32 values -static const uint32_t kCnnV3TestWeightsU32[982] = { +// 1238 u32 values +static const uint32_t kCnnV3TestWeightsU32[1238] = { 0xa8b23143u, 0x2f9432e3u, 0x3491b3cbu, 0x317e3104u, 0xa79fb324u, 0x3419acf6u, 0x32322d86u, 0xb13da859u, 0xb4302831u, 0x2d0e324au, 0xad9630f5u, 0x338c3485u, 0xb1dd3158u, 0xb461a51du, 0x2f07b2a3u, 0x347d30b3u, 0xacf9aeb0u, 0xb1f6a4adu, 0xa377b31bu, 0x2e85b13eu, 0x3263a8d4u, 0xaf352fb1u, 0x31da3261u, 0xb010ac52u, @@ -203,91 +203,123 @@ static const uint32_t kCnnV3TestWeightsU32[982] = { 0xa83f2c18u, 0xb41ca864u, 0x338c31d0u, 0xb22cb4b2u, 0x279a33c1u, 0xb1b5b2b8u, 0x30512e25u, 0x345a2ba3u, 0xafab9b4bu, 0xad64a2feu, 0xb45cb14bu, 0x300fadadu, 0xa8acb49fu, 0x2c3d2d88u, 0x31f63150u, 0xb3a03011u, 0x2bf1a3acu, 0xb464b0e3u, 0xa6eeb14fu, 0xb235aa9cu, 0x3416323bu, 0x3420b1bcu, 0x3414b4a1u, 0xb4af3457u, - 0x3484310du, 0x348533cbu, 0xb40d27bbu, 0x2c5f32b7u, 0xaa5b2c68u, 0xb2a72984u, + 0x3484310du, 0x348533cbu, 0xb40d27bbu, 0x2c5f32b7u, 0xaa5b2c68u, 0xb2a72984u, 0xb414309bu, 0x32b33069u, + 0x1e0aa43bu, 0x3482af36u, 0xad08307au, 0xb162b23eu, 0x3440a58bu, 0xb178307fu, 0xacad32e7u, 0xb0f632c1u, + 0x34192c8eu, 0x2f69b0a6u, 0xb2b534aeu, 0x2eb0b3e7u, 0xb41eae27u, 0x30dfa396u, 0xae56b020u, 0x222b32a3u, + 0xa81e3295u, 0x2dca3459u, 0x3365b360u, 0xb2e19e98u, 0x2f34b2abu, 0xb019b458u, 0xa886b2ebu, 0x22b8aa94u, + 0xb47eb03bu, 0xacd92c64u, 0xb3832dd0u, 0xb0d5b4abu, 0xac11a6adu, 0xacb131f5u, 0x2b2f24adu, 0x20a6b497u, + 0xaa0cadf5u, 0x316eb3adu, 0xb496343fu, 0x31112bc9u, 0x3185b022u, 0x341f2d15u, 0xb465349eu, 0x2738a83bu, + 0xae49b2c8u, 0xb4a534aeu, 0x3294a74bu, 0xa235aec3u, 0xa3b83497u, 0xb44eb316u, 0xb07f3447u, 0xb3dc18feu, + 0x3421a9ddu, 0x348615eeu, 0x1996b0a1u, 0xa7f332e7u, 0x32d3b03cu, 0x24b8ac3au, 0xb2053493u, 0xb480afa0u, + 0xb1c2ac27u, 0xb21e2eeau, 0xb08b2eb6u, 0xadcead8fu, 0xa5253029u, 0x32c5ad53u, 0xb17f2987u, 0xae0b33afu, + 0x9aa3b46du, 0xb105b338u, 0xb31730bfu, 0x343231e5u, 0x300a2c17u, 0x34bb301au, 0xb279ae16u, 0x251b21e3u, + 0x2c58b22fu, 0x341bb4aau, 0xb46cb085u, 0xb0fdb386u, 0xb47cb057u, 0xb1e5b03du, 0xac69aca9u, 0xae9cae2fu, + 0xb48fb3e1u, 0x30edb1b8u, 0x341d34b6u, 0x24e3192fu, 0x3142af1fu, 0x329c3115u, 0xa90b3398u, 0x31e23120u, + 0x341faf5bu, 0x34bfb3cau, 0xb3cf3130u, 0xb4792e00u, 0x31bf3130u, 0x32da2bddu, 0xb04db3b8u, 0xb464aa97u, + 0xb082a7f4u, 0xa9c1ac1eu, 0xb0693349u, 0xa9af338fu, 0x162cae9du, 0xb0a9aa51u, 0xb2af1696u, 0x290dadb0u, + 0x3238aaa6u, 0x3483b0acu, 0x347d3177u, 0xb2df327eu, 0xb2562410u, 0x2a77321cu, 0x3420b08bu, 0x28e8b363u, + 0xb43c303eu, 0x32112b84u, 0x1f86b427u, 0x2e42b0a3u, 0x3432b352u, 0xb2073394u, 0x2abbaec9u, 0xa8673030u, + 0xb39ab299u, 0xa6dc34ccu, 0xa16a3327u, 0xb3ea340eu, 0x3420b369u, 0xaf1d344cu, 0xa74ead90u, 0xb1f3aa70u, + 0xb0bd33a6u, 0xb4282fe2u, 0x2de7b46eu, 0x2df8ae2fu, 0x3452b3cbu, 0x333930c5u, 0xaee8b2fbu, 0x25b6ad0eu, + 0xb438afcdu, 0xb0b6ad09u, 0xb1d2ac61u, 0x2ce0b092u, 0xadf0ac4bu, 0x31382535u, 0x2ab9aca7u, 0x22c1347au, + 0x31a333deu, 0xa972b43cu, 0x34ac2f9eu, 0xb3d2a665u, 0xb32c28c3u, 0x1cb730d4u, 0x3317304au, 0x2c512cf4u, + 0x329330e3u, 0xb4733316u, 0xb1732851u, 0x2db332ebu, 0xb1fdaa20u, 0x2fd3ae2eu, 0xb3ceb1adu, 0x31133373u, + 0xaffab1c4u, 0x2fff3488u, 0xaf632c3eu, 0xb46cafb7u, 0xb4633063u, 0x3068b4c1u, 0x30ed344fu, 0xa049a45bu, + 0xaebca8e8u, 0xa94a22acu, 0x33a52b8au, 0xb40b34c1u, 0xb221ac6eu, 0xb015adaeu, 0x3112b240u, 0x3406988fu, + 0xb428b47du, 0xb408ab6eu, 0x34aab08eu, 0xb1ccb197u, 0x94eb29a8u, 0xacbc2a2du, 0xb2f03246u, 0x2f49a980u, + 0xad023312u, 0xb4232934u, 0xb423b254u, 0xb0123060u, 0xb42a304cu, 0x327132f6u, 0xb492b3e4u, 0x32cab442u, + 0x276ab118u, 0x31ada9aau, 0x0e7f9ed2u, 0xb2b834b2u, 0xb44e3259u, 0x336ba2deu, 0x2f1d2e58u, 0xaa41b08bu, + 0x2296ad20u, 0xaea6a5cdu, 0xb0c9af78u, 0xb2b9ad2fu, 0x2bd83325u, 0x2f72b308u, 0xb10a32adu, 0xb4b8b2b5u, + 0x3109b459u, 0xb45f34adu, 0xb41c30c3u, 0x30eb2b13u, 0xb4b2ad68u, 0x34b72b4fu, 0xb1f6b0a7u, 0x283eb338u, + 0x319d2b68u, 0x338930dcu, 0xb0da31dfu, 0xafc8284bu, 0x3426ae89u, 0x348e2efcu, 0x25c0aa62u, 0xb38a9febu, + 0x243fb10eu, 0x3424b427u, 0xb1ccb339u, 0xb3bd3118u, 0x305533afu, 0x2f5eb424u, 0x30f12d0eu, 0x3031324du, + 0xaed12a9eu, 0x34632f93u, 0x2e502ab9u, 0x30eba8d4u, 0xb28534c7u, 0x260fb1b7u, 0x297fa1b9u, 0xab5ab454u, + 0x2a8b2a5fu, 0x303a2e0bu, 0x31932d6fu, 0x25c32ccau, 0xb3a82c14u, 0x2435b05bu, 0x2ee03329u, 0x2b16b3ddu, + 0x307eb158u, 0x2b2d3249u, 0xae332b04u, 0x32fea821u, 0x2211304au, 0xb451ad0fu, }; // 256 uint16 values (raw f16 bits) static const uint16_t kCnnV3ExpectedEnc0U16[256] = { - 0x0000u, 0x0000u, 0x350cu, 0x3b3cu, 0x19bcu, 0x0000u, 0x0000u, 0x3d10u, - 0x31e9u, 0x0000u, 0x35d0u, 0x39c3u, 0x0000u, 0x0000u, 0x2c6fu, 0x35fbu, - 0x39b9u, 0x0000u, 0x0000u, 0x3538u, 0x2ebbu, 0x0000u, 0x34f8u, 0x0000u, - 0x0000u, 0x0000u, 0x0000u, 0x3c96u, 0x0000u, 0x3029u, 0x0000u, 0x0000u, - 0x0000u, 0x0000u, 0x0000u, 0x405au, 0x0000u, 0x367eu, 0x0000u, 0x3d2fu, - 0x383bu, 0x0000u, 0x342cu, 0x3f97u, 0x0000u, 0x3c3cu, 0x0000u, 0x424eu, - 0x0000u, 0x0000u, 0x0000u, 0x3a3au, 0x0000u, 0x3d8fu, 0x0000u, 0x3fd4u, - 0x307du, 0x0000u, 0x0000u, 0x3f68u, 0x0000u, 0x0000u, 0x0000u, 0x3c81u, - 0x0000u, 0x0000u, 0x398fu, 0x3ffeu, 0x0000u, 0x0000u, 0x0000u, 0x3ec1u, - 0x0000u, 0x39b8u, 0x0000u, 0x3c61u, 0x0000u, 0x2e3au, 0x3699u, 0x41deu, - 0x0000u, 0x0000u, 0x0000u, 0x3d2cu, 0x329au, 0x0000u, 0x0000u, 0x41a9u, - 0x2d70u, 0x342fu, 0x0000u, 0x4066u, 0x2c77u, 0x0000u, 0x37b7u, 0x3842u, - 0x2b9au, 0x0000u, 0x3655u, 0x4001u, 0x340au, 0x0000u, 0x30f5u, 0x41a5u, - 0x0000u, 0x0000u, 0x0000u, 0x3d05u, 0x0000u, 0x0000u, 0x30a6u, 0x40a3u, - 0x0000u, 0x0000u, 0x0000u, 0x4263u, 0x0000u, 0x0000u, 0x0000u, 0x3e62u, - 0x0000u, 0x0000u, 0x0000u, 0x42d7u, 0x0000u, 0x0000u, 0x0000u, 0x3de8u, - 0x0000u, 0x0000u, 0x0000u, 0x3f4du, 0x0000u, 0x38d4u, 0x3a61u, 0x3fb7u, - 0x0000u, 0x0000u, 0x0000u, 0x404cu, 0x3811u, 0x31a4u, 0x0000u, 0x3edfu, - 0x0000u, 0x0000u, 0x0000u, 0x3f30u, 0x0000u, 0x0000u, 0x0000u, 0x3ec7u, - 0x27dau, 0x0000u, 0x0000u, 0x3efeu, 0x0000u, 0x3027u, 0x0000u, 0x39ceu, - 0x28e8u, 0x0000u, 0x0000u, 0x4121u, 0x0000u, 0x0000u, 0x0000u, 0x40eeu, - 0x3b70u, 0x3379u, 0x0000u, 0x40d3u, 0x0000u, 0x0000u, 0x0000u, 0x3d88u, - 0x329du, 0x0000u, 0x0000u, 0x3fafu, 0x35c0u, 0x0000u, 0x374cu, 0x40ceu, - 0x32b4u, 0x2c9au, 0x0000u, 0x4094u, 0x3105u, 0x31f4u, 0x34e9u, 0x3cd7u, - 0x0000u, 0x0000u, 0x344bu, 0x3cd1u, 0x0000u, 0x2d13u, 0x0000u, 0x3e7eu, - 0x0000u, 0x2eacu, 0x0000u, 0x4123u, 0x0000u, 0x36edu, 0x0000u, 0x3c69u, - 0x0000u, 0x0000u, 0x0000u, 0x41d5u, 0x0000u, 0x36e4u, 0x0000u, 0x4049u, - 0x0000u, 0x0000u, 0x0000u, 0x401du, 0x0000u, 0x38d1u, 0x333au, 0x3b08u, - 0x0000u, 0x0000u, 0x0000u, 0x3d12u, 0x0000u, 0x0000u, 0x0000u, 0x3e6eu, - 0x0000u, 0x0000u, 0x0000u, 0x4028u, 0x0000u, 0x0000u, 0x0000u, 0x3f64u, - 0x0000u, 0x0000u, 0x0000u, 0x3e4bu, 0x2eeau, 0x393cu, 0x0000u, 0x4007u, - 0x0000u, 0x267fu, 0x0000u, 0x3eabu, 0x35b4u, 0x38f9u, 0x0000u, 0x3e6bu, + 0x3c3fu, 0x0000u, 0x2aeeu, 0x3cdfu, 0x0000u, 0x0000u, 0x3a34u, 0x0000u, + 0x33e1u, 0x251du, 0x29e7u, 0x3dd0u, 0x0000u, 0x3996u, 0x2e7du, 0x3847u, + 0x259bu, 0x29a6u, 0x3a17u, 0x0000u, 0x3022u, 0x0000u, 0x3c4bu, 0x3c15u, + 0x0000u, 0x0000u, 0x38e0u, 0x3a98u, 0x0000u, 0x37dbu, 0x0000u, 0x0000u, + 0x0000u, 0x0000u, 0x0000u, 0x4027u, 0x0000u, 0x393cu, 0x0000u, 0x3c3bu, + 0x0000u, 0x31c4u, 0x3918u, 0x3f6fu, 0x0000u, 0x0000u, 0x0000u, 0x3c35u, + 0x0000u, 0x0000u, 0x0000u, 0x403eu, 0x0000u, 0x32b6u, 0x0000u, 0x4008u, + 0x3440u, 0x0000u, 0x0000u, 0x4003u, 0x0000u, 0x0000u, 0x0000u, 0x3d6bu, + 0x0000u, 0x0000u, 0x0000u, 0x4115u, 0x0000u, 0x0000u, 0x0000u, 0x3bcdu, + 0x30acu, 0x301eu, 0x3a8eu, 0x40e1u, 0x0000u, 0x0000u, 0x2dc0u, 0x401au, + 0x0000u, 0x0000u, 0x3638u, 0x3df2u, 0x0000u, 0x3c65u, 0x0000u, 0x3feau, + 0x2d79u, 0x0000u, 0x2e52u, 0x3f56u, 0x0000u, 0x0000u, 0x0000u, 0x3e3fu, + 0x34d0u, 0x0000u, 0x0000u, 0x3c46u, 0x38b0u, 0x3324u, 0x0000u, 0x4018u, + 0x0000u, 0x3385u, 0x0000u, 0x408du, 0x31ddu, 0x3585u, 0x40bau, 0x4009u, + 0x0000u, 0x2fd2u, 0x0000u, 0x4147u, 0x3baau, 0x0000u, 0x0000u, 0x3c42u, + 0x0000u, 0x0000u, 0x3378u, 0x3fc6u, 0x30cbu, 0x0000u, 0x3978u, 0x3440u, + 0x0000u, 0x0000u, 0x0000u, 0x38eeu, 0x0000u, 0x0000u, 0x0000u, 0x4117u, + 0x0000u, 0x0000u, 0x0000u, 0x4089u, 0x0000u, 0x3647u, 0x0000u, 0x43cfu, + 0x3752u, 0x2d2bu, 0x0000u, 0x3c2bu, 0x0000u, 0x3615u, 0x39cau, 0x0000u, + 0x0000u, 0x0000u, 0x0000u, 0x3e2du, 0x0000u, 0x0000u, 0x0000u, 0x3e18u, + 0x0000u, 0x0000u, 0x0000u, 0x3d99u, 0x2ca5u, 0x0000u, 0x0000u, 0x3d64u, + 0x0000u, 0x2b7fu, 0x0000u, 0x3f9eu, 0x0000u, 0x0000u, 0x0000u, 0x4133u, + 0x0000u, 0x0000u, 0x0000u, 0x3fc4u, 0x0000u, 0x0000u, 0x0000u, 0x3c91u, + 0x0000u, 0x2a5du, 0x0000u, 0x4166u, 0x0000u, 0x0000u, 0x0000u, 0x4089u, + 0x3165u, 0x0000u, 0x0000u, 0x3f6eu, 0x0000u, 0x0000u, 0x358du, 0x417fu, + 0x0000u, 0x356cu, 0x0000u, 0x4243u, 0x3c04u, 0x0000u, 0x0000u, 0x406bu, + 0x0000u, 0x315bu, 0x0000u, 0x40b7u, 0x0000u, 0x34beu, 0x0000u, 0x4108u, + 0x0000u, 0x390au, 0x2607u, 0x408fu, 0x0000u, 0x0000u, 0x0000u, 0x3b05u, + 0x3407u, 0x0000u, 0x0000u, 0x3d13u, 0x0000u, 0x33b5u, 0x0000u, 0x3dafu, + 0x0000u, 0x0000u, 0x0000u, 0x3d80u, 0x0000u, 0x2f2fu, 0x0000u, 0x3d4cu, + 0x0000u, 0x0000u, 0x0000u, 0x416eu, 0x0000u, 0x0000u, 0x0000u, 0x402au, + 0x0000u, 0x3b06u, 0x0000u, 0x3f77u, 0x0000u, 0x37fbu, 0x0000u, 0x4060u, }; // kCnnV3Dec1HW = (W/2) x (H/2) = 4 x 4 // 64 uint16 values (raw f16 bits) static const uint16_t kCnnV3ExpectedDec1U16[64] = { - 0x0000u, 0x2692u, 0x3823u, 0x397eu, 0x0000u, 0x22dcu, 0x35dcu, 0x35f9u, - 0x0000u, 0x3936u, 0x24b5u, 0x3434u, 0x0000u, 0x3b63u, 0x0000u, 0x32fcu, - 0x0000u, 0x2913u, 0x3523u, 0x33d6u, 0x0000u, 0x3023u, 0x2575u, 0x0000u, - 0x0000u, 0x39edu, 0x0000u, 0x0000u, 0x0000u, 0x3c91u, 0x0000u, 0x0000u, - 0x0000u, 0x0000u, 0x0000u, 0x0000u, 0x0000u, 0x0000u, 0x0000u, 0x0000u, - 0x0000u, 0x3754u, 0x0000u, 0x0000u, 0x318cu, 0x3a4du, 0x0000u, 0x0000u, - 0x3206u, 0x32deu, 0x0000u, 0x0000u, 0x317du, 0x3437u, 0x0000u, 0x0000u, - 0x312au, 0x357fu, 0x0000u, 0x0000u, 0x0000u, 0x39b5u, 0x0000u, 0x0000u, + 0x38dcu, 0x3d03u, 0x0000u, 0x39b0u, 0x3965u, 0x3dd1u, 0x30fdu, 0x3adau, + 0x387au, 0x3c79u, 0x3114u, 0x3c0eu, 0x0000u, 0x3a66u, 0x2ed6u, 0x3816u, + 0x3a16u, 0x3dbau, 0x0000u, 0x3a4du, 0x3cf6u, 0x3fccu, 0x0000u, 0x3c1cu, + 0x367bu, 0x3f06u, 0x0000u, 0x3b5cu, 0x0000u, 0x39ecu, 0x3660u, 0x3781u, + 0x3936u, 0x3accu, 0x0000u, 0x38dbu, 0x3d0fu, 0x3e45u, 0x0000u, 0x38bau, + 0x3905u, 0x3b8eu, 0x265du, 0x3c1eu, 0x0000u, 0x3881u, 0x2c6cu, 0x0000u, + 0x3905u, 0x3c23u, 0x0000u, 0x3271u, 0x3837u, 0x35e1u, 0x0000u, 0x0000u, + 0x3961u, 0x3c10u, 0x0000u, 0x0000u, 0x3594u, 0x3af9u, 0x382cu, 0x0000u, }; // 256 uint16 values (raw f16 bits) static const uint16_t kCnnV3ExpectedOutputU16[256] = { - 0x3800u, 0x3934u, 0x3800u, 0x38aau, 0x384au, 0x3800u, 0x3800u, 0x3917u, - 0x38d5u, 0x3800u, 0x3800u, 0x38f2u, 0x3800u, 0x38c9u, 0x3800u, 0x38d4u, - 0x3800u, 0x3800u, 0x3800u, 0x3800u, 0x3800u, 0x38dau, 0x3800u, 0x3800u, - 0x3800u, 0x383eu, 0x3800u, 0x3800u, 0x3800u, 0x3800u, 0x3800u, 0x3800u, - 0x396du, 0x38eeu, 0x3800u, 0x3a87u, 0x3899u, 0x3800u, 0x3800u, 0x3972u, - 0x3a4au, 0x3800u, 0x3800u, 0x3847u, 0x386du, 0x3800u, 0x3800u, 0x3a70u, - 0x3800u, 0x381fu, 0x3800u, 0x3800u, 0x3800u, 0x3945u, 0x3800u, 0x392eu, - 0x3800u, 0x3800u, 0x3800u, 0x3844u, 0x3800u, 0x3800u, 0x3820u, 0x3800u, - 0x3a6du, 0x3832u, 0x3800u, 0x3ab0u, 0x3909u, 0x3800u, 0x3800u, 0x3a12u, - 0x3873u, 0x3800u, 0x3800u, 0x39b8u, 0x3a9au, 0x3800u, 0x3800u, 0x3a41u, - 0x3800u, 0x3800u, 0x3800u, 0x38d0u, 0x3952u, 0x3800u, 0x3800u, 0x398cu, - 0x3800u, 0x3800u, 0x3800u, 0x3a21u, 0x3800u, 0x3800u, 0x3800u, 0x3800u, - 0x3950u, 0x3800u, 0x3800u, 0x3abdu, 0x39ccu, 0x3800u, 0x3800u, 0x39e0u, - 0x3800u, 0x3800u, 0x3800u, 0x3a62u, 0x38d7u, 0x3800u, 0x3800u, 0x3a23u, - 0x3858u, 0x3800u, 0x3800u, 0x39f8u, 0x3800u, 0x3800u, 0x3800u, 0x3a01u, - 0x38e7u, 0x3800u, 0x3800u, 0x3822u, 0x38fcu, 0x3800u, 0x3832u, 0x3800u, - 0x3840u, 0x383au, 0x3800u, 0x3b39u, 0x390du, 0x3800u, 0x3800u, 0x399bu, - 0x3800u, 0x3800u, 0x3800u, 0x39c2u, 0x3802u, 0x3800u, 0x3800u, 0x3a41u, - 0x398bu, 0x3800u, 0x3800u, 0x39fau, 0x3800u, 0x3800u, 0x3800u, 0x396au, - 0x38d3u, 0x3800u, 0x3800u, 0x3888u, 0x3909u, 0x3800u, 0x3800u, 0x3800u, - 0x3863u, 0x3800u, 0x3800u, 0x3ae8u, 0x3a06u, 0x3800u, 0x3800u, 0x3a7du, - 0x38c1u, 0x3800u, 0x3800u, 0x3a20u, 0x38cdu, 0x3800u, 0x3800u, 0x390cu, - 0x3820u, 0x3800u, 0x3800u, 0x39d5u, 0x3863u, 0x3800u, 0x3800u, 0x389cu, - 0x3800u, 0x3800u, 0x3800u, 0x38bcu, 0x3887u, 0x3800u, 0x3866u, 0x3800u, - 0x38bbu, 0x3800u, 0x3800u, 0x3a8du, 0x394cu, 0x3800u, 0x3800u, 0x39b9u, - 0x394au, 0x3800u, 0x3800u, 0x3977u, 0x3800u, 0x3800u, 0x3800u, 0x3906u, - 0x3800u, 0x3800u, 0x386bu, 0x3a02u, 0x38bbu, 0x3800u, 0x3800u, 0x39d7u, - 0x38a2u, 0x3800u, 0x3800u, 0x3800u, 0x3899u, 0x3800u, 0x3811u, 0x3800u, - 0x3830u, 0x3800u, 0x387au, 0x3918u, 0x386au, 0x3800u, 0x38acu, 0x39f0u, - 0x39c7u, 0x3800u, 0x38beu, 0x3988u, 0x38c3u, 0x3800u, 0x3930u, 0x39d5u, - 0x397bu, 0x3800u, 0x3918u, 0x3a09u, 0x394cu, 0x3800u, 0x3952u, 0x3961u, - 0x3980u, 0x3800u, 0x392eu, 0x3872u, 0x39c2u, 0x3800u, 0x3903u, 0x3800u, + 0x3988u, 0x391du, 0x3800u, 0x390au, 0x3800u, 0x39e6u, 0x3800u, 0x3836u, + 0x3959u, 0x39e8u, 0x3800u, 0x3817u, 0x38c4u, 0x39cbu, 0x3800u, 0x392au, + 0x3837u, 0x3961u, 0x3800u, 0x3884u, 0x38a4u, 0x391fu, 0x3800u, 0x3800u, + 0x3943u, 0x38e9u, 0x3800u, 0x3800u, 0x3920u, 0x397fu, 0x3800u, 0x3800u, + 0x3a53u, 0x3800u, 0x3800u, 0x39deu, 0x393cu, 0x3956u, 0x3800u, 0x3b15u, + 0x3960u, 0x383cu, 0x3800u, 0x3aa5u, 0x38b9u, 0x3966u, 0x3800u, 0x3a4bu, + 0x38eau, 0x392au, 0x3800u, 0x3b2fu, 0x38c2u, 0x3800u, 0x3800u, 0x3aafu, + 0x3a59u, 0x3879u, 0x3800u, 0x3a5bu, 0x3924u, 0x3933u, 0x3800u, 0x38c0u, + 0x393bu, 0x3800u, 0x3800u, 0x3a0bu, 0x38ecu, 0x385cu, 0x3800u, 0x3b25u, + 0x3968u, 0x384bu, 0x3800u, 0x39dbu, 0x3800u, 0x3972u, 0x3800u, 0x3b7cu, + 0x38b9u, 0x3800u, 0x3800u, 0x3b3fu, 0x388eu, 0x3898u, 0x3800u, 0x39d2u, + 0x38fau, 0x3800u, 0x3800u, 0x391eu, 0x3872u, 0x3966u, 0x3800u, 0x38c1u, + 0x38c5u, 0x3800u, 0x3800u, 0x3a4au, 0x3a61u, 0x3800u, 0x3800u, 0x3b9cu, + 0x38edu, 0x3800u, 0x3800u, 0x3b9du, 0x3844u, 0x38a2u, 0x3800u, 0x3b5au, + 0x3800u, 0x38edu, 0x3800u, 0x3a57u, 0x3800u, 0x3828u, 0x3800u, 0x3ad7u, + 0x3810u, 0x3800u, 0x3800u, 0x3aa6u, 0x38ceu, 0x38e7u, 0x3800u, 0x3800u, + 0x3921u, 0x3800u, 0x3800u, 0x3a61u, 0x3a11u, 0x3800u, 0x3800u, 0x3b23u, + 0x3994u, 0x3800u, 0x3800u, 0x3b95u, 0x3995u, 0x3800u, 0x3800u, 0x3b83u, + 0x38c6u, 0x3a05u, 0x3800u, 0x3b7cu, 0x3887u, 0x385au, 0x3800u, 0x3b0bu, + 0x38efu, 0x3800u, 0x3800u, 0x398eu, 0x39edu, 0x38d8u, 0x3800u, 0x381bu, + 0x3932u, 0x3800u, 0x3800u, 0x3a29u, 0x3992u, 0x3800u, 0x3800u, 0x3ac4u, + 0x394du, 0x3800u, 0x3800u, 0x3b3bu, 0x384bu, 0x3800u, 0x3800u, 0x3b07u, + 0x3991u, 0x384cu, 0x3800u, 0x3b38u, 0x392eu, 0x3834u, 0x3800u, 0x3ab9u, + 0x397fu, 0x3800u, 0x3800u, 0x3948u, 0x38d1u, 0x3800u, 0x3800u, 0x3825u, + 0x3938u, 0x3800u, 0x3800u, 0x39a1u, 0x3991u, 0x3800u, 0x3800u, 0x3ac0u, + 0x3998u, 0x3800u, 0x3800u, 0x3adfu, 0x3973u, 0x3800u, 0x3800u, 0x3b7bu, + 0x39fdu, 0x3800u, 0x3800u, 0x3b0du, 0x3991u, 0x3800u, 0x3800u, 0x3a5du, + 0x38b6u, 0x3800u, 0x3800u, 0x39cau, 0x38acu, 0x3840u, 0x3800u, 0x3825u, + 0x3813u, 0x3800u, 0x3800u, 0x398fu, 0x3800u, 0x3800u, 0x3800u, 0x3a33u, + 0x3800u, 0x3800u, 0x3800u, 0x398eu, 0x3845u, 0x3800u, 0x3800u, 0x3a2du, + 0x384fu, 0x3800u, 0x3800u, 0x3a2eu, 0x3800u, 0x3800u, 0x3800u, 0x3a3fu, + 0x3834u, 0x3800u, 0x3800u, 0x39ebu, 0x387eu, 0x3839u, 0x393au, 0x3989u, }; diff --git a/cnn_v3/tools/shaders.js b/cnn_v3/tools/shaders.js index 6c49864..36f53c8 100644 --- a/cnn_v3/tools/shaders.js +++ b/cnn_v3/tools/shaders.js @@ -1,9 +1,10 @@ 'use strict'; // CNN v3 WGSL shaders — matches cnn_v3/shaders/*.wgsl exactly. -// Weight offsets (f16 index): enc0=0, enc1=724, bn=1020, dec1=1092, dec0=1672, total=1964 +// Weight offsets (f16 index): enc0=0, enc1=724, bn=1020, dec1=1604, dec0=2184, total=2476 +// BN is now Conv(8→8, 3×3, dilation=2): 8*8*9+8=584 weights (was 72 for 1×1) -const ENC0_OFF=0, ENC1_OFF=724, BN_OFF=1020, DEC1_OFF=1092, DEC0_OFF=1672; -const TOTAL_F16=1964, TOTAL_U32=982; +const ENC0_OFF=0, ENC1_OFF=724, BN_OFF=1020, DEC1_OFF=1604, DEC0_OFF=2184; +const TOTAL_F16=2476, TOTAL_U32=1238; // Inlined helpers — prepended to shaders that need them. const H = ` @@ -108,7 +109,7 @@ fn main(@builtin(global_invocation_id) id:vec3u){ pack2x16float(vec2f(o[4],o[5])),pack2x16float(vec2f(o[6],o[7])))); }`; -// Bottleneck: AvgPool(enc1) + Conv(8→8, 1×1) + ReLU → rgba32uint quarter-res (no FiLM) +// Bottleneck: AvgPool(enc1) + Conv(8→8, 3×3, dilation=2) + ReLU → rgba32uint quarter-res (no FiLM) // Params (16 bytes): wo u32 _pad×3 const BN_SHADER=H+` struct P{wo:u32,_a:u32,_b:u32,_c:u32} @@ -129,10 +130,13 @@ fn avg(qc:vec2i,hd:vec2i)->array{ fn main(@builtin(global_invocation_id) id:vec3u){ let hd=vec2i(textureDimensions(e1)); let qd=hd/2; let c=vec2i(id.xy); if(c.x>=qd.x||c.y>=qd.y){return;} - let ft=avg(c,hd); var o:array; + var o:array; for(var oc:u32=0u;oc<8u;oc++){ - var s=get_w(p.wo,64u+oc); - for(var i:u32=0u;i<8u;i++){s+=get_w(p.wo,oc*8u+i)*ft[i];} + var s=get_w(p.wo,576u+oc); + for(var ky:i32=-1;ky<=1;ky++){for(var kx:i32=-1;kx<=1;kx++){ + let ft=avg(c+vec2i(kx,ky)*2,hd); let ki=u32(ky+1)*3u+u32(kx+1); + for(var i:u32=0u;i<8u;i++){s+=get_w(p.wo,oc*72u+i*9u+ki)*ft[i];} + }} o[oc]=max(0.,s); } textureStore(out,c,vec4u(pack2x16float(vec2f(o[0],o[1])),pack2x16float(vec2f(o[2],o[3])), diff --git a/cnn_v3/training/export_cnn_v3_weights.py b/cnn_v3/training/export_cnn_v3_weights.py index edf76e2..78f5f25 100644 --- a/cnn_v3/training/export_cnn_v3_weights.py +++ b/cnn_v3/training/export_cnn_v3_weights.py @@ -15,8 +15,8 @@ Outputs /cnn_v3_weights.bin Conv+bias weights for all 5 passes, packed as f16-pairs-in-u32. Matches the format expected by CNNv3Effect::upload_weights(). - Layout: enc0 (724) | enc1 (296) | bottleneck (72) | dec1 (580) | dec0 (292) - = 1964 f16 values = 982 u32 = 3928 bytes. + Layout: enc0 (724) | enc1 (296) | bottleneck (584) | dec1 (580) | dec0 (292) + = 2476 f16 values = 1238 u32 = 4952 bytes. /cnn_v3_film_mlp.bin FiLM MLP weights as raw f32: L0_W (5×16) L0_b (16) L1_W (16×40) L1_b (40). @@ -48,13 +48,13 @@ from train_cnn_v3 import CNNv3 # cnn_v3/src/cnn_v3_effect.cc (kEnc0Weights, kEnc1Weights, …) # cnn_v3/training/gen_test_vectors.py (same constants) # --------------------------------------------------------------------------- -ENC0_WEIGHTS = 20 * 4 * 9 + 4 # Conv(20→4,3×3)+bias = 724 -ENC1_WEIGHTS = 4 * 8 * 9 + 8 # Conv(4→8,3×3)+bias = 296 -BN_WEIGHTS = 8 * 8 * 1 + 8 # Conv(8→8,1×1)+bias = 72 -DEC1_WEIGHTS = 16 * 4 * 9 + 4 # Conv(16→4,3×3)+bias = 580 -DEC0_WEIGHTS = 8 * 4 * 9 + 4 # Conv(8→4,3×3)+bias = 292 +ENC0_WEIGHTS = 20 * 4 * 9 + 4 # Conv(20→4,3×3)+bias = 724 +ENC1_WEIGHTS = 4 * 8 * 9 + 8 # Conv(4→8,3×3)+bias = 296 +BN_WEIGHTS = 8 * 8 * 9 + 8 # Conv(8→8,3×3,dil=2)+bias = 584 +DEC1_WEIGHTS = 16 * 4 * 9 + 4 # Conv(16→4,3×3)+bias = 580 +DEC0_WEIGHTS = 8 * 4 * 9 + 4 # Conv(8→4,3×3)+bias = 292 TOTAL_F16 = ENC0_WEIGHTS + ENC1_WEIGHTS + BN_WEIGHTS + DEC1_WEIGHTS + DEC0_WEIGHTS -# = 1964 +# = 2476 def pack_weights_u32(w_f16: np.ndarray) -> np.ndarray: diff --git a/cnn_v3/training/gen_test_vectors.py b/cnn_v3/training/gen_test_vectors.py index 640971c..2eb889c 100644 --- a/cnn_v3/training/gen_test_vectors.py +++ b/cnn_v3/training/gen_test_vectors.py @@ -23,7 +23,7 @@ DEC0_IN, DEC0_OUT = 8, 4 ENC0_WEIGHTS = ENC0_IN * ENC0_OUT * 9 + ENC0_OUT # 724 ENC1_WEIGHTS = ENC1_IN * ENC1_OUT * 9 + ENC1_OUT # 296 -BN_WEIGHTS = BN_IN * BN_OUT * 1 + BN_OUT # 72 +BN_WEIGHTS = BN_IN * BN_OUT * 9 + BN_OUT # 584 (3x3 dilation=2) DEC1_WEIGHTS = DEC1_IN * DEC1_OUT * 9 + DEC1_OUT # 580 DEC0_WEIGHTS = DEC0_IN * DEC0_OUT * 9 + DEC0_OUT # 292 @@ -32,30 +32,8 @@ ENC1_OFFSET = ENC0_OFFSET + ENC0_WEIGHTS BN_OFFSET = ENC1_OFFSET + ENC1_WEIGHTS DEC1_OFFSET = BN_OFFSET + BN_WEIGHTS DEC0_OFFSET = DEC1_OFFSET + DEC1_WEIGHTS -TOTAL_F16 = DEC0_OFFSET + DEC0_WEIGHTS # 1964 + 292 = 2256? let me check -# 724 + 296 + 72 + 580 + 292 = 1964 ... actually let me recount -# ENC0: 20*4*9 + 4 = 720+4 = 724 -# ENC1: 4*8*9 + 8 = 288+8 = 296 -# BN: 8*8*1 + 8 = 64+8 = 72 -# DEC1: 16*4*9 + 4 = 576+4 = 580 -# DEC0: 8*4*9 + 4 = 288+4 = 292 -# Total = 724+296+72+580+292 = 1964 ... but HOWTO.md says 2064. Let me recheck. -# DEC1: 16*4*9 = 576 ... but the shader says Conv(16->4) which is IN=16, OUT=4 -# weight idx: o * DEC1_IN * 9 + i * 9 + ki where o4) = OUT*IN*K^2 = 4*16*9 = 576 + bias 4 = 580. HOWTO says 576+4=580 OK. -# Total = 724+296+72+580+292 = let me sum: 724+296=1020, +72=1092, +580=1672, +292=1964. -# Hmm, HOWTO.md says 2064. Let me recheck HOWTO weight table: -# enc0: 20*4*9=720 +4 = 724 -# enc1: 4*8*9=288 +8 = 296 -# bottleneck: 8*8*1=64 +8 = 72 -# dec1: 16*4*9=576 +4 = 580 -# dec0: 8*4*9=288 +4 = 292 -# Total = 724+296+72+580+292 = 1964 -# The HOWTO says 2064 but I get 1964... 100 difference. Possible typo in doc. -# I'll use the correct value derived from the formulas: 1964. +TOTAL_F16 = DEC0_OFFSET + DEC0_WEIGHTS +# 724 + 296 + 584 + 580 + 292 = 2476 (BN is now 3x3 dilation=2, was 72) # --------------------------------------------------------------------------- # Helpers @@ -140,35 +118,41 @@ def enc1_forward(enc0, w, gamma_lo, gamma_hi, beta_lo, beta_hi): def bottleneck_forward(enc1, w): """ - AvgPool2x2(enc1, clamp-border) + Conv(8->8, 1x1) + ReLU + AvgPool2x2(enc1, clamp-border) + Conv(8->8, 3x3, dilation=2) + ReLU → rgba32uint (f16, quarter-res). No FiLM. enc1: (hH, hW, 8) f32 — half-res + Matches cnn_v3_bottleneck.wgsl exactly. """ hH, hW = enc1.shape[:2] qH, qW = hH // 2, hW // 2 wo = BN_OFFSET - # AvgPool2x2 with clamp (matches load_enc1_avg in WGSL) - avg = np.zeros((qH, qW, BN_IN), dtype=np.float32) - for qy in range(qH): - for qx in range(qW): - s = np.zeros(BN_IN, dtype=np.float32) - for dy in range(2): - for dx in range(2): - hy = min(qy * 2 + dy, hH - 1) - hx = min(qx * 2 + dx, hW - 1) - s += enc1[hy, hx, :] - avg[qy, qx, :] = s * 0.25 - - # 1x1 conv (no spatial loop, just channel dot-product) + def load_enc1_avg(qy, qx): + """Avg-pool 2x2 from enc1 at quarter-res coord. Zero for OOB (matches WGSL).""" + if qy < 0 or qx < 0 or qy >= qH or qx >= qW: + return np.zeros(BN_IN, dtype=np.float32) + s = np.zeros(BN_IN, dtype=np.float32) + for dy in range(2): + for dx in range(2): + hy = min(qy * 2 + dy, hH - 1) + hx = min(qx * 2 + dx, hW - 1) + s += enc1[hy, hx, :] + return s * 0.25 + + # 3x3 conv with dilation=2 in quarter-res space out = np.zeros((qH, qW, BN_OUT), dtype=np.float32) for o in range(BN_OUT): - bias = get_w(w, wo, BN_OUT * BN_IN + o) - s = np.full((qH, qW), bias, dtype=np.float32) - for i in range(BN_IN): - wv = get_w(w, wo, o * BN_IN + i) - s += wv * avg[:, :, i] - out[:, :, o] = np.maximum(0.0, s) + bias = get_w(w, wo, BN_OUT * BN_IN * 9 + o) + for qy in range(qH): + for qx in range(qW): + s = bias + for ky in range(-1, 2): + for kx in range(-1, 2): + feat = load_enc1_avg(qy + ky * 2, qx + kx * 2) # dilation=2 + ki = (ky + 1) * 3 + (kx + 1) + for i in range(BN_IN): + s += get_w(w, wo, o * BN_IN * 9 + i * 9 + ki) * feat[i] + out[qy, qx, o] = max(0.0, s) return np.float16(out).astype(np.float32) # pack2x16float boundary diff --git a/cnn_v3/training/train_cnn_v3.py b/cnn_v3/training/train_cnn_v3.py index 31cfd9d..c790495 100644 --- a/cnn_v3/training/train_cnn_v3.py +++ b/cnn_v3/training/train_cnn_v3.py @@ -6,17 +6,21 @@ """CNN v3 Training Script — U-Net + FiLM Architecture: - enc0 Conv(20→4, 3×3) + FiLM + ReLU H×W - enc1 Conv(4→8, 3×3) + FiLM + ReLU + pool2 H/2×W/2 - bottleneck Conv(8→8, 1×1) + ReLU H/4×W/4 - dec1 upsample×2 + cat(enc1) Conv(16→4) + FiLM H/2×W/2 - dec0 upsample×2 + cat(enc0) Conv(8→4) + FiLM H×W + enc0 Conv(20→4, 3×3) + FiLM + ReLU H×W + enc1 Conv(4→8, 3×3) + FiLM + ReLU + pool2 H/2×W/2 + bottleneck Conv(8→8, 3×3, dilation=2) + ReLU H/4×W/4 + dec1 upsample×2 + cat(enc1) Conv(16→4) + FiLM H/2×W/2 + dec0 upsample×2 + cat(enc0) Conv(8→4) + FiLM H×W output sigmoid → RGBA FiLM MLP: Linear(5→16) → ReLU → Linear(16→40) 40 = 2 × (γ+β) for enc0(4) enc1(8) dec1(4) dec0(4) -Weight budget: ~5.4 KB f16 (fits ≤6 KB target) +Weight budget: ~4.84 KB conv f16 (fits ≤6 KB target) + +Training improvements: + --edge-loss-weight Sobel edge loss alongside MSE (default 0.1) + --film-warmup-epochs Train U-Net only for N epochs before unfreezing FiLM MLP (default 50) """ import argparse @@ -56,7 +60,7 @@ class CNNv3(nn.Module): self.enc0 = nn.Conv2d(N_FEATURES, c0, 3, padding=1) self.enc1 = nn.Conv2d(c0, c1, 3, padding=1) - self.bottleneck = nn.Conv2d(c1, c1, 1) + self.bottleneck = nn.Conv2d(c1, c1, 3, padding=2, dilation=2) self.dec1 = nn.Conv2d(c1 * 2, c0, 3, padding=1) # +skip enc1 self.dec0 = nn.Conv2d(c0 * 2, 4, 3, padding=1) # +skip enc0 @@ -95,6 +99,24 @@ class CNNv3(nn.Module): return torch.sigmoid(x) +# --------------------------------------------------------------------------- +# Loss +# --------------------------------------------------------------------------- + +def sobel_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Gradient loss via Sobel filters. No VGG dependency. + pred, target: (B, C, H, W) in [0, 1]. Returns scalar on same device.""" + kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], + dtype=pred.dtype, device=pred.device).view(1, 1, 3, 3) + ky = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], + dtype=pred.dtype, device=pred.device).view(1, 1, 3, 3) + B, C, H, W = pred.shape + p = pred.view(B * C, 1, H, W) + t = target.view(B * C, 1, H, W) + return (F.mse_loss(F.conv2d(p, kx, padding=1), F.conv2d(t, kx, padding=1)) + + F.mse_loss(F.conv2d(p, ky, padding=1), F.conv2d(t, ky, padding=1))) + + # --------------------------------------------------------------------------- # Training # --------------------------------------------------------------------------- @@ -129,11 +151,20 @@ def train(args): print(f"Model: enc={enc_channels} film_cond_dim={args.film_cond_dim} " f"params={nparams} (~{nparams*2/1024:.1f} KB f16)") - optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + # Phase 1: freeze FiLM MLP so U-Net convolutions stabilise first. + film_warmup = args.film_warmup_epochs + if film_warmup > 0: + for p in model.film_mlp.parameters(): + p.requires_grad = False + print(f"FiLM MLP frozen for first {film_warmup} epochs (phase-1 warmup)") + + optimizer = torch.optim.Adam( + filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr) criterion = nn.MSELoss() ckpt_dir = Path(args.checkpoint_dir) ckpt_dir.mkdir(parents=True, exist_ok=True) start_epoch = 1 + film_unfrozen = (film_warmup == 0) if args.resume: ckpt_path = Path(args.resume) @@ -168,6 +199,15 @@ def train(args): for epoch in range(start_epoch, args.epochs + 1): if interrupted: break + + # Phase 2: unfreeze FiLM MLP after warmup, rebuild optimizer at reduced LR. + if not film_unfrozen and epoch > film_warmup: + for p in model.film_mlp.parameters(): + p.requires_grad = True + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr * 0.1) + film_unfrozen = True + print(f"\nPhase 2: FiLM MLP unfrozen at epoch {epoch} (lr={args.lr*0.1:.2e})") + model.train() epoch_loss = 0.0 n_batches = 0 @@ -177,7 +217,10 @@ def train(args): break feat, cond, target = feat.to(device), cond.to(device), target.to(device) optimizer.zero_grad() - loss = criterion(model(feat, cond), target) + pred = model(feat, cond) + loss = criterion(pred, target) + if args.edge_loss_weight > 0.0: + loss = loss + args.edge_loss_weight * sobel_loss(pred, target) loss.backward() optimizer.step() epoch_loss += loss.item() @@ -215,6 +258,8 @@ def _checkpoint(model, optimizer, epoch, loss, args): 'enc_channels': [int(c) for c in args.enc_channels.split(',')], 'film_cond_dim': args.film_cond_dim, 'input_mode': args.input_mode, + 'edge_loss_weight': args.edge_loss_weight, + 'film_warmup_epochs': args.film_warmup_epochs, }, } @@ -266,6 +311,10 @@ def main(): help='Save checkpoint every N epochs (0=disable)') p.add_argument('--resume', default='', metavar='CKPT', help='Resume from checkpoint path; if path missing, use latest in --checkpoint-dir') + p.add_argument('--edge-loss-weight', type=float, default=0.1, + help='Weight for Sobel edge loss alongside MSE (default 0.1; 0=disable)') + p.add_argument('--film-warmup-epochs', type=int, default=50, + help='Epochs to train U-Net only before unfreezing FiLM MLP (default 50; 0=joint)') train(p.parse_args()) diff --git a/src/tests/gpu/test_cnn_v3_parity.cc b/src/tests/gpu/test_cnn_v3_parity.cc index 15fe818..1a7f169 100644 --- a/src/tests/gpu/test_cnn_v3_parity.cc +++ b/src/tests/gpu/test_cnn_v3_parity.cc @@ -190,8 +190,8 @@ static std::vector run_cnn_v3(WebGPUTestFixture& fixture, effect.upload_weights(ctx.queue, weights_u32, weights_bytes); } else { // Explicitly zero weights to override any asset-loaded defaults. - // kWeightsBufBytes = ((1964+1)/2)*4 = 3928 - const uint32_t zero_size = ((1964u + 1u) / 2u) * 4u; + // kWeightsBufBytes = ((2476+1)/2)*4 = 4952 + const uint32_t zero_size = ((2476u + 1u) / 2u) * 4u; std::vector zeros(zero_size, 0); effect.upload_weights(ctx.queue, zeros.data(), zero_size); } -- cgit v1.2.3