diff options
Diffstat (limited to 'workspaces/main/shaders/cnn_v2_layer_2.wgsl')
| -rw-r--r-- | workspaces/main/shaders/cnn_v2_layer_2.wgsl | 156 |
1 files changed, 156 insertions, 0 deletions
diff --git a/workspaces/main/shaders/cnn_v2_layer_2.wgsl b/workspaces/main/shaders/cnn_v2_layer_2.wgsl new file mode 100644 index 0000000..2f9836a --- /dev/null +++ b/workspaces/main/shaders/cnn_v2_layer_2.wgsl @@ -0,0 +1,156 @@ +// CNN v2 Layer 2 - Auto-generated +// Kernel: 3×3, In: 12, Out: 4 + +const KERNEL_SIZE: u32 = 3u; +const IN_CHANNELS: u32 = 12u; +const OUT_CHANNELS: u32 = 4u; +const KERNEL_RADIUS: i32 = 1; + +// Weights quantized to float16 (stored as f32 in WGSL) +const weights: array<f32, 432> = array( + 0.030212, -0.041351, 0.053864, -0.025635, 0.099976, -0.016830, -0.068665, 0.112488, + -0.069824, 0.030197, 0.020142, 0.101807, 0.061920, 0.022415, -0.025864, -0.056366, + 0.085571, -0.053650, 0.109802, 0.129272, 0.023438, 0.087341, 0.066284, 0.037079, + -0.067566, 0.021530, -0.046814, 0.029343, -0.028534, 0.047150, -0.079346, -0.022675, + -0.019669, -0.024185, 0.029587, 0.068970, 0.108826, 0.050598, -0.072144, 0.083008, + -0.002201, 0.006275, 0.056396, 0.001884, 0.097168, -0.028503, -0.002499, 0.008919, + -0.013771, -0.017502, -0.033478, 0.105530, 0.032898, 0.068726, -0.036285, -0.021011, + -0.018250, 0.073914, 0.024277, 0.061066, 0.008682, -0.022766, 0.074219, 0.094421, + 0.050903, 0.072571, 0.117493, -0.033234, 0.067993, -0.008049, 0.046997, -0.064209, + -0.381104, 0.107788, -0.213867, 0.145142, 0.514160, 0.407715, -0.317871, 0.249023, + 0.055634, -0.006294, -0.067444, 0.025131, 0.012939, -0.074158, -0.013741, -0.033020, + 0.026871, -0.007671, 0.089661, -0.003016, 0.029007, -0.038483, 0.045044, 0.104065, + 0.077148, 0.092468, -0.090027, -0.048126, 0.096863, -0.088013, 0.009483, 0.075012, + -0.076843, -0.085449, -0.066040, 0.019165, -0.019958, 0.083496, 0.069275, -0.019714, + 0.027786, -0.042389, 0.054718, 0.010635, -0.071777, 0.029282, -0.003605, 0.113770, + 0.080994, 0.106079, 0.047333, -0.013733, 0.034760, 0.099365, -0.020813, 0.095886, + 0.052490, -0.049194, 0.047394, 0.072510, -0.030930, -0.003782, -0.038025, -0.019318, + -0.047852, -0.043915, 0.026810, -0.041138, 0.038422, 0.009605, -0.080688, -0.019653, + 0.075256, -0.013817, -0.022400, 0.050629, 0.048462, 0.072998, -0.009109, 0.070923, + 0.079895, 0.071350, 0.002869, 0.081543, 0.037231, 0.020767, -0.017929, 0.042328, + -0.075134, -0.010681, -0.009079, 0.057007, -0.040253, -0.025574, -0.041534, 0.105835, + -0.039703, 0.032104, 0.076050, 0.070923, -0.013046, -0.054108, -0.024582, -0.033997, + 0.092285, 0.000525, 0.114685, 0.036926, -0.419434, 0.087891, -0.187866, 0.128906, + 0.665527, 0.268311, -0.337891, 0.195557, 0.140503, 0.014465, -0.043671, 0.031677, + 0.073059, 0.085144, 0.014290, -0.046967, 0.033356, 0.004177, 0.102844, 0.015259, + 0.026627, -0.005032, 0.111694, -0.010590, 0.029816, 0.108154, -0.072327, 0.056213, + 0.022903, 0.053772, 0.084473, -0.059845, -0.032776, -0.000015, -0.093872, -0.085815, + 0.081604, 0.069336, 0.034149, -0.067322, -0.020859, 0.120911, 0.077209, -0.016388, + 0.050140, -0.045563, -0.046326, 0.032623, -0.005009, 0.008003, 0.109192, 0.086548, + 0.096558, 0.118530, 0.035034, 0.110352, -0.041748, 0.009178, 0.049957, 0.084839, + 0.042053, -0.069153, -0.024796, -0.094604, -0.047028, -0.053802, 0.024979, 0.049591, + -0.016373, -0.047607, -0.008797, -0.058868, 0.107178, 0.055695, 0.092407, 0.092346, + 0.053894, 0.054657, -0.039703, -0.073792, 0.041779, -0.044159, 0.099182, 0.037109, + 0.097778, 0.098206, -0.057831, -0.054016, -0.068604, -0.061584, -0.054382, 0.005268, + 0.096008, -0.007118, -0.063049, 0.059113, 0.076904, 0.045288, -0.055695, -0.052612, + -0.022110, 0.049103, 0.095276, 0.014572, 0.064819, 0.014671, 0.029800, 0.066284, + -0.383301, 0.071838, -0.207275, 0.099365, 0.640137, 0.393311, -0.334229, 0.275391, + -0.013977, -0.025269, -0.007065, -0.033478, -0.017349, 0.026764, 0.005192, 0.093384, + 0.014313, 0.018906, 0.006962, 0.094849, 0.005390, 0.101624, -0.041199, 0.026245, + 0.027588, 0.062408, 0.033356, -0.010826, 0.067993, -0.054199, 0.076416, 0.023315, + -0.002886, -0.112061, -0.041473, -0.012703, 0.016022, 0.010506, -0.021362, -0.037750, + 0.062927, 0.061920, 0.038177, -0.037201, -0.011620, 0.014015, -0.062164, -0.045441, + -0.063416, -0.040100, 0.035950, 0.045563, -0.017227, -0.060547, -0.017593, 0.111877, + 0.121521, 0.073853, 0.023331, -0.012428, 0.018478, -0.010948, 0.030716, 0.043427, + 0.003117, -0.069092, 0.038361, -0.053497, 0.039154, -0.085754, 0.012642, -0.051208, + 0.022934, 0.127197, 0.117920, 0.074036, 0.083313, -0.061951, 0.079224, 0.091248, + 0.009132, 0.069946, 0.123474, 0.130127, 0.118835, 0.020874, -0.045380, -0.000111, + 0.111206, 0.054688, 0.008995, 0.085693, 0.005562, 0.103088, -0.034698, 0.119934, + -0.067200, 0.065430, -0.021942, 0.089783, 0.033112, -0.025467, 0.040161, -0.052155, + -0.048920, 0.031250, 0.112549, 0.122192, 0.126587, 0.180908, 0.194946, 0.121704, + 0.217529, 0.224243, 0.269287, 0.222656, 0.288086, 0.035492, 0.066711, -0.046600, + 0.085144, 0.013855, -0.065979, -0.083252, -0.058289, 0.104126, 0.013702, -0.018188, + 0.036591, 0.099854, 0.056061, 0.151855, 0.062134, 0.133789, 0.084045, 0.095825, + 0.036987, 0.022308, 0.070923, 0.031036, 0.101868, 0.062347, 0.141235, 0.066650 +); + +@group(0) @binding(0) var static_features: texture_2d<u32>; +@group(0) @binding(1) var layer_input: texture_2d<u32>; +@group(0) @binding(2) var output_tex: texture_storage_2d<rgba32uint, write>; + +fn unpack_static_features(coord: vec2<i32>) -> array<f32, 8> { + let packed = textureLoad(static_features, coord, 0); + let v0 = unpack2x16float(packed.x); + let v1 = unpack2x16float(packed.y); + let v2 = unpack2x16float(packed.z); + let v3 = unpack2x16float(packed.w); + return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y); +} + +fn unpack_layer_channels(coord: vec2<i32>) -> array<f32, 8> { + let packed = textureLoad(layer_input, coord, 0); + let v0 = unpack2x16float(packed.x); + let v1 = unpack2x16float(packed.y); + let v2 = unpack2x16float(packed.z); + let v3 = unpack2x16float(packed.w); + return array<f32, 8>(v0.x, v0.y, v1.x, v1.y, v2.x, v2.y, v3.x, v3.y); +} + +fn pack_channels(values: array<f32, 8>) -> vec4<u32> { + return vec4<u32>( + pack2x16float(vec2<f32>(values[0], values[1])), + pack2x16float(vec2<f32>(values[2], values[3])), + pack2x16float(vec2<f32>(values[4], values[5])), + pack2x16float(vec2<f32>(values[6], values[7])) + ); +} + +@compute @workgroup_size(8, 8) +fn main(@builtin(global_invocation_id) id: vec3<u32>) { + let coord = vec2<i32>(id.xy); + let dims = textureDimensions(static_features); + + if (coord.x >= i32(dims.x) || coord.y >= i32(dims.y)) { + return; + } + + // Load static features (always available) + let static_feat = unpack_static_features(coord); + + // Convolution + var output: array<f32, OUT_CHANNELS>; + for (var c: u32 = 0u; c < OUT_CHANNELS; c++) { + var sum: f32 = 0.0; + + for (var ky: i32 = -KERNEL_RADIUS; ky <= KERNEL_RADIUS; ky++) { + for (var kx: i32 = -KERNEL_RADIUS; kx <= KERNEL_RADIUS; kx++) { + let sample_coord = coord + vec2<i32>(kx, ky); + + // Border handling (clamp) + let clamped = vec2<i32>( + clamp(sample_coord.x, 0, i32(dims.x) - 1), + clamp(sample_coord.y, 0, i32(dims.y) - 1) + ); + + // Load input features + let static_local = unpack_static_features(clamped); + let layer_local = unpack_layer_channels(clamped); + + // Weight index calculation + let ky_idx = u32(ky + KERNEL_RADIUS); + let kx_idx = u32(kx + KERNEL_RADIUS); + let spatial_idx = ky_idx * KERNEL_SIZE + kx_idx; + + // Accumulate: static features (8D) + for (var i: u32 = 0u; i < 8u; i++) { + let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE + + i * KERNEL_SIZE * KERNEL_SIZE + spatial_idx; + sum += weights[w_idx] * static_local[i]; + } + + // Accumulate: layer input channels (if layer_idx > 0) + let prev_channels = IN_CHANNELS - 8u; + for (var i: u32 = 0u; i < prev_channels; i++) { + let w_idx = c * IN_CHANNELS * KERNEL_SIZE * KERNEL_SIZE + + (8u + i) * KERNEL_SIZE * KERNEL_SIZE + spatial_idx; + sum += weights[w_idx] * layer_local[i]; + } + } + } + + output[c] = clamp(sum, 0.0, 1.0); // Sigmoid approximation + } + + // Pack and store + textureStore(output_tex, coord, pack_channels(output)); +} |
