summaryrefslogtreecommitdiff
path: root/cnn_v3/docs/gen_architecture_png.py
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-03-25 10:05:42 +0100
committerskal <pascal.massimino@gmail.com>2026-03-25 10:05:42 +0100
commitce6e5b99f26e4e7c69a3cacf360bd0d492de928c (patch)
treea8d64b33a7ea1109b6b7e1043ced946cac416756 /cnn_v3/docs/gen_architecture_png.py
parent8b4d7a49f038d7e849e6764dcc3abd1e1be01061 (diff)
feat(cnn_v3): 3×3 dilated bottleneck + Sobel loss + FiLM warmup + architecture PNG
- Replace 1×1 pointwise bottleneck with Conv(8→8, 3×3, dilation=2): effective RF grows from ~13px to ~29px at ¼res (~+1 KB weights) - Add Sobel edge loss in training (--edge-loss-weight, default 0.1) - Add FiLM 2-phase training: freeze MLP for warmup epochs then unfreeze at lr×0.1 (--film-warmup-epochs, default 50) - Update weight layout: BN 72→584 f16, total 1964→2476 f16 (4952 B) - Cascade offsets in C++ effect, JS tool, export/gen_test_vectors scripts - Regenerate test_vectors.h (1238 u32); parity max_err=9.77e-04 - Generate dark-theme U-Net+FiLM architecture PNG (gen_architecture_png.py) - Replace ASCII art in CNN_V3.md and HOW_TO_CNN.md with PNG embed handoff(Gemini): bottleneck dilation + Sobel loss + FiLM warmup landed. Next: run first real training pass (see cnn_v3/docs/HOWTO.md §3).
Diffstat (limited to 'cnn_v3/docs/gen_architecture_png.py')
-rw-r--r--cnn_v3/docs/gen_architecture_png.py238
1 files changed, 238 insertions, 0 deletions
diff --git a/cnn_v3/docs/gen_architecture_png.py b/cnn_v3/docs/gen_architecture_png.py
new file mode 100644
index 0000000..bd60a97
--- /dev/null
+++ b/cnn_v3/docs/gen_architecture_png.py
@@ -0,0 +1,238 @@
+#!/usr/bin/env python3
+# /// script
+# requires-python = ">=3.10"
+# dependencies = ["matplotlib"]
+# ///
+"""Generate CNN v3 U-Net + FiLM architecture diagram → cnn_v3_architecture.png"""
+
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import matplotlib.patches as mpatches
+from matplotlib.patches import FancyBboxPatch
+from matplotlib.path import Path
+import matplotlib.patheffects as pe
+
+# ---------------------------------------------------------------------------
+# Canvas
+# ---------------------------------------------------------------------------
+BG = '#0F172A'
+fig = plt.figure(figsize=(17, 10), facecolor=BG)
+ax = fig.add_axes([0, 0, 1, 1], facecolor=BG)
+ax.set_xlim(0, 17)
+ax.set_ylim(0, 10)
+ax.axis('off')
+
+# ---------------------------------------------------------------------------
+# Palette
+# ---------------------------------------------------------------------------
+C_ENC = '#3B82F6' # encoder — blue
+C_BN = '#8B5CF6' # bottleneck — violet
+C_DEC = '#10B981' # decoder — emerald
+C_MLP = '#EC4899' # FiLM MLP — pink
+C_FILM = '#F59E0B' # FiLM γ/β arrows — amber
+C_IO = '#475569' # input/output — slate
+C_SKP = '#F97316' # skip connections — orange
+C_ARR = '#94A3B8' # main flow arrows — cool-grey
+C_TXT = '#F1F5F9' # text — near-white
+C_DIM = '#64748B' # dim labels — slate
+
+# ---------------------------------------------------------------------------
+# Geometry — two-column U layout
+# ---------------------------------------------------------------------------
+EX, DX = 3.8, 13.2 # encoder / decoder centre-x
+BX = 8.5 # bottleneck centre-x
+
+BW = 4.6 # block width (enc / dec)
+BH = 0.95 # block height (enc / dec)
+BW_BN = 5.4 # bottleneck wider
+BH_BN = 0.95
+BH_IO = 0.72
+
+# y positions (top = high number)
+Y_IN = 8.90
+Y_E0 = 7.50 # enc0 full res
+Y_E1 = 5.80 # enc1 ½ res
+Y_BN = 3.20 # bottleneck ¼ res
+Y_D1 = 5.80 # dec1 ½ res
+Y_D0 = 7.50 # dec0 full res
+Y_OUT = 8.90
+
+Y_MLP = 1.25 # FiLM MLP
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def box(cx, cy, w, h, color, line1, line2='', lfs=9.5, sfs=8.0, alpha=0.92):
+ r = FancyBboxPatch((cx - w/2, cy - h/2), w, h,
+ boxstyle='round,pad=0.10',
+ fc=color, ec='white', lw=1.3, alpha=alpha, zorder=3)
+ ax.add_patch(r)
+ dy = 0.18 if line2 else 0
+ ax.text(cx, cy + dy, line1, ha='center', va='center',
+ fontsize=lfs, fontweight='bold', color='white', zorder=4,
+ fontfamily='DejaVu Sans Mono')
+ if line2:
+ ax.text(cx, cy - 0.18, line2, ha='center', va='center',
+ fontsize=sfs, color='white', alpha=0.80, zorder=4)
+
+
+def arrow(x0, y0, x1, y1, color=C_ARR, lw=1.8, dashed=False,
+ rad=0.0, label='', lx=None, ly=None):
+ ls = (0, (5, 3)) if dashed else 'solid'
+ cs = f'arc3,rad={rad}' if rad else 'arc3,rad=0'
+ ax.annotate('', xy=(x1, y1), xytext=(x0, y0),
+ arrowprops=dict(arrowstyle='->', color=color, lw=lw,
+ linestyle=ls, mutation_scale=13,
+ connectionstyle=cs),
+ zorder=2)
+ if label:
+ ax.text(lx if lx else (x0+x1)/2,
+ ly if ly else (y0+y1)/2,
+ label, ha='center', va='center', fontsize=7.5,
+ color=color, zorder=5,
+ bbox=dict(fc=BG, ec='none', alpha=0.75,
+ boxstyle='round,pad=0.15'))
+
+
+def dim_label(x, y, txt):
+ ax.text(x, y, txt, ha='center', va='center',
+ fontsize=8.5, color=C_DIM, style='italic')
+
+
+# ---------------------------------------------------------------------------
+# Blocks
+# ---------------------------------------------------------------------------
+
+box(EX, Y_IN, BW, BH_IO, C_IO, 'G-Buffer Features',
+ '20 channels · full res')
+
+box(EX, Y_E0, BW, BH, C_ENC, 'enc0 Conv(20→4, 3×3) + FiLM + ReLU',
+ 'full res · 4 ch')
+
+box(EX, Y_E1, BW, BH, C_ENC, 'enc1 Conv(4→8, 3×3) + FiLM + ReLU',
+ '½ res · 8 ch · (AvgPool↓ on input)')
+
+box(BX, Y_BN, BW_BN, BH_BN, C_BN,
+ 'bottleneck Conv(8→8, 3×3, dilation=2) + ReLU',
+ '¼ res · 8 ch · no FiLM · effective RF ≈ 10 px @ ½res')
+
+box(DX, Y_D1, BW, BH, C_DEC, 'dec1 Conv(16→4, 3×3) + FiLM + ReLU',
+ '½ res · 4 ch · (upsample↑ + cat enc1 skip)')
+
+box(DX, Y_D0, BW, BH, C_DEC, 'dec0 Conv(8→4, 3×3) + FiLM + sigmoid',
+ 'full res · 4 ch · (upsample↑ + cat enc0 skip)')
+
+box(DX, Y_OUT, BW, BH_IO, C_IO, 'RGBA Output',
+ '4 channels · full res')
+
+box(BX, Y_MLP, 9.2, 1.10, C_MLP,
+ 'FiLM MLP Linear(5→16) → ReLU → Linear(16→40)',
+ 'in: beat_phase · beat_norm · audio_intensity · style_p0 · style_p1'
+ ' → γ/β (×2) for enc0(4) enc1(8) dec1(4) dec0(4) = 40 values',
+ sfs=7.5)
+
+# ---------------------------------------------------------------------------
+# Main-flow arrows
+# ---------------------------------------------------------------------------
+
+# Input → enc0
+arrow(EX, Y_IN - BH_IO/2 - .04, EX, Y_E0 + BH/2 + .04)
+
+# enc0 → enc1 (AvgPool label beside)
+arrow(EX, Y_E0 - BH/2 - .04, EX, Y_E1 + BH/2 + .04,
+ label='AvgPool\n 2×2', lx=EX + 0.72, ly=(Y_E0 + Y_E1)/2)
+
+# enc1 → bottleneck (curve down-right)
+arrow(EX, Y_E1 - BH/2 - .04,
+ BX - BW_BN/2 - .04, Y_BN,
+ rad=-0.28,
+ label='AvgPool\n 2×2', lx=(EX + BX)/2 - 0.5, ly=Y_BN + 0.90)
+
+# bottleneck → dec1 (curve right-up)
+arrow(BX + BW_BN/2 + .04, Y_BN,
+ DX, Y_D1 - BH/2 - .04,
+ rad=-0.28,
+ label='upsample\n 2×', lx=(BX + DX)/2 + 0.5, ly=Y_D1 - 0.90)
+
+# dec1 → dec0
+arrow(DX, Y_D1 + BH/2 + .04, DX, Y_D0 - BH/2 - .04,
+ label='upsample\n 2×', lx=DX - 0.72, ly=(Y_D1 + Y_D0)/2)
+
+# dec0 → output
+arrow(DX, Y_D0 + BH/2 + .04, DX, Y_OUT - BH_IO/2 - .04)
+
+# ---------------------------------------------------------------------------
+# Skip connections
+# ---------------------------------------------------------------------------
+
+# enc0 skip → dec0
+arrow(EX + BW/2 + .04, Y_E0,
+ DX - BW/2 - .04, Y_D0,
+ color=C_SKP, lw=1.6, dashed=True,
+ label='skip enc0 (4 ch)', ly=Y_E0 + 0.40)
+
+# enc1 skip → dec1
+arrow(EX + BW/2 + .04, Y_E1,
+ DX - BW/2 - .04, Y_D1,
+ color=C_SKP, lw=1.6, dashed=True,
+ label='skip enc1 (8 ch)', ly=Y_E1 + 0.40)
+
+# ---------------------------------------------------------------------------
+# FiLM γ/β arrows (MLP → each FiLM layer)
+# ---------------------------------------------------------------------------
+film_targets = [
+ (EX, Y_E0 - BH/2 - .04), # enc0 bottom
+ (EX, Y_E1 - BH/2 - .04), # enc1 bottom
+ (DX, Y_D1 - BH/2 - .04), # dec1 bottom
+ (DX, Y_D0 - BH/2 - .04), # dec0 bottom
+]
+for tx, ty in film_targets:
+ ax.annotate('', xy=(tx, ty),
+ xytext=(BX + (tx - BX) * 0.05, Y_MLP + 0.55 + .04),
+ arrowprops=dict(arrowstyle='->', color=C_FILM, lw=1.2,
+ linestyle=(0, (3, 3)), mutation_scale=10,
+ connectionstyle='arc3,rad=0.18'),
+ zorder=2)
+
+ax.text(8.5, 4.30, 'γ / β', ha='center', va='center',
+ fontsize=9, color=C_FILM, alpha=0.85, style='italic', zorder=5)
+
+# ---------------------------------------------------------------------------
+# Resolution markers (left margin)
+# ---------------------------------------------------------------------------
+for y, lbl in [(Y_E0, 'full res'), (Y_E1, '½ res'), (Y_BN, '¼ res')]:
+ dim_label(0.62, y, lbl)
+ ax.plot([0.95, 1.10], [y, y], color=C_DIM, lw=0.8, zorder=1)
+
+# ---------------------------------------------------------------------------
+# Legend
+# ---------------------------------------------------------------------------
+legend_items = [
+ mpatches.Patch(fc=C_ENC, ec='white', lw=0.8, label='Encoder'),
+ mpatches.Patch(fc=C_BN, ec='white', lw=0.8, label='Bottleneck'),
+ mpatches.Patch(fc=C_DEC, ec='white', lw=0.8, label='Decoder'),
+ mpatches.Patch(fc=C_MLP, ec='white', lw=0.8, label='FiLM MLP'),
+ mpatches.Patch(fc=C_IO, ec='white', lw=0.8, label='I/O'),
+ plt.Line2D([0], [0], color=C_SKP, lw=1.6, ls='--', label='Skip connection'),
+ plt.Line2D([0], [0], color=C_FILM, lw=1.2, ls=(0, (3,3)), label='FiLM γ/β'),
+]
+leg = ax.legend(handles=legend_items, loc='lower right',
+ bbox_to_anchor=(0.99, 0.01),
+ framealpha=0.15, facecolor=BG, edgecolor=C_DIM,
+ fontsize=8, labelcolor=C_TXT, ncol=1)
+
+# ---------------------------------------------------------------------------
+# Title
+# ---------------------------------------------------------------------------
+ax.text(8.5, 9.68, 'CNN v3 — U-Net + FiLM Architecture',
+ ha='center', va='center', fontsize=14, fontweight='bold', color=C_TXT)
+
+# ---------------------------------------------------------------------------
+# Save
+# ---------------------------------------------------------------------------
+import pathlib
+out = pathlib.Path(__file__).parent / 'cnn_v3_architecture.png'
+fig.savefig(out, dpi=180, bbox_inches='tight', facecolor=BG, edgecolor='none')
+print(f'Saved → {out} ({out.stat().st_size // 1024} KB)')