#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = ["matplotlib"]
# ///
"""Generate CNN v3 U-Net + FiLM architecture diagram → cnn_v3_architecture.png"""

import matplotlib
matplotlib.use('Agg')  # non-interactive backend; selected before pyplot is imported
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.patches import FancyBboxPatch
from matplotlib.path import Path
import matplotlib.patheffects as pe

# ---------------------------------------------------------------------------
# Canvas
# ---------------------------------------------------------------------------
# A single borderless axes spans the whole figure, and its data limits equal
# the figure size in inches, so one data unit is one inch on the canvas.
BG = '#0F172A'
CANVAS_W, CANVAS_H = 17, 10

fig = plt.figure(figsize=(CANVAS_W, CANVAS_H), facecolor=BG)
ax = fig.add_axes([0, 0, 1, 1], facecolor=BG)
ax.set_xlim(0, CANVAS_W)
ax.set_ylim(0, CANVAS_H)
ax.set_axis_off()

# ---------------------------------------------------------------------------
# Palette
# ---------------------------------------------------------------------------
C_ENC  = '#3B82F6'   # encoder           — blue
C_BN   = '#8B5CF6'   # bottleneck        — violet
C_DEC  = '#10B981'   # decoder           — emerald
C_MLP  = '#EC4899'   # FiLM MLP          — pink
C_FILM = '#F59E0B'   # FiLM γ/β arrows   — amber
C_IO   = '#475569'   # input/output      — slate
C_SKP  = '#F97316'   # skip connections  — orange
C_ARR  = '#94A3B8'   # main flow arrows  — cool-grey
C_TXT  = '#F1F5F9'   # text              — near-white
C_DIM  = '#64748B'   # dim labels        — slate

# ---------------------------------------------------------------------------
# Geometry — two-column U layout
# ---------------------------------------------------------------------------
EX, DX = 3.8, 13.2         # encoder / decoder column centre-x
BX = 8.5                   # bottleneck centre-x (midway between the columns)

BW, BH = 4.6, 0.95         # encoder / decoder block width, height
BW_BN, BH_BN = 5.4, 0.95   # bottleneck block: wider, same height
BH_IO = 0.72               # input / output block height

# Row y-positions (y grows upward). Each decoder row shares its encoder row's
# height, so the two columns mirror each other across the bottleneck.
Y_IN = Y_OUT = 8.90        # input / output
Y_E0 = Y_D0 = 7.50         # enc0 / dec0 — full res
Y_E1 = Y_D1 = 5.80         # enc1 / dec1 — ½ res
Y_BN = 3.20                # bottleneck  — ¼ res
Y_MLP = 1.25               # FiLM MLP
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def box(cx, cy, w, h, color, line1, line2='', lfs=9.5, sfs=8.0, alpha=0.92):
    """Draw a rounded block centred at (cx, cy) with a bold title and optional subtitle.

    Args:
        cx, cy: Centre of the block in data coordinates.
        w, h: Width and height of the block body; the ``round,pad=0.10``
            box style draws its padding outside this rectangle.
        color: Face colour of the block.
        line1: Title, drawn bold in a monospace font.
        line2: Optional subtitle drawn below the title. When empty, the title
            is vertically centred in the block instead.
        lfs, sfs: Font sizes of the title and subtitle.
        alpha: Opacity of the block face.
    """
    ax.add_patch(FancyBboxPatch((cx - w / 2, cy - h / 2), w, h,
                                boxstyle='round,pad=0.10',
                                fc=color, ec='white', lw=1.3, alpha=alpha,
                                zorder=3))
    # With a subtitle, the two text lines sit symmetrically about cy.
    dy = 0.18 if line2 else 0
    ax.text(cx, cy + dy, line1, ha='center', va='center',
            fontsize=lfs, fontweight='bold', color='white', zorder=4,
            fontfamily='DejaVu Sans Mono')
    if line2:
        ax.text(cx, cy - dy, line2, ha='center', va='center',
                fontsize=sfs, color='white', alpha=0.80, zorder=4)


def arrow(x0, y0, x1, y1, color=C_ARR, lw=1.8, dashed=False,
          rad=0.0, label='', lx=None, ly=None):
    """Draw an arrow from (x0, y0) to (x1, y1), optionally curved and labelled.

    Args:
        x0, y0: Tail of the arrow.
        x1, y1: Head of the arrow.
        color: Line colour, also used for the label text.
        lw: Line width.
        dashed: Use a (5, 3) dash pattern instead of a solid line.
        rad: ``arc3`` curvature. 0 gives a straight arrow, and the sign
            selects which side the arc bows towards.
        label: Optional text drawn on a semi-opaque background plate.
        lx, ly: Label position. Each one independently defaults to the
            midpoint of the arrow's endpoints along that axis.
    """
    ls = (0, (5, 3)) if dashed else 'solid'
    ax.annotate('', xy=(x1, y1), xytext=(x0, y0),
                arrowprops=dict(arrowstyle='->', color=color, lw=lw,
                                linestyle=ls, mutation_scale=13,
                                connectionstyle=f'arc3,rad={rad}'),
                zorder=2)
    if label:
        # Compare against None: 0.0 is a valid coordinate and must not fall
        # back to the midpoint.
        tx = (x0 + x1) / 2 if lx is None else lx
        ty = (y0 + y1) / 2 if ly is None else ly
        ax.text(tx, ty, label, ha='center', va='center', fontsize=7.5,
                color=color, zorder=5,
                bbox=dict(fc=BG, ec='none', alpha=0.75,
                          boxstyle='round,pad=0.15'))


def dim_label(x, y, txt):
    """Write a small, dimmed, italic annotation centred at (x, y)."""
    ax.text(x, y, txt, ha='center', va='center',
            fontsize=8.5, color=C_DIM, style='italic')


# ---------------------------------------------------------------------------
# Blocks
# ---------------------------------------------------------------------------
box(EX, Y_IN, BW, BH_IO, C_IO, 'G-Buffer Features', '20 channels · full res')
box(EX, Y_E0, BW, BH, C_ENC, 'enc0 Conv(20→4, 3×3) + FiLM + ReLU',
    'full res · 4 ch')
box(EX, Y_E1, BW, BH, C_ENC, 'enc1 Conv(4→8, 3×3) + FiLM + ReLU',
    '½ res · 8 ch · (AvgPool↓ on input)')
box(BX, Y_BN, BW_BN, BH_BN, C_BN,
    'bottleneck Conv(8→8, 3×3, dilation=2) + ReLU',
    '¼ res · 8 ch · no FiLM · effective RF ≈ 10 px @ ½res')
box(DX, Y_D1, BW, BH, C_DEC, 'dec1 Conv(16→4, 3×3) + FiLM + ReLU',
    '½ res · 4 ch · (upsample↑ + cat enc1 skip)')
box(DX, Y_D0, BW, BH, C_DEC, 'dec0 Conv(8→4, 3×3) + FiLM + sigmoid',
    'full res · 4 ch · (upsample↑ + cat enc0 skip)')
box(DX, Y_OUT, BW, BH_IO, C_IO, 'RGBA Output', '4 channels · full res')
box(BX, Y_MLP, 9.2, 1.10, C_MLP,
    'FiLM MLP Linear(5→16) → ReLU → Linear(16→40)',
    'in: beat_phase · beat_norm · audio_intensity · style_p0 · style_p1'
    ' → γ/β (×2) for enc0(4) enc1(8) dec1(4) dec0(4) = 40 values',
    sfs=7.5)

# ---------------------------------------------------------------------------
# Main-flow arrows
# ---------------------------------------------------------------------------
# Clearance between arrow tips and the block edges they touch.
GAP = .04

# Input → enc0
arrow(EX, Y_IN - BH_IO/2 - GAP, EX, Y_E0 + BH/2 + GAP)
# enc0 → enc1 (AvgPool label beside)
arrow(EX, Y_E0 - BH/2 - GAP, EX, Y_E1 + BH/2 + GAP,
      label='AvgPool\n 2×2', lx=EX + 0.72, ly=(Y_E0 + Y_E1)/2)
# enc1 → bottleneck (curve down-right)
arrow(EX, Y_E1 - BH/2 - GAP, BX - BW_BN/2 - GAP, Y_BN, rad=-0.28,
      label='AvgPool\n 2×2', lx=(EX + BX)/2 - 0.5, ly=Y_BN + 0.90)
# bottleneck → dec1 (curve right-up)
arrow(BX + BW_BN/2 + GAP, Y_BN, DX, Y_D1 - BH/2 - GAP, rad=-0.28,
      label='upsample\n 2×', lx=(BX + DX)/2 + 0.5, ly=Y_D1 - 0.90)
# dec1 → dec0
arrow(DX, Y_D1 + BH/2 + GAP, DX, Y_D0 - BH/2 - GAP,
      label='upsample\n 2×', lx=DX - 0.72, ly=(Y_D1 + Y_D0)/2)
# dec0 → output
arrow(DX, Y_D0 + BH/2 + GAP, DX, Y_OUT - BH_IO/2 - GAP)

# ---------------------------------------------------------------------------
# Skip connections
# ---------------------------------------------------------------------------
# Dashed arrows from each encoder block's right edge to the matching decoder
# block's left edge; labels sit 0.40 above the arrow's midpoint.
# enc0 skip → dec0
arrow(EX + BW/2 + GAP, Y_E0, DX - BW/2 - GAP, Y_D0,
      color=C_SKP, lw=1.6, dashed=True,
      label='skip enc0 (4 ch)', ly=Y_E0 + 0.40)
# enc1 skip → dec1
arrow(EX + BW/2 + GAP, Y_E1, DX - BW/2 - GAP, Y_D1,
      color=C_SKP, lw=1.6, dashed=True,
      label='skip enc1 (8 ch)', ly=Y_E1 + 0.40)
# ---------------------------------------------------------------------------
# FiLM γ/β arrows (MLP → each FiLM layer)
# ---------------------------------------------------------------------------
# One arrow per FiLM-modulated block. The bottleneck has no FiLM, so it gets
# no arrow. Heads stop just below each block's bottom edge.
film_targets = [
    (EX, Y_E0 - BH/2 - .04),  # enc0 bottom
    (EX, Y_E1 - BH/2 - .04),  # enc1 bottom
    (DX, Y_D1 - BH/2 - .04),  # dec1 bottom
    (DX, Y_D0 - BH/2 - .04),  # dec0 bottom
]
# Tails start just above the FiLM MLP box. It is drawn 1.10 tall around
# Y_MLP, so its top edge is at Y_MLP + 0.55. Each tail is nudged 5 % of the
# way toward its target's column so the four arrows fan out.
film_src_y = Y_MLP + 0.55 + .04
for tx, ty in film_targets:
    ax.annotate('', xy=(tx, ty), xytext=(BX + (tx - BX) * 0.05, film_src_y),
                arrowprops=dict(arrowstyle='->', color=C_FILM, lw=1.2,
                                linestyle=(0, (3, 3)), mutation_scale=10,
                                connectionstyle='arc3,rad=0.18'),
                zorder=2)
ax.text(BX, 4.30, 'γ / β', ha='center', va='center', fontsize=9,
        color=C_FILM, alpha=0.85, style='italic', zorder=5)

# ---------------------------------------------------------------------------
# Resolution markers (left margin)
# ---------------------------------------------------------------------------
for y, lbl in [(Y_E0, 'full res'), (Y_E1, '½ res'), (Y_BN, '¼ res')]:
    dim_label(0.62, y, lbl)
    ax.plot([0.95, 1.10], [y, y], color=C_DIM, lw=0.8, zorder=1)

# ---------------------------------------------------------------------------
# Legend
# ---------------------------------------------------------------------------
legend_items = [
    mpatches.Patch(fc=C_ENC, ec='white', lw=0.8, label='Encoder'),
    mpatches.Patch(fc=C_BN, ec='white', lw=0.8, label='Bottleneck'),
    mpatches.Patch(fc=C_DEC, ec='white', lw=0.8, label='Decoder'),
    mpatches.Patch(fc=C_MLP, ec='white', lw=0.8, label='FiLM MLP'),
    mpatches.Patch(fc=C_IO, ec='white', lw=0.8, label='I/O'),
    # Dash patterns mirror the lines actually drawn: (5, 3) for the dashed
    # skip arrows and (3, 3) for the FiLM arrows. The '--' shorthand renders
    # a different dash than the skip arrows, so it is not used here.
    plt.Line2D([0], [0], color=C_SKP, lw=1.6, ls=(0, (5, 3)),
               label='Skip connection'),
    plt.Line2D([0], [0], color=C_FILM, lw=1.2, ls=(0, (3, 3)),
               label='FiLM γ/β'),
]
leg = ax.legend(handles=legend_items, loc='lower right',
                bbox_to_anchor=(0.99, 0.01), framealpha=0.15,
                facecolor=BG, edgecolor=C_DIM, fontsize=8,
                labelcolor=C_TXT, ncol=1)

# ---------------------------------------------------------------------------
# Title
# ---------------------------------------------------------------------------
# x = 8.5 is the horizontal centre of the 0–17 canvas.
ax.text(8.5, 9.68, 'CNN v3 — U-Net + FiLM Architecture',
        ha='center', va='center', fontsize=14, fontweight='bold',
        color=C_TXT)

# ---------------------------------------------------------------------------
# Save
# ---------------------------------------------------------------------------
# Module-qualified on purpose: the bare name `Path` is bound to
# matplotlib.path.Path by the imports at the top of the file.
import pathlib

out = pathlib.Path(__file__).parent / 'cnn_v3_architecture.png'
fig.savefig(out, dpi=180, bbox_inches='tight', facecolor=BG, edgecolor='none')
plt.close(fig)  # release the figure once it has been written
print(f'Saved → {out} ({out.stat().st_size // 1024} KB)')