summaryrefslogtreecommitdiff
path: root/training/gen_identity_weights.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/gen_identity_weights.py')
-rwxr-xr-xtraining/gen_identity_weights.py33
1 files changed, 24 insertions, 9 deletions
diff --git a/training/gen_identity_weights.py b/training/gen_identity_weights.py
index c996758..0d79593 100755
--- a/training/gen_identity_weights.py
+++ b/training/gen_identity_weights.py
@@ -7,9 +7,13 @@ Output Ch{0,1,2,3} = Input Ch{0,1,2,3} (ignores static features).
With --mix: Output Ch{i} = Input Ch{i} + Input Ch{i+4}
(p0+p4, p1+p5, p2+p6, p3+p7)
+With --p47: Output Ch{i} = Input Ch{i+4} (static features only)
+ (p4→ch0, p5→ch1, p6→ch2, p7→ch3)
+
Usage:
./training/gen_identity_weights.py [output.bin]
./training/gen_identity_weights.py --mix [output.bin]
+ ./training/gen_identity_weights.py --p47 [output.bin]
"""
import argparse
@@ -18,10 +22,11 @@ import struct
from pathlib import Path
-def generate_identity_weights(output_path, kernel_size=1, mip_level=0, mix=False):
+def generate_identity_weights(output_path, kernel_size=1, mip_level=0, mix=False, p47=False):
"""Generate identity weights: output = input (ignores static features).
If mix=True, adds p4→p0, p5→p1, p6→p2, p7→p3 (blends input with static).
+ If p47=True, transfers p4→p0, p5→p1, p6→p2, p7→p3 (static features only).
Binary format:
Header (20 bytes):
@@ -54,26 +59,34 @@ def generate_identity_weights(output_path, kernel_size=1, mip_level=0, mix=False
# Center position for kernel
center = kernel_size // 2
- # Set diagonal to 1.0 (output ch i = input ch i)
- for i in range(out_channels):
- weights[i, i, center, center] = 1.0
-
- # If mix, add p4→p0, p5→p1, p6→p2, p7→p3
- if mix:
+ if p47:
+ # p47 mode: p4→ch0, p5→ch1, p6→ch2, p7→ch3 (static features only)
for i in range(out_channels):
weights[i, i + 4, center, center] = 1.0
+ else:
+ # Set diagonal to 1.0 (output ch i = input ch i)
+ for i in range(out_channels):
+ weights[i, i, center, center] = 1.0
+
+ # If mix, add p4→p0, p5→p1, p6→p2, p7→p3
+ if mix:
+ for i in range(out_channels):
+ weights[i, i + 4, center, center] = 1.0
# Flatten
weights_flat = weights.flatten()
weight_count = len(weights_flat)
- print(f"Generating {'mix' if mix else 'identity'} weights:")
+ mode_name = 'p47' if p47 else ('mix' if mix else 'identity')
+ print(f"Generating {mode_name} weights:")
print(f" Kernel size: {kernel_size}×{kernel_size}")
print(f" Channels: 12D→4D")
print(f" Weights: {weight_count}")
print(f" Mip level: {mip_level}")
if mix:
print(f" Mode: p0+p4, p1+p5, p2+p6, p3+p7")
+ elif p47:
+ print(f" Mode: p4→ch0, p5→ch1, p6→ch2, p7→ch3")
# Convert to f16
weights_f16 = np.array(weights_flat, dtype=np.float16)
@@ -138,11 +151,13 @@ def main():
help='Mip level for p0-p3 features (default: 0)')
parser.add_argument('--mix', action='store_true',
help='Mix mode: p0+p4, p1+p5, p2+p6, p3+p7')
+ parser.add_argument('--p47', action='store_true',
+ help='Static features only: p4→ch0, p5→ch1, p6→ch2, p7→ch3')
args = parser.parse_args()
print("=== Identity Weight Generator ===\n")
- generate_identity_weights(args.output, args.kernel_size, args.mip_level, args.mix)
+ generate_identity_weights(args.output, args.kernel_size, args.mip_level, args.mix, args.p47)
print("\nDone!")