summaryrefslogtreecommitdiff
path: root/workspaces/main/shaders/cnn/cnn_conv5x5.wgsl
blob: bd9abfacfa86766893c5bf58779b0d83d39a380f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
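// 5x5 convolution helpers: each function sums 25 texture samples taken on a
// 5x5 grid of offsets (-2..2 in x and y) around the given uv.
// Every sample (vec4) is multiplied by its own mat4x4 weight; weights are
// indexed row-major, starting at offset (-2, -2) and ending at (2, 2).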

// Weighted sum of a 5x5 neighbourhood of `tex` centred on `uv`.
//
//   tex, samp   texture and sampler passed to textureSample for every tap
//   uv          centre coordinate of the neighbourhood
//   resolution  divisor for the tap spacing: adjacent taps are 1.0 / resolution
//               apart in uv (presumably the texture size in pixels - confirm
//               against callers)
//   weights     one mat4x4 per tap, row-major: index 0 is offset (-2, -2),
//               index 24 is offset (2, 2)
//   bias        initial value of the accumulator
//
// Returns bias + sum over taps of weights[k] * sample[k].
fn cnn_conv5x5(
  tex: texture_2d<f32>,
  samp: sampler,
  uv: vec2<f32>,
  resolution: vec2<f32>,
  weights: array<mat4x4<f32>, 25>,
  bias: vec4<f32>
) -> vec4<f32> {
  let texel = 1.0 / resolution;
  var acc = bias;

  // One flat pass over the 25 taps; tap k sits at column k % 5, row k / 5,
  // which visits the grid in the same order as a dy-outer / dx-inner nest.
  for (var k = 0; k < 25; k++) {
    let col = k % 5 - 2;
    let row = k / 5 - 2;
    let tap_uv = uv + vec2<f32>(f32(col), f32(row)) * texel;
    acc += weights[k] * textureSample(tex, samp, tap_uv);
  }

  return acc;
}

// Weighted sum of a 5x5 neighbourhood of `tex` centred on `uv`, plus a linear
// term in the centre coordinate itself.
//
//   tex, samp      texture and sampler passed to textureSample for every tap
//   uv             centre coordinate; also fed through coord_weights
//   resolution     divisor for the tap spacing: adjacent taps are
//                  1.0 / resolution apart in uv (presumably the texture size
//                  in pixels - confirm against callers)
//   rgba_weights   one mat4x4 per tap, row-major: index 0 is offset (-2, -2),
//                  index 24 is offset (2, 2)
//   coord_weights  mat2x4 mapping the vec2 uv to a vec4 contribution
//   bias           constant term of the result
//
// Returns bias + coord_weights * uv + sum over taps of rgba_weights[k] * sample[k].
fn cnn_conv5x5_with_coord(
  tex: texture_2d<f32>,
  samp: sampler,
  uv: vec2<f32>,
  resolution: vec2<f32>,
  rgba_weights: array<mat4x4<f32>, 25>,
  coord_weights: mat2x4<f32>,
  bias: vec4<f32>
) -> vec4<f32> {
  let texel = 1.0 / resolution;

  // Start from the bias plus the coordinate term; the texture taps are
  // accumulated on top of this afterwards.
  var acc = bias + coord_weights * uv;

  // One flat pass over the 25 taps; tap k sits at column k % 5, row k / 5,
  // which visits the grid in the same order as a dy-outer / dx-inner nest.
  for (var k = 0; k < 25; k++) {
    let col = k % 5 - 2;
    let row = k / 5 - 2;
    let tap_uv = uv + vec2<f32>(f32(col), f32(row)) * texel;
    acc += rgba_weights[k] * textureSample(tex, samp, tap_uv);
  }

  return acc;
}