diff options
Diffstat (limited to 'cnn_v3/shaders/cnn_v3_common.wgsl')
| -rw-r--r-- | cnn_v3/shaders/cnn_v3_common.wgsl | 23 |
1 files changed, 23 insertions, 0 deletions
// cnn_v3/shaders/cnn_v3_common.wgsl (new file, blob 54b0f3d)
//
// Helper routines shared by the CNN v3 inference compute shaders.
// This file declares no bindings of its own; the including shader must provide:
//   @group(?) @binding(?) var<storage, read> weights: array<u32>;

// Fetch a single f16 weight (widened to f32) from the packed weights buffer.
// Two f16 values share each u32 word: even f16 indices live in the low half
// (unpack2x16float().x), odd indices in the high half (.y).
//   `base` — weight_offset from Params (f16 index where the layer starts)
//   `idx`  — f16 index relative to `base` (conv weight or bias)
fn get_w(base: u32, idx: u32) -> f32 {
    let f16_index = base + idx;
    let word_index = f16_index / 2u;
    let halves = unpack2x16float(weights[word_index]);
    if ((f16_index % 2u) == 0u) {
        return halves.x;
    }
    return halves.y;
}

// Load one rgba32uint texel at `coord` (mip level 0) and expand it into
// 8 f32 channels. Each u32 component carries two f16 values in
// pack2x16float layout:
//   .x = ch0|ch1, .y = ch2|ch3, .z = ch4|ch5, .w = ch6|ch7
// The returned array is ordered ch0..ch7.
fn unpack_8ch(tex: texture_2d<u32>, coord: vec2i) -> array<f32, 8> {
    let texel = textureLoad(tex, coord, 0);
    // Channels 0..3 come from the first two words, 4..7 from the last two.
    let ch0_3 = vec4f(unpack2x16float(texel.x), unpack2x16float(texel.y));
    let ch4_7 = vec4f(unpack2x16float(texel.z), unpack2x16float(texel.w));
    return array<f32, 8>(
        ch0_3.x, ch0_3.y, ch0_3.z, ch0_3.w,
        ch4_7.x, ch4_7.y, ch4_7.z, ch4_7.w
    );
}
