Skip to main content

max / audiofiles

8.0 KB · 257 lines History Blame Raw
//! MFCC (Mel-Frequency Cepstral Coefficients) from FFT magnitude frames.
//!
//! Computed from existing STFT magnitudes — no additional FFT needed:
//! mel filterbank (26 bands) -> log energy -> DCT-II -> first 13 coefficients.
//! The per-frame coefficients are aggregated as mean + variance across frames.
6
/// How many triangular mel bands the filterbank splits each spectrum into.
const NUM_MEL_FILTERS: usize = 26;
/// How many leading DCT-II coefficients are retained as MFCCs per frame.
pub const NUM_MFCC: usize = 13;
11
12 /// Aggregated MFCC features across all frames.
13 #[derive(Debug, Clone)]
14 pub struct MfccFeatures {
15 pub means: [f64; NUM_MFCC],
16 pub variances: [f64; NUM_MFCC],
17 }
18
19 impl Default for MfccFeatures {
20 fn default() -> Self {
21 Self {
22 means: [0.0; NUM_MFCC],
23 variances: [0.0; NUM_MFCC],
24 }
25 }
26 }
27
/// One triangular mel filter, stored sparsely as a run of consecutive weights.
struct MelFilter {
    /// Index of the spectrum bin that `weights[0]` applies to.
    start_bin: usize,
    /// Weights for bins `start_bin`, `start_bin + 1`, ...; empty for a
    /// filter that covers no bins.
    weights: Vec<f64>,
}
33
34 /// Precomputed mel filterbank for a given FFT size and sample rate.
35 struct MelFilterbank {
36 filters: Vec<MelFilter>,
37 }
38
/// Map a frequency in Hz onto the mel scale using `2595 * log10(1 + hz/700)`.
///
/// The mel scale models perceived pitch: steps of equal size in mel sound
/// evenly spaced, while the corresponding steps in Hz widen as frequency rises.
fn hz_to_mel(hz: f64) -> f64 {
    let ratio = 1.0 + hz / 700.0;
    ratio.log10() * 2595.0
}
46
/// Map a mel value back to Hz: `700 * (10^(mel/2595) - 1)`, the inverse of
/// the `2595 * log10(1 + hz/700)` mel formula.
fn mel_to_hz(mel: f64) -> f64 {
    let exponent = mel / 2595.0;
    let ratio = 10.0_f64.powf(exponent);
    (ratio - 1.0) * 700.0
}
51
52 /// Build a mel-spaced triangular filterbank for MFCC computation.
53 ///
54 /// Creates `num_filters` overlapping triangular filters spanning 0 Hz to Nyquist,
55 /// spaced uniformly on the mel scale. Each filter has a rising slope from its left
56 /// edge to its center and a falling slope from center to right edge — this smoothly
57 /// bins FFT magnitudes into perceptually meaningful frequency bands.
58 fn build_mel_filterbank(fft_size: usize, sample_rate: u32, num_filters: usize) -> MelFilterbank {
59 let spectrum_len = fft_size / 2 + 1;
60 let freq_bin_hz = sample_rate as f64 / fft_size as f64;
61
62 let mel_low = hz_to_mel(0.0);
63 let mel_high = hz_to_mel(sample_rate as f64 / 2.0);
64
65 // num_filters + 2 points define the edges of all triangular filters.
66 let num_points = num_filters + 2;
67 let mel_points: Vec<f64> = (0..num_points)
68 .map(|i| mel_low + (mel_high - mel_low) * i as f64 / (num_points - 1) as f64)
69 .collect();
70
71 let hz_points: Vec<f64> = mel_points.iter().map(|&m| mel_to_hz(m)).collect();
72 let bin_points: Vec<usize> = hz_points
73 .iter()
74 .map(|&hz| ((hz / freq_bin_hz).round() as usize).min(spectrum_len - 1))
75 .collect();
76
77 let mut filters = Vec::with_capacity(num_filters);
78
79 for i in 0..num_filters {
80 let start = bin_points[i];
81 let center = bin_points[i + 1];
82 let end = bin_points[i + 2];
83
84 if start >= end || center <= start || center >= end {
85 filters.push(MelFilter {
86 start_bin: 0,
87 weights: Vec::new(),
88 });
89 continue;
90 }
91
92 let mut weights = vec![0.0; end - start];
93
94 // Rising slope: start -> center
95 for bin in start..center {
96 weights[bin - start] = (bin - start) as f64 / (center - start) as f64;
97 }
98 // Falling slope: center -> end
99 for bin in center..end {
100 weights[bin - start] = (end - bin) as f64 / (end - center) as f64;
101 }
102
103 filters.push(MelFilter {
104 start_bin: start,
105 weights,
106 });
107 }
108
109 MelFilterbank { filters }
110 }
111
112 /// Compute aggregated MFCC features from per-frame magnitude spectra.
113 ///
114 /// Each frame is a magnitude spectrum from the STFT. Applies a mel filterbank,
115 /// takes log energies, applies DCT-II, and aggregates mean + variance across frames.
116 pub fn compute_mfccs(
117 magnitude_frames: &[Vec<f64>],
118 sample_rate: u32,
119 fft_size: usize,
120 ) -> MfccFeatures {
121 if magnitude_frames.is_empty() {
122 return MfccFeatures::default();
123 }
124
125 let filterbank = build_mel_filterbank(fft_size, sample_rate, NUM_MEL_FILTERS);
126
127 let mut all_mfccs: Vec<[f64; NUM_MFCC]> = Vec::with_capacity(magnitude_frames.len());
128
129 for frame in magnitude_frames {
130 // Apply mel filterbank
131 let mut mel_energies = [0.0f64; NUM_MEL_FILTERS];
132 for (i, filter) in filterbank.filters.iter().enumerate() {
133 let mut energy = 0.0;
134 for (j, &w) in filter.weights.iter().enumerate() {
135 let bin = filter.start_bin + j;
136 if bin < frame.len() {
137 energy += w * frame[bin];
138 }
139 }
140 mel_energies[i] = energy.max(1e-10);
141 }
142
143 // Log energies
144 let log_energies: [f64; NUM_MEL_FILTERS] =
145 std::array::from_fn(|i| mel_energies[i].ln());
146
147 // DCT-II: keep first NUM_MFCC coefficients
148 let n = NUM_MEL_FILTERS as f64;
149 let mut mfccs = [0.0f64; NUM_MFCC];
150 for (k, coeff) in mfccs.iter_mut().enumerate() {
151 let mut sum = 0.0;
152 for (i, &le) in log_energies.iter().enumerate() {
153 sum += le
154 * (std::f64::consts::PI * k as f64 * (2.0 * i as f64 + 1.0) / (2.0 * n))
155 .cos();
156 }
157 *coeff = sum * 2.0;
158 }
159
160 all_mfccs.push(mfccs);
161 }
162
163 // Mean across frames
164 let num_frames = all_mfccs.len() as f64;
165 let mut means = [0.0f64; NUM_MFCC];
166 for mfccs in &all_mfccs {
167 for (i, &v) in mfccs.iter().enumerate() {
168 means[i] += v;
169 }
170 }
171 for m in &mut means {
172 *m /= num_frames;
173 }
174
175 // Variance across frames
176 let mut variances = [0.0f64; NUM_MFCC];
177 if all_mfccs.len() > 1 {
178 for mfccs in &all_mfccs {
179 for (i, &v) in mfccs.iter().enumerate() {
180 let diff = v - means[i];
181 variances[i] += diff * diff;
182 }
183 }
184 for v in &mut variances {
185 *v /= num_frames;
186 }
187 }
188
189 MfccFeatures { means, variances }
190 }
191
#[cfg(test)]
mod tests {
    use super::*;

    /// FFT size shared by the tests below.
    const FFT_SIZE: usize = 1024;
    /// Sample rate shared by the tests below.
    const SAMPLE_RATE: u32 = 44100;
    /// One-sided spectrum length for `FFT_SIZE` (`FFT_SIZE / 2 + 1`).
    const SPECTRUM_LEN: usize = 513;

    /// A spectrum whose every bin holds `value`.
    fn flat_frame(value: f64) -> Vec<f64> {
        vec![value; SPECTRUM_LEN]
    }

    #[test]
    fn hz_mel_roundtrip() {
        let probes = [0.0, 100.0, 440.0, 1000.0, 8000.0, 22050.0];
        for hz in probes {
            let back = mel_to_hz(hz_to_mel(hz));
            assert!(
                (hz - back).abs() < 0.01,
                "roundtrip failed for {hz}: got {back}"
            );
        }
    }

    #[test]
    fn filterbank_has_correct_count() {
        let bank = build_mel_filterbank(FFT_SIZE, SAMPLE_RATE, NUM_MEL_FILTERS);
        assert_eq!(bank.filters.len(), NUM_MEL_FILTERS);
    }

    #[test]
    fn empty_frames_returns_default() {
        let features = compute_mfccs(&[], SAMPLE_RATE, FFT_SIZE);
        assert_eq!(features.means, [0.0; NUM_MFCC]);
        assert_eq!(features.variances, [0.0; NUM_MFCC]);
    }

    #[test]
    fn single_frame_has_zero_variance() {
        let features = compute_mfccs(&[flat_frame(1.0)], SAMPLE_RATE, FFT_SIZE);
        assert_eq!(features.variances, [0.0; NUM_MFCC]);
    }

    #[test]
    fn different_frames_have_nonzero_variance() {
        let frames = [flat_frame(1.0), flat_frame(2.0)];
        let features = compute_mfccs(&frames, SAMPLE_RATE, FFT_SIZE);
        assert!(features.variances.iter().any(|&v| v > 0.0));
    }

    #[test]
    fn mel_scale_is_monotonic() {
        let freqs = [100.0, 200.0, 500.0, 1000.0, 2000.0, 5000.0, 10000.0];
        for pair in freqs.windows(2) {
            let (lower, higher) = (pair[0], pair[1]);
            assert!(hz_to_mel(higher) > hz_to_mel(lower));
        }
    }

    #[test]
    fn mfcc_values_are_finite() {
        // Deterministic, non-flat, strictly positive magnitude spectrum.
        let frame: Vec<f64> = (0..SPECTRUM_LEN)
            .map(|i| (i as f64 * 0.01).sin().abs() + 0.001)
            .collect();
        let frames = [frame.clone(), frame];
        let features = compute_mfccs(&frames, SAMPLE_RATE, FFT_SIZE);
        for &m in &features.means {
            assert!(m.is_finite(), "non-finite mean: {m}");
        }
        for &v in &features.variances {
            assert!(v.is_finite(), "non-finite variance: {v}");
        }
    }
}
257