diff options
| author | skal <pascal.massimino@gmail.com> | 2026-03-21 14:01:30 +0100 |
|---|---|---|
| committer | skal <pascal.massimino@gmail.com> | 2026-03-21 14:01:30 +0100 |
| commit | bf33fee131b1eee03bc5a765ba360299bbcead06 (patch) | |
| tree | b6a076ec977bb250a13b6a69be1092a183ae18ce /cnn_v3/training/export_cnn_v3_weights.py | |
| parent | 35355b17576e93b035a2a78ecd05771e98f068ee (diff) | |
refactor(cnn_v3): code review — comments, simplifications, test fix
C++:
- cnn_v3_effect.cc: fix declare_nodes comment (output node declared by caller)
- cnn_v3_effect.cc: add TODO(phase-7) marker for FiLM MLP replacement
WGSL:
- cnn_v3_bottleneck.wgsl: consolidate _pad fields onto one line, explain why
array<u32,3> is invalid in uniform address space
- cnn_v3_enc0.wgsl: fix "12xu8" → "12ch u8norm" in header comment
- cnn_v3_dec0.wgsl: clarify parity note (sigmoid after FiLM+ReLU, not raw conv)
- cnn_v3_common.wgsl: clarify unpack_8ch pack layout (low/high 16 bits)
Python:
- cnn_v3_utils.py: replace PIL-based _upsample_nearest (uint8 round-trip) with
pure numpy index arithmetic; rename _resize_rgb → _resize_img (handles any
channel count); add comment on normal zero-pad workaround
- export_cnn_v3_weights.py: add cross-ref to cnn_v3_effect.cc constants;
clarify weight count comments with Conv notation
Test:
- test_cnn_v3_parity.cc: enc0/dec1 layer failures now return 0 (were print-only)
handoff(Gemini): CNN v3 review complete, 36/36 tests passing.
Diffstat (limited to 'cnn_v3/training/export_cnn_v3_weights.py')
| -rw-r--r-- | cnn_v3/training/export_cnn_v3_weights.py | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/cnn_v3/training/export_cnn_v3_weights.py b/cnn_v3/training/export_cnn_v3_weights.py index a1ad42d..6d99af9 100644 --- a/cnn_v3/training/export_cnn_v3_weights.py +++ b/cnn_v3/training/export_cnn_v3_weights.py @@ -34,13 +34,15 @@ sys.path.insert(0, str(Path(__file__).parent)) from train_cnn_v3 import CNNv3 # --------------------------------------------------------------------------- -# Weight layout constants (must match cnn_v3_effect.cc and gen_test_vectors.py) +# Weight layout constants — must stay in sync with: +# cnn_v3/src/cnn_v3_effect.cc (kEnc0Weights, kEnc1Weights, …) +# cnn_v3/training/gen_test_vectors.py (same constants) # --------------------------------------------------------------------------- -ENC0_WEIGHTS = 20 * 4 * 9 + 4 # 724 -ENC1_WEIGHTS = 4 * 8 * 9 + 8 # 296 -BN_WEIGHTS = 8 * 8 * 1 + 8 # 72 -DEC1_WEIGHTS = 16 * 4 * 9 + 4 # 580 -DEC0_WEIGHTS = 8 * 4 * 9 + 4 # 292 +ENC0_WEIGHTS = 20 * 4 * 9 + 4 # Conv(20→4,3×3)+bias = 724 +ENC1_WEIGHTS = 4 * 8 * 9 + 8 # Conv(4→8,3×3)+bias = 296 +BN_WEIGHTS = 8 * 8 * 1 + 8 # Conv(8→8,1×1)+bias = 72 +DEC1_WEIGHTS = 16 * 4 * 9 + 4 # Conv(16→4,3×3)+bias = 580 +DEC0_WEIGHTS = 8 * 4 * 9 + 4 # Conv(8→4,3×3)+bias = 292 TOTAL_F16 = ENC0_WEIGHTS + ENC1_WEIGHTS + BN_WEIGHTS + DEC1_WEIGHTS + DEC0_WEIGHTS # = 1964 |
