summaryrefslogtreecommitdiff
path: root/cnn_v3/shaders/cnn_v3_common.wgsl
blob: 54b0f3dab2c9e503a0e6f7aa98ee15d2579cf8f1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
// CNN v3 shared helpers — included by all inference compute shaders.
// Requires the host shader to declare:
//   @group(?) @binding(?) var<storage, read> weights: array<u32>;

// Fetch a single f16 weight from the packed `weights` buffer and widen it to f32.
// Two f16 values share each u32 word: the even f16 index comes out as the
// first (.x) component of unpack2x16float, the odd f16 index as the second (.y).
// `base`  — weight_offset from Params (f16 index of the layer start)
// `idx`   — local f16 index within the layer (conv weight or bias)
fn get_w(base: u32, idx: u32) -> f32 {
    // Absolute f16 index into the buffer.
    let pos = base + idx;
    // Each u32 word carries two f16 values, so the word index is pos / 2.
    let pair = unpack2x16float(weights[pos / 2u]);
    // Odd f16 indices select the second half of the pair.
    if (pos % 2u) == 1u {
        return pair.y;
    }
    return pair.x;
}

// Load one rgba32uint texel (mip level 0) at `coord` and expand it into
// 8 f32 channels. Each of the four u32 components holds two f16 values in
// pack2x16float layout:
//   u32[0]=ch0|ch1, u32[1]=ch2|ch3, u32[2]=ch4|ch5, u32[3]=ch6|ch7
fn unpack_8ch(tex: texture_2d<u32>, coord: vec2i) -> array<f32, 8> {
    let texel = textureLoad(tex, coord, 0);
    var channels: array<f32, 8>;
    // Word w contributes channels 2w (first half) and 2w+1 (second half).
    for (var w = 0u; w < 4u; w++) {
        let pair = unpack2x16float(texel[w]);
        channels[2u * w] = pair.x;
        channels[2u * w + 1u] = pair.y;
    }
    return channels;
}