diff options
| author | skal <pascal.massimino@gmail.com> | 2026-03-26 07:03:01 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-03-26 07:03:01 +0100 |
| commit | 8f14bdd66cb002b2f89265b2a578ad93249089c9 (patch) | |
| tree | 2ccdb3939b673ebc3a5df429160631240239cee2 /cnn_v3/training/export_cnn_v3_weights.py | |
| parent | 4ca498277b033ae10134045dae9c8c249a8d2b2b (diff) | |
feat(cnn_v3): upgrade architecture to enc_channels=[8,16]
Double encoder capacity: enc0 4→8ch, enc1 8→16ch, bottleneck 16→16ch,
dec1 32→8ch, dec0 16→4ch. Total weights 2476→7828 f16 (~15.3 KB).
FiLM MLP output 40→72 params (L1: 16×40→16×72).
16-ch textures split into _lo/_hi rgba32uint pairs (enc1, bottleneck).
enc0 and dec1 textures changed from rgba16float to rgba32uint (8ch).
GBUF_RGBA32UINT node gains CopySrc for parity test readback.
- WGSL shaders: all 5 passes rewritten for new channel counts
- C++ CNNv3Effect: new weight offsets/sizes, 8ch uniform structs
- Web tool (shaders.js + tester.js): matching texture formats and bindings
- Parity test: readback_rgba32uint_8ch helper, updated vector counts
- Training scripts: default enc_channels=[8,16], updated docstrings
- Docs + architecture PNG regenerated
handoff(Gemini): CNN v3 [8,16] upgrade complete. All code, tests, web
tool, training scripts, and docs updated. Next: run training pass.
Diffstat (limited to 'cnn_v3/training/export_cnn_v3_weights.py')
| -rw-r--r-- | cnn_v3/training/export_cnn_v3_weights.py | 51 |
1 file changed, 29 insertions, 22 deletions
diff --git a/cnn_v3/training/export_cnn_v3_weights.py b/cnn_v3/training/export_cnn_v3_weights.py index 78f5f25..2fa83d1 100644 --- a/cnn_v3/training/export_cnn_v3_weights.py +++ b/cnn_v3/training/export_cnn_v3_weights.py @@ -15,12 +15,12 @@ Outputs <output_dir>/cnn_v3_weights.bin Conv+bias weights for all 5 passes, packed as f16-pairs-in-u32. Matches the format expected by CNNv3Effect::upload_weights(). - Layout: enc0 (724) | enc1 (296) | bottleneck (584) | dec1 (580) | dec0 (292) - = 2476 f16 values = 1238 u32 = 4952 bytes. + Layout: enc0 (1448) | enc1 (1168) | bottleneck (2320) | dec1 (2312) | dec0 (580) + = 7828 f16 values = 3914 u32 = 15656 bytes. <output_dir>/cnn_v3_film_mlp.bin - FiLM MLP weights as raw f32: L0_W (5×16) L0_b (16) L1_W (16×40) L1_b (40). - = 5*16 + 16 + 16*40 + 40 = 80 + 16 + 640 + 40 = 776 f32 = 3104 bytes. + FiLM MLP weights as raw f32: L0_W (5×16) L0_b (16) L1_W (16×72) L1_b (72). + = 5*16 + 16 + 16*72 + 72 = 80 + 16 + 1152 + 72 = 1320 f32 = 5280 bytes. For future CPU-side MLP inference in CNNv3Effect::set_film_params(). Usage @@ -44,17 +44,19 @@ sys.path.insert(0, str(Path(__file__).parent)) from train_cnn_v3 import CNNv3 # --------------------------------------------------------------------------- -# Weight layout constants — must stay in sync with: -# cnn_v3/src/cnn_v3_effect.cc (kEnc0Weights, kEnc1Weights, …) -# cnn_v3/training/gen_test_vectors.py (same constants) +# Weight layout helpers — derived from enc_channels at runtime. +# Must stay in sync with cnn_v3/src/cnn_v3_effect.cc and gen_test_vectors.py. 
# --------------------------------------------------------------------------- -ENC0_WEIGHTS = 20 * 4 * 9 + 4 # Conv(20→4,3×3)+bias = 724 -ENC1_WEIGHTS = 4 * 8 * 9 + 8 # Conv(4→8,3×3)+bias = 296 -BN_WEIGHTS = 8 * 8 * 9 + 8 # Conv(8→8,3×3,dil=2)+bias = 584 -DEC1_WEIGHTS = 16 * 4 * 9 + 4 # Conv(16→4,3×3)+bias = 580 -DEC0_WEIGHTS = 8 * 4 * 9 + 4 # Conv(8→4,3×3)+bias = 292 -TOTAL_F16 = ENC0_WEIGHTS + ENC1_WEIGHTS + BN_WEIGHTS + DEC1_WEIGHTS + DEC0_WEIGHTS -# = 2476 +N_IN = 20 # feature input channels (fixed) + +def weight_counts(enc_channels): + c0, c1 = enc_channels + enc0 = N_IN * c0 * 9 + c0 + enc1 = c0 * c1 * 9 + c1 + bn = c1 * c1 * 9 + c1 + dec1 = (c1 * 2) * c0 * 9 + c0 + dec0 = (c0 * 2) * 4 * 9 + 4 + return enc0, enc1, bn, dec1, dec0 def pack_weights_u32(w_f16: np.ndarray) -> np.ndarray: @@ -86,7 +88,7 @@ def export_weights(checkpoint_path: str, output_dir: str) -> None: ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=True) cfg = ckpt.get('config', {}) - enc_channels = cfg.get('enc_channels', [4, 8]) + enc_channels = cfg.get('enc_channels', [8, 16]) film_cond_dim = cfg.get('film_cond_dim', 5) model = CNNv3(enc_channels=enc_channels, film_cond_dim=film_cond_dim) @@ -102,13 +104,18 @@ def export_weights(checkpoint_path: str, output_dir: str) -> None: # ----------------------------------------------------------------------- # 1. 
CNN conv weights → cnn_v3_weights.bin # ----------------------------------------------------------------------- + enc0_w, enc1_w, bn_w, dec1_w, dec0_w = weight_counts(enc_channels) + total_f16 = enc0_w + enc1_w + bn_w + dec1_w + dec0_w layers = [ - ('enc0', ENC0_WEIGHTS), - ('enc1', ENC1_WEIGHTS), - ('bottleneck', BN_WEIGHTS), - ('dec1', DEC1_WEIGHTS), - ('dec0', DEC0_WEIGHTS), + ('enc0', enc0_w), + ('enc1', enc1_w), + ('bottleneck', bn_w), + ('dec1', dec1_w), + ('dec0', dec0_w), ] + print(f" Weight layout: enc0={enc0_w} enc1={enc1_w} bn={bn_w} " + f"dec1={dec1_w} dec0={dec0_w} total={total_f16} f16 " + f"({total_f16*2/1024:.1f} KB)") all_f16 = [] for name, expected in layers: @@ -119,13 +126,13 @@ def export_weights(checkpoint_path: str, output_dir: str) -> None: all_f16.append(chunk) flat_f16 = np.concatenate(all_f16) - assert len(flat_f16) == TOTAL_F16, f"total mismatch: {len(flat_f16)} != {TOTAL_F16}" + assert len(flat_f16) == total_f16, f"total mismatch: {len(flat_f16)} != {total_f16}" packed_u32 = pack_weights_u32(flat_f16) weights_path = out / 'cnn_v3_weights.bin' packed_u32.astype('<u4').tofile(weights_path) # little-endian u32 print(f"\ncnn_v3_weights.bin") - print(f" {TOTAL_F16} f16 values → {len(packed_u32)} u32 → {weights_path.stat().st_size} bytes") + print(f" {total_f16} f16 values → {len(packed_u32)} u32 → {weights_path.stat().st_size} bytes") print(f" Upload via CNNv3Effect::upload_weights(queue, data, {len(packed_u32)*4})") # ----------------------------------------------------------------------- |
