summaryrefslogtreecommitdiff
path: root/training/export_cnn_v2_weights.py
blob: d8c7c1062f116e786a86e06c506d2c844ad1207f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
#!/usr/bin/env python3
"""CNN v2 Weight Export Script

Converts PyTorch checkpoints to binary weight format for storage buffer.
Exports single shader template + binary weights asset.
"""

import argparse
import numpy as np
import torch
import struct
from pathlib import Path


def export_weights_binary(checkpoint_path, output_path):
    """Export CNN v2 weights to binary format.

    Binary format (every field little-endian):
      Header (16 bytes):
        uint32 magic ('CNN2')
        uint32 version (1)
        uint32 num_layers
        uint32 total_weights (f16 count, including the single zero pad
                              element appended when the raw count is odd)

      LayerInfo × num_layers (20 bytes each):
        uint32 kernel_size
        uint32 in_channels
        uint32 out_channels
        uint32 weight_offset (f16 index)
        uint32 weight_count

      Weights (f16 array, consecutive pairs packed into u32 words):
        float16[] all_weights

    Layer topology (fixed at three layers):
      layer0: 8 static features   -> channels[0]
      layer1: 8 + channels[0]     -> channels[1]
      layer2: 8 + channels[1]     -> 4 (RGBA output)

    Args:
        checkpoint_path: Path to .pth checkpoint containing
            'model_state_dict' and 'config' (with 'kernels' and 'channels').
        output_path: Output .bin file path; parent directories are created.

    Returns:
        config dict for shader generation:
        {'num_layers': int, 'layers': [LayerInfo dicts in file order]}
    """
    print(f"Loading checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    state_dict = checkpoint['model_state_dict']
    config = checkpoint['config']
    kernels = config['kernels']
    channels = config['channels']

    print("Configuration:")
    print(f"  Kernels: {kernels}")
    print(f"  Channels: {channels}")

    # (in_channels, out_channels) for each layer. Every layer after the first
    # sees the 8 static features concatenated with the previous layer's output.
    layer_io = [
        (8, channels[0]),
        (8 + channels[0], channels[1]),
        (8 + channels[1], 4),  # RGBA output
    ]

    layers = []
    flat_arrays = []
    weight_offset = 0
    for idx, (in_ch, out_ch) in enumerate(layer_io):
        flat = state_dict[f'layer{idx}.weight'].detach().numpy().ravel()
        layers.append({
            'kernel_size': kernels[idx],
            'in_channels': in_ch,
            'out_channels': out_ch,
            'weight_offset': weight_offset,  # f16 index into the weights array
            'weight_count': len(flat),
        })
        flat_arrays.append(flat)
        weight_offset += len(flat)

    # Convert to f16, explicitly little-endian so the payload matches the
    # '<'-packed header and layer info regardless of host byte order.
    # TODO: Use 8-bit quantization for 2× size reduction
    # Requires quantization-aware training (QAT) to maintain accuracy
    all_weights_f16 = np.concatenate(flat_arrays).astype('<f2')

    # Pad to an even count so the f16 values pack exactly into u32 words.
    if all_weights_f16.size % 2 == 1:
        all_weights_f16 = np.concatenate([all_weights_f16, np.zeros(1, dtype='<f2')])

    # Reinterpret consecutive f16 pairs as little-endian u32 words: element 2k
    # lands in the low 16 bits of word k, matching the shader template's
    # get_weight(), which takes unpack2x16float(...).x for even indices.
    weights_u32 = all_weights_f16.view('<u4')

    header_bytes = 16
    layer_info_bytes = 20
    print("\nWeight statistics:")
    print(f"  Total layers: {len(layers)}")
    print(f"  Total weights: {len(all_weights_f16)} (f16)")
    print(f"  Packed: {len(weights_u32)} u32")
    print(f"  Binary size: {header_bytes + len(layers) * layer_info_bytes + len(weights_u32) * 4} bytes")

    # Write binary file
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, 'wb') as f:
        # Header (16 bytes)
        f.write(struct.pack('<4sIII',
                            b'CNN2',               # magic
                            1,                     # version
                            len(layers),           # num_layers
                            len(all_weights_f16)))  # total_weights (padded f16 count)

        # Layer info (20 bytes per layer)
        for layer in layers:
            f.write(struct.pack('<IIIII',
                                layer['kernel_size'],
                                layer['in_channels'],
                                layer['out_channels'],
                                layer['weight_offset'],
                                layer['weight_count']))

        # Weights (u32 packed f16 pairs, little-endian)
        f.write(weights_u32.tobytes())

    print(f"  → {output_path}")

    return {
        'num_layers': len(layers),
        'layers': layers,
    }


def export_shader_template(config, output_dir):
    """Generate single WGSL shader template with storage buffer binding.

    The generated shader is a starting-point template. It reads the header and
    the layer 0 info from the weights storage buffer, but does not yet
    implement the convolution: output channels are zero-filled (see the TODOs
    in the shader text).

    Buffer offsets in the template follow the layout written by
    export_weights_binary():
      - a 16-byte header of four u32 fields (magic, version, num_layers,
        total_weights);
      - num_layers LayerInfo records of five u32 each;
      - the f16 weights, packed in pairs into u32 words.

    Args:
        config: Layer configuration from export_weights_binary(). It is
            currently not used to specialize the template.
        output_dir: Output directory path. The shader is written to
            <output_dir>/cnn_v2/cnn_v2_compute.wgsl, and missing directories
            are created.
    """
    shader_code = """// CNN v2 Compute Shader - Storage Buffer Version
// Reads weights from storage buffer, processes all layers in sequence
//
// Weights buffer layout (u32 words):
//   [0] magic ('CNN2')   [1] version   [2] num_layers   [3] total_weights
//   [4 .. 4 + 5 * num_layers)   CNNv2LayerInfo records (5 u32 each)
//   [4 + 5 * num_layers ..)     f16 weights, packed two per u32

struct CNNv2Header {
  magic: u32,           // 'CNN2'
  version: u32,         // 1
  num_layers: u32,      // Number of layers
  total_weights: u32,   // Total f16 weight count
}

struct CNNv2LayerInfo {
  kernel_size: u32,
  in_channels: u32,
  out_channels: u32,
  weight_offset: u32,   // Offset in weights array
  weight_count: u32,
}

@group(0) @binding(0) var static_features: texture_2d<u32>;
@group(0) @binding(1) var layer_input: texture_2d<u32>;
@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>;
@group(0) @binding(3) var<storage, read> weights: array<u32>;  // Packed f16 pairs

fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {
  let packed = textureLoad(static_features, coord, 0);
  let v0 = unpack2x16float(packed.x);
  let v1 = unpack2x16float(packed.y);
  let v2 = unpack2x16float(packed.z);
  let v3 = unpack2x16float(packed.w);
  return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
}

fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> {
  let packed = textureLoad(layer_input, coord, 0);
  let v0 = unpack2x16float(packed.x);
  let v1 = unpack2x16float(packed.y);
  let v2 = unpack2x16float(packed.z);
  let v3 = unpack2x16float(packed.w);
  return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
}

fn pack_channels(values: array<f32, 8>) -> vec4<u32> {
  return vec4<u32>(
    pack2x16float(vec2<f32>(values[0], values[1])),
    pack2x16float(vec2<f32>(values[2], values[3])),
    pack2x16float(vec2<f32>(values[4], values[5])),
    pack2x16float(vec2<f32>(values[6], values[7]))
  );
}

// Index (in u32 words) of the first packed weight pair:
// 4-word header + num_layers LayerInfo records of 5 words each.
fn weights_data_base() -> u32 {
  return 4u + weights[2] * 5u;
}

fn get_weight(idx: u32) -> f32 {
  let pair_idx = idx / 2u;
  let packed = weights[weights_data_base() + pair_idx];  // Skip header + layer info
  let unpacked = unpack2x16float(packed);
  return select(unpacked.y, unpacked.x, (idx & 1u) == 0u);
}

@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
  let coord = vec2<i32>(id.xy);
  let dims = textureDimensions(static_features);

  if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {
    return;
  }

  // Read header (four separate u32 words)
  let num_layers = weights[2];

  // Load static features
  let static_feat = unpack_static_features(coord);

  // Process each layer (hardcoded for 3 layers for now)
  // TODO: Dynamic layer loop when needed

  // Example for layer 0 - expand to full multi-layer when tested
  let layer_info_offset = 4u;  // After 16-byte header (4 u32)
  let layer0_info_base = layer_info_offset;

  // Read layer 0 info (5 u32 values = 20 bytes)
  let kernel_size = weights[layer0_info_base];
  let in_channels = weights[layer0_info_base + 1u];
  let out_channels = weights[layer0_info_base + 2u];
  let weight_offset = weights[layer0_info_base + 3u];

  // Convolution (simplified - expand to full kernel loop)
  var output: array<f32, 8>;
  for (var c: u32 = 0u; c < min(out_channels, 8u); c++) {
    output[c] = 0.0;  // TODO: Actual convolution
  }

  textureStore(output_tex, coord, pack_channels(output));
}
"""

    output_path = Path(output_dir) / "cnn_v2" / "cnn_v2_compute.wgsl"
    # write_text() does not create directories; make sure cnn_v2/ exists.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(shader_code, encoding='utf-8')
    print(f"  → {output_path}")


def main():
    """Command-line entry point: export a CNN v2 checkpoint's weights.

    Parses the checkpoint path and the optional output locations, then writes
    the binary weights asset via export_weights_binary().
    """
    cli = argparse.ArgumentParser(description='Export CNN v2 weights to binary format')
    cli.add_argument('checkpoint', type=str,
                     help='Path to checkpoint .pth file')
    cli.add_argument('--output-weights', type=str,
                     default='workspaces/main/cnn_v2_weights.bin',
                     help='Output binary weights file')
    cli.add_argument('--output-shader', type=str,
                     default='workspaces/main/shaders',
                     help='Output directory for shader template')
    opts = cli.parse_args()

    print("=== CNN v2 Weight Export ===\n")
    export_weights_binary(opts.checkpoint, opts.output_weights)
    print()
    # NOTE: the compute shader (cnn_v2_compute.wgsl) is maintained by hand, so
    # export_shader_template() is intentionally not called here and
    # opts.output_shader is currently unused.
    print("\nExport complete!")


if __name__ == '__main__':
    main()