summaryrefslogtreecommitdiff
path: root/workspaces/main/shaders/cnn_v2/cnn_v2_static.wgsl
blob: 29acddd6cd74deaed1866b3281085765df0ed5f7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
// CNN v2 Static Features Compute Shader
// Generates 8D parametric features: [p0, p1, p2, p3, uv.x, uv.y, sin10_x, bias]
// p0-p2: RGB from the mip level selected by params.mip_level
//        (0=mip0, 1=mip1, 2=mip2; 3 or higher falls back to mip2)
// p3:    depth, read from depth_tex
// Note: the input image RGBD (mip0) is fed separately to Layer 0.
//
// TODO: Binary format should support arbitrary layout and ordering for the feature vector (8D).
//       Current layout is hardcoded. Future versions should allow runtime-specified
//       feature combinations (e.g., [R, G, B, dx, dy, uv_x, bias] or custom encodings).

// Uniform parameters for the static-feature pass.
//   mip_level: selects the source of p0-p2
//              0 = input_tex, 1 = input_tex_mip1, 2 or higher = input_tex_mip2.
// The padding is three scalar u32 fields rather than a vec3<u32>. In WGSL a
// vec3<u32> member is 16-byte aligned, so it would sit at offset 16 and inflate
// the struct (and its minimum uniform binding size) to 32 bytes. With scalar
// padding the struct is exactly 16 bytes, and mip_level stays at offset 0.
struct StaticFeatureParams {
  mip_level: u32,
  padding0: u32,
  padding1: u32,
  padding2: u32,
}

// Full-resolution input; its dimensions define the dispatch grid and the output size.
@group(0) @binding(0) var input_tex: texture_2d<f32>;
// Source for p0-p2 when mip_level == 1. Presumably a downsampled copy of the input — TODO confirm its size.
@group(0) @binding(1) var input_tex_mip1: texture_2d<f32>;
// Source for p0-p2 when mip_level >= 2. Presumably a further-downsampled copy — TODO confirm its size.
@group(0) @binding(2) var input_tex_mip2: texture_2d<f32>;
// Only the .r channel is read; it becomes feature p3.
@group(0) @binding(3) var depth_tex: texture_2d<f32>;
// Receives 8 f16 features per pixel, packed two per u32 channel.
@group(0) @binding(4) var output_tex: texture_storage_2d<rgba32uint, write>;
@group(0) @binding(5) var<uniform> params: StaticFeatureParams;

// Fetch the texel at `pix` (given in base_dims pixel space) from `tex`.
// If `tex` has a different size than the base grid, the coordinate is
// rescaled proportionally:
//   - sizes match: pix * d / d == pix, so this is an exact identity;
//   - half-size texture: pixel p maps to roughly p / 2.
// The result is always in bounds, because pix < base_dims implies
// pix * tex_dims / base_dims < tex_dims.
fn load_matched(tex: texture_2d<f32>, pix: vec2<u32>, base_dims: vec2<u32>) -> vec4<f32> {
  let tex_dims = textureDimensions(tex);
  let src = (pix * tex_dims) / base_dims;
  return textureLoad(tex, vec2<i32>(src), 0);
}

// One invocation per pixel of input_tex. Writes 8 f16 features
// [p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias], packed into the 4 u32
// channels of output_tex.
@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
  let dims = textureDimensions(input_tex);
  let pix = id.xy;

  // The dispatch is rounded up to whole 8x8 workgroups; drop the overhang.
  if (pix.x >= dims.x || pix.y >= dims.y) {
    return;
  }
  let coord = vec2<i32>(pix);

  // Parametric features p0-p2: RGB from the selected mip level.
  // The mip textures may be smaller than input_tex, so their coordinates are
  // rescaled instead of reusing the full-resolution coord (which would read
  // out of bounds). Mip 3 or higher falls back to mip 2.
  var rgba: vec4<f32>;
  switch (params.mip_level) {
    case 0u: {
      rgba = textureLoad(input_tex, coord, 0);
    }
    case 1u: {
      rgba = load_matched(input_tex_mip1, pix, dims);
    }
    default: {
      rgba = load_matched(input_tex_mip2, pix, dims);
    }
  }

  let p0 = rgba.r;
  let p1 = rgba.g;
  let p2 = rgba.b;
  // p3: depth (red channel). Rescaled in case depth_tex differs in size.
  let p3 = load_matched(depth_tex, pix, dims).r;

  // UV coordinates, normalized, with a bottom-left origin.
  // They are sampled at pixel corners (x / w), not pixel centres.
  // NOTE(review): kept as-is, presumably because trained weights depend on
  // this exact encoding — confirm before changing.
  let uv_x = f32(coord.x) / f32(dims.x);
  let uv_y = 1.0 - (f32(coord.y) / f32(dims.y));

  // Single-frequency sinusoidal position encoding of uv_x.
  let sin10_x = sin(10.0 * uv_x);

  // Bias dimension (always 1.0).
  let bias = 1.0;

  // Pack 8 x f16 into 4 x u32 (rgba32uint).
  // Layout: [p0, p1, p2, p3, uv_x, uv_y, sin10_x, bias]
  let packed = vec4<u32>(
    pack2x16float(vec2<f32>(p0, p1)),
    pack2x16float(vec2<f32>(p2, p3)),
    pack2x16float(vec2<f32>(uv_x, uv_y)),
    pack2x16float(vec2<f32>(sin10_x, bias))
  );

  textureStore(output_tex, coord, packed);
}