diff options
| author | skal <pascal.massimino@gmail.com> | 2026-03-25 10:05:42 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-03-25 10:05:42 +0100 |
| commit | ce6e5b99f26e4e7c69a3cacf360bd0d492de928c (patch) | |
| tree | a8d64b33a7ea1109b6b7e1043ced946cac416756 /cnn_v3/docs | |
| parent | 8b4d7a49f038d7e849e6764dcc3abd1e1be01061 (diff) | |
feat(cnn_v3): 3×3 dilated bottleneck + Sobel loss + FiLM warmup + architecture PNG
- Replace 1×1 pointwise bottleneck with Conv(8→8, 3×3, dilation=2):
effective RF grows from ~13px to ~29px at ¼res (~+1 KB weights)
- Add Sobel edge loss in training (--edge-loss-weight, default 0.1)
- Add FiLM 2-phase training: freeze MLP for warmup epochs then
unfreeze at lr×0.1 (--film-warmup-epochs, default 50)
- Update weight layout: BN 72→584 f16, total 1964→2476 f16 (4952 B)
- Cascade offsets in C++ effect, JS tool, export/gen_test_vectors scripts
- Regenerate test_vectors.h (1238 u32); parity max_err=9.77e-04
- Generate dark-theme U-Net+FiLM architecture PNG (gen_architecture_png.py)
- Replace ASCII art in CNN_V3.md and HOW_TO_CNN.md with PNG embed
handoff(Gemini): bottleneck dilation + Sobel loss + FiLM warmup landed.
Next: run first real training pass (see cnn_v3/docs/HOWTO.md §3).
Diffstat (limited to 'cnn_v3/docs')
| -rw-r--r-- | cnn_v3/docs/CNN_V3.md | 38 | ||||
| -rw-r--r-- | cnn_v3/docs/HOWTO.md | 6 | ||||
| -rw-r--r-- | cnn_v3/docs/HOW_TO_CNN.md | 41 | ||||
| -rw-r--r-- | cnn_v3/docs/cnn_v3_architecture.png | bin | 0 -> 254783 bytes | |||
| -rw-r--r-- | cnn_v3/docs/gen_architecture_png.py | 238 |
5 files changed, 264 insertions, 59 deletions
diff --git a/cnn_v3/docs/CNN_V3.md b/cnn_v3/docs/CNN_V3.md index 4d58811..d775e2b 100644 --- a/cnn_v3/docs/CNN_V3.md +++ b/cnn_v3/docs/CNN_V3.md @@ -27,33 +27,7 @@ CNN v3 is a next-generation post-processing effect using: ### Pipeline Overview -``` -G-Buffer (albedo, normal, depth, matID, UV) - │ - ▼ - FiLM Conditioning - (beat_time, audio_intensity, style_params) - │ → γ[], β[] per channel - ▼ - U-Net - ┌─────────────────────────────────────────┐ - │ Encoder │ - │ enc0 (H×W, 4ch) ────────────skip──────┤ - │ ↓ down (avg pool 2×2) │ - │ enc1 (H/2×W/2, 8ch) ────────skip──────┤ - │ ↓ down │ - │ bottleneck (H/4×W/4, 8ch) │ - │ │ - │ Decoder │ - │ ↑ up (nearest ×2) + skip enc1 │ - │ dec1 (H/2×W/2, 4ch) │ - │ ↑ up + skip enc0 │ - │ dec0 (H×W, 4ch) │ - └─────────────────────────────────────────┘ - │ - ▼ - output RGBA (H×W) -``` + FiLM is applied **inside each encoder/decoder block**, after each convolution. @@ -352,11 +326,11 @@ All f16, little-endian, same packing as v2 (`pack2x16float`). |-----------|---------|------|-----------| | enc0: Conv(20→4, 3×3) | 20×4×9=720 | +4 | 724 | | enc1: Conv(4→8, 3×3) | 4×8×9=288 | +8 | 296 | -| bottleneck: Conv(8→8, 1×1) | 8×8×1=64 | +8 | 72 | +| bottleneck: Conv(8→8, 3×3, dil=2) | 8×8×9=576 | +8 | 584 | | dec1: Conv(16→4, 3×3) | 16×4×9=576 | +4 | 580 | | dec0: Conv(8→4, 3×3) | 8×4×9=288 | +4 | 292 | | FiLM MLP (5→16→40) | 5×16+16×40=720 | +16+40 | 776 | -| **Total** | | | **~3.9 KB f16** | +| **Total conv** | | | **~4.84 KB f16** | Skip connections: dec1 input = 8ch (bottleneck) + 8ch (enc1 skip) = 16ch. dec0 input = 4ch (dec1) + 4ch (enc0 skip) = 8ch. @@ -541,7 +515,7 @@ class CNNv3(nn.Module): nn.Conv2d(enc_channels[0], enc_channels[1], 3, padding=1), ]) # Bottleneck - self.bottleneck = nn.Conv2d(enc_channels[1], enc_channels[1], 1) + self.bottleneck = nn.Conv2d(enc_channels[1], enc_channels[1], 3, padding=2, dilation=2) # Decoder (skip connections: concat → double channels) self.dec = nn.ModuleList([ nn.Conv2d(enc_channels[1]*2, enc_channels[0], 3, padding=1), @@ -709,7 +683,7 @@ Parity results: Pass 0: pack_gbuffer.wgsl — assemble G-buffer channels into storage texture Pass 1: cnn_v3_enc0.wgsl — encoder level 0 (20→4ch, 3×3) Pass 2: cnn_v3_enc1.wgsl — encoder level 1 (4→8ch, 3×3) + downsample -Pass 3: cnn_v3_bottleneck.wgsl — bottleneck (8→8, 1×1) +Pass 3: cnn_v3_bottleneck.wgsl — bottleneck (8→8, 3×3, dilation=2) Pass 4: cnn_v3_dec1.wgsl — decoder level 1: upsample + skip + (16→4, 3×3) Pass 5: cnn_v3_dec0.wgsl — decoder level 0: upsample + skip + (8→4, 3×3) Pass 6: cnn_v3_output.wgsl — sigmoid + composite to framebuffer @@ -816,7 +790,7 @@ Status bar shows which channels are loaded. | `PACK_SHADER` | `STATIC_SHADER` | 20ch into feat_tex0 + feat_tex1 (rgba32uint each) | | `ENC0_SHADER` | part of `CNN_SHADER` | Conv(20→4, 3×3) + FiLM + ReLU; writes enc0_tex | | `ENC1_SHADER` | | Conv(4→8, 3×3) + FiLM + ReLU + avg_pool2×2; writes enc1_tex (half-res) | -| `BOTTLENECK_SHADER` | | Conv(8→8, 1×1) + FiLM + ReLU; writes bn_tex | +| `BOTTLENECK_SHADER` | | Conv(8→8, 3×3, dilation=2) + ReLU; writes bn_tex | | `DEC1_SHADER` | | nearest upsample×2 + concat(bn, enc1_skip) + Conv(16→4, 3×3) + FiLM + ReLU | | `DEC0_SHADER` | | nearest upsample×2 + concat(dec1, enc0_skip) + Conv(8→4, 3×3) + FiLM + ReLU | | `OUTPUT_SHADER` | | Conv(4→4, 1×1) + sigmoid → composites to canvas | diff --git a/cnn_v3/docs/HOWTO.md b/cnn_v3/docs/HOWTO.md index 1aead68..9a3efdf 100644 --- a/cnn_v3/docs/HOWTO.md +++ b/cnn_v3/docs/HOWTO.md @@ -439,10 +439,10 @@ FiLM γ/β are computed CPU-side by the FiLM MLP (Phase 4) and uploaded each fra |-------|---------|------|-----------| | enc0 | 20×4×9=720 | +4 | 724 | | enc1 | 4×8×9=288 | +8 | 296 | -| bottleneck | 8×8×1=64 | +8 | 72 | +| bottleneck | 8×8×9=576 | +8 | 584 | | dec1 | 16×4×9=576 | +4 | 580 | | dec0 | 8×4×9=288 | +4 | 292 | -| **Total** | | | **2064 f16 = ~4 KB** | +| **Total** | | | **2476 f16 = ~4.84 KB** | **Asset IDs** (registered in `workspaces/main/assets.txt` + `src/effects/shaders.cc`): `SHADER_CNN_V3_COMMON`, `SHADER_CNN_V3_ENC0`, `SHADER_CNN_V3_ENC1`, @@ -725,7 +725,7 @@ Both tools print first 8 pixels in the same format: **Expected delta:** ≤ 1/255 (≈ 4e-3) per channel, matching the parity test (`test_cnn_v3_parity`). Larger deltas indicate a weight mismatch — re-export -with `export_cnn_v3_weights.py` and verify the `.bin` size is 3928 bytes. +with `export_cnn_v3_weights.py` and verify the `.bin` size is 4952 bytes. --- diff --git a/cnn_v3/docs/HOW_TO_CNN.md b/cnn_v3/docs/HOW_TO_CNN.md index f5f1b1a..09db97c 100644 --- a/cnn_v3/docs/HOW_TO_CNN.md +++ b/cnn_v3/docs/HOW_TO_CNN.md @@ -28,26 +28,13 @@ CNN v3 is a 2-level U-Net with FiLM conditioning, designed to run in real-time a **Architecture:** -``` -Input: 20-channel G-buffer feature textures (rgba32uint) - │ - enc0 ──── Conv(20→4, 3×3) + FiLM + ReLU ┐ full res - │ ↘ skip │ - enc1 ──── AvgPool2×2 + Conv(4→8, 3×3) + FiLM ┐ ½ res - │ ↘ skip │ - bottleneck AvgPool2×2 + Conv(8→8, 1×1) + ReLU ¼ res (no FiLM) - │ │ - dec1 ←── upsample×2 + cat(enc1 skip) + Conv(16→4, 3×3) + FiLM - │ │ ½ res - dec0 ←── upsample×2 + cat(enc0 skip) + Conv(8→4, 3×3) + FiLM + sigmoid - full res → RGBA output -``` + **FiLM MLP:** `Linear(5→16) → ReLU → Linear(16→40)` trained jointly with U-Net. - Input: `[beat_phase, beat_norm, audio_intensity, style_p0, style_p1]` - Output: 40 γ/β values controlling style across all 4 FiLM layers -**Weight budget:** ~3.9 KB f16 (fits ≤6 KB target) +**Weight budget:** ~4.84 KB f16 conv (fits ≤6 KB target) **Two data paths:** - **Simple mode** — real photos with zeroed geometric channels (normal, depth, matid) @@ -307,7 +294,9 @@ uv run train_cnn_v3.py --input dataset/ --epochs 1 \ uv run train_cnn_v3.py \ --input dataset/ \ --input-mode simple \ - --epochs 200 + --epochs 200 \ + --edge-loss-weight 0.1 \ + --film-warmup-epochs 50 ``` **Blender G-buffer training:** @@ -315,7 +304,9 @@ uv run train_cnn_v3.py \ uv run train_cnn_v3.py \ --input dataset/ \ --input-mode full \ - --epochs 200 + --epochs 200 \ + --edge-loss-weight 0.1 \ + --film-warmup-epochs 50 ``` **Full-image mode (better global coherence, slower):** @@ -360,12 +351,14 @@ uv run train_cnn_v3.py \ | `--checkpoint-dir DIR` | `checkpoints/` | Set per-experiment | | `--checkpoint-every N` | 50 | 0 to disable intermediate checkpoints | | `--resume [CKPT]` | — | Resume from checkpoint path; if path missing, uses latest in `--checkpoint-dir` | +| `--edge-loss-weight F` | 0.1 | Sobel gradient loss weight alongside MSE; improves style/edge capture; 0=MSE only | +| `--film-warmup-epochs N` | 50 | Freeze FiLM MLP for first N epochs (phase-1), then unfreeze at lr×0.1; 0=joint training | ### Architecture at startup The model prints its parameter count: ``` -Model: enc=[4, 8] film_cond_dim=5 params=2740 (~5.4 KB f16) +Model: enc=[4, 8] film_cond_dim=5 params=3252 (~6.4 KB f16) ``` If `params` is much higher, `--enc-channels` was changed; update C++ constants accordingly. @@ -489,7 +482,7 @@ Use `--html-output PATH` to write to a different `weights.js` location. Output files are registered in `workspaces/main/assets.txt` as: ``` -WEIGHTS_CNN_V3, BINARY, weights/cnn_v3_weights.bin, "CNN v3 conv weights (f16, 3928 bytes)" +WEIGHTS_CNN_V3, BINARY, weights/cnn_v3_weights.bin, "CNN v3 conv weights (f16, 4952 bytes)" WEIGHTS_CNN_V3_FILM_MLP, BINARY, weights/cnn_v3_film_mlp.bin, "CNN v3 FiLM MLP weights (f32, 3104 bytes)" ``` @@ -501,10 +494,10 @@ WEIGHTS_CNN_V3_FILM_MLP, BINARY, weights/cnn_v3_film_mlp.bin, "CNN v3 FiLM MLP w |-------|-----------|-------| | enc0 Conv(20→4,3×3)+bias | 724 | — | | enc1 Conv(4→8,3×3)+bias | 296 | — | -| bottleneck Conv(8→8,1×1)+bias | 72 | — | +| bottleneck Conv(8→8,3×3,dil=2)+bias | 584 | — | | dec1 Conv(16→4,3×3)+bias | 580 | — | | dec0 Conv(8→4,3×3)+bias | 292 | — | -| **Total** | **1964 f16** | **3928 bytes** | +| **Total** | **2476 f16** | **4952 bytes** | **`cnn_v3_film_mlp.bin`** — FiLM MLP weights as raw f32, row-major: @@ -534,8 +527,8 @@ Checkpoint: epoch=200 loss=0.012345 enc_channels=[4, 8] film_cond_dim=5 cnn_v3_weights.bin - 1964 f16 values → 982 u32 → 3928 bytes - Upload via CNNv3Effect::upload_weights(queue, data, 3928) + 2476 f16 values → 1238 u32 → 4952 bytes + Upload via CNNv3Effect::upload_weights(queue, data, 4952) cnn_v3_film_mlp.bin L0: weight (16, 5) + bias (16,) @@ -824,7 +817,7 @@ all geometric channels (normal, depth, depth_grad, mat_id, prev) = 0. ### Pitfalls - `rgba32uint` and `rgba16float` textures both need `STORAGE_BINDING | TEXTURE_BINDING` usage. -- Weight offsets are **f16 indices** (enc0=0, enc1=724, bn=1020, dec1=1092, dec0=1672). +- Weight offsets are **f16 indices** (enc0=0, enc1=724, bn=1020, dec1=1604, dec0=2184). - Uniform buffer layouts must match WGSL `Params` structs exactly (padding included). --- diff --git a/cnn_v3/docs/cnn_v3_architecture.png b/cnn_v3/docs/cnn_v3_architecture.png Binary files differnew file mode 100644 index 0000000..2116c2b --- /dev/null +++ b/cnn_v3/docs/cnn_v3_architecture.png diff --git a/cnn_v3/docs/gen_architecture_png.py b/cnn_v3/docs/gen_architecture_png.py new file mode 100644 index 0000000..bd60a97 --- /dev/null +++ b/cnn_v3/docs/gen_architecture_png.py @@ -0,0 +1,238 @@ +#!/usr/bin/env python3 +# /// script +# requires-python = ">=3.10" +# dependencies = ["matplotlib"] +# /// +"""Generate CNN v3 U-Net + FiLM architecture diagram → cnn_v3_architecture.png""" + +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +from matplotlib.patches import FancyBboxPatch +from matplotlib.path import Path +import matplotlib.patheffects as pe + +# --------------------------------------------------------------------------- +# Canvas +# --------------------------------------------------------------------------- +BG = '#0F172A' +fig = plt.figure(figsize=(17, 10), facecolor=BG) +ax = fig.add_axes([0, 0, 1, 1], facecolor=BG) +ax.set_xlim(0, 17) +ax.set_ylim(0, 10) +ax.axis('off') + +# --------------------------------------------------------------------------- +# Palette +# --------------------------------------------------------------------------- +C_ENC = '#3B82F6' # encoder — blue +C_BN = '#8B5CF6' # bottleneck — violet +C_DEC = '#10B981' # decoder — emerald +C_MLP = '#EC4899' # FiLM MLP — pink +C_FILM = '#F59E0B' # FiLM γ/β arrows — amber +C_IO = '#475569' # input/output — slate +C_SKP = '#F97316' # skip connections — orange +C_ARR = '#94A3B8' # main flow arrows — cool-grey +C_TXT = '#F1F5F9' # text — near-white +C_DIM = '#64748B' # dim labels — slate + +# --------------------------------------------------------------------------- +# Geometry — two-column U layout +# --------------------------------------------------------------------------- +EX, DX = 3.8, 13.2 # encoder / decoder centre-x +BX = 8.5 # bottleneck centre-x + +BW = 4.6 # block width (enc / dec) +BH = 0.95 # block height (enc / dec) +BW_BN = 5.4 # bottleneck wider +BH_BN = 0.95 +BH_IO = 0.72 + +# y positions (top = high number) +Y_IN = 8.90 +Y_E0 = 7.50 # enc0 full res +Y_E1 = 5.80 # enc1 ½ res +Y_BN = 3.20 # bottleneck ¼ res +Y_D1 = 5.80 # dec1 ½ res +Y_D0 = 7.50 # dec0 full res +Y_OUT = 8.90 + +Y_MLP = 1.25 # FiLM MLP + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def box(cx, cy, w, h, color, line1, line2='', lfs=9.5, sfs=8.0, alpha=0.92): + r = FancyBboxPatch((cx - w/2, cy - h/2), w, h, + boxstyle='round,pad=0.10', + fc=color, ec='white', lw=1.3, alpha=alpha, zorder=3) + ax.add_patch(r) + dy = 0.18 if line2 else 0 + ax.text(cx, cy + dy, line1, ha='center', va='center', + fontsize=lfs, fontweight='bold', color='white', zorder=4, + fontfamily='DejaVu Sans Mono') + if line2: + ax.text(cx, cy - 0.18, line2, ha='center', va='center', + fontsize=sfs, color='white', alpha=0.80, zorder=4) + + +def arrow(x0, y0, x1, y1, color=C_ARR, lw=1.8, dashed=False, + rad=0.0, label='', lx=None, ly=None): + ls = (0, (5, 3)) if dashed else 'solid' + cs = f'arc3,rad={rad}' if rad else 'arc3,rad=0' + ax.annotate('', xy=(x1, y1), xytext=(x0, y0), + arrowprops=dict(arrowstyle='->', color=color, lw=lw, + linestyle=ls, mutation_scale=13, + connectionstyle=cs), + zorder=2) + if label: + ax.text(lx if lx else (x0+x1)/2, + ly if ly else (y0+y1)/2, + label, ha='center', va='center', fontsize=7.5, + color=color, zorder=5, + bbox=dict(fc=BG, ec='none', alpha=0.75, + boxstyle='round,pad=0.15')) + + +def dim_label(x, y, txt): + ax.text(x, y, txt, ha='center', va='center', + fontsize=8.5, color=C_DIM, style='italic') + + +# --------------------------------------------------------------------------- +# Blocks +# --------------------------------------------------------------------------- + +box(EX, Y_IN, BW, BH_IO, C_IO, 'G-Buffer Features', + '20 channels · full res') + +box(EX, Y_E0, BW, BH, C_ENC, 'enc0 Conv(20→4, 3×3) + FiLM + ReLU', + 'full res · 4 ch') + +box(EX, Y_E1, BW, BH, C_ENC, 'enc1 Conv(4→8, 3×3) + FiLM + ReLU', + '½ res · 8 ch · (AvgPool↓ on input)') + +box(BX, Y_BN, BW_BN, BH_BN, C_BN, + 'bottleneck Conv(8→8, 3×3, dilation=2) + ReLU', + '¼ res · 8 ch · no FiLM · effective RF ≈ 10 px @ ½res') + +box(DX, Y_D1, BW, BH, C_DEC, 'dec1 Conv(16→4, 3×3) + FiLM + ReLU', + '½ res · 4 ch · (upsample↑ + cat enc1 skip)') + +box(DX, Y_D0, BW, BH, C_DEC, 'dec0 Conv(8→4, 3×3) + FiLM + sigmoid', + 'full res · 4 ch · (upsample↑ + cat enc0 skip)') + +box(DX, Y_OUT, BW, BH_IO, C_IO, 'RGBA Output', + '4 channels · full res') + +box(BX, Y_MLP, 9.2, 1.10, C_MLP, + 'FiLM MLP Linear(5→16) → ReLU → Linear(16→40)', + 'in: beat_phase · beat_norm · audio_intensity · style_p0 · style_p1' + ' → γ/β (×2) for enc0(4) enc1(8) dec1(4) dec0(4) = 40 values', + sfs=7.5) + +# --------------------------------------------------------------------------- +# Main-flow arrows +# --------------------------------------------------------------------------- + +# Input → enc0 +arrow(EX, Y_IN - BH_IO/2 - .04, EX, Y_E0 + BH/2 + .04) + +# enc0 → enc1 (AvgPool label beside) +arrow(EX, Y_E0 - BH/2 - .04, EX, Y_E1 + BH/2 + .04, + label='AvgPool\n 2×2', lx=EX + 0.72, ly=(Y_E0 + Y_E1)/2) + +# enc1 → bottleneck (curve down-right) +arrow(EX, Y_E1 - BH/2 - .04, + BX - BW_BN/2 - .04, Y_BN, + rad=-0.28, + label='AvgPool\n 2×2', lx=(EX + BX)/2 - 0.5, ly=Y_BN + 0.90) + +# bottleneck → dec1 (curve right-up) +arrow(BX + BW_BN/2 + .04, Y_BN, + DX, Y_D1 - BH/2 - .04, + rad=-0.28, + label='upsample\n 2×', lx=(BX + DX)/2 + 0.5, ly=Y_D1 - 0.90) + +# dec1 → dec0 +arrow(DX, Y_D1 + BH/2 + .04, DX, Y_D0 - BH/2 - .04, + label='upsample\n 2×', lx=DX - 0.72, ly=(Y_D1 + Y_D0)/2) + +# dec0 → output +arrow(DX, Y_D0 + BH/2 + .04, DX, Y_OUT - BH_IO/2 - .04) + +# --------------------------------------------------------------------------- +# Skip connections +# --------------------------------------------------------------------------- + +# enc0 skip → dec0 +arrow(EX + BW/2 + .04, Y_E0, + DX - BW/2 - .04, Y_D0, + color=C_SKP, lw=1.6, dashed=True, + label='skip enc0 (4 ch)', ly=Y_E0 + 0.40) + +# enc1 skip → dec1 +arrow(EX + BW/2 + .04, Y_E1, + DX - BW/2 - .04, Y_D1, + color=C_SKP, lw=1.6, dashed=True, + label='skip enc1 (8 ch)', ly=Y_E1 + 0.40) + +# --------------------------------------------------------------------------- +# FiLM γ/β arrows (MLP → each FiLM layer) +# --------------------------------------------------------------------------- +film_targets = [ + (EX, Y_E0 - BH/2 - .04), # enc0 bottom + (EX, Y_E1 - BH/2 - .04), # enc1 bottom + (DX, Y_D1 - BH/2 - .04), # dec1 bottom + (DX, Y_D0 - BH/2 - .04), # dec0 bottom +] +for tx, ty in film_targets: + ax.annotate('', xy=(tx, ty), + xytext=(BX + (tx - BX) * 0.05, Y_MLP + 0.55 + .04), + arrowprops=dict(arrowstyle='->', color=C_FILM, lw=1.2, + linestyle=(0, (3, 3)), mutation_scale=10, + connectionstyle='arc3,rad=0.18'), + zorder=2) + +ax.text(8.5, 4.30, 'γ / β', ha='center', va='center', + fontsize=9, color=C_FILM, alpha=0.85, style='italic', zorder=5) + +# --------------------------------------------------------------------------- +# Resolution markers (left margin) +# --------------------------------------------------------------------------- +for y, lbl in [(Y_E0, 'full res'), (Y_E1, '½ res'), (Y_BN, '¼ res')]: + dim_label(0.62, y, lbl) + ax.plot([0.95, 1.10], [y, y], color=C_DIM, lw=0.8, zorder=1) + +# --------------------------------------------------------------------------- +# Legend +# --------------------------------------------------------------------------- +legend_items = [ + mpatches.Patch(fc=C_ENC, ec='white', lw=0.8, label='Encoder'), + mpatches.Patch(fc=C_BN, ec='white', lw=0.8, label='Bottleneck'), + mpatches.Patch(fc=C_DEC, ec='white', lw=0.8, label='Decoder'), + mpatches.Patch(fc=C_MLP, ec='white', lw=0.8, label='FiLM MLP'), + mpatches.Patch(fc=C_IO, ec='white', lw=0.8, label='I/O'), + plt.Line2D([0], [0], color=C_SKP, lw=1.6, ls='--', label='Skip connection'), + plt.Line2D([0], [0], color=C_FILM, lw=1.2, ls=(0, (3,3)), label='FiLM γ/β'), +] +leg = ax.legend(handles=legend_items, loc='lower right', + bbox_to_anchor=(0.99, 0.01), + framealpha=0.15, facecolor=BG, edgecolor=C_DIM, + fontsize=8, labelcolor=C_TXT, ncol=1) + +# --------------------------------------------------------------------------- +# Title +# --------------------------------------------------------------------------- +ax.text(8.5, 9.68, 'CNN v3 — U-Net + FiLM Architecture', + ha='center', va='center', fontsize=14, fontweight='bold', color=C_TXT) + +# --------------------------------------------------------------------------- +# Save +# --------------------------------------------------------------------------- +import pathlib +out = pathlib.Path(__file__).parent / 'cnn_v3_architecture.png' +fig.savefig(out, dpi=180, bbox_inches='tight', facecolor=BG, edgecolor='none') +print(f'Saved → {out} ({out.stat().st_size // 1024} KB)') |
