summaryrefslogtreecommitdiff
path: root/tools/mq_editor/mq_extract.js
diff options
context:
space:
mode:
authorskal <pascal.massimino@gmail.com>2026-02-18 14:00:18 +0100
committerskal <pascal.massimino@gmail.com>2026-02-18 14:00:18 +0100
commita9a60dfd2df938ef1e3ecc0a06d3d50cc329ef30 (patch)
treed4bfa5597eda60db9c0ea3ad20c53c39e228819a /tools/mq_editor/mq_extract.js
parentc3c1011cb6bf9bca28736b89049d76875a031ebe (diff)
feat(mq_editor): Implement phase-coherent partial tracking
Implements a more robust partial tracking algorithm that uses phase coherence to validate and link spectral peaks across frames. This significantly improves tracking accuracy, especially for crossing or closely spaced partials. Key changes: - STFT cache: now computes and stores the phase of each frequency bin alongside the magnitude. - Peak detection: now interpolates the phase for sub-bin accuracy. - Partial tracking: the core algorithm was replaced with a phase-aware model. It predicts the expected phase of a partial in the next frame and uses a cost function combining both frequency and phase error to find the best match. This implementation follows the design outlined in the accompanying new design document. handoff(Gemini): Phase prediction implemented. Further tuning of the cost function weights may be beneficial.
Diffstat (limited to 'tools/mq_editor/mq_extract.js')
-rw-r--r--tools/mq_editor/mq_extract.js107
1 files changed, 82 insertions, 25 deletions
diff --git a/tools/mq_editor/mq_extract.js b/tools/mq_editor/mq_extract.js
index 97fbb00..a530960 100644
--- a/tools/mq_editor/mq_extract.js
+++ b/tools/mq_editor/mq_extract.js
@@ -9,12 +9,13 @@ function extractPartials(params, stftCache) {
const frames = [];
for (let i = 0; i < numFrames; ++i) {
const cachedFrame = stftCache.getFrameAtIndex(i);
- const squaredAmp = stftCache.getSquaredAmplitude(cachedFrame.time);
- const peaks = detectPeaks(squaredAmp, fftSize, sampleRate, threshold, freqWeight, prominence);
+ const squaredAmp = cachedFrame.squaredAmplitude;
+ const phase = cachedFrame.phase;
+ const peaks = detectPeaks(squaredAmp, phase, fftSize, sampleRate, threshold, freqWeight, prominence);
frames.push({time: cachedFrame.time, peaks});
}
- const partials = trackPartials(frames);
+ const partials = trackPartials(frames, params);
// Second pass: extend partials leftward to recover onset frames
expandPartialsLeft(partials, frames);
@@ -27,10 +28,27 @@ function extractPartials(params, stftCache) {
return {partials, frames};
}
/**
 * Interpolate phase at a sub-bin offset by fitting a parabola through three
 * neighbouring phase samples.
 *
 * The two neighbours are first unwrapped relative to the centre sample so
 * that a 2*pi jump between adjacent bins does not distort the fit. The
 * parabola through (-1, p_minus), (0, p_center), (+1, p_plus) is
 *
 *   y(p) = p_center + 0.5 * (d_plus - d_minus) * p
 *                   + 0.5 * (d_plus + d_minus) * p^2
 *
 * where d_minus / d_plus are the unwrapped differences to the centre.
 * Both the linear and the quadratic term carry a factor 0.5; with that,
 * p_frac = -1, 0, +1 reproduce the three input samples (neighbours modulo
 * 2*pi).
 *
 * @param {number} p_minus  Phase of the sample left of the centre (radians).
 * @param {number} p_center Phase of the centre sample (radians).
 * @param {number} p_plus   Phase of the sample right of the centre (radians).
 * @param {number} p_frac   Fractional offset from the centre sample.
 * @returns {number} Interpolated phase in radians. The result is expressed
 *   relative to p_center and is NOT wrapped back into [-pi, pi].
 */
function phaseInterp(p_minus, p_center, p_plus, p_frac) {
  const TWO_PI = 2 * Math.PI;

  // Bring a phase difference into [-pi, pi] by adding/removing whole turns.
  const unwrap = (delta) => {
    let d = delta;
    while (d > Math.PI) d -= TWO_PI;
    while (d < -Math.PI) d += TWO_PI;
    return d;
  };

  const dMinus = unwrap(p_minus - p_center);
  const dPlus = unwrap(p_plus - p_center);

  const slope = 0.5 * (dPlus - dMinus);      // first-order (central difference)
  const curvature = 0.5 * (dPlus + dMinus);  // second-order coefficient

  return p_center + slope * p_frac + curvature * p_frac * p_frac;
}
+
// Detect spectral peaks via local maxima + parabolic interpolation
// squaredAmp: pre-computed re*re+im*im per bin
+// phase: pre-computed atan2(im,re) per bin
// freqWeight: if true, weight by f before peak detection (f * Power(f))
-function detectPeaks(squaredAmp, fftSize, sampleRate, thresholdDB, freqWeight, prominenceDB = 0) {
+function detectPeaks(squaredAmp, phase, fftSize, sampleRate, thresholdDB, freqWeight, prominenceDB = 0) {
const mag = new Float32Array(fftSize / 2);
const binHz = sampleRate / fftSize;
for (let i = 0; i < fftSize / 2; ++i) {
@@ -62,23 +80,31 @@ function detectPeaks(squaredAmp, fftSize, sampleRate, thresholdDB, freqWeight, p
if (mag[i] - valley < prominenceDB) continue;
}
- // Parabolic interpolation for sub-bin accuracy
+ // Parabolic interpolation for sub-bin accuracy on frequency, amplitude, and phase
const alpha = mag[i-1];
const beta = mag[i];
const gamma = mag[i+1];
const p = 0.5 * (alpha - gamma) / (alpha - 2*beta + gamma);
+
+ const p_phase = phaseInterp(phase[i-1], phase[i], phase[i+1], p);
const freq = (i + p) * sampleRate / fftSize;
const ampDB = beta - 0.25 * (alpha - gamma) * p;
- peaks.push({freq, amp: Math.pow(10, ampDB / 20)});
+ peaks.push({freq, amp: Math.pow(10, ampDB / 20), phase: p_phase});
}
}
return peaks;
}
-// Track partials across frames (birth/death/continuation)
-function trackPartials(frames) {
// Wrap an angle (radians) into the half-open interval [-pi, pi).
// Applied to a phase difference, this yields the shortest signed distance
// between the two phases.
function normalizeAngle(angle) {
  const TWO_PI = 2 * Math.PI;
  // Number of whole turns to remove so the result lands in [-pi, pi).
  const turns = Math.floor((angle + Math.PI) / TWO_PI);
  return angle - TWO_PI * turns;
}
+
+// Track partials across frames using phase coherence for robust matching.
+function trackPartials(frames, params) {
+ const { sampleRate, hopSize } = params;
const partials = [];
const activePartials = [];
const candidates = []; // pre-birth
@@ -89,22 +115,37 @@ function trackPartials(frames) {
const deathAge = 5; // frames without match before death
const minLength = 10; // frames required to keep partial
+ // Weight phase error heavily in cost function, scaled by frequency.
+ // This makes phase deviation more significant for high-frequency partials.
+ const phaseErrorWeight = 2.0;
+
for (const frame of frames) {
const matched = new Set();
- // Continue active partials
+ // --- Continue active partials ---
for (const partial of activePartials) {
const lastFreq = partial.freqs[partial.freqs.length - 1];
+ const lastPhase = partial.phases[partial.phases.length - 1];
const velocity = partial.velocity || 0;
- const predicted = lastFreq + velocity;
+ const predictedFreq = lastFreq + velocity;
- const tol = Math.max(lastFreq * trackingRatio, minTrackingHz);
- let bestIdx = -1, bestDist = Infinity;
+ // Predict phase for the current frame based on the last frame's frequency.
+ const phaseAdvance = 2 * Math.PI * lastFreq * hopSize / sampleRate;
+ const predictedPhase = lastPhase + phaseAdvance;
+
+ const tol = Math.max(predictedFreq * trackingRatio, minTrackingHz);
+ let bestIdx = -1, bestCost = Infinity;
+ // Find the peak in the new frame with the lowest cost (freq + phase error).
for (let i = 0; i < frame.peaks.length; ++i) {
if (matched.has(i)) continue;
- const dist = Math.abs(frame.peaks[i].freq - predicted);
- if (dist < tol && dist < bestDist) { bestDist = dist; bestIdx = i; }
+ const pk = frame.peaks[i];
+ const freqError = Math.abs(pk.freq - predictedFreq);
+ if (freqError > tol) continue;
+
+ const phaseError = Math.abs(normalizeAngle(pk.phase - predictedPhase));
+ const cost = freqError + phaseErrorWeight * phaseError * predictedFreq;
+ if (cost < bestCost) { bestCost = cost; bestIdx = i; }
}
if (bestIdx >= 0) {
@@ -112,6 +153,7 @@ function trackPartials(frames) {
partial.times.push(frame.time);
partial.freqs.push(pk.freq);
partial.amps.push(pk.amp);
+ partial.phases.push(pk.phase);
partial.age = 0;
partial.velocity = pk.freq - lastFreq;
matched.add(bestIdx);
@@ -120,20 +162,29 @@ function trackPartials(frames) {
}
}
- // Advance candidates
+ // --- Advance candidates ---
for (let i = candidates.length - 1; i >= 0; --i) {
const cand = candidates[i];
const lastFreq = cand.freqs[cand.freqs.length - 1];
+ const lastPhase = cand.phases[cand.phases.length - 1];
const velocity = cand.velocity || 0;
- const predicted = lastFreq + velocity;
+ const predictedFreq = lastFreq + velocity;
- const tol = Math.max(lastFreq * trackingRatio, minTrackingHz);
- let bestIdx = -1, bestDist = Infinity;
+ const phaseAdvance = 2 * Math.PI * lastFreq * hopSize / sampleRate;
+ const predictedPhase = lastPhase + phaseAdvance;
+
+ const tol = Math.max(predictedFreq * trackingRatio, minTrackingHz);
+ let bestIdx = -1, bestCost = Infinity;
for (let j = 0; j < frame.peaks.length; ++j) {
if (matched.has(j)) continue;
- const dist = Math.abs(frame.peaks[j].freq - predicted);
- if (dist < tol && dist < bestDist) { bestDist = dist; bestIdx = j; }
+ const pk = frame.peaks[j];
+ const freqError = Math.abs(pk.freq - predictedFreq);
+ if (freqError > tol) continue;
+
+ const phaseError = Math.abs(normalizeAngle(pk.phase - predictedPhase));
+ const cost = freqError + phaseErrorWeight * phaseError * predictedFreq;
+ if (cost < bestCost) { bestCost = cost; bestIdx = j; }
}
if (bestIdx >= 0) {
@@ -141,31 +192,34 @@ function trackPartials(frames) {
cand.times.push(frame.time);
cand.freqs.push(pk.freq);
cand.amps.push(pk.amp);
+ cand.phases.push(pk.phase);
cand.velocity = pk.freq - lastFreq;
matched.add(bestIdx);
+ // "graduate" a candidate to a full partial
if (cand.times.length >= birthPersistence) {
activePartials.push(cand);
candidates.splice(i, 1);
}
} else {
- candidates.splice(i, 1);
+ candidates.splice(i, 1); // kill candidate
}
}
- // Spawn candidates from unmatched peaks
+ // --- Spawn new candidates from unmatched peaks ---
for (let i = 0; i < frame.peaks.length; ++i) {
if (matched.has(i)) continue;
const pk = frame.peaks[i];
candidates.push({
times: [frame.time],
freqs: [pk.freq],
- amps: [pk.amp],
+ amps: [pk.amp],
+ phases: [pk.phase],
age: 0,
velocity: 0
});
}
- // Kill aged-out partials
+ // --- Kill aged-out partials ---
for (let i = activePartials.length - 1; i >= 0; --i) {
if (activePartials[i].age > deathAge) {
if (activePartials[i].times.length >= minLength) partials.push(activePartials[i]);
@@ -174,7 +228,7 @@ function trackPartials(frames) {
}
}
- // Collect remaining active partials
+ // --- Collect remaining active partials ---
for (const partial of activePartials) {
if (partial.times.length >= minLength) partials.push(partial);
}
@@ -193,6 +247,8 @@ function expandPartialsLeft(partials, frames) {
for (let i = 0; i < frames.length; ++i) timeToIdx.set(frames[i].time, i);
for (const partial of partials) {
+ if (!partial.phases) partial.phases = []; // Ensure old partials have phase array
+
let startIdx = timeToIdx.get(partial.times[0]);
if (startIdx == null || startIdx === 0) continue;
@@ -213,6 +269,7 @@ function expandPartialsLeft(partials, frames) {
partial.times.unshift(frame.time);
partial.freqs.unshift(pk.freq);
partial.amps.unshift(pk.amp);
+ partial.phases.unshift(pk.phase);
}
}
}