// path: workspaces/main/shaders/cnn_v2_layer_2.wgsl
// blob: 2f9836a02fdae63fcbda0c0ac89e744e58ca1b07
// CNN v2 Layer 2 - Auto-generated
// Kernel: 3×3, In: 12, Out: 4

// Convolution geometry for this layer. These constants fix both the loop
// bounds in main() and the indexing of the `weights` table.
const KERNEL_SIZE: u32 = 3u;    // kernel width and height, in texels
const IN_CHANNELS: u32 = 12u;   // inputs per tap: 8 static features + 4 layer_input channels
const OUT_CHANNELS: u32 = 4u;   // values produced per pixel
const KERNEL_RADIUS: i32 = 1;   // tap offsets run from -KERNEL_RADIUS to +KERNEL_RADIUS

// Weights quantized to float16 (stored as f32 in WGSL).
//
// Layout, as indexed by main(): [out_channel][in_channel][ky][kx], flattened as
//   ((c * IN_CHANNELS + i) * KERNEL_SIZE + ky_idx) * KERNEL_SIZE + kx_idx
// Input channels 0..7 are the static features; channels 8..11 are the first
// four channels unpacked from layer_input.
// Size: OUT_CHANNELS * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE = 4 * 12 * 3 * 3 = 432.
const weights: array<f32, 432> = array(
  0.030212, -0.041351, 0.053864, -0.025635, 0.099976, -0.016830, -0.068665, 0.112488,
  -0.069824, 0.030197, 0.020142, 0.101807, 0.061920, 0.022415, -0.025864, -0.056366,
  0.085571, -0.053650, 0.109802, 0.129272, 0.023438, 0.087341, 0.066284, 0.037079,
  -0.067566, 0.021530, -0.046814, 0.029343, -0.028534, 0.047150, -0.079346, -0.022675,
  -0.019669, -0.024185, 0.029587, 0.068970, 0.108826, 0.050598, -0.072144, 0.083008,
  -0.002201, 0.006275, 0.056396, 0.001884, 0.097168, -0.028503, -0.002499, 0.008919,
  -0.013771, -0.017502, -0.033478, 0.105530, 0.032898, 0.068726, -0.036285, -0.021011,
  -0.018250, 0.073914, 0.024277, 0.061066, 0.008682, -0.022766, 0.074219, 0.094421,
  0.050903, 0.072571, 0.117493, -0.033234, 0.067993, -0.008049, 0.046997, -0.064209,
  -0.381104, 0.107788, -0.213867, 0.145142, 0.514160, 0.407715, -0.317871, 0.249023,
  0.055634, -0.006294, -0.067444, 0.025131, 0.012939, -0.074158, -0.013741, -0.033020,
  0.026871, -0.007671, 0.089661, -0.003016, 0.029007, -0.038483, 0.045044, 0.104065,
  0.077148, 0.092468, -0.090027, -0.048126, 0.096863, -0.088013, 0.009483, 0.075012,
  -0.076843, -0.085449, -0.066040, 0.019165, -0.019958, 0.083496, 0.069275, -0.019714,
  0.027786, -0.042389, 0.054718, 0.010635, -0.071777, 0.029282, -0.003605, 0.113770,
  0.080994, 0.106079, 0.047333, -0.013733, 0.034760, 0.099365, -0.020813, 0.095886,
  0.052490, -0.049194, 0.047394, 0.072510, -0.030930, -0.003782, -0.038025, -0.019318,
  -0.047852, -0.043915, 0.026810, -0.041138, 0.038422, 0.009605, -0.080688, -0.019653,
  0.075256, -0.013817, -0.022400, 0.050629, 0.048462, 0.072998, -0.009109, 0.070923,
  0.079895, 0.071350, 0.002869, 0.081543, 0.037231, 0.020767, -0.017929, 0.042328,
  -0.075134, -0.010681, -0.009079, 0.057007, -0.040253, -0.025574, -0.041534, 0.105835,
  -0.039703, 0.032104, 0.076050, 0.070923, -0.013046, -0.054108, -0.024582, -0.033997,
  0.092285, 0.000525, 0.114685, 0.036926, -0.419434, 0.087891, -0.187866, 0.128906,
  0.665527, 0.268311, -0.337891, 0.195557, 0.140503, 0.014465, -0.043671, 0.031677,
  0.073059, 0.085144, 0.014290, -0.046967, 0.033356, 0.004177, 0.102844, 0.015259,
  0.026627, -0.005032, 0.111694, -0.010590, 0.029816, 0.108154, -0.072327, 0.056213,
  0.022903, 0.053772, 0.084473, -0.059845, -0.032776, -0.000015, -0.093872, -0.085815,
  0.081604, 0.069336, 0.034149, -0.067322, -0.020859, 0.120911, 0.077209, -0.016388,
  0.050140, -0.045563, -0.046326, 0.032623, -0.005009, 0.008003, 0.109192, 0.086548,
  0.096558, 0.118530, 0.035034, 0.110352, -0.041748, 0.009178, 0.049957, 0.084839,
  0.042053, -0.069153, -0.024796, -0.094604, -0.047028, -0.053802, 0.024979, 0.049591,
  -0.016373, -0.047607, -0.008797, -0.058868, 0.107178, 0.055695, 0.092407, 0.092346,
  0.053894, 0.054657, -0.039703, -0.073792, 0.041779, -0.044159, 0.099182, 0.037109,
  0.097778, 0.098206, -0.057831, -0.054016, -0.068604, -0.061584, -0.054382, 0.005268,
  0.096008, -0.007118, -0.063049, 0.059113, 0.076904, 0.045288, -0.055695, -0.052612,
  -0.022110, 0.049103, 0.095276, 0.014572, 0.064819, 0.014671, 0.029800, 0.066284,
  -0.383301, 0.071838, -0.207275, 0.099365, 0.640137, 0.393311, -0.334229, 0.275391,
  -0.013977, -0.025269, -0.007065, -0.033478, -0.017349, 0.026764, 0.005192, 0.093384,
  0.014313, 0.018906, 0.006962, 0.094849, 0.005390, 0.101624, -0.041199, 0.026245,
  0.027588, 0.062408, 0.033356, -0.010826, 0.067993, -0.054199, 0.076416, 0.023315,
  -0.002886, -0.112061, -0.041473, -0.012703, 0.016022, 0.010506, -0.021362, -0.037750,
  0.062927, 0.061920, 0.038177, -0.037201, -0.011620, 0.014015, -0.062164, -0.045441,
  -0.063416, -0.040100, 0.035950, 0.045563, -0.017227, -0.060547, -0.017593, 0.111877,
  0.121521, 0.073853, 0.023331, -0.012428, 0.018478, -0.010948, 0.030716, 0.043427,
  0.003117, -0.069092, 0.038361, -0.053497, 0.039154, -0.085754, 0.012642, -0.051208,
  0.022934, 0.127197, 0.117920, 0.074036, 0.083313, -0.061951, 0.079224, 0.091248,
  0.009132, 0.069946, 0.123474, 0.130127, 0.118835, 0.020874, -0.045380, -0.000111,
  0.111206, 0.054688, 0.008995, 0.085693, 0.005562, 0.103088, -0.034698, 0.119934,
  -0.067200, 0.065430, -0.021942, 0.089783, 0.033112, -0.025467, 0.040161, -0.052155,
  -0.048920, 0.031250, 0.112549, 0.122192, 0.126587, 0.180908, 0.194946, 0.121704,
  0.217529, 0.224243, 0.269287, 0.222656, 0.288086, 0.035492, 0.066711, -0.046600,
  0.085144, 0.013855, -0.065979, -0.083252, -0.058289, 0.104126, 0.013702, -0.018188,
  0.036591, 0.099854, 0.056061, 0.151855, 0.062134, 0.133789, 0.084045, 0.095825,
  0.036987, 0.022308, 0.070923, 0.031036, 0.101868, 0.062347, 0.141235, 0.066650
);

// Resource bindings. Both input textures hold four u32 words per texel, and
// each word is unpacked by this shader into two f16 values (8 floats per texel).
// The output texture receives texels packed the same way by pack_channels().
@group(0) @binding(0) var static_features: texture_2d<u32>;
@group(0) @binding(1) var layer_input: texture_2d<u32>;
@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>;

// Reads the texel at `coord` (mip level 0) from `static_features` and expands
// its four u32 words into eight floats. Each word yields two f16 values; the
// output order is word x (low, high), then y, z, w.
fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {
  let texel: vec4<u32> = textureLoad(static_features, coord, 0);

  // Unpack two words at a time into a vec4: (w0.lo, w0.hi, w1.lo, w1.hi).
  let first_half = vec4<f32>(unpack2x16float(texel.x), unpack2x16float(texel.y));
  let second_half = vec4<f32>(unpack2x16float(texel.z), unpack2x16float(texel.w));

  return array<f32, 8>(
    first_half.x, first_half.y, first_half.z, first_half.w,
    second_half.x, second_half.y, second_half.z, second_half.w
  );
}

// Reads the texel at `coord` (mip level 0) from `layer_input` and expands its
// four u32 words into eight floats. Each word yields two f16 values; the
// output order is word x (low, high), then y, z, w.
fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> {
  let texel: vec4<u32> = textureLoad(layer_input, coord, 0);

  // Unpack two words at a time into a vec4: (w0.lo, w0.hi, w1.lo, w1.hi).
  let channels_0_3 = vec4<f32>(unpack2x16float(texel.x), unpack2x16float(texel.y));
  let channels_4_7 = vec4<f32>(unpack2x16float(texel.z), unpack2x16float(texel.w));

  return array<f32, 8>(
    channels_0_3.x, channels_0_3.y, channels_0_3.z, channels_0_3.w,
    channels_4_7.x, channels_4_7.y, channels_4_7.z, channels_4_7.w
  );
}

// Packs eight floats into four u32 words, two f16 values per word.
// values[0] and values[1] go into .x (low and high halves respectively),
// values[2] and values[3] into .y, and so on through .w.
fn pack_channels(values: array<f32, 8>) -> vec4<u32> {
  let word_x = pack2x16float(vec2<f32>(values[0], values[1]));
  let word_y = pack2x16float(vec2<f32>(values[2], values[3]));
  let word_z = pack2x16float(vec2<f32>(values[4], values[5]));
  let word_w = pack2x16float(vec2<f32>(values[6], values[7]));
  return vec4<u32>(word_x, word_y, word_z, word_w);
}

// Entry point: computes one output texel of CNN v2 layer 2.
//
// Each invocation handles the pixel at `id.xy`. It convolves the
// KERNEL_SIZE x KERNEL_SIZE neighbourhood around that pixel (coordinates
// clamped to the texture, so edge texels are repeated) with `weights`.
// Inputs per tap are the 8 static features followed by the first
// (IN_CHANNELS - 8) channels of `layer_input`. Each of the OUT_CHANNELS sums is
// clamped to [0, 1] and the result is written to `output_tex` as packed f16
// pairs; slots beyond OUT_CHANNELS are written as zero.
@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
  let coord = vec2<i32>(id.xy);
  let dims = textureDimensions(static_features);

  // Invocations past the texture edge (partially filled workgroups) do nothing.
  if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {
    return;
  }

  let max_coord = vec2<i32>(i32(dims.x) - 1, i32(dims.y) - 1);
  let taps = KERNEL_SIZE * KERNEL_SIZE;
  let prev_channels = IN_CHANNELS - 8u;

  // Sized to match pack_channels(), which takes exactly 8 values. Only the
  // first OUT_CHANNELS entries are accumulated; the rest keep their zero
  // initial value.
  var output: array<f32, 8>;

  // The spatial loops are outermost so each neighbour texel is fetched and
  // unpacked once and shared by all output channels. For any given channel
  // the terms are still added in (ky, kx, input channel) order.
  for (var ky: i32 = -KERNEL_RADIUS; ky <= KERNEL_RADIUS; ky++) {
    for (var kx: i32 = -KERNEL_RADIUS; kx <= KERNEL_RADIUS; kx++) {
      // Border handling (clamp to edge)
      let clamped = clamp(coord + vec2<i32>(kx, ky), vec2<i32>(0, 0), max_coord);

      // Load input features for this tap
      let static_local = unpack_static_features(clamped);
      let layer_local = unpack_layer_channels(clamped);

      // Position of this tap inside a single KERNEL_SIZE x KERNEL_SIZE kernel
      let spatial_idx = u32(ky + KERNEL_RADIUS) * KERNEL_SIZE + u32(kx + KERNEL_RADIUS);

      for (var c: u32 = 0u; c < OUT_CHANNELS; c++) {
        // weights layout: [out_channel][in_channel][ky][kx]
        let base = c * IN_CHANNELS * taps + spatial_idx;
        var sum: f32 = output[c];

        // Accumulate: static features (input channels 0..7)
        for (var i: u32 = 0u; i < 8u; i++) {
          sum += weights[base + i * taps] * static_local[i];
        }

        // Accumulate: previous-layer channels (input channels 8..IN_CHANNELS-1)
        for (var i: u32 = 0u; i < prev_channels; i++) {
          sum += weights[base + (8u + i) * taps] * layer_local[i];
        }

        output[c] = sum;
      }
    }
  }

  // Activation: hard clamp to [0, 1] (the generator uses this in place of a sigmoid).
  for (var c: u32 = 0u; c < OUT_CHANNELS; c++) {
    output[c] = clamp(output[c], 0.0, 1.0);
  }

  // Pack and store
  textureStore(output_tex, coord, pack_channels(output));
}