summaryrefslogtreecommitdiff
path: root/workspaces/main/shaders/cnn_v2_compute.wgsl
diff options
context:
space:
mode:
Diffstat (limited to 'workspaces/main/shaders/cnn_v2_compute.wgsl')
-rw-r--r--workspaces/main/shaders/cnn_v2_compute.wgsl137
1 files changed, 137 insertions, 0 deletions
diff --git a/workspaces/main/shaders/cnn_v2_compute.wgsl b/workspaces/main/shaders/cnn_v2_compute.wgsl
new file mode 100644
index 0000000..b19a692
--- /dev/null
+++ b/workspaces/main/shaders/cnn_v2_compute.wgsl
@@ -0,0 +1,137 @@
+// CNN v2 Compute Shader - Storage Buffer Version
+// Processes single layer per dispatch with weights from storage buffer
+// Multi-layer execution handled by C++ with ping-pong buffers
+
+// Push constants for layer parameters (passed per dispatch)
+struct LayerParams {
+ kernel_size: u32,
+ in_channels: u32,
+ out_channels: u32,
+ weight_offset: u32, // Offset in f16 units
+ is_output_layer: u32, // 1 if final layer (sigmoid), 0 otherwise (relu)
+}
+
+@group(0) @binding(0) var static_features: texture_2d<u32>; // 8-channel static features
+@group(0) @binding(1) var layer_input: texture_2d<u32>; // Previous layer output (8-channel packed)
+@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>; // Current layer output
+@group(0) @binding(3) var<storage, read> weights_buffer: array<u32>; // Packed f16 weights
+@group(0) @binding(4) var<uniform> params: LayerParams;
+
+fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {
+ let packed = textureLoad(static_features, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let v2 = unpack2x16float(packed.z);
+ let v3 = unpack2x16float(packed.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}
+
+fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> {
+ let packed = textureLoad(layer_input, coord, 0);
+ let v0 = unpack2x16float(packed.x);
+ let v1 = unpack2x16float(packed.y);
+ let v2 = unpack2x16float(packed.z);
+ let v3 = unpack2x16float(packed.w);
+ return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y);
+}
+
+fn pack_channels(values: array<f32, 8>) -> vec4<u32> {
+ return vec4<u32>(
+ pack2x16float(vec2<f32>(values[0], values[1])),
+ pack2x16float(vec2<f32>(values[2], values[3])),
+ pack2x16float(vec2<f32>(values[4], values[5])),
+ pack2x16float(vec2<f32>(values[6], values[7]))
+ );
+}
+
+// Get weight from storage buffer (f16 packed as u32 pairs)
+// Buffer layout: [header: 4 u32][layer_info: N×5 u32][weights: packed f16]
+// TODO: Support 8-bit quantized weights (4× per u32) for 2× size reduction
+fn get_weight(idx: u32) -> f32 {
+ // Skip header (16 bytes = 4 u32) and layer info
+ // Weights start after header + layer_info, but weight_offset already accounts for this
+ let pair_idx = idx / 2u;
+ let packed = weights_buffer[pair_idx];
+ let unpacked = unpack2x16float(packed);
+ return select(unpacked.y, unpacked.x, (idx & 1u) == 0u);
+}
+
+@compute @workgroup_size(8, 8)
+fn main(@builtin(global_invocation_id) id: vec3<u32>) {
+ let coord = vec2<i32>(id.xy);
+ let dims = textureDimensions(static_features);
+
+ if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {
+ return;
+ }
+
+ let kernel_size = params.kernel_size;
+ let in_channels = params.in_channels;
+ let out_channels = params.out_channels;
+ let weight_offset = params.weight_offset;
+ let is_output = params.is_output_layer != 0u;
+
+ let kernel_radius = i32(kernel_size / 2u);
+
+ // Load static features (always 8D)
+ let static_feat = unpack_static_features(coord);
+
+ // Convolution per output channel
+ var output: array<f32, 8>;
+ for (var c: u32 = 0u; c < out_channels && c < 8u; c++) {
+ var sum: f32 = 0.0;
+
+ // Convolve over kernel
+ for (var ky: i32 = -kernel_radius; ky <= kernel_radius; ky++) {
+ for (var kx: i32 = -kernel_radius; kx <= kernel_radius; kx++) {
+ let sample_coord = coord + vec2<i32>(kx, ky);
+
+ // Border handling (clamp)
+ let clamped = vec2<i32>(
+ clamp(sample_coord.x, 0, i32(dims.x) - 1),
+ clamp(sample_coord.y, 0, i32(dims.y) - 1)
+ );
+
+ // Load input features at this spatial location
+ let static_local = unpack_static_features(clamped);
+ let layer_local = unpack_layer_channels(clamped);
+
+ // Weight index calculation
+ let ky_idx = u32(ky + kernel_radius);
+ let kx_idx = u32(kx + kernel_radius);
+ let spatial_idx = ky_idx * kernel_size + kx_idx;
+
+ // Accumulate: static features (always 8 channels)
+ for (var i: u32 = 0u; i < 8u; i++) {
+ let w_idx = weight_offset +
+ c * in_channels * kernel_size * kernel_size +
+ i * kernel_size * kernel_size + spatial_idx;
+ sum += get_weight(w_idx) * static_local[i];
+ }
+
+ // Accumulate: previous layer channels (in_channels - 8)
+ let prev_channels = in_channels - 8u;
+ for (var i: u32 = 0u; i < prev_channels && i < 8u; i++) {
+ let w_idx = weight_offset +
+ c * in_channels * kernel_size * kernel_size +
+ (8u + i) * kernel_size * kernel_size + spatial_idx;
+ sum += get_weight(w_idx) * layer_local[i];
+ }
+ }
+ }
+
+ // Activation
+ if (is_output) {
+ output[c] = clamp(sum, 0.0, 1.0); // Sigmoid approximation
+ } else {
+ output[c] = max(0.0, sum); // ReLU
+ }
+ }
+
+ // Zero unused channels
+ for (var c: u32 = out_channels; c < 8u; c++) {
+ output[c] = 0.0;
+ }
+
+ textureStore(output_tex, coord, pack_channels(output));
+}