// CNN v2 Compute Shader - Uniform 12D→4D Architecture
// All layers: input/previous (4D) + static (8D) = 12D → 4 channels
// Storage buffer weights, ping-pong execution
// Per-layer kernel sizes supported via LayerParams
//
// NOTE(review): the template parameters in this file (address spaces, texel
// formats, vector/array element types) were reconstructed from how each
// resource is used below. Confirm them against the host-side bind group
// layout before shipping.

// Per-dispatch layer parameters, bound at @group(0) @binding(4).
struct LayerParams {
    kernel_size: u32,      // Kernel width/height; main() assumes it is odd — TODO confirm
    in_channels: u32,      // Always 12 (4 prev + 8 static); not read by this shader
    out_channels: u32,     // Always 4; not read by this shader
    weight_offset: u32,    // Offset in f16 units
    is_output_layer: u32,  // 1 if final layer (clamp [0,1], then blend), 0 otherwise
    blend_amount: f32,     // [0,1] blend with original: 0 = original, 1 = CNN result
    is_layer_0: u32,       // 1 if first layer (clamp [0,1]), 0 otherwise (ReLU)
}

@group(0) @binding(0) var static_features: texture_2d<u32>;  // 8D static features (p0-p3 + spatial)
@group(0) @binding(1) var layer_input: texture_2d<u32>;      // 4D previous/input (RGBD or prev layer)
@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>;  // 4D output
@group(0) @binding(3) var<storage, read> weights_buffer: array<u32>;  // Packed f16 weights
@group(0) @binding(4) var<uniform> params: LayerParams;
@group(0) @binding(5) var original_input: texture_2d<f32>;   // Original RGB for blending

// Loads one texel of `static_features` and expands its four u32 components
// into eight f32 values (two f16 halves per component, low half first).
fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> {
    let texel = textureLoad(static_features, coord, 0);
    let a = unpack2x16float(texel.x);
    let b = unpack2x16float(texel.y);
    let c = unpack2x16float(texel.z);
    let d = unpack2x16float(texel.w);
    return array<f32, 8>(a.x, a.y, b.x, b.y, c.x, c.y, d.x, d.y);
}

// Loads one texel of `layer_input` and expands its .x/.y components into the
// four f32 channels; .z/.w are ignored (pack_channels writes them as zero).
fn unpack_layer_channels(coord: vec2<i32>) -> vec4<f32> {
    let texel = textureLoad(layer_input, coord, 0);
    let lo = unpack2x16float(texel.x);
    let hi = unpack2x16float(texel.y);
    return vec4<f32>(lo.x, lo.y, hi.x, hi.y);
}

// Inverse of unpack_layer_channels: packs four f32 channels as f16 pairs into
// the first two u32 components of the output texel.
fn pack_channels(values: vec4<f32>) -> vec4<u32> {
    return vec4<u32>(
        pack2x16float(values.xy),
        pack2x16float(values.zw),
        0u,  // Unused
        0u   // Unused
    );
}

// Get weight from storage buffer (f16 packed as u32 pairs).
// Buffer layout: [header: 4 u32][layer_info: N×5 u32][weights: packed f16]
// `idx` is in f16 units. No header/layer_info skip is applied here: per the
// layout note above, weight_offset is expected to already account for it.
// TODO: Support 8-bit quantized weights (4× per u32) for 2× size reduction
fn get_weight(idx: u32) -> f32 {
    let pair = unpack2x16float(weights_buffer[idx / 2u]);
    // Even indices live in the low half of the u32, odd indices in the high half.
    return select(pair.y, pair.x, (idx & 1u) == 0u);
}

// One invocation per output pixel: convolves the 12 input channels
// (4 previous-layer + 8 static) over a kernel_size × kernel_size window with
// clamp-to-edge borders, applies the layer activation, optionally blends with
// the original image, and stores the packed 4-channel result.
@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let coord = vec2<i32>(id.xy);
    let dims = textureDimensions(static_features);

    // Workgroups may overhang the image; drop the excess invocations.
    if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) {
        return;
    }

    let kernel_size = params.kernel_size;
    let weight_offset = params.weight_offset;
    let is_output = params.is_output_layer != 0u;
    let kernel_radius = i32(kernel_size / 2u);

    // Weight layout: [out_channel][in_channel (4 prev, then 8 static)][ky][kx].
    let taps = kernel_size * kernel_size;  // weights per (out, in) channel pair
    let out_stride = 12u * taps;           // weights per output channel
    let max_coord = vec2<i32>(i32(dims.x) - 1, i32(dims.y) - 1);

    // Convolution: 12D input → 4D output.
    // The spatial loops are outermost so each tap's textures are loaded once
    // and shared by all four output channels. Every channel still accumulates
    // its terms in the order (ky, kx, prev 0..3, static 0..7).
    var sums = vec4<f32>(0.0);
    for (var ky: i32 = -kernel_radius; ky <= kernel_radius; ky++) {
        for (var kx: i32 = -kernel_radius; kx <= kernel_radius; kx++) {
            // Border handling (clamp)
            let clamped = clamp(coord + vec2<i32>(kx, ky), vec2<i32>(0), max_coord);

            // Load features at this spatial location
            let static_local = unpack_static_features(clamped);
            let layer_local = unpack_layer_channels(clamped);  // 4D

            let spatial_idx = u32(ky + kernel_radius) * kernel_size + u32(kx + kernel_radius);

            for (var c: u32 = 0u; c < 4u; c++) {
                let base = weight_offset + c * out_stride + spatial_idx;
                var acc = sums[c];

                // Accumulate: previous/input channels (4D)
                for (var i: u32 = 0u; i < 4u; i++) {
                    acc += get_weight(base + i * taps) * layer_local[i];
                }

                // Accumulate: static features (8D)
                for (var i: u32 = 0u; i < 8u; i++) {
                    acc += get_weight(base + (4u + i) * taps) * static_local[i];
                }

                sums[c] = acc;
            }
        }
    }

    // Activation (matches train_cnn_v2.py)
    var result: vec4<f32>;
    if (is_output || params.is_layer_0 != 0u) {
        result = clamp(sums, vec4<f32>(0.0), vec4<f32>(1.0));  // Output layer and layer 0: clamp [0,1]
    } else {
        result = max(vec4<f32>(0.0), sums);                    // Middle layers: ReLU
    }

    // Blend with original on final layer (RGB only; the 4th channel is untouched)
    if (is_output) {
        let original = textureLoad(original_input, coord, 0).rgb;
        let blended = mix(original, result.xyz, params.blend_amount);
        result = vec4<f32>(blended, result.w);
    }

    textureStore(output_tex, coord, pack_channels(result));
}