From b04816a400703ac6c364efb70ae84930d79ccb12 Mon Sep 17 00:00:00 2001 From: skal Date: Fri, 13 Feb 2026 16:12:24 +0100 Subject: CNN v2: Fix activation function mismatch between training and inference Layer 0 now uses clamp [0,1] in both training and inference (was using ReLU in shaders). - index.html: Add is_layer_0 flag to LayerParams, handle Layer 0 separately - export_cnn_v2_shader.py: Generate correct activation for Layer 0 Co-Authored-By: Claude Sonnet 4.5 --- tools/cnn_v2_test/index.html | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) (limited to 'tools/cnn_v2_test') diff --git a/tools/cnn_v2_test/index.html b/tools/cnn_v2_test/index.html index 79c54b7..fc93223 100644 --- a/tools/cnn_v2_test/index.html +++ b/tools/cnn_v2_test/index.html @@ -397,6 +397,7 @@ struct LayerParams { weight_offset: u32, is_output_layer: u32, blend_amount: f32, + is_layer_0: u32, } @group(0) @binding(0) var static_features: texture_2d; @@ -490,8 +491,10 @@ fn main(@builtin(global_invocation_id) id: vec3) { if (is_output) { output[c] = clamp(sum, 0.0, 1.0); + } else if (params.is_layer_0 != 0u) { + output[c] = clamp(sum, 0.0, 1.0); // Layer 0: clamp [0,1] } else { - output[c] = max(0.0, sum); + output[c] = max(0.0, sum); // Middle layers: ReLU } } @@ -1105,18 +1108,19 @@ class CNNTester { const headerOffsetU32 = 4 + this.weights.layers.length * 5; // Header + layer info in u32 const absoluteWeightOffset = headerOffsetU32 * 2 + layer.weightOffset; // Convert to f16 units - const paramsData = new Uint32Array(6); + const paramsData = new Uint32Array(7); paramsData[0] = layer.kernelSize; paramsData[1] = layer.inChannels; paramsData[2] = layer.outChannels; paramsData[3] = absoluteWeightOffset; // Use absolute offset paramsData[4] = isOutput ? 1 : 0; + paramsData[6] = (i === 0) ? 1 : 0; // is_layer_0 flag const paramsView = new Float32Array(paramsData.buffer); paramsView[5] = this.blendAmount; const paramsBuffer = this.device.createBuffer({ - size: 24, + size: 28, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); this.device.queue.writeBuffer(paramsBuffer, 0, paramsData); -- cgit v1.2.3