summaryrefslogtreecommitdiff
path: root/src/tests
diff options
context:
space:
mode:
Diffstat (limited to 'src/tests')
-rw-r--r--src/tests/util/test_ans.cc189
1 files changed, 189 insertions, 0 deletions
diff --git a/src/tests/util/test_ans.cc b/src/tests/util/test_ans.cc
new file mode 100644
index 0000000..108c46d
--- /dev/null
+++ b/src/tests/util/test_ans.cc
@@ -0,0 +1,189 @@
+// This file is part of the 64k demo project.
+// Tests for the ANS (rANS) entropy coder in util/ans.h.
+// Encoder is enabled in tests via ANS_ENABLE_ENCODER.
+
+#include "util/ans.h"
+
+#include <cassert>
+#include <cstdint>
+#include <cstdio>
+#include <cstring>
+#include <random>
+#include <string>
+#include <vector>
+
+namespace {
+
+bool RoundtripCheck(const std::vector<uint8_t>& input,
+ const uint32_t* initial_counts,
+ const char* label) {
+ std::vector<uint8_t> compressed;
+ if (!ans::Encode(input.data(), input.size(), &compressed, initial_counts)) {
+ fprintf(stderr, "[%s] Encode failed\n", label);
+ return false;
+ }
+ std::vector<uint8_t> decoded(input.size());
+ size_t decoded_size = 0;
+ if (!ans::Decode(compressed.data(), compressed.size(), decoded.data(),
+ decoded.size(), &decoded_size, initial_counts)) {
+ fprintf(stderr, "[%s] Decode failed\n", label);
+ return false;
+ }
+ if (decoded_size != input.size() ||
+ std::memcmp(decoded.data(), input.data(), input.size()) != 0) {
+ fprintf(stderr, "[%s] payload mismatch\n", label);
+ return false;
+ }
+ const double ratio =
+ input.empty() ? 0.0 : (double)compressed.size() / (double)input.size();
+ fprintf(stderr, "[%s] OK: %zu -> %zu bytes (%.3f x)\n", label, input.size(),
+ compressed.size(), ratio);
+ return true;
+}
+
+// Covers: empty / single byte / single-symbol run / all-zeros / random uniform
+// / random skewed / repeated ASCII. Each is a roundtrip with default (uniform)
+// initial counts.
+void TestRoundtripVariants() {
+ std::mt19937 rng_uniform(12345);
+ std::vector<uint8_t> random_uniform(64 * 1024);
+ for (auto& b : random_uniform) b = (uint8_t)(rng_uniform() & 0xff);
+
+ std::mt19937 rng_skewed(67890);
+ std::vector<uint8_t> random_skewed(32 * 1024);
+ for (auto& b : random_skewed) {
+ // 90% byte 'A', 10% other ASCII letters.
+ const uint32_t r = rng_skewed();
+ b = (r % 10 == 0) ? (uint8_t)('B' + (r >> 8) % 25) : (uint8_t)'A';
+ }
+
+ const char* ascii =
+ "@group(0) @binding(0) var smplr: sampler;\n"
+ "@group(0) @binding(1) var tex: texture_2d<f32>;\n"
+ "@fragment fn fs_main(@builtin(position) p: vec4f) -> @location(0) "
+ "vec4f {\n"
+ " let uv = p.xy / vec2f(1280.0, 720.0);\n"
+ " return textureSample(tex, smplr, uv);\n"
+ "}\n";
+ std::vector<uint8_t> ascii_block;
+ for (int i = 0; i < 50; ++i) { // cross chunk boundary
+ ascii_block.insert(ascii_block.end(), ascii, ascii + std::strlen(ascii));
+ }
+
+ struct Case {
+ const char* label;
+ std::vector<uint8_t> data;
+ };
+ const Case cases[] = {
+ {"empty", {}},
+ {"single-byte", {0x42}},
+ {"single-symbol-run", std::vector<uint8_t>(4096, 'A')},
+ {"all-zeros", std::vector<uint8_t>(8192, 0)},
+ {"random-uniform", random_uniform},
+ {"random-skewed", random_skewed},
+ {"ascii-shader-text", ascii_block},
+ };
+ for (const Case& c : cases) {
+ assert(RoundtripCheck(c.data, nullptr, c.label));
+ }
+}
+
+// Seeded initial counts should compress a matching corpus at least as well
+// as uniform init, and roundtrip identically.
+void TestSeededInitialCounts() {
+ std::vector<uint8_t> corpus;
+ const char* sample = "fn main() { let x = vec3f(1.0, 2.0, 3.0); }\n";
+ for (int i = 0; i < 200; ++i) {
+ corpus.insert(corpus.end(), sample, sample + std::strlen(sample));
+ }
+ uint32_t hist[256] = {};
+ ans::Histogram(corpus.data(), corpus.size(), hist);
+
+ std::vector<uint8_t> seeded_out, uniform_out;
+ assert(ans::Encode(corpus.data(), corpus.size(), &seeded_out, hist));
+ assert(ans::Encode(corpus.data(), corpus.size(), &uniform_out, nullptr));
+
+ std::vector<uint8_t> decoded(corpus.size());
+ size_t decoded_size = 0;
+ assert(ans::Decode(seeded_out.data(), seeded_out.size(), decoded.data(),
+ decoded.size(), &decoded_size, hist));
+ assert(decoded_size == corpus.size());
+ assert(std::memcmp(decoded.data(), corpus.data(), corpus.size()) == 0);
+
+ fprintf(stderr, "[seeded-vs-uniform] seeded=%zu uniform=%zu raw=%zu\n",
+ seeded_out.size(), uniform_out.size(), corpus.size());
+ assert(seeded_out.size() <= uniform_out.size());
+}
+
+// The chunk-end state check must reject: a mismatched model (decoder uses a
+// different initial distribution), a single bit-flip in the payload, and a
+// truncated stream.
+void TestRejection() {
+ std::mt19937 rng(1);
+
+ // 1) Mismatched models.
+ {
+ std::vector<uint8_t> v(4096);
+ for (auto& b : v) b = (uint8_t)('a' + (rng() % 26));
+ uint32_t hist[256] = {};
+ ans::Histogram(v.data(), v.size(), hist);
+
+ std::vector<uint8_t> encoded;
+ assert(ans::Encode(v.data(), v.size(), &encoded, hist));
+
+ std::vector<uint8_t> decoded(v.size());
+ size_t decoded_size = 0;
+ assert(!ans::Decode(encoded.data(), encoded.size(), decoded.data(),
+ decoded.size(), &decoded_size, nullptr));
+ fprintf(stderr, "[rejection] mismatched-counts OK\n");
+ }
+
+ // 2) Corruption.
+ {
+ std::vector<uint8_t> v(2048);
+ for (auto& b : v) b = (uint8_t)(rng() & 0xff);
+ std::vector<uint8_t> encoded;
+ assert(ans::Encode(v.data(), v.size(), &encoded, nullptr));
+ encoded[encoded.size() / 2] ^= 0x55; // flip a payload byte
+
+ std::vector<uint8_t> decoded(v.size());
+ size_t decoded_size = 0;
+ assert(!ans::Decode(encoded.data(), encoded.size(), decoded.data(),
+ decoded.size(), &decoded_size, nullptr));
+ fprintf(stderr, "[rejection] corruption OK\n");
+ }
+
+ // 3) Truncation.
+ {
+ std::vector<uint8_t> v(4096);
+ for (size_t i = 0; i < v.size(); ++i) v[i] = (uint8_t)i;
+ std::vector<uint8_t> encoded;
+ assert(ans::Encode(v.data(), v.size(), &encoded, nullptr));
+ encoded.resize(encoded.size() - 8);
+
+ std::vector<uint8_t> decoded(v.size());
+ size_t decoded_size = 0;
+ assert(!ans::Decode(encoded.data(), encoded.size(), decoded.data(),
+ decoded.size(), &decoded_size, nullptr));
+ fprintf(stderr, "[rejection] truncation OK\n");
+ }
+}
+
+void TestPeekSize() {
+ std::vector<uint8_t> v(1234, 'Q');
+ std::vector<uint8_t> encoded;
+ assert(ans::Encode(v.data(), v.size(), &encoded, nullptr));
+ assert(ans::PeekUncompressedSize(encoded.data(), encoded.size()) ==
+ v.size());
+}
+
+} // namespace
+
+int main() {
+ TestRoundtripVariants();
+ TestSeededInitialCounts();
+ TestRejection();
+ TestPeekSize();
+ fprintf(stderr, "All ANS tests passed.\n");
+ return 0;
+}