1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
|
// CNN v2 Compute Shader - Uniform 12D→4D Architecture
// All layers: input/previous (4D) + static (8D) = 12D → 4 channels
// Storage buffer weights, ping-pong execution
// Per-layer kernel sizes supported via LayerParams
// Per-dispatch layer parameters, read from the uniform buffer at
// @group(0) @binding(4) (a `var<uniform>`, not push constants).
// Field order and types define the buffer layout the host must write,
// so fields must not be reordered or retyped.
struct LayerParams {
kernel_size: u32, // Square kernel width K; main() uses K / 2 as the radius, so odd K is presumably intended (TODO confirm)
in_channels: u32, // Not used by main()'s convolution, which hard-codes 12 inputs (4 layer + 8 static)
out_channels: u32, // Not used by main()'s convolution, which hard-codes 4 outputs
weight_offset: u32, // Offset in f16 units into weights_buffer for this layer's weights
is_output_layer: u32, // 1 if final layer (clamp to [0,1] + blend with original), 0 otherwise (relu). NOTE(review): an older comment said sigmoid; main() clamps — confirm what training expects
blend_amount: f32, // [0,1] blend with original: mix(original, output, blend_amount), so 0 = original only, 1 = network output only
}
// Resource bindings for one layer dispatch (all in group 0).
// f16 values are stored two per u32; unpack2x16float/pack2x16float put the
// first value in the low 16 bits.
@group(0) @binding(0) var static_features: texture_2d<u32>; // 8D static features (p0-p3 + spatial), two f16 per word in .xyzw; its size also bounds the dispatch in main()
@group(0) @binding(1) var layer_input: texture_2d<u32>; // 4D previous/input (RGBD or prev layer), two f16 per word in .xy; .zw ignored
@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>; // 4D output, two f16 per word in .xy; .zw written as 0
@group(0) @binding(3) var<storage, read> weights_buffer: array<u32>; // Packed f16 weights, two per u32 (even index = low half)
@group(0) @binding(4) var<uniform> params: LayerParams; // Per-dispatch layer parameters
@group(0) @binding(5) var original_input: texture_2d<f32>; // Original RGB for blending; read only on the output layer at the same coord (presumably same size as static_features — TODO confirm)
// Decode the eight f16 static features stored in the texel at `coord` (mip 0).
// Each u32 word holds two halves, low half first, so the result order is
// x.lo, x.hi, y.lo, y.hi, z.lo, z.hi, w.lo, w.hi.
fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {
    let texel = textureLoad(static_features, coord, 0);
    var features: array<f32, 8>;
    for (var k: u32 = 0u; k < 4u; k++) {
        let pair = unpack2x16float(texel[k]);
        features[2u * k] = pair.x;
        features[2u * k + 1u] = pair.y;
    }
    return features;
}
// Decode the four f16 channels of the previous layer (or RGBD input) at `coord`.
// Only the .x and .y words of the texel carry data; .z and .w are ignored.
fn unpack_layer_channels(coord: vec2<i32>) -> vec4<f32> {
    let texel = textureLoad(layer_input, coord, 0);
    return vec4<f32>(unpack2x16float(texel.x), unpack2x16float(texel.y));
}
// Encode four f32 channels as f16 pairs in the .x/.y words of an rgba32uint
// texel (first of each pair in the low half); .z/.w are unused and set to 0.
fn pack_channels(values: vec4<f32>) -> vec4<u32> {
    let lo_word = pack2x16float(values.xy);
    let hi_word = pack2x16float(values.zw);
    return vec4<u32>(lo_word, hi_word, 0u, 0u);
}
// Fetch weight number `idx`, counted in f16 units from the start of weights_buffer.
// Two f16 values share each u32: even indices are the low 16 bits, odd indices
// the high 16 bits.
// Buffer layout: [header: 4 u32][layer_info: N×5 u32][weights: packed f16]
// No header skipping happens here: callers add params.weight_offset, which is
// expected to already account for the header and layer_info.
// TODO: Support 8-bit quantized weights (4× per u32) for 2× size reduction
fn get_weight(idx: u32) -> f32 {
    let halves = unpack2x16float(weights_buffer[idx >> 1u]);
    return halves[idx & 1u];
}
// One invocation per output texel: a KxK convolution from 12 input channels
// (4 from layer_input + 8 from static_features) to 4 output channels, followed
// by ReLU (hidden layers) or a [0,1] clamp plus RGB blend with the original
// image (output layer). Results are stored as packed f16 in output_tex.
//
// Weight index (f16 units): weight_offset + c * 12*K*K + in_ch * K*K + ky_idx * K + kx_idx,
// where in_ch 0..3 are layer_input channels and 4..11 are static features.
// The channel counts are fixed by this shader; params.in_channels and
// params.out_channels are not consulted.
@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let coord = vec2<i32>(id.xy);
    let dims = textureDimensions(static_features);
    // The dispatch is rounded up to whole 8x8 workgroups; skip the overhang.
    if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {
        return;
    }

    let kernel_size = params.kernel_size;
    let weight_offset = params.weight_offset;
    let is_output = params.is_output_layer != 0u;
    let kernel_radius = i32(kernel_size / 2u);
    let max_coord = vec2<i32>(dims) - vec2<i32>(1);

    // Loop-invariant strides: one K*K plane per input channel, 12 planes per output channel.
    let kernel_area = kernel_size * kernel_size;
    let out_channel_stride = 12u * kernel_area;

    // Spatial loop is outermost so each neighbour is fetched once and reused
    // for all 4 output channels. Per channel, terms are still added in the
    // same order (ky, kx, then 4 layer channels, then 8 static channels).
    var sum = vec4<f32>(0.0);
    for (var ky: i32 = -kernel_radius; ky <= kernel_radius; ky++) {
        for (var kx: i32 = -kernel_radius; kx <= kernel_radius; kx++) {
            // Border handling: clamp to edge.
            let sample_coord = clamp(coord + vec2<i32>(kx, ky), vec2<i32>(0), max_coord);
            let static_local = unpack_static_features(sample_coord); // 8D
            let layer_local = unpack_layer_channels(sample_coord); // 4D
            let spatial_idx = u32(ky + kernel_radius) * kernel_size + u32(kx + kernel_radius);

            for (var c: u32 = 0u; c < 4u; c++) {
                let base = weight_offset + c * out_channel_stride + spatial_idx;
                var acc = sum[c];
                // Previous/input channels (weights planes 0..3).
                for (var i: u32 = 0u; i < 4u; i++) {
                    acc += get_weight(base + i * kernel_area) * layer_local[i];
                }
                // Static features (weight planes 4..11).
                for (var i: u32 = 0u; i < 8u; i++) {
                    acc += get_weight(base + (4u + i) * kernel_area) * static_local[i];
                }
                sum[c] = acc;
            }
        }
    }

    // Activation: clamp to [0,1] on the output layer, ReLU otherwise.
    var result: vec4<f32>;
    if (is_output) {
        result = clamp(sum, vec4<f32>(0.0), vec4<f32>(1.0));
    } else {
        result = max(vec4<f32>(0.0), sum);
    }

    // Output layer: blend RGB with the original image; alpha/4th channel untouched.
    if (is_output) {
        let original = textureLoad(original_input, coord, 0).rgb;
        result = vec4<f32>(mix(original, result.rgb, params.blend_amount), result.a);
    }

    textureStore(output_tex, coord, pack_channels(result));
}
|