summaryrefslogtreecommitdiff
path: root/cnn_v3/training/infer_cnn_v3.py
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-03-26 07:03:01 +0100
committerskal <pascal.massimino@gmail.com>2026-03-26 07:03:01 +0100
commit8f14bdd66cb002b2f89265b2a578ad93249089c9 (patch)
tree2ccdb3939b673ebc3a5df429160631240239cee2 /cnn_v3/training/infer_cnn_v3.py
parent4ca498277b033ae10134045dae9c8c249a8d2b2b (diff)
feat(cnn_v3): upgrade architecture to enc_channels=[8,16]
Double encoder capacity: enc0 4→8ch, enc1 8→16ch, bottleneck 16→16ch, dec1 32→8ch, dec0 16→4ch. Total weights 2476→7828 f16 (~15.3 KB). FiLM MLP output 40→72 params (L1: 16×40→16×72). 16-ch textures split into _lo/_hi rgba32uint pairs (enc1, bottleneck). enc0 and dec1 textures changed from rgba16float to rgba32uint (8ch). GBUF_RGBA32UINT node gains CopySrc for parity test readback. - WGSL shaders: all 5 passes rewritten for new channel counts - C++ CNNv3Effect: new weight offsets/sizes, 8ch uniform structs - Web tool (shaders.js + tester.js): matching texture formats and bindings - Parity test: readback_rgba32uint_8ch helper, updated vector counts - Training scripts: default enc_channels=[8,16], updated docstrings - Docs + architecture PNG regenerated handoff(Gemini): CNN v3 [8,16] upgrade complete. All code, tests, web tool, training scripts, and docs updated. Next: run training pass.
Diffstat (limited to 'cnn_v3/training/infer_cnn_v3.py')
-rw-r--r--cnn_v3/training/infer_cnn_v3.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/cnn_v3/training/infer_cnn_v3.py b/cnn_v3/training/infer_cnn_v3.py
index ca1c72a..b0fe9e6 100644
--- a/cnn_v3/training/infer_cnn_v3.py
+++ b/cnn_v3/training/infer_cnn_v3.py
@@ -129,8 +129,8 @@ def main():
p.add_argument('output', help='Output PNG')
p.add_argument('--checkpoint', '-c', metavar='CKPT',
help='Path to .pth checkpoint (auto-finds latest if omitted)')
- p.add_argument('--enc-channels', default='4,8',
- help='Encoder channels (default: 4,8 — must match checkpoint)')
+ p.add_argument('--enc-channels', default='8,16',
+ help='Encoder channels (default: 8,16 — must match checkpoint)')
p.add_argument('--cond', nargs=5, type=float, metavar='F', default=[0.0]*5,
help='FiLM conditioning: 5 floats (beat_phase beat_norm audio style0 style1)')
p.add_argument('--identity-film', action='store_true',