summaryrefslogtreecommitdiff
path: root/training/gen_identity_weights.py
blob: 5756e671da432603097693e6770583ab4e228de4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#!/usr/bin/env python3
"""Generate Identity CNN v2 Weights

Creates a trivial .bin with 1 layer and a 1×1 kernel (the default; see --kernel-size) that passes the input straight through.
Output Ch{0,1,2,3} = Input Ch{0,1,2,3} (ignores static features).

With --mix: Output Ch{i} = 0.5*Input Ch{i} + 0.5*Input Ch{i+4}
  (50-50 blend, avoids overflow)

With --p47: Output Ch{i} = Input Ch{i+4} (static features only)
  (p4→ch0, p5→ch1, p6→ch2, p7→ch3)

Usage:
  ./training/gen_identity_weights.py [output.bin]
  ./training/gen_identity_weights.py --mix [output.bin]
  ./training/gen_identity_weights.py --p47 [output.bin]

Common options: --kernel-size N (default 1), --mip-level N (default 0).
"""

import argparse
import numpy as np
import struct
from pathlib import Path


def _identity_weight_tensor(kernel_size, mix, p47, in_channels=12, out_channels=4):
    """Return the float32 weight tensor [out_ch, in_ch, k, k] for the chosen mode.

    Only the tap at spatial index ``kernel_size // 2`` is non-zero, so each output
    pixel reads a single input tap. For even kernel sizes there is no true centre
    and this picks the later of the two middle taps.
    NOTE(review): confirm that matches the consumer's padding convention.
    ``p47`` takes precedence over ``mix`` when both are set.
    """
    weights = np.zeros((out_channels, in_channels, kernel_size, kernel_size), dtype=np.float32)
    c = kernel_size // 2
    for i in range(out_channels):
        if p47:
            # Static features only: p{i+4} -> ch{i}.
            weights[i, i + 4, c, c] = 1.0
        elif mix:
            # 50-50 blend: a convex combination never exceeds the larger input.
            weights[i, i, c, c] = 0.5
            weights[i, i + 4, c, c] = 0.5
        else:
            # Identity: ch{i} = p{i}.
            weights[i, i, c, c] = 1.0
    return weights


def _pack_f16_pairs(values):
    """Convert *values* to little-endian f16 and pack them two per u32.

    The values are zero-padded to an even count. Returns ``(halves, words)``:
    ``words`` views consecutive f16 pairs as little-endian u32, with the first
    value of each pair in the low 16 bits.
    """
    flat = np.asarray(values, dtype=np.float32).ravel()
    if flat.size % 2:
        flat = np.append(flat, np.float32(0.0))
    # Explicit '<' byte order keeps the output independent of host endianness;
    # the header is also packed little-endian.
    halves = flat.astype('<f2')
    return halves, halves.view('<u4')


def _verify_weights_file(path, expected_header, expected_layer, expected_size):
    """Re-read *path*, print its header and layer info, and check them against what was written.

    Raises:
        RuntimeError: the file size, header, or layer-info fields differ from the
            expected values.
    """
    data = Path(path).read_bytes()
    print("\nVerification:")
    if len(data) < 40:
        raise RuntimeError(f"{path}: only {len(data)} bytes, expected {expected_size}")
    header = struct.unpack_from('<4sIIII', data, 0)
    layer = struct.unpack_from('<IIIII', data, 20)
    magic, version, num_layers, total_weights, mip = header
    print(f"  Magic: {magic}")
    print(f"  Version: {version}")
    print(f"  Layers: {num_layers}")
    print(f"  Total weights: {total_weights}")
    print(f"  Mip level: {mip}")
    print(f"  Layer: kernel={layer[0]}, in={layer[1]}, out={layer[2]}, "
          f"offset={layer[3]}, count={layer[4]}")
    print(f"  File size: {len(data)} bytes")
    if header != expected_header or layer != expected_layer or len(data) != expected_size:
        raise RuntimeError(f"{path}: read-back does not match the data that was written")


def generate_identity_weights(output_path, kernel_size=1, mip_level=0, mix=False, p47=False):
    """Write a single-layer CNN v2 weight file that passes channels through.

    Modes (p47 takes precedence over mix if both are set):
      * default: ch{i} = p{i} for i in 0..3; static features p4-p11 are ignored.
      * mix:     ch{i} = 0.5*p{i} + 0.5*p{i+4} (50-50 blend, avoids overflow).
      * p47:     ch{i} = p{i+4} (static features only).

    Args:
        output_path: Destination file. Parent directories are created.
        kernel_size: Spatial kernel size k (>= 1). Only one tap per
            (out, in) pair is non-zero.
        mip_level: Value stored in the header (0 .. 2**32-1).
        mix: Select the 50-50 blend mode.
        p47: Select the static-features-only mode.

    Raises:
        ValueError: kernel_size < 1 or mip_level outside the uint32 range.
            This is raised before any file is created or truncated.
        RuntimeError: The written file does not read back as expected.

    Binary format (all integers little-endian):
      Header (20 bytes):
        char[4] magic ('CNN2')
        uint32  version (2)
        uint32  num_layers (1)
        uint32  total_weights (f16 count after padding to even)
        uint32  mip_level

      LayerInfo (20 bytes):
        uint32 kernel_size
        uint32 in_channels (12)
        uint32 out_channels (4)
        uint32 weight_offset (0)
        uint32 weight_count (unpadded f16 count)

      Weights:
        out_ch*in_ch*k*k f16 values in [out, in, kh, kw] order, packed two
        per u32 (first value in the low 16 bits).
    """
    # Validate before touching the filesystem so bad input never truncates an
    # existing output file.
    if kernel_size < 1:
        raise ValueError(f"kernel_size must be >= 1, got {kernel_size}")
    if not 0 <= mip_level <= 0xFFFFFFFF:
        raise ValueError(f"mip_level must fit in uint32, got {mip_level}")

    in_channels = 12  # 4 input + 8 static
    out_channels = 4

    weights_flat = _identity_weight_tensor(
        kernel_size, mix, p47, in_channels, out_channels).ravel()
    weight_count = int(weights_flat.size)

    # Derive both the banner and the mode description from the same precedence
    # the tensor builder uses, so they can never disagree.
    mode_name = 'p47' if p47 else ('mix' if mix else 'identity')
    mode_desc = {
        'p47': "p4→ch0, p5→ch1, p6→ch2, p7→ch3",
        'mix': "0.5*p0+0.5*p4, 0.5*p1+0.5*p5, 0.5*p2+0.5*p6, 0.5*p3+0.5*p7",
    }.get(mode_name)

    print(f"Generating {mode_name} weights:")
    print(f"  Kernel size: {kernel_size}×{kernel_size}")
    print(f"  Channels: {in_channels}D→{out_channels}D")
    print(f"  Weights: {weight_count}")
    print(f"  Mip level: {mip_level}")
    if mode_desc:
        print(f"  Mode: {mode_desc}")

    weights_f16, weights_u32 = _pack_f16_pairs(weights_flat)
    file_size = 20 + 20 + weights_u32.nbytes
    print(f"  Packed: {len(weights_u32)} u32")
    print(f"  Binary size: {file_size} bytes")

    header = (b'CNN2', 2, 1, len(weights_f16), mip_level)
    layer = (kernel_size, in_channels, out_channels, 0, weight_count)

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'wb') as f:
        f.write(struct.pack('<4sIIII', *header))
        f.write(struct.pack('<IIIII', *layer))
        f.write(weights_u32.tobytes())

    print(f"  → {output_path}")

    _verify_weights_file(output_path, header, layer, file_size)


def main(argv=None):
    """Parse command-line arguments and write the requested weight file.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]`` (default:
            None, i.e. use the real command line).

    ``--mix`` and ``--p47`` select different weight layouts and are rejected
    when given together. ``--kernel-size`` must be >= 1 and ``--mip-level``
    must be >= 0; both are written as uint32 in the output file. Invalid
    arguments exit with argparse's usage error (status 2).
    """
    parser = argparse.ArgumentParser(description='Generate identity CNN v2 weights')
    parser.add_argument('output', type=str, nargs='?',
                        default='workspaces/main/weights/cnn_v2_identity.bin',
                        help='Output .bin file path')
    parser.add_argument('--kernel-size', type=int, default=1,
                        help='Kernel size (default: 1×1)')
    parser.add_argument('--mip-level', type=int, default=0,
                        help='Mip level for p0-p3 features (default: 0)')

    # The two modes are alternative weight layouts. Without this group, passing
    # both would silently build the p47 layout.
    mode = parser.add_mutually_exclusive_group()
    mode.add_argument('--mix', action='store_true',
                      help='Mix mode: 50-50 blend of p0-p3 and p4-p7')
    mode.add_argument('--p47', action='store_true',
                      help='Static features only: p4→ch0, p5→ch1, p6→ch2, p7→ch3')

    args = parser.parse_args(argv)
    if args.kernel_size < 1:
        parser.error(f'--kernel-size must be >= 1 (got {args.kernel_size})')
    if args.mip_level < 0:
        parser.error(f'--mip-level must be >= 0 (got {args.mip_level})')

    print("=== Identity Weight Generator ===\n")
    generate_identity_weights(args.output, args.kernel_size, args.mip_level, args.mix, args.p47)
    print("\nDone!")


if __name__ == '__main__':
    main()