diff options
Diffstat (limited to 'training/gen_identity_weights.py')
| -rwxr-xr-x | training/gen_identity_weights.py | 51 |
1 files changed, 44 insertions, 7 deletions
diff --git a/training/gen_identity_weights.py b/training/gen_identity_weights.py index a84ea87..7865d68 100755 --- a/training/gen_identity_weights.py +++ b/training/gen_identity_weights.py @@ -4,8 +4,16 @@ Creates trivial .bin with 1 layer, 1×1 kernel, identity passthrough. Output Ch{0,1,2,3} = Input Ch{0,1,2,3} (ignores static features). +With --mix: Output Ch{i} = 0.5*prev[i] + 0.5*static_p{4+i} + (50-50 blend of prev layer with uv_x, uv_y, sin20_y, bias) + +With --p47: Output Ch{i} = static p{4+i} (uv_x, uv_y, sin20_y, bias) + (p4/uv_x→ch0, p5/uv_y→ch1, p6/sin20_y→ch2, p7/bias→ch3) + Usage: ./training/gen_identity_weights.py [output.bin] + ./training/gen_identity_weights.py --mix [output.bin] + ./training/gen_identity_weights.py --p47 [output.bin] """ import argparse @@ -14,9 +22,15 @@ import struct from pathlib import Path -def generate_identity_weights(output_path, kernel_size=1, mip_level=0): +def generate_identity_weights(output_path, kernel_size=1, mip_level=0, mix=False, p47=False): """Generate identity weights: output = input (ignores static features). + If mix=True, 50-50 blend: 0.5*p0+0.5*p4, 0.5*p1+0.5*p5, etc (avoids overflow). + If p47=True, transfers static p4-p7 (uv_x, uv_y, sin20_y, bias) to output channels. + + Input channel layout: [0-3: prev layer, 4-11: static (p0-p7)] + Static features: p0-p3 (RGB+D), p4 (uv_x), p5 (uv_y), p6 (sin20_y), p7 (bias) + Binary format: Header (20 bytes): uint32 magic ('CNN2') @@ -34,7 +48,8 @@ def generate_identity_weights(output_path, kernel_size=1, mip_level=0): Weights (u32 packed f16): Identity matrix for first 4 input channels - Zeros for static features (channels 4-11) + Zeros for static features (channels 4-11) OR + Mix matrix (p0+p4, p1+p5, p2+p6, p3+p7) if mix=True """ # Identity: 4 output channels, 12 input channels # Weight shape: [out_ch, in_ch, kernel_h, kernel_w] @@ -47,19 +62,37 @@ def generate_identity_weights(output_path, kernel_size=1, mip_level=0): # Center position for kernel center = kernel_size // 2 - # Set diagonal to 1.0 (output ch i = input ch i) - for i in range(out_channels): - weights[i, i, center, center] = 1.0 + if p47: + # p47 mode: p4→ch0, p5→ch1, p6→ch2, p7→ch3 (static features only) + # Input channels: [0-3: prev layer, 4-11: static features (p0-p7)] + # p4-p7 are at input channels 8-11 + for i in range(out_channels): + weights[i, i + 8, center, center] = 1.0 + elif mix: + # Mix mode: 50-50 blend (p0+p4, p1+p5, p2+p6, p3+p7) + # p0-p3 are at channels 0-3 (prev layer), p4-p7 at channels 8-11 (static) + for i in range(out_channels): + weights[i, i, center, center] = 0.5 # 0.5*p{i} (prev layer) + weights[i, i + 8, center, center] = 0.5 # 0.5*p{i+4} (static) + else: + # Identity: output ch i = input ch i + for i in range(out_channels): + weights[i, i, center, center] = 1.0 # Flatten weights_flat = weights.flatten() weight_count = len(weights_flat) - print(f"Generating identity weights:") + mode_name = 'p47' if p47 else ('mix' if mix else 'identity') + print(f"Generating {mode_name} weights:") print(f" Kernel size: {kernel_size}×{kernel_size}") print(f" Channels: 12D→4D") print(f" Weights: {weight_count}") print(f" Mip level: {mip_level}") + if mix: + print(f" Mode: 0.5*prev[i] + 0.5*static_p{{4+i}} (blend with uv/sin/bias)") + elif p47: + print(f" Mode: p4→ch0, p5→ch1, p6→ch2, p7→ch3") # Convert to f16 weights_f16 = np.array(weights_flat, dtype=np.float16) @@ -122,11 +155,15 @@ def main(): help='Kernel size (default: 1×1)') parser.add_argument('--mip-level', type=int, default=0, help='Mip level for p0-p3 features (default: 0)') + parser.add_argument('--mix', action='store_true', + help='Mix mode: 50-50 blend of p0-p3 and p4-p7') + parser.add_argument('--p47', action='store_true', + help='Static features only: p4→ch0, p5→ch1, p6→ch2, p7→ch3') args = parser.parse_args() print("=== Identity Weight Generator ===\n") - generate_identity_weights(args.output, args.kernel_size, args.mip_level) + generate_identity_weights(args.output, args.kernel_size, args.mip_level, args.mix, args.p47) print("\nDone!") |
