// This file is part of the 64k demo project. // Tests for the ANS (rANS) entropy coder in util/ans.h. // Encoder is enabled in tests via ANS_ENABLE_ENCODER. #include "util/ans.h" #include #include #include #include #include #include #include namespace { bool RoundtripCheck(const std::vector& input, const uint32_t* initial_counts, const char* label) { std::vector compressed; if (!ans::Encode(input.data(), input.size(), &compressed, initial_counts)) { fprintf(stderr, "[%s] Encode failed\n", label); return false; } std::vector decoded(input.size()); size_t decoded_size = 0; if (!ans::Decode(compressed.data(), compressed.size(), decoded.data(), decoded.size(), &decoded_size, initial_counts)) { fprintf(stderr, "[%s] Decode failed\n", label); return false; } if (decoded_size != input.size() || std::memcmp(decoded.data(), input.data(), input.size()) != 0) { fprintf(stderr, "[%s] payload mismatch\n", label); return false; } const double ratio = input.empty() ? 0.0 : (double)compressed.size() / (double)input.size(); fprintf(stderr, "[%s] OK: %zu -> %zu bytes (%.3f x)\n", label, input.size(), compressed.size(), ratio); return true; } // Covers: empty / single byte / single-symbol run / all-zeros / random uniform // / random skewed / repeated ASCII. Each is a roundtrip with default (uniform) // initial counts. void TestRoundtripVariants() { std::mt19937 rng_uniform(12345); std::vector random_uniform(64 * 1024); for (auto& b : random_uniform) b = (uint8_t)(rng_uniform() & 0xff); std::mt19937 rng_skewed(67890); std::vector random_skewed(32 * 1024); for (auto& b : random_skewed) { // 90% byte 'A', 10% other ASCII letters. const uint32_t r = rng_skewed(); b = (r % 10 == 0) ? (uint8_t)('B' + (r >> 8) % 25) : (uint8_t)'A'; } const char* ascii = "@group(0) @binding(0) var smplr: sampler;\n" "@group(0) @binding(1) var tex: texture_2d;\n" "@fragment fn fs_main(@builtin(position) p: vec4f) -> @location(0) " "vec4f {\n" " let uv = p.xy / vec2f(1280.0, 720.0);\n" " return textureSample(tex, smplr, uv);\n" "}\n"; std::vector ascii_block; for (int i = 0; i < 50; ++i) { // cross chunk boundary ascii_block.insert(ascii_block.end(), ascii, ascii + std::strlen(ascii)); } struct Case { const char* label; std::vector data; }; const Case cases[] = { {"empty", {}}, {"single-byte", {0x42}}, {"single-symbol-run", std::vector(4096, 'A')}, {"all-zeros", std::vector(8192, 0)}, {"random-uniform", random_uniform}, {"random-skewed", random_skewed}, {"ascii-shader-text", ascii_block}, }; for (const Case& c : cases) { assert(RoundtripCheck(c.data, nullptr, c.label)); } } // Seeded initial counts should compress a matching corpus at least as well // as uniform init, and roundtrip identically. void TestSeededInitialCounts() { std::vector corpus; const char* sample = "fn main() { let x = vec3f(1.0, 2.0, 3.0); }\n"; for (int i = 0; i < 200; ++i) { corpus.insert(corpus.end(), sample, sample + std::strlen(sample)); } uint32_t hist[256] = {}; ans::Histogram(corpus.data(), corpus.size(), hist); std::vector seeded_out, uniform_out; assert(ans::Encode(corpus.data(), corpus.size(), &seeded_out, hist)); assert(ans::Encode(corpus.data(), corpus.size(), &uniform_out, nullptr)); std::vector decoded(corpus.size()); size_t decoded_size = 0; assert(ans::Decode(seeded_out.data(), seeded_out.size(), decoded.data(), decoded.size(), &decoded_size, hist)); assert(decoded_size == corpus.size()); assert(std::memcmp(decoded.data(), corpus.data(), corpus.size()) == 0); fprintf(stderr, "[seeded-vs-uniform] seeded=%zu uniform=%zu raw=%zu\n", seeded_out.size(), uniform_out.size(), corpus.size()); assert(seeded_out.size() <= uniform_out.size()); } // The chunk-end state check must reject: a mismatched model (decoder uses a // different initial distribution), a single bit-flip in the payload, and a // truncated stream. void TestRejection() { std::mt19937 rng(1); // 1) Mismatched models. { std::vector v(4096); for (auto& b : v) b = (uint8_t)('a' + (rng() % 26)); uint32_t hist[256] = {}; ans::Histogram(v.data(), v.size(), hist); std::vector encoded; assert(ans::Encode(v.data(), v.size(), &encoded, hist)); std::vector decoded(v.size()); size_t decoded_size = 0; assert(!ans::Decode(encoded.data(), encoded.size(), decoded.data(), decoded.size(), &decoded_size, nullptr)); fprintf(stderr, "[rejection] mismatched-counts OK\n"); } // 2) Corruption. { std::vector v(2048); for (auto& b : v) b = (uint8_t)(rng() & 0xff); std::vector encoded; assert(ans::Encode(v.data(), v.size(), &encoded, nullptr)); encoded[encoded.size() / 2] ^= 0x55; // flip a payload byte std::vector decoded(v.size()); size_t decoded_size = 0; assert(!ans::Decode(encoded.data(), encoded.size(), decoded.data(), decoded.size(), &decoded_size, nullptr)); fprintf(stderr, "[rejection] corruption OK\n"); } // 3) Truncation. { std::vector v(4096); for (size_t i = 0; i < v.size(); ++i) v[i] = (uint8_t)i; std::vector encoded; assert(ans::Encode(v.data(), v.size(), &encoded, nullptr)); encoded.resize(encoded.size() - 8); std::vector decoded(v.size()); size_t decoded_size = 0; assert(!ans::Decode(encoded.data(), encoded.size(), decoded.data(), decoded.size(), &decoded_size, nullptr)); fprintf(stderr, "[rejection] truncation OK\n"); } } void TestPeekSize() { std::vector v(1234, 'Q'); std::vector encoded; assert(ans::Encode(v.data(), v.size(), &encoded, nullptr)); assert(ans::PeekUncompressedSize(encoded.data(), encoded.size()) == v.size()); } } // namespace int main() { TestRoundtripVariants(); TestSeededInitialCounts(); TestRejection(); TestPeekSize(); fprintf(stderr, "All ANS tests passed.\n"); return 0; }