diff options
Diffstat (limited to 'src/tests')
| -rw-r--r-- | src/tests/util/test_ans.cc | 189 |
1 files changed, 189 insertions, 0 deletions
diff --git a/src/tests/util/test_ans.cc b/src/tests/util/test_ans.cc new file mode 100644 index 0000000..108c46d --- /dev/null +++ b/src/tests/util/test_ans.cc @@ -0,0 +1,189 @@ +// This file is part of the 64k demo project. +// Tests for the ANS (rANS) entropy coder in util/ans.h. +// Encoder is enabled in tests via ANS_ENABLE_ENCODER. + +#include "util/ans.h" + +#include <cassert> +#include <cstdint> +#include <cstdio> +#include <cstring> +#include <random> +#include <string> +#include <vector> + +namespace { + +bool RoundtripCheck(const std::vector<uint8_t>& input, + const uint32_t* initial_counts, + const char* label) { + std::vector<uint8_t> compressed; + if (!ans::Encode(input.data(), input.size(), &compressed, initial_counts)) { + fprintf(stderr, "[%s] Encode failed\n", label); + return false; + } + std::vector<uint8_t> decoded(input.size()); + size_t decoded_size = 0; + if (!ans::Decode(compressed.data(), compressed.size(), decoded.data(), + decoded.size(), &decoded_size, initial_counts)) { + fprintf(stderr, "[%s] Decode failed\n", label); + return false; + } + if (decoded_size != input.size() || + std::memcmp(decoded.data(), input.data(), input.size()) != 0) { + fprintf(stderr, "[%s] payload mismatch\n", label); + return false; + } + const double ratio = + input.empty() ? 0.0 : (double)compressed.size() / (double)input.size(); + fprintf(stderr, "[%s] OK: %zu -> %zu bytes (%.3f x)\n", label, input.size(), + compressed.size(), ratio); + return true; +} + +// Covers: empty / single byte / single-symbol run / all-zeros / random uniform +// / random skewed / repeated ASCII. Each is a roundtrip with default (uniform) +// initial counts. +void TestRoundtripVariants() { + std::mt19937 rng_uniform(12345); + std::vector<uint8_t> random_uniform(64 * 1024); + for (auto& b : random_uniform) b = (uint8_t)(rng_uniform() & 0xff); + + std::mt19937 rng_skewed(67890); + std::vector<uint8_t> random_skewed(32 * 1024); + for (auto& b : random_skewed) { + // 90% byte 'A', 10% other ASCII letters. + const uint32_t r = rng_skewed(); + b = (r % 10 == 0) ? (uint8_t)('B' + (r >> 8) % 25) : (uint8_t)'A'; + } + + const char* ascii = + "@group(0) @binding(0) var smplr: sampler;\n" + "@group(0) @binding(1) var tex: texture_2d<f32>;\n" + "@fragment fn fs_main(@builtin(position) p: vec4f) -> @location(0) " + "vec4f {\n" + " let uv = p.xy / vec2f(1280.0, 720.0);\n" + " return textureSample(tex, smplr, uv);\n" + "}\n"; + std::vector<uint8_t> ascii_block; + for (int i = 0; i < 50; ++i) { // cross chunk boundary + ascii_block.insert(ascii_block.end(), ascii, ascii + std::strlen(ascii)); + } + + struct Case { + const char* label; + std::vector<uint8_t> data; + }; + const Case cases[] = { + {"empty", {}}, + {"single-byte", {0x42}}, + {"single-symbol-run", std::vector<uint8_t>(4096, 'A')}, + {"all-zeros", std::vector<uint8_t>(8192, 0)}, + {"random-uniform", random_uniform}, + {"random-skewed", random_skewed}, + {"ascii-shader-text", ascii_block}, + }; + for (const Case& c : cases) { + assert(RoundtripCheck(c.data, nullptr, c.label)); + } +} + +// Seeded initial counts should compress a matching corpus at least as well +// as uniform init, and roundtrip identically. +void TestSeededInitialCounts() { + std::vector<uint8_t> corpus; + const char* sample = "fn main() { let x = vec3f(1.0, 2.0, 3.0); }\n"; + for (int i = 0; i < 200; ++i) { + corpus.insert(corpus.end(), sample, sample + std::strlen(sample)); + } + uint32_t hist[256] = {}; + ans::Histogram(corpus.data(), corpus.size(), hist); + + std::vector<uint8_t> seeded_out, uniform_out; + assert(ans::Encode(corpus.data(), corpus.size(), &seeded_out, hist)); + assert(ans::Encode(corpus.data(), corpus.size(), &uniform_out, nullptr)); + + std::vector<uint8_t> decoded(corpus.size()); + size_t decoded_size = 0; + assert(ans::Decode(seeded_out.data(), seeded_out.size(), decoded.data(), + decoded.size(), &decoded_size, hist)); + assert(decoded_size == corpus.size()); + assert(std::memcmp(decoded.data(), corpus.data(), corpus.size()) == 0); + + fprintf(stderr, "[seeded-vs-uniform] seeded=%zu uniform=%zu raw=%zu\n", + seeded_out.size(), uniform_out.size(), corpus.size()); + assert(seeded_out.size() <= uniform_out.size()); +} + +// The chunk-end state check must reject: a mismatched model (decoder uses a +// different initial distribution), a single bit-flip in the payload, and a +// truncated stream. +void TestRejection() { + std::mt19937 rng(1); + + // 1) Mismatched models. + { + std::vector<uint8_t> v(4096); + for (auto& b : v) b = (uint8_t)('a' + (rng() % 26)); + uint32_t hist[256] = {}; + ans::Histogram(v.data(), v.size(), hist); + + std::vector<uint8_t> encoded; + assert(ans::Encode(v.data(), v.size(), &encoded, hist)); + + std::vector<uint8_t> decoded(v.size()); + size_t decoded_size = 0; + assert(!ans::Decode(encoded.data(), encoded.size(), decoded.data(), + decoded.size(), &decoded_size, nullptr)); + fprintf(stderr, "[rejection] mismatched-counts OK\n"); + } + + // 2) Corruption. + { + std::vector<uint8_t> v(2048); + for (auto& b : v) b = (uint8_t)(rng() & 0xff); + std::vector<uint8_t> encoded; + assert(ans::Encode(v.data(), v.size(), &encoded, nullptr)); + encoded[encoded.size() / 2] ^= 0x55; // flip a payload byte + + std::vector<uint8_t> decoded(v.size()); + size_t decoded_size = 0; + assert(!ans::Decode(encoded.data(), encoded.size(), decoded.data(), + decoded.size(), &decoded_size, nullptr)); + fprintf(stderr, "[rejection] corruption OK\n"); + } + + // 3) Truncation. + { + std::vector<uint8_t> v(4096); + for (size_t i = 0; i < v.size(); ++i) v[i] = (uint8_t)i; + std::vector<uint8_t> encoded; + assert(ans::Encode(v.data(), v.size(), &encoded, nullptr)); + encoded.resize(encoded.size() - 8); + + std::vector<uint8_t> decoded(v.size()); + size_t decoded_size = 0; + assert(!ans::Decode(encoded.data(), encoded.size(), decoded.data(), + decoded.size(), &decoded_size, nullptr)); + fprintf(stderr, "[rejection] truncation OK\n"); + } +} + +void TestPeekSize() { + std::vector<uint8_t> v(1234, 'Q'); + std::vector<uint8_t> encoded; + assert(ans::Encode(v.data(), v.size(), &encoded, nullptr)); + assert(ans::PeekUncompressedSize(encoded.data(), encoded.size()) == + v.size()); +} + +} // namespace + +int main() { + TestRoundtripVariants(); + TestSeededInitialCounts(); + TestRejection(); + TestPeekSize(); + fprintf(stderr, "All ANS tests passed.\n"); + return 0; +} |
