summaryrefslogtreecommitdiff
path: root/cnn_v3/shaders/cnn_v3_enc0.wgsl
diff options
context:
space:
mode:
Diffstat (limited to 'cnn_v3/shaders/cnn_v3_enc0.wgsl')
-rw-r--r--cnn_v3/shaders/cnn_v3_enc0.wgsl75
1 files changed, 75 insertions, 0 deletions
diff --git a/cnn_v3/shaders/cnn_v3_enc0.wgsl b/cnn_v3/shaders/cnn_v3_enc0.wgsl
new file mode 100644
index 0000000..f52a167
--- /dev/null
+++ b/cnn_v3/shaders/cnn_v3_enc0.wgsl
@@ -0,0 +1,75 @@
+// CNN v3 — Encoder level 0
+// Conv(20->4, 3x3, zero-pad) + FiLM + ReLU
+//
+// Input: feat_tex0 (rgba32uint, 8xf16), feat_tex1 (rgba32uint, 12xu8) full-res
+// Output: enc0_out (rgba16float, 4ch) full-res
+//
+// Weight layout (f16, OIHW + bias):
+// [0 .. 20*4*9) conv: w[out][in][ky][kx]
+// [720 .. 724) bias: b[out]
+
+#include "cnn_v3/common"
+
+const ENC0_IN: u32 = 20u; // input feature channels per texel; matches the 20-element array returned by load_feat
+const ENC0_OUT: u32 = 4u; // output channels; one per component of the texel written to enc0_out
+
+struct Params { // uniform block bound as `params`; per WGSL layout rules the offsets are 0 / 16 / 32 / 48, size 64 bytes
+ weight_offset: u32, // passed as the first argument of every get_w call in enc0_main
+ _pad: vec3u, // never read. NOTE(review): vec3u is 16-byte aligned, so this lands at byte 16 (not 4) and pushes gamma to byte 32 — confirm the host-side struct uses the same layout
+ gamma: vec4f, // FiLM scale, component o multiplies the conv sum of output channel o
+ beta: vec4f, // FiLM shift, component o is added to output channel o before the ReLU
+}
+
+@group(0) @binding(0) var feat_tex0: texture_2d<u32>; // packed features: each of x, y, z, w is unpacked as 2 x f16 (channels 0..7); also defines the dispatch bounds
+@group(0) @binding(1) var feat_tex1: texture_2d<u32>; // packed features: each of x, y, z is unpacked as 4 x unorm8 (channels 8..19); w is not read in this file
+@group(0) @binding(2) var<storage, read> weights: array<u32>; // not referenced directly in this file; presumably read by get_w from the included common module — confirm
+@group(0) @binding(3) var<uniform> params: Params; // weight offset plus FiLM gamma/beta
+@group(0) @binding(4) var enc0_out: texture_storage_2d<rgba16float, write>; // destination texel: 4 channels after FiLM + ReLU
+
+// Fetch the 20 feature channels stored at `coord`.
+//   coord: texel position to read (may lie outside the texture).
+//   dims:  texture size used for the bounds test.
+// Returns channels 0..7 from the four f16 pairs of feat_tex0 and channels
+// 8..19 from the first three unorm8 quads of feat_tex1. Any coordinate outside
+// [0, dims) yields all zeros, which implements the conv's zero padding.
+fn load_feat(coord: vec2i, dims: vec2i) -> array<f32, 20> {
+  let outside = any(coord < vec2i(0)) || any(coord >= dims);
+  if (outside) {
+    // Zero-value constructor: every element is 0.0.
+    return array<f32, 20>();
+  }
+
+  let half_words = textureLoad(feat_tex0, coord, 0);
+  let byte_words = textureLoad(feat_tex1, coord, 0);
+
+  var channels: array<f32, 20>;
+
+  // Channels 0..7: two f16 values per 32-bit word, low half first.
+  for (var w: u32 = 0u; w < 4u; w++) {
+    let pair = unpack2x16float(half_words[w]);
+    channels[2u * w] = pair.x;
+    channels[2u * w + 1u] = pair.y;
+  }
+
+  // Channels 8..19: four unorm8 values per word; the fourth word is unused.
+  for (var w: u32 = 0u; w < 3u; w++) {
+    let quad = unpack4x8unorm(byte_words[w]);
+    let first = 8u + 4u * w;
+    channels[first] = quad.x;
+    channels[first + 1u] = quad.y;
+    channels[first + 2u] = quad.z;
+    channels[first + 3u] = quad.w;
+  }
+
+  return channels;
+}
+
+// Entry point: one invocation per output texel.
+// Computes a 3x3 convolution over the 20 channels returned by load_feat
+// (out-of-range neighbours contribute zeros), then applies FiLM
+// (gamma * x + beta) and ReLU, and stores the 4 results in enc0_out.
+//
+// Each of the 9 neighbours is fetched and unpacked once and reused for all
+// output channels, rather than once per output channel. Every channel still
+// accumulates bias first, then taps in (ky, kx, i) order, so the summation
+// order per channel is the same as a per-channel loop.
+@compute @workgroup_size(8, 8)
+fn enc0_main(@builtin(global_invocation_id) id: vec3u) {
+  let coord = vec2i(id.xy);
+  let dims = vec2i(textureDimensions(feat_tex0));
+  // Invocations outside feat_tex0's extent have nothing to write.
+  if (coord.x >= dims.x || coord.y >= dims.y) { return; }
+
+  let wo = params.weight_offset;
+
+  // Seed each accumulator with its bias, stored right after the conv weights.
+  var acc: array<f32, ENC0_OUT>;
+  for (var o: u32 = 0u; o < ENC0_OUT; o++) {
+    acc[o] = get_w(wo, ENC0_OUT * ENC0_IN * 9u + o);
+  }
+
+  for (var ky: i32 = -1; ky <= 1; ky++) {
+    for (var kx: i32 = -1; kx <= 1; kx++) {
+      // The neighbour's features do not depend on the output channel:
+      // load them once per tap.
+      let feat = load_feat(coord + vec2i(kx, ky), dims);
+      let ki = u32(ky + 1) * 3u + u32(kx + 1);
+      for (var o: u32 = 0u; o < ENC0_OUT; o++) {
+        // w[out][in][ky][kx] flattened: o * IN * 9 + i * 9 + ki.
+        let base = o * ENC0_IN * 9u + ki;
+        for (var i: u32 = 0u; i < ENC0_IN; i++) {
+          acc[o] += get_w(wo, base + i * 9u) * feat[i];
+        }
+      }
+    }
+  }
+
+  // FiLM then ReLU, applied component-wise to the four channels.
+  let conv = vec4f(acc[0], acc[1], acc[2], acc[3]);
+  let film = params.gamma * conv + params.beta;
+  textureStore(enc0_out, coord, max(vec4f(0.0), film));
+}