summaryrefslogtreecommitdiff
path: root/cnn_v3/shaders/gbuf_view.wgsl
blob: 6a812e69aa3b1feb2e29faf11316aad9697c943f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
// G-buffer channel visualization — grid of 4 columns × 5 rows showing 20 feature channels.
// Takes feat_tex0 (rgba32uint, ch 0-7 f16) and feat_tex1 (rgba32uint, ch 8-19 unorm8).
// Outputs tiled channel view to a standard rgba8unorm render target.
//
// Channel layout (row×col):
//   Row 0: ch0(alb.r)  ch1(alb.g)  ch2(alb.b)  ch3(nrm.x)
//   Row 1: ch4(nrm.y)  ch5(depth)  ch6(dzdx)   ch7(dzdy)
//   Row 2: ch8(matid)  ch9(prv.r)  ch10(prv.g) ch11(prv.b)
//   Row 3: ch12(m1.r)  ch13(m1.g)  ch14(m1.b)  ch15(m2.r)
//   Row 4: ch16(m2.g)  ch17(m2.b)  ch18(dif)   ch19(trns)

#include "debug/debug_print"

// Per-pass uniforms. `resolution` is the render-target size in pixels:
// fs_main divides @builtin(position).xy by it to obtain a [0,1] UV.
struct GBufViewUniforms {
    resolution: vec2f,
}

// Packed G-buffer feature textures, both rgba32uint:
//   feat0 - channels 0..7, two f16 values per u32 component (pack2x16float)
//   feat1 - channels 8..19, four unorm8 values per u32 component in .x/.y/.z
@binding(0) @group(0) var feat0: texture_2d<u32>;
@binding(1) @group(0) var feat1: texture_2d<u32>;
@binding(2) @group(0) var<uniform> u: GBufViewUniforms;

// Full-screen triangle: three clip-space vertices (-1,-1), (3,-1), (-1,3)
// whose triangle contains the whole [-1,1]^2 viewport. No vertex buffers are
// read; the position comes solely from @builtin(vertex_index).
@vertex
fn vs_main(@builtin(vertex_index) vid: u32) -> @builtin(position) vec4f {
    var clip_xy = array<vec2f, 3>(
        vec2f(-1.0, -1.0),
        vec2f( 3.0, -1.0),
        vec2f(-1.0,  3.0)
    );
    let xy = clip_xy[vid];
    return vec4f(xy.x, xy.y, 0.0, 1.0);
}

// Fragment stage: split the render target into a grid of COLS × ROWS tiles,
// one per G-buffer channel (ch = row * COLS + col), and draw the channel under
// this pixel as grayscale with a short text label near the tile's top-left.
@fragment
fn fs_main(@builtin(position) pos: vec4f) -> @location(0) vec4f {
    // Grid shape, stated once; the channel index, bounds check and label
    // placement below all derive from these two values.
    let COLS: u32 = 4u;
    let ROWS: u32 = 5u;
    let grid = vec2f(f32(COLS), f32(ROWS));

    let uv   = pos.xy / u.resolution;
    let cell = uv * grid;          // integer part = tile index, fraction = position inside tile
    let col  = u32(cell.x);
    let row  = u32(cell.y);
    let ch   = row * COLS + col;

    // Outside the grid or past the last channel: dark background.
    if (col >= COLS || ch >= COLS * ROWS) {
        return vec4f(0.05, 0.05, 0.05, 1.0);
    }

    // Thin cell borders: the outer 0.5% of each tile's width/height on every
    // edge. This is a fraction of the tile, so its pixel width scales with the
    // render-target size (it is not a fixed 1-pixel line).
    let in_tile = fract(cell);
    if (any(in_tile < vec2f(0.005)) || any(in_tile > vec2f(0.995))) {
        return vec4f(0.25, 0.25, 0.25, 1.0);
    }

    // Map the position inside the tile onto a texel of the full feature
    // texture, clamped to valid coordinates.
    // NOTE(review): feat1 is addressed with feat0's dimensions, which assumes
    // both textures are the same size — confirm with the pass that writes them.
    let dim = vec2i(textureDimensions(feat0));
    let tc  = clamp(vec2i(in_tile * vec2f(dim)), vec2i(0), dim - vec2i(1));

    let disp = gbuf_display_value(ch, gbuf_fetch_channel(tc, ch));

    // Label origin: 4 px right/down from the tile's top-left corner.
    let tile   = u.resolution / grid;
    let origin = vec2f(f32(col), f32(row)) * tile + vec2f(4.0);
    return gbuf_draw_label(vec4f(disp, disp, disp, 1.0), pos.xy, origin, ch);
}

// Read raw channel `ch` (expected 0..19) at texel `tc`.
//   ch 0..7  : feat0, each u32 component packs two f16 (pack2x16float);
//              component ch/2, even ch -> .x (low 16 bits), odd ch -> .y.
//   ch 8..19 : feat1, each u32 component packs four unorm8 (pack4x8unorm);
//              component (ch-8)/4 (.x/.y/.z), byte (ch-8)%4 (byte 0 = low bits).
fn gbuf_fetch_channel(tc: vec2i, ch: u32) -> f32 {
    if (ch < 8u) {
        let t      = textureLoad(feat0, tc, 0);
        let halves = unpack2x16float(t[ch >> 1u]);
        return select(halves.y, halves.x, (ch & 1u) == 0u);
    }
    let t     = textureLoad(feat1, tc, 0);
    let idx   = ch - 8u;
    let bytes = unpack4x8unorm(t[idx / 4u]);
    return bytes[idx % 4u];
}

// Map a raw channel value to a [0,1] display intensity, per channel kind.
fn gbuf_display_value(ch: u32, v: f32) -> f32 {
    var d = v;                     // albedo (ch 0-2) and ch 8-19: shown as-is
    if (ch == 3u || ch == 4u) {
        d = v * 0.5 + 0.5;         // oct-encoded normals: [-1,1] -> [0,1]
    } else if (ch == 5u) {
        d = 1.0 - v;               // depth: inverted so smaller values are brighter
    } else if (ch == 6u || ch == 7u) {
        d = v * 20.0 + 0.5;        // signed depth gradients: amplified x20, centered at 0.5
    }
    return clamp(d, 0.0, 1.0);
}

// Overlay the short name of channel `ch` onto `color` at `origin` using
// debug_str. Labels are packed 4 ASCII chars per u32, big-endian, with an
// explicit character count.
fn gbuf_draw_label(color: vec4f, px: vec2f, origin: vec2f, ch: u32) -> vec4f {
    var shaded = color;
    switch ch {
        case  0u: { shaded = debug_str(shaded, px, origin, vec4u(0x616C622Eu, 0x72000000u, 0u, 0u), 5u); } // alb.r
        case  1u: { shaded = debug_str(shaded, px, origin, vec4u(0x616C622Eu, 0x67000000u, 0u, 0u), 5u); } // alb.g
        case  2u: { shaded = debug_str(shaded, px, origin, vec4u(0x616C622Eu, 0x62000000u, 0u, 0u), 5u); } // alb.b
        case  3u: { shaded = debug_str(shaded, px, origin, vec4u(0x6E726D2Eu, 0x78000000u, 0u, 0u), 5u); } // nrm.x
        case  4u: { shaded = debug_str(shaded, px, origin, vec4u(0x6E726D2Eu, 0x79000000u, 0u, 0u), 5u); } // nrm.y
        case  5u: { shaded = debug_str(shaded, px, origin, vec4u(0x64657074u, 0x68000000u, 0u, 0u), 5u); } // depth
        case  6u: { shaded = debug_str(shaded, px, origin, vec4u(0x647A6478u, 0u, 0u, 0u), 4u); }          // dzdx
        case  7u: { shaded = debug_str(shaded, px, origin, vec4u(0x647A6479u, 0u, 0u, 0u), 4u); }          // dzdy
        case  8u: { shaded = debug_str(shaded, px, origin, vec4u(0x6D617469u, 0x64000000u, 0u, 0u), 5u); } // matid
        case  9u: { shaded = debug_str(shaded, px, origin, vec4u(0x7072762Eu, 0x72000000u, 0u, 0u), 5u); } // prv.r
        case 10u: { shaded = debug_str(shaded, px, origin, vec4u(0x7072762Eu, 0x67000000u, 0u, 0u), 5u); } // prv.g
        case 11u: { shaded = debug_str(shaded, px, origin, vec4u(0x7072762Eu, 0x62000000u, 0u, 0u), 5u); } // prv.b
        case 12u: { shaded = debug_str(shaded, px, origin, vec4u(0x6D312E72u, 0u, 0u, 0u), 4u); }          // m1.r
        case 13u: { shaded = debug_str(shaded, px, origin, vec4u(0x6D312E67u, 0u, 0u, 0u), 4u); }          // m1.g
        case 14u: { shaded = debug_str(shaded, px, origin, vec4u(0x6D312E62u, 0u, 0u, 0u), 4u); }          // m1.b
        case 15u: { shaded = debug_str(shaded, px, origin, vec4u(0x6D322E72u, 0u, 0u, 0u), 4u); }          // m2.r
        case 16u: { shaded = debug_str(shaded, px, origin, vec4u(0x6D322E67u, 0u, 0u, 0u), 4u); }          // m2.g
        case 17u: { shaded = debug_str(shaded, px, origin, vec4u(0x6D322E62u, 0u, 0u, 0u), 4u); }          // m2.b
        case 18u: { shaded = debug_str(shaded, px, origin, vec4u(0x64696600u, 0u, 0u, 0u), 3u); }          // dif
        default:  { shaded = debug_str(shaded, px, origin, vec4u(0x74726E73u, 0u, 0u, 0u), 4u); }          // trns
    }
    return shaded;
}