diff options
Diffstat (limited to 'src/tests/gpu/test_cnn_v3_parity.cc')
| -rw-r--r-- | src/tests/gpu/test_cnn_v3_parity.cc | 93 |
1 files changed, 82 insertions, 11 deletions
diff --git a/src/tests/gpu/test_cnn_v3_parity.cc b/src/tests/gpu/test_cnn_v3_parity.cc index 1a7f169..4fada5d 100644 --- a/src/tests/gpu/test_cnn_v3_parity.cc +++ b/src/tests/gpu/test_cnn_v3_parity.cc @@ -103,6 +103,76 @@ static std::vector<float> readback_rgba16float(WGPUDevice device, } // --------------------------------------------------------------------------- +// Raw RGBA32Uint readback → flat array of f32 (8 channels via unpack2x16float) +// --------------------------------------------------------------------------- + +static std::vector<float> readback_rgba32uint_8ch(WGPUDevice device, + WGPUQueue queue, + WGPUTexture tex, + int W, int H) { + const uint32_t bytes_per_px = 16; // 4 × u32 + const uint32_t unaligned_bpr = (uint32_t)(W * bytes_per_px); + const uint32_t aligned_bpr = ((unaligned_bpr + 255u) / 256u) * 256u; + const size_t buf_size = aligned_bpr * (size_t)H; + + WGPUBufferDescriptor bd = {}; + bd.usage = WGPUBufferUsage_CopyDst | WGPUBufferUsage_MapRead; + bd.size = buf_size; + WGPUBuffer staging = wgpuDeviceCreateBuffer(device, &bd); + + WGPUCommandEncoder enc = wgpuDeviceCreateCommandEncoder(device, nullptr); + WGPUTexelCopyTextureInfo src = {}; + src.texture = tex; + WGPUTexelCopyBufferInfo dst = {}; + dst.buffer = staging; + dst.layout.bytesPerRow = aligned_bpr; + dst.layout.rowsPerImage = (uint32_t)H; + WGPUExtent3D extent = { (uint32_t)W, (uint32_t)H, 1 }; + wgpuCommandEncoderCopyTextureToBuffer(enc, &src, &dst, &extent); + WGPUCommandBuffer cmds = wgpuCommandEncoderFinish(enc, nullptr); + wgpuQueueSubmit(queue, 1, &cmds); + wgpuCommandBufferRelease(cmds); + wgpuCommandEncoderRelease(enc); + wgpuDevicePoll(device, true, nullptr); + + MapState ms = {}; + WGPUBufferMapCallbackInfo mi = {}; + mi.mode = WGPUCallbackMode_AllowProcessEvents; + mi.callback = [](WGPUMapAsyncStatus s, WGPUStringView, void* u, void*) { + auto* st = (MapState*)u; + st->status = s; st->done = true; + }; + mi.userdata1 = &ms; + wgpuBufferMapAsync(staging, WGPUMapMode_Read, 0, buf_size, mi); + for (int i = 0; i < 100 && !ms.done; ++i) + wgpuDevicePoll(device, true, nullptr); + + std::vector<float> result(W * H * 8, 0.0f); + if (ms.done && ms.status == WGPUMapAsyncStatus_Success) { + const uint8_t* mapped = (const uint8_t*)wgpuBufferGetConstMappedRange( + staging, 0, buf_size); + if (mapped) { + for (int y = 0; y < H; ++y) { + const uint32_t* row = + (const uint32_t*)(mapped + (size_t)y * aligned_bpr); + for (int x = 0; x < W; ++x) { + for (int j = 0; j < 4; ++j) { + uint32_t packed = row[x * 4 + j]; + result[(y * W + x) * 8 + j * 2 + 0] = + fp16_bits_to_f32((uint16_t)(packed & 0xFFFFu)); + result[(y * W + x) * 8 + j * 2 + 1] = + fp16_bits_to_f32((uint16_t)(packed >> 16)); + } + } + } + } + } + wgpuBufferUnmap(staging); + wgpuBufferRelease(staging); + return result; +} + +// --------------------------------------------------------------------------- // Helper: create rgba32uint texture with TextureBinding | CopyDst // --------------------------------------------------------------------------- @@ -190,8 +260,8 @@ static std::vector<float> run_cnn_v3(WebGPUTestFixture& fixture, effect.upload_weights(ctx.queue, weights_u32, weights_bytes); } else { // Explicitly zero weights to override any asset-loaded defaults. - // kWeightsBufBytes = ((2476+1)/2)*4 = 4952 - const uint32_t zero_size = ((2476u + 1u) / 2u) * 4u; + // kWeightsBufBytes = ((7828+1)/2)*4 = 15660 + const uint32_t zero_size = ((7828u + 1u) / 2u) * 4u; std::vector<uint8_t> zeros(zero_size, 0); effect.upload_weights(ctx.queue, zeros.data(), zero_size); } @@ -215,13 +285,14 @@ static std::vector<float> run_cnn_v3(WebGPUTestFixture& fixture, // Optional: read back intermediate layers if (enc0_out) { + // enc0 is rgba32uint, 8ch (pack2x16float), full-res WGPUTexture enc0_tex = registry.get_texture("cnn3_out_enc0"); - *enc0_out = readback_rgba16float(ctx.device, ctx.queue, enc0_tex, W, H); + *enc0_out = readback_rgba32uint_8ch(ctx.device, ctx.queue, enc0_tex, W, H); } if (dec1_out) { + // dec1 is rgba32uint, 8ch (pack2x16float), half-res WGPUTexture dec1_tex = registry.get_texture("cnn3_out_dec1"); - // dec1 is rgba16float written at half-res (W/2, H/2) — read only valid region - *dec1_out = readback_rgba16float(ctx.device, ctx.queue, dec1_tex, W / 2, H / 2); + *dec1_out = readback_rgba32uint_8ch(ctx.device, ctx.queue, dec1_tex, W / 2, H / 2); } // Cleanup @@ -298,18 +369,18 @@ static int test_random_weights() { kCnnV3TestWeightsU32, weights_bytes, &enc0_pixels, &dec1_pixels); - // Check enc0 layer first + // Check enc0 layer first (8ch, rgba32uint) const float tol = 1.0f / 255.0f; float enc0_max_err = 0.0f; int enc0_worst = -1; - for (int i = 0; i < W * H * 4; ++i) { + for (int i = 0; i < W * H * 8; ++i) { float ref = fp16_bits_to_f32(kCnnV3ExpectedEnc0U16[i]); float err = fabsf(enc0_pixels[i] - ref); if (err > enc0_max_err) { enc0_max_err = err; enc0_worst = i; } } bool enc0_ok = (enc0_max_err <= tol); if (!enc0_ok) { - int px = enc0_worst / 4, ch = enc0_worst % 4; + int px = enc0_worst / 8, ch = enc0_worst % 8; fprintf(stderr, " ✗ enc0 mismatch: max_err=%.5f > %.5f at px=%d ch=%d" " gpu=%.5f ref=%.5f\n", enc0_max_err, tol, px, ch, @@ -319,10 +390,10 @@ static int test_random_weights() { fprintf(stdout, " ✓ enc0: max_err=%.2e OK\n", enc0_max_err); } - // Check dec1 layer (half-res: W/2 x H/2 x 4) + // Check dec1 layer (8ch, rgba32uint, half-res: W/2 x H/2 x 8) float dec1_max_err = 0.0f; int dec1_worst = -1; - int dec1_n = (W / 2) * (H / 2) * 4; + int dec1_n = (W / 2) * (H / 2) * 8; for (int i = 0; i < dec1_n; ++i) { float ref = fp16_bits_to_f32(kCnnV3ExpectedDec1U16[i]); float err = fabsf(dec1_pixels[i] - ref); @@ -330,7 +401,7 @@ static int test_random_weights() { } bool dec1_ok = (dec1_max_err <= tol); if (!dec1_ok) { - int px = dec1_worst / 4, ch = dec1_worst % 4; + int px = dec1_worst / 8, ch = dec1_worst % 8; fprintf(stderr, " ✗ dec1 mismatch: max_err=%.5f > %.5f at px=%d ch=%d" " gpu=%.5f ref=%.5f\n", dec1_max_err, tol, px, ch, |
