diff options
Diffstat (limited to 'training')
| -rwxr-xr-x | training/gen_identity_weights.py | 33 |
1 files changed, 24 insertions, 9 deletions
diff --git a/training/gen_identity_weights.py b/training/gen_identity_weights.py index c996758..0d79593 100755 --- a/training/gen_identity_weights.py +++ b/training/gen_identity_weights.py @@ -7,9 +7,13 @@ Output Ch{0,1,2,3} = Input Ch{0,1,2,3} (ignores static features). With --mix: Output Ch{i} = Input Ch{i} + Input Ch{i+4} (p0+p4, p1+p5, p2+p6, p3+p7) +With --p47: Output Ch{i} = Input Ch{i+4} (static features only) + (p4→ch0, p5→ch1, p6→ch2, p7→ch3) + Usage: ./training/gen_identity_weights.py [output.bin] ./training/gen_identity_weights.py --mix [output.bin] + ./training/gen_identity_weights.py --p47 [output.bin] """ import argparse @@ -18,10 +22,11 @@ import struct from pathlib import Path -def generate_identity_weights(output_path, kernel_size=1, mip_level=0, mix=False): +def generate_identity_weights(output_path, kernel_size=1, mip_level=0, mix=False, p47=False): """Generate identity weights: output = input (ignores static features). If mix=True, adds p4→p0, p5→p1, p6→p2, p7→p3 (blends input with static). + If p47=True, transfers p4→p0, p5→p1, p6→p2, p7→p3 (static features only). Binary format: Header (20 bytes): @@ -54,26 +59,34 @@ def generate_identity_weights(output_path, kernel_size=1, mip_level=0, mix=False # Center position for kernel center = kernel_size // 2 - # Set diagonal to 1.0 (output ch i = input ch i) - for i in range(out_channels): - weights[i, i, center, center] = 1.0 - - # If mix, add p4→p0, p5→p1, p6→p2, p7→p3 - if mix: + if p47: + # p47 mode: p4→ch0, p5→ch1, p6→ch2, p7→ch3 (static features only) for i in range(out_channels): weights[i, i + 4, center, center] = 1.0 + else: + # Set diagonal to 1.0 (output ch i = input ch i) + for i in range(out_channels): + weights[i, i, center, center] = 1.0 + + # If mix, add p4→p0, p5→p1, p6→p2, p7→p3 + if mix: + for i in range(out_channels): + weights[i, i + 4, center, center] = 1.0 # Flatten weights_flat = weights.flatten() weight_count = len(weights_flat) - print(f"Generating {'mix' if mix else 'identity'} weights:") + mode_name = 'p47' if p47 else ('mix' if mix else 'identity') + print(f"Generating {mode_name} weights:") print(f" Kernel size: {kernel_size}×{kernel_size}") print(f" Channels: 12D→4D") print(f" Weights: {weight_count}") print(f" Mip level: {mip_level}") if mix: print(f" Mode: p0+p4, p1+p5, p2+p6, p3+p7") + elif p47: + print(f" Mode: p4→ch0, p5→ch1, p6→ch2, p7→ch3") # Convert to f16 weights_f16 = np.array(weights_flat, dtype=np.float16) @@ -138,11 +151,13 @@ def main(): help='Mip level for p0-p3 features (default: 0)') parser.add_argument('--mix', action='store_true', help='Mix mode: p0+p4, p1+p5, p2+p6, p3+p7') + parser.add_argument('--p47', action='store_true', + help='Static features only: p4→ch0, p5→ch1, p6→ch2, p7→ch3') args = parser.parse_args() print("=== Identity Weight Generator ===\n") - generate_identity_weights(args.output, args.kernel_size, args.mip_level, args.mix) + generate_identity_weights(args.output, args.kernel_size, args.mip_level, args.mix, args.p47) print("\nDone!") |
