summaryrefslogtreecommitdiff
path: root/training/export_cnn_v2_shader.py
blob: 1c74ad03240d1fe2255db4defe3ebbf84ad69d47 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
#!/usr/bin/env python3
"""CNN v2 Shader Export Script - Uniform 12D→4D Architecture

Converts PyTorch checkpoints to WGSL compute shaders with f16 weights.
Generates one shader per layer with embedded weight arrays.

Note: Storage buffer approach (export_cnn_v2_weights.py) is preferred for size.
      This script is for debugging/testing with per-layer shaders.
"""

import argparse
import numpy as np
import torch
from pathlib import Path


def export_layer_shader(layer_idx, weights, kernel_size, output_dir, mip_level=0, is_output_layer=False):
    """Generate a WGSL compute shader for a single CNN layer and write it to disk.

    The shader is written to ``<output_dir>/cnn_v2/cnn_v2_layer_<layer_idx>.wgsl``;
    the ``cnn_v2`` subdirectory is created if it does not already exist.

    Args:
        layer_idx: Layer index (0, 1, 2, ...)
        weights: (4, 12, k, k) weight array (uniform 12D→4D). The generated
            shader indexes the flattened (C-order) array as
            [out_channel, in_channel, ky, kx].
        kernel_size: Odd kernel size (3, 5, etc.)
        output_dir: Output directory path
        mip_level: Mip level used for p0-p3 (0=original, 1=half, etc.);
            only recorded in the shader's header comment.
        is_output_layer: True if this is the final RGBA output layer

    Raises:
        ValueError: If kernel_size is not a positive odd integer, if weights
            does not have shape (4, 12, kernel_size, kernel_size), or if a
            weight is not finite after float16 quantization (float16
            overflows to inf for magnitudes beyond its ~65504 maximum).
    """
    # The shader loops over taps -radius..+radius, which only matches a
    # kernel_size x kernel_size weight grid when kernel_size is odd.
    if kernel_size < 1 or kernel_size % 2 == 0:
        raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")

    weights = np.asarray(weights)
    expected_shape = (4, 12, kernel_size, kernel_size)
    if weights.shape != expected_shape:
        # The shader's index arithmetic is hard-coded for this layout; any other
        # shape would silently read the wrong weights.
        raise ValueError(
            f"Layer {layer_idx}: expected weights of shape {expected_shape}, got {weights.shape}"
        )

    weights_f16 = weights.reshape(-1).astype(np.float16)
    if not np.isfinite(weights_f16).all():
        # 'inf' / 'nan' are not valid WGSL literals.
        raise ValueError(f"Layer {layer_idx}: weights are not finite after float16 quantization")
    weights_f32 = weights_f16.astype(np.float32)  # WGSL stores as f32 literals

    # Emit the shortest positional decimal that round-trips each value exactly,
    # so the shader holds precisely the f16-quantized weights. A fixed '.6f'
    # format would round away most of the precision of small weights.
    literals = [np.format_float_positional(float(w), unique=True, trim='0') for w in weights_f32]

    # Format weights as WGSL array, 8 values per line
    weights_str = ",\n  ".join(
        ", ".join(literals[i:i + 8])
        for i in range(0, len(literals), 8)
    )

    radius = kernel_size // 2
    if is_output_layer:
        activation = "output[c] = clamp(sum, 0.0, 1.0);  // Output layer"
    elif layer_idx == 0:
        activation = "output[c] = clamp(sum, 0.0, 1.0);  // Layer 0: clamp [0,1]"
    else:
        activation = "output[c] = max(0.0, sum);  // Middle layers: ReLU"

    shader_code = f"""// CNN v2 Layer {layer_idx} - Auto-generated (uniform 12D→4D)
// Kernel: {kernel_size}×{kernel_size}, In: 12D (4 prev + 8 static), Out: 4D
// Mip level: {mip_level} (p0-p3 features)

const KERNEL_SIZE: u32 = {kernel_size}u;
const IN_CHANNELS: u32 = 12u;  // 4 (input/prev) + 8 (static)
const OUT_CHANNELS: u32 = 4u;   // Uniform output
const KERNEL_RADIUS: i32 = {radius};

// Weights quantized to float16 (stored as f32 in WGSL)
const weights: array<f32, {len(weights_f32)}> = array(
  {weights_str}
);

@group(0) @binding(0) var static_features: texture_2d<u32>;
@group(0) @binding(1) var layer_input: texture_2d<u32>;
@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>;

fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {{
  let packed = textureLoad(static_features, coord, 0);
  let v0 = unpack2x16float(packed.x);
  let v1 = unpack2x16float(packed.y);
  let v2 = unpack2x16float(packed.z);
  let v3 = unpack2x16float(packed.w);
  return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
}}

fn unpack_layer_channels(coord: vec2<i32>) -> vec4<f32> {{
  let packed = textureLoad(layer_input, coord, 0);
  let v0 = unpack2x16float(packed.x);
  let v1 = unpack2x16float(packed.y);
  return vec4<f32>(v0.x, v0.y, v1.x, v1.y);
}}

fn pack_channels(values: vec4<f32>) -> vec4<u32> {{
  return vec4<u32>(
    pack2x16float(vec2<f32>(values.x, values.y)),
    pack2x16float(vec2<f32>(values.z, values.w)),
    0u,  // Unused
    0u   // Unused
  );
}}

@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {{
  let coord = vec2<i32>(id.xy);
  let dims = textureDimensions(static_features);

  if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {{
    return;
  }}

  // Load static features (always available)
  let static_feat = unpack_static_features(coord);

  // Convolution: 12D input (4 prev + 8 static) → 4D output
  var output: vec4<f32> = vec4<f32>(0.0);
  for (var c: u32 = 0u; c < 4u; c++) {{
    var sum: f32 = 0.0;

    for (var ky: i32 = -KERNEL_RADIUS; ky <= KERNEL_RADIUS; ky++) {{
      for (var kx: i32 = -KERNEL_RADIUS; kx <= KERNEL_RADIUS; kx++) {{
        let sample_coord = coord + vec2<i32>(kx, ky);

        // Border handling (clamp)
        let clamped = vec2<i32>(
          clamp(sample_coord.x, 0, i32(dims.x) - 1),
          clamp(sample_coord.y, 0, i32(dims.y) - 1)
        );

        // Load features at this spatial location
        let static_local = unpack_static_features(clamped);
        let layer_local = unpack_layer_channels(clamped);  // 4D

        // Weight index calculation
        let ky_idx = u32(ky + KERNEL_RADIUS);
        let kx_idx = u32(kx + KERNEL_RADIUS);
        let spatial_idx = ky_idx * KERNEL_SIZE + kx_idx;

        // Accumulate: previous/input channels (4D)
        for (var i: u32 = 0u; i < 4u; i++) {{
          let w_idx = c * 12u * KERNEL_SIZE * KERNEL_SIZE +
                     i * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
          sum += weights[w_idx] * layer_local[i];
        }}

        // Accumulate: static features (8D)
        for (var i: u32 = 0u; i < 8u; i++) {{
          let w_idx = c * 12u * KERNEL_SIZE * KERNEL_SIZE +
                     (4u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx;
          sum += weights[w_idx] * static_local[i];
        }}
      }}
    }}

    {activation}
  }}

  // Pack and store
  textureStore(output_tex, coord, pack_channels(output));
}}
"""

    output_path = Path(output_dir) / "cnn_v2" / f"cnn_v2_layer_{layer_idx}.wgsl"
    # Callers may only have created output_dir itself; make sure the
    # cnn_v2 subdirectory exists before writing into it.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(shader_code)
    print(f"  → {output_path}")


def export_checkpoint(checkpoint_path, output_dir):
    """Export a PyTorch checkpoint to per-layer WGSL shaders.

    Reads ``model_state_dict`` and ``config`` from the checkpoint and writes one
    shader per layer into ``<output_dir>/cnn_v2/``.

    Args:
        checkpoint_path: Path to .pth checkpoint
        output_dir: Output directory for shaders (a ``cnn_v2`` subdirectory is
            created inside it)

    Raises:
        KeyError: If the checkpoint lacks ``model_state_dict`` or ``config``.
        ValueError: If weights for one of the configured layers are missing.
    """
    print(f"Loading checkpoint: {checkpoint_path}")
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    state_dict = checkpoint['model_state_dict']
    config = checkpoint['config']

    kernel_size = config.get('kernel_size', 3)
    num_layers = config.get('num_layers', 3)
    mip_level = config.get('mip_level', 0)

    print("Configuration:")
    print(f"  Kernel size: {kernel_size}×{kernel_size}")
    print(f"  Layers: {num_layers}")
    print(f"  Mip level: {mip_level} (p0-p3 features)")
    print("  Architecture: uniform 12D→4D")

    output_dir = Path(output_dir)
    # Shaders are written into the cnn_v2 subdirectory, so create that
    # (and any missing parents) rather than only output_dir itself.
    shader_dir = output_dir / "cnn_v2"
    shader_dir.mkdir(parents=True, exist_ok=True)

    print(f"\nExporting shaders to {shader_dir}/")

    # All layers uniform: 12D→4D
    for i in range(num_layers):
        layer_key = f'layers.{i}.weight'
        if layer_key not in state_dict:
            raise ValueError(f"Missing weights for layer {i}: {layer_key}")

        # .float() makes reduced-precision tensors (e.g. bfloat16, which
        # NumPy cannot represent) convertible; float32 passes through unchanged.
        layer_weights = state_dict[layer_key].detach().float().numpy()
        is_output = (i == num_layers - 1)

        export_layer_shader(
            layer_idx=i,
            weights=layer_weights,
            kernel_size=kernel_size,
            output_dir=output_dir,
            mip_level=mip_level,
            is_output_layer=is_output
        )

    print(f"\nExport complete! Generated {num_layers} shader files.")


def main():
    """Command-line entry point: parse arguments and run the checkpoint export."""
    cli = argparse.ArgumentParser(
        description='Export CNN v2 checkpoint to WGSL shaders',
    )
    cli.add_argument(
        'checkpoint',
        type=str,
        help='Path to checkpoint .pth file',
    )
    cli.add_argument(
        '--output-dir',
        type=str,
        default='workspaces/main/shaders',
        help='Output directory for shaders',
    )

    parsed = cli.parse_args()
    export_checkpoint(parsed.checkpoint, parsed.output_dir)


if __name__ == '__main__':
    main()