// CNN v3 — Decoder level 1
// NearestUp2x(bottleneck) + cat(enc1_skip) -> Conv(32->8, 3x3) + FiLM + ReLU
//
// Inputs:  bn_tex_lo   (rgba32uint, 8xf16)  quarter-res       ch 0-7
//          bn_tex_hi   (rgba32uint, 8xf16)  quarter-res       ch 8-15
//          enc1_tex_lo (rgba32uint, 8xf16)  half-res skip     ch 0-7
//          enc1_tex_hi (rgba32uint, 8xf16)  half-res skip     ch 8-15
// Output:  dec1_out    (rgba32uint, 8xf16)  half-res
//
// Weight layout (f16, OIHW + bias):
//   [0 .. 32*8*9)   conv: w[out][in][ky][kx]  (in=32: 16 bn + 16 enc1 skip)
//   [2304 .. +8)    bias: b[out]

// Project preprocessor include. get_w() and unpack_8ch() are not defined in
// this file and are presumably provided by it.
#include "cnn_v3/common"

// Channel counts of the 3x3 convolution: 32 concatenated inputs, 8 outputs.
const DEC1_IN: u32 = 32u;
const DEC1_OUT: u32 = 8u;

// Per-dispatch parameters.
//   weight_offset      start of this layer's weights, passed as the base
//                      argument of get_w()
//   gamma_lo/gamma_hi  FiLM scale for output channels 0-3 / 4-7
//   beta_lo/beta_hi    FiLM shift for output channels 0-3 / 4-7
// NOTE(review): vec3u has 16-byte alignment in WGSL, so _pad starts at byte 16
// and gamma_lo at byte 32 — confirm the host-side struct uses the same layout.
struct Params {
    weight_offset: u32,
    _pad: vec3u,
    gamma_lo: vec4f,
    gamma_hi: vec4f,
    beta_lo: vec4f,
    beta_hi: vec4f,
}

// NOTE(review): the template arguments below were missing from the reviewed
// text. texture_2d<u32> and the rgba32uint storage format follow from the
// header comment and from the vec4u written by textureStore(). The address
// spaces and the element type of `weights` are assumed (read-only storage of
// u32 words holding packed f16 pairs; uniform params) — confirm against the
// bind group layout and the get_w() implementation.
@group(0) @binding(0) var bn_tex_lo: texture_2d<u32>;
@group(0) @binding(1) var bn_tex_hi: texture_2d<u32>;
@group(0) @binding(2) var enc1_tex_lo: texture_2d<u32>;
@group(0) @binding(3) var enc1_tex_hi: texture_2d<u32>;
@group(0) @binding(4) var<storage, read> weights: array<u32>;
@group(0) @binding(5) var<uniform> params: Params;
@group(0) @binding(6) var dec1_out: texture_storage_2d<rgba32uint, write>;

// FiLM scale for output channel o (0..7): gamma_lo holds channels 0-3,
// gamma_hi holds channels 4-7.
fn film_gamma(o: u32) -> f32 {
    if (o < 4u) {
        return params.gamma_lo[o];
    }
    return params.gamma_hi[o - 4u];
}

// FiLM shift for output channel o (0..7): beta_lo holds channels 0-3,
// beta_hi holds channels 4-7.
fn film_beta(o: u32) -> f32 {
    if (o < 4u) {
        return params.beta_lo[o];
    }
    return params.beta_hi[o - 4u];
}

// Load 32ch: [bn_nearest_up(16ch), enc1_skip(16ch)]
//
// hcoord    half-res texel coordinate; may lie outside the image
// half_dims half-res image size
//
// Returns all zeros for coordinates outside [0, half_dims), which gives the
// 3x3 convolution zero padding. Otherwise:
//   r[ 0.. 7] = bn_tex_lo   at hcoord / 2 (nearest-neighbour 2x upsample)
//   r[ 8..15] = bn_tex_hi   at hcoord / 2
//   r[16..23] = enc1_tex_lo at hcoord
//   r[24..31] = enc1_tex_hi at hcoord
fn load_dec1_concat(hcoord: vec2i, half_dims: vec2i) -> array<f32, 32> {
    // A function-scope var without an initializer is zero-initialized.
    var r: array<f32, 32>;
    if (hcoord.x < 0 || hcoord.y < 0 ||
        hcoord.x >= half_dims.x || hcoord.y >= half_dims.y) {
        return r;
    }

    // Clamp keeps the quarter-res lookup in range when half_dims is odd.
    // The upper bound is floored at 0 so that a 1-texel-wide or 1-texel-high
    // half-res image (quart_dims == 0) cannot yield a negative coordinate.
    let quart_dims = half_dims / 2;
    let q_max = max(quart_dims - vec2i(1), vec2i(0));
    let qc = clamp(hcoord / 2, vec2i(0), q_max);

    let blo = unpack_8ch(bn_tex_lo, qc);
    let bhi = unpack_8ch(bn_tex_hi, qc);
    let slo = unpack_8ch(enc1_tex_lo, hcoord);
    let shi = unpack_8ch(enc1_tex_hi, hcoord);

    for (var i: u32 = 0u; i < 8u; i++) {
        r[i] = blo[i];
        r[i + 8u] = bhi[i];
        r[i + 16u] = slo[i];
        r[i + 24u] = shi[i];
    }
    return r;
}

// One invocation per half-res output texel: 3x3 convolution over the 32
// concatenated input channels, then FiLM (gamma * x + beta) and ReLU, written
// as 8 f16 values packed into one rgba32uint texel.
@compute @workgroup_size(8, 8)
fn dec1_main(@builtin(global_invocation_id) id: vec3u) {
    let half_dims = vec2i(textureDimensions(enc1_tex_lo));
    let coord = vec2i(id.xy);
    // Workgroups may overhang the image; surplus invocations do nothing.
    if (coord.x >= half_dims.x || coord.y >= half_dims.y) {
        return;
    }

    let wo = params.weight_offset;

    // Every accumulator starts from its bias, stored after the conv weights.
    var acc: array<f32, 8>;
    for (var o: u32 = 0u; o < DEC1_OUT; o++) {
        acc[o] = get_w(wo, DEC1_OUT * DEC1_IN * 9u + o);
    }

    // Taps are the outer loops so each neighbour's 32 channels are fetched and
    // unpacked once and shared by all 8 output channels (9 loads instead of
    // 72). Each acc[o] still receives its terms in (ky, kx, i) order, so the
    // summation order per channel is the same as with the channel loop
    // outermost.
    for (var ky: i32 = -1; ky <= 1; ky++) {
        for (var kx: i32 = -1; kx <= 1; kx++) {
            let feat = load_dec1_concat(coord + vec2i(kx, ky), half_dims);
            let ki = u32(ky + 1) * 3u + u32(kx + 1);
            for (var o: u32 = 0u; o < DEC1_OUT; o++) {
                // OIHW indexing: w[o][i][ky][kx] = o * IN * 9 + i * 9 + ki.
                let w_base = o * DEC1_IN * 9u + ki;
                var sum = acc[o];
                for (var i: u32 = 0u; i < DEC1_IN; i++) {
                    sum += get_w(wo, w_base + i * 9u) * feat[i];
                }
                acc[o] = sum;
            }
        }
    }

    // FiLM modulation followed by ReLU.
    var out: array<f32, 8>;
    for (var o: u32 = 0u; o < DEC1_OUT; o++) {
        out[o] = max(0.0, film_gamma(o) * acc[o] + film_beta(o));
    }

    // Two f16 values per u32 component: 8 channels in one rgba32uint texel.
    textureStore(dec1_out, coord, vec4u(
        pack2x16float(vec2f(out[0], out[1])),
        pack2x16float(vec2f(out[2], out[3])),
        pack2x16float(vec2f(out[4], out[5])),
        pack2x16float(vec2f(out[6], out[7]))
    ));
}