summaryrefslogtreecommitdiff
path: root/src/tests/audio/test_wav_roundtrip.cc
blob: 6294d6d8701e52ab4315b076667cfc995bd359a8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
// Tests the wav->spec->wav roundtrip SNR.
// Generates a sine wave, runs OLA-DCT analysis then IDCT-OLA synthesis,
// and asserts the reconstruction SNR exceeds the threshold.

#include "audio/dct.h"
#include "audio/window.h"
#include <assert.h>
#include <cmath>
#include <cstdio>
#include <vector>

static const int SAMPLE_RATE = 32000;
static const float PI = 3.14159265358979323846f;

// Replicate analyze_audio OLA pass (Hann + FDCT, hop = OLA_HOP_SIZE)
static std::vector<float> ola_analyze(const std::vector<float>& pcm) {
  float win[DCT_SIZE];
  hann_window_512(win);

  const int hop = OLA_HOP_SIZE;
  const int n_pcm = (int)pcm.size();
  const int num_frames = (n_pcm > DCT_SIZE) ? (n_pcm - DCT_SIZE) / hop + 1 : 1;

  std::vector<float> spec(num_frames * DCT_SIZE);
  float chunk[DCT_SIZE];

  for (int f = 0; f < num_frames; ++f) {
    const int start = f * hop;
    const int avail = (start + DCT_SIZE <= n_pcm) ? DCT_SIZE : n_pcm - start;
    for (int i = 0; i < avail; ++i) chunk[i] = pcm[start + i] * win[i];
    for (int i = avail; i < DCT_SIZE; ++i) chunk[i] = 0.0f;

    fdct_512(chunk, spec.data() + f * DCT_SIZE);
  }
  return spec;
}

// IDCT + OLA synthesis (no synthesis window) matching decode_to_wav.
// Analysis used Hann; since Hann satisfies w[n]+w[n+H]=1 at 50% overlap,
// skipping the synthesis window gives perfect reconstruction.
static std::vector<float> ola_decode(const std::vector<float>& spec,
                                     int num_frames) {
  std::vector<float> pcm(num_frames * OLA_HOP_SIZE + OLA_OVERLAP, 0.0f);
  float overlap[OLA_OVERLAP] = {};
  float tmp[DCT_SIZE];

  for (int f = 0; f < num_frames; ++f) {
    idct_512(spec.data() + f * DCT_SIZE, tmp);
    for (int j = 0; j < OLA_HOP_SIZE; ++j)
      pcm[f * OLA_HOP_SIZE + j] = tmp[j] + overlap[j];
    for (int j = 0; j < OLA_OVERLAP; ++j)
      overlap[j] = tmp[OLA_HOP_SIZE + j];
  }
  pcm.resize(num_frames * OLA_HOP_SIZE);
  return pcm;
}

// Signal-to-noise ratio, in dB, of `out` measured against `ref`.
//
// ref:          reference samples.
// out:          reconstructed samples; only the common prefix of ref/out is
//               compared.
// skip_samples: number of leading samples to exclude from the measurement.
//               Negative values are treated as 0 so the loop never indexes
//               before the start of the vectors.
//
// Returns:
//   999.0f   when the accumulated error energy is below 1e-30 (treated as
//            a perfect match);
//   -999.0f  when there is nothing to compare (skip_samples covers the whole
//            common prefix, or either vector is empty), so that an empty
//            comparison can never be mistaken for a perfect reconstruction;
//   otherwise 10 * log10(signal energy / error energy).
static float compute_snr_db(const std::vector<float>& ref,
                             const std::vector<float>& out,
                             int skip_samples) {
  const std::size_t n = (ref.size() < out.size()) ? ref.size() : out.size();
  const std::size_t first =
      (skip_samples > 0) ? static_cast<std::size_t>(skip_samples) : 0;

  // No samples in the measurement window: report a failing value rather than
  // falling through to the "zero noise" perfect-score path below.
  if (first >= n) return -999.0f;

  // Accumulate energies in double to limit rounding error over long signals.
  double sig = 0.0, noise = 0.0;
  for (std::size_t i = first; i < n; ++i) {
    sig += (double)ref[i] * ref[i];
    double e = ref[i] - out[i];
    noise += e * e;
  }
  if (noise < 1e-30) return 999.0f;
  return 10.0f * (float)std::log10(sig / noise);
}

// Test driver: synthesizes a sine tone, pushes it through the analysis and
// decode helpers, and checks the reconstruction SNR against a fixed threshold.
// Returns 0 on pass, 1 on failure.
int main() {
  printf("Running WAV roundtrip test...\n");

  // Reference signal: SAMPLE_RATE samples (one second) of a 440 Hz sine at
  // half amplitude.
  std::vector<float> input(SAMPLE_RATE);
  int idx = 0;
  for (float& sample : input) {
    sample = 0.5f * sinf(2.0f * PI * 440.0f * idx / SAMPLE_RATE);
    ++idx;
  }

  // Forward pass: windowed DCT frames.
  const std::vector<float> spec = ola_analyze(input);
  const int num_frames = (int)(spec.size() / DCT_SIZE);

  // Inverse pass: IDCT + overlap-add, no synthesis window.
  const std::vector<float> output = ola_decode(spec, num_frames);

  // The first DCT_SIZE samples are excluded from the measurement because the
  // overlap-add has not fully ramped up there.
  const float snr = compute_snr_db(input, output, DCT_SIZE);
  printf("Roundtrip SNR: %.1f dB  (frames=%d, out_samples=%zu)\n",
         snr, num_frames, output.size());

  const float MIN_SNR_DB = 30.0f;
  if (!(snr < MIN_SNR_DB)) {
    printf("PASS\n");
    return 0;
  }

  printf("FAIL: SNR %.1f dB < %.0f dB threshold\n", snr, MIN_SNR_DB);
  return 1;
}