1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
|
// CNN v2 Compute Shader - Storage Buffer Version
// Processes single layer per dispatch with weights from storage buffer
// Multi-layer execution handled by C++ with ping-pong buffers
// Layer parameters are supplied per dispatch through a uniform buffer (group 0, binding 4)
// Parameters describing the single layer evaluated by one dispatch.
// Read through the `params` uniform binding; the field order is the buffer
// layout, so it has to match whatever the host side writes.
struct LayerParams {
kernel_size: u32, // Side length of the square convolution kernel
in_channels: u32, // Input channels per filter: 8 static features plus previous-layer channels
out_channels: u32, // Output channels to compute (main processes at most 8)
weight_offset: u32, // Offset in f16 units
is_output_layer: u32, // Nonzero: final layer, result clamped to [0, 1]; 0: ReLU
}
// Resource bindings (all in group 0).
// Both input textures store 8 f16 channels per texel, two per u32 component;
// the output texture is written in the same packed form.
@group(0) @binding(0) var static_features: texture_2d<u32>; // 8-channel static features
@group(0) @binding(1) var layer_input: texture_2d<u32>; // Previous layer output (8-channel packed)
@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>; // Current layer output
@group(0) @binding(3) var<storage, read> weights_buffer: array<u32>; // Packed f16 weights, two per u32
@group(0) @binding(4) var<uniform> params: LayerParams; // Parameters of the layer being evaluated
// Loads the texel at `coord` (mip level 0) from `static_features` and expands
// its four u32 components into eight f32 channels. Each component carries two
// f16 values, low half first, so the result order is
// x.lo, x.hi, y.lo, y.hi, z.lo, z.hi, w.lo, w.hi.
fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {
    let texel = textureLoad(static_features, coord, 0);
    let lower = vec4<f32>(unpack2x16float(texel.x), unpack2x16float(texel.y));
    let upper = vec4<f32>(unpack2x16float(texel.z), unpack2x16float(texel.w));
    return array<f32, 8>(
        lower.x, lower.y, lower.z, lower.w,
        upper.x, upper.y, upper.z, upper.w
    );
}
// Loads the texel at `coord` (mip level 0) from `layer_input` and returns its
// eight packed f16 channels as f32. Channel 2k is the low half and channel
// 2k+1 the high half of texel component k (x, y, z, w).
fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> {
    let texel = textureLoad(layer_input, coord, 0);
    let ch01 = unpack2x16float(texel.x);
    let ch23 = unpack2x16float(texel.y);
    let ch45 = unpack2x16float(texel.z);
    let ch67 = unpack2x16float(texel.w);
    var channels: array<f32, 8>;
    channels[0] = ch01.x;
    channels[1] = ch01.y;
    channels[2] = ch23.x;
    channels[3] = ch23.y;
    channels[4] = ch45.x;
    channels[5] = ch45.y;
    channels[6] = ch67.x;
    channels[7] = ch67.y;
    return channels;
}
// Packs eight f32 channels into one texel of four u32 values. Consecutive
// channel pairs (0,1), (2,3), (4,5), (6,7) are converted to f16 and stored in
// the x, y, z and w components respectively, first channel in the low half.
fn pack_channels(values: array<f32, 8>) -> vec4<u32> {
    let x = pack2x16float(vec2<f32>(values[0], values[1]));
    let y = pack2x16float(vec2<f32>(values[2], values[3]));
    let z = pack2x16float(vec2<f32>(values[4], values[5]));
    let w = pack2x16float(vec2<f32>(values[6], values[7]));
    return vec4<u32>(x, y, z, w);
}
// Fetches one weight from `weights_buffer`, where f16 values are stored two per u32.
// Buffer layout: [header: 4 u32][layer_info: N×5 u32][weights: packed f16]
// `idx` is an absolute index in f16 units from the start of the buffer; callers
// pass indices that already include the layer's weight_offset, so no header or
// layer-info skip is applied here.
// TODO: Support 8-bit quantized weights (4× per u32) for 2× size reduction
fn get_weight(idx: u32) -> f32 {
    // Two f16 values per word: word index is idx / 2, lane is idx's low bit.
    let halves = unpack2x16float(weights_buffer[idx >> 1u]);
    let is_high_half = (idx & 1u) != 0u;
    if (is_high_half) {
        return halves.y;
    }
    return halves.x;
}
// Entry point: evaluates one CNN layer for the texel at `id.xy` and writes the
// packed result to `output_tex`.
//
// For every kernel tap the 8 static features and the 8 previous-layer channels
// are read at the (edge-clamped) neighbour position and accumulated into up to
// 8 output channels. Weights for one layer are indexed as
//   weight_offset + ((out_channel * in_channels + in_channel) * kernel_size + ky) * kernel_size + kx
// with input channels 0..7 taken from `static_features` and 8..15 from
// `layer_input`. Output channels beyond `out_channels` are written as 0.
//
// NOTE(review): the tap loops cover 2 * (kernel_size / 2) + 1 positions per
// axis, which equals kernel_size only for odd sizes - confirm the host never
// passes an even kernel_size.
@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let coord = vec2<i32>(id.xy);
    let dims = textureDimensions(static_features);
    // Invocations that fall outside the image do nothing.
    if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {
        return;
    }

    let kernel_size = params.kernel_size;
    let in_channels = params.in_channels;
    let weight_offset = params.weight_offset;
    let is_output = params.is_output_layer != 0u;
    let kernel_radius = i32(kernel_size / 2u);

    // Loop-invariant strides of the weight layout.
    let kernel_area = kernel_size * kernel_size;
    let filter_stride = in_channels * kernel_area;

    // At most 8 output channels fit in one packed texel.
    let active_outputs = min(params.out_channels, 8u);
    // Input channels 0..7 are static features, the following (up to 8) come from
    // the previous layer. Computing the counts with min() avoids the u32
    // underflow of `in_channels - 8u` when a layer declares fewer than 8 inputs,
    // and keeps the static loop from reading the next filter's weights.
    let static_count = min(in_channels, 8u);
    let prev_count = min(in_channels - static_count, 8u);

    let max_coord = vec2<i32>(dims) - vec2<i32>(1, 1);

    // Per-output-channel accumulators; function-scope vars start zeroed.
    var sums: array<f32, 8>;

    // Taps are the outer loops so each neighbour texel is loaded once and
    // reused for every output channel. Each channel still accumulates its
    // terms in the order (ky, kx, static 0..7, previous 0..7).
    for (var ky: i32 = -kernel_radius; ky <= kernel_radius; ky++) {
        for (var kx: i32 = -kernel_radius; kx <= kernel_radius; kx++) {
            // Border handling: clamp the sample position to the image.
            let clamped = clamp(coord + vec2<i32>(kx, ky), vec2<i32>(0, 0), max_coord);

            let static_local = unpack_static_features(clamped);
            let layer_local = unpack_layer_channels(clamped);

            let spatial_idx = u32(ky + kernel_radius) * kernel_size + u32(kx + kernel_radius);

            for (var c: u32 = 0u; c < active_outputs; c++) {
                // Index of this tap's weight for input channel 0 of filter c.
                let tap_base = weight_offset + c * filter_stride + spatial_idx;
                var acc = sums[c];

                for (var i: u32 = 0u; i < static_count; i++) {
                    acc += get_weight(tap_base + i * kernel_area) * static_local[i];
                }
                for (var i: u32 = 0u; i < prev_count; i++) {
                    acc += get_weight(tap_base + (8u + i) * kernel_area) * layer_local[i];
                }

                sums[c] = acc;
            }
        }
    }

    // Activation. Channels >= active_outputs keep their initial value of 0.
    var output: array<f32, 8>;
    for (var c: u32 = 0u; c < active_outputs; c++) {
        if (is_output) {
            // Final layer: hard clamp to [0, 1] (used in place of a sigmoid).
            output[c] = clamp(sums[c], 0.0, 1.0);
        } else {
            output[c] = max(0.0, sums[c]); // ReLU
        }
    }

    textureStore(output_tex, coord, pack_channels(output));
}
|