Skip to main content

max / audiofiles

11.0 KB · 306 lines History Blame Raw
1 //! Spectral analysis via STFT: centroid, bandwidth, flatness, rolloff, onset strength, and centroid variance, plus the time-domain zero-crossing rate.
2 //!
3 //! The heavy `compute_spectral_features` function (which depends on `realfft`) is gated behind
4 //! the `analysis` feature. The [`SpectralFeatures`] struct is always available so that downstream
5 //! code (classify, AnalysisResult) compiles unconditionally.
6
7 #[cfg(feature = "analysis")]
8 use realfft::RealFftPlanner;
9 use tracing::instrument;
10
/// Results of spectral analysis across the entire signal.
///
/// The derived `Default` sets every field to `0.0`. Values can be compared with `==`
/// (plain `f64` comparison, so a `NaN` field never compares equal).
#[derive(Debug, Clone, Default, PartialEq)]
pub struct SpectralFeatures {
    /// Spectral center of mass in Hz, averaged over frames (low = bassy, high = bright).
    pub centroid: f64,
    /// Ratio from 0.0 (pure tone) to 1.0 (white noise), computed per frame as the
    /// geometric mean divided by the arithmetic mean of the magnitudes, then averaged.
    pub flatness: f64,
    /// Frequency in Hz at which the cumulative spectral magnitude first reaches 85% of
    /// the frame total, averaged over frames.
    pub rolloff: f64,
    /// Fraction of adjacent sample pairs whose sign differs (high = noisy/bright transients).
    pub zero_crossing_rate: f64,
    /// Average positive spectral flux per frame, where the flux of a frame is the sum of
    /// its bin-wise magnitude increases over the preceding frame (higher = more
    /// percussive onsets).
    pub onset_strength: f64,
    /// Spectral standard deviation around the centroid in Hz (spread of energy),
    /// averaged over frames.
    pub bandwidth: f64,
    /// Population variance of the per-frame centroids in Hz² (high = evolving spectrum,
    /// low = static).
    pub centroid_variance: f64,
}
30
/// Compute spectral features from mono audio samples.
///
/// The signal is cut into 1024-sample frames with no overlap (hop size = window size);
/// each frame is Hann-windowed and transformed with a real FFT. Centroid, flatness,
/// rolloff, bandwidth and onset strength are averages of per-frame values, the centroid
/// variance is taken over the per-frame centroids, and the zero-crossing rate is
/// measured on the raw samples.
///
/// Empty input yields `SpectralFeatures::default()`. Input shorter than one window
/// yields only the zero-crossing rate, with every other field left at zero. Samples
/// after the last full window do not contribute to the STFT-based features.
///
/// Requires the `analysis` feature (pulls in `realfft`).
#[cfg(feature = "analysis")]
#[instrument(skip_all)]
pub fn compute_spectral_features(samples: &[f32], sample_rate: u32) -> SpectralFeatures {
    // The per-frame spectra are only needed by callers that go on to compute MFCCs.
    let (features, _magnitude_frames) = compute_spectral_inner(samples, sample_rate);
    features
}
42
/// Compute spectral features together with the per-frame magnitude spectra.
///
/// The [`SpectralFeatures`] half of the result is identical to what
/// [`compute_spectral_features`] returns. The second half holds one magnitude vector
/// per retained STFT frame, in signal order, so that downstream MFCC computation can
/// reuse the spectra instead of running the FFT again.
///
/// Frames whose magnitudes sum to zero are not retained, and input shorter than one
/// analysis window produces an empty frame list.
///
/// Requires the `analysis` feature (pulls in `realfft`).
#[cfg(feature = "analysis")]
#[instrument(skip_all)]
pub fn compute_spectral_features_with_frames(
    samples: &[f32],
    sample_rate: u32,
) -> (SpectralFeatures, Vec<Vec<f64>>) {
    let (features, magnitude_frames) = compute_spectral_inner(samples, sample_rate);
    (features, magnitude_frames)
}
55
/// Arithmetic mean of `values`, or `0.0` when the slice is empty.
#[cfg(feature = "analysis")]
fn mean_or_zero(values: &[f64]) -> f64 {
    if values.is_empty() {
        0.0
    } else {
        values.iter().sum::<f64>() / values.len() as f64
    }
}

/// Half-wave-rectified spectral flux: the sum of bin-wise magnitude increases from
/// `previous` to `current`. Decreases are ignored because onsets are characterised by
/// energy appearing, not disappearing.
#[cfg(feature = "analysis")]
fn positive_spectral_flux(current: &[f64], previous: &[f64]) -> f64 {
    current
        .iter()
        .zip(previous)
        .map(|(&now, &before)| (now - before).max(0.0))
        .sum()
}

/// Shared implementation for spectral feature extraction.
///
/// Splits `samples` into non-overlapping 1024-sample frames, applies a Hann window and a
/// real FFT to each, and derives centroid, bandwidth, flatness, rolloff and spectral
/// flux from every frame's magnitude spectrum. `sample_rate` (in Hz) is only used to
/// convert bin indices to frequencies.
///
/// Returns the aggregated features plus the magnitude spectrum of every non-silent
/// frame, in signal order. Empty input returns all-default features; input shorter than
/// one window returns only the zero-crossing rate; both come with an empty frame list.
/// Samples after the last full window are ignored by the STFT-based features.
///
/// Spectral flux is measured against the frame that immediately precedes the current
/// one in time. Silent frames are not stored, so a frame that follows silence is
/// compared with an all-zero spectrum (its flux is its total magnitude) rather than
/// with the last stored frame; otherwise an onset after a gap of digital silence would
/// be compared with the sound before the gap and be under-reported.
#[cfg(feature = "analysis")]
fn compute_spectral_inner(
    samples: &[f32],
    sample_rate: u32,
) -> (SpectralFeatures, Vec<Vec<f64>>) {
    /// STFT frame length in samples.
    const WINDOW_SIZE: usize = 1024;
    /// No overlap — classification needs stable averages, not frame-level detail.
    const HOP_SIZE: usize = WINDOW_SIZE;
    /// Fraction of the summed magnitude that defines the rolloff frequency.
    /// 85% is the standard threshold (vs 90% or 95%) — it captures the "useful"
    /// bandwidth while ignoring the long tail of high-frequency noise.
    const ROLLOFF_FRACTION: f64 = 0.85;
    /// Floor applied to magnitudes before taking logarithms, to avoid ln(0).
    const LOG_FLOOR: f64 = 1e-10;

    /// What is known about the frame analysed immediately before the current one.
    #[derive(Clone, Copy)]
    enum Previous {
        /// No usable predecessor: first frame, or the predecessor could not be analysed.
        Unavailable,
        /// The predecessor's magnitudes were all zero (it was not stored).
        Silent,
        /// The predecessor is the last element of `magnitude_frames`.
        Stored,
    }

    if samples.is_empty() {
        return (SpectralFeatures::default(), Vec::new());
    }

    // ZCR is a time-domain measure, so it is taken on the raw signal.
    let zcr = zero_crossing_rate(samples);

    // Too short for a single window: report just the ZCR.
    if samples.len() < WINDOW_SIZE {
        return (
            SpectralFeatures {
                zero_crossing_rate: zcr,
                ..Default::default()
            },
            Vec::new(),
        );
    }

    let mut planner = RealFftPlanner::<f64>::new();
    let fft = planner.plan_fft_forward(WINDOW_SIZE);

    // Hann window: w(n) = 0.5 * (1 - cos(2*pi*n / (N-1)))
    // Smoothly tapers the signal to zero at frame edges, reducing spectral leakage.
    let hann: Vec<f64> = (0..WINDOW_SIZE)
        .map(|i| {
            let phase = 2.0 * std::f64::consts::PI * i as f64 / (WINDOW_SIZE - 1) as f64;
            0.5 * (1.0 - phase.cos())
        })
        .collect();

    let freq_bin_hz = f64::from(sample_rate) / WINDOW_SIZE as f64;

    // The number of full windows is known up front, so size the result vectors once.
    let frame_count = (samples.len() - WINDOW_SIZE) / HOP_SIZE + 1;
    let mut centroids: Vec<f64> = Vec::with_capacity(frame_count);
    let mut flatnesses: Vec<f64> = Vec::with_capacity(frame_count);
    let mut rolloffs: Vec<f64> = Vec::with_capacity(frame_count);
    let mut bandwidths: Vec<f64> = Vec::with_capacity(frame_count);
    let mut onset_diffs: Vec<f64> = Vec::with_capacity(frame_count);
    let mut magnitude_frames: Vec<Vec<f64>> = Vec::with_capacity(frame_count);

    // FFT input/output buffers are reused for every frame instead of being reallocated.
    // The FFT is free to overwrite its input, so `windowed` is refilled completely on
    // each iteration.
    let mut windowed = vec![0.0_f64; WINDOW_SIZE];
    let mut spectrum = fft.make_output_vec();

    let mut previous = Previous::Unavailable;

    for frame in samples.windows(WINDOW_SIZE).step_by(HOP_SIZE) {
        // Apply the window.
        for ((slot, &sample), &weight) in windowed.iter_mut().zip(frame).zip(&hann) {
            *slot = f64::from(sample) * weight;
        }

        if fft.process(&mut windowed, &mut spectrum).is_err() {
            // Nothing is known about this frame, so the next one has no predecessor.
            previous = Previous::Unavailable;
            continue;
        }

        let magnitudes: Vec<f64> = spectrum.iter().map(|c| c.norm()).collect();
        let mag_sum: f64 = magnitudes.iter().sum();

        if mag_sum > 0.0 {
            // Spectral centroid: Sum(bin_freq * magnitude) / Sum(magnitude), the
            // "center of mass" of the spectrum in Hz.
            let weighted_sum: f64 = magnitudes
                .iter()
                .enumerate()
                .map(|(bin, &m)| bin as f64 * freq_bin_hz * m)
                .sum();
            let frame_centroid = weighted_sum / mag_sum;
            centroids.push(frame_centroid);

            // Spectral bandwidth: magnitude-weighted standard deviation around the
            // centroid, sqrt(Sum((freq - centroid)^2 * mag) / Sum(mag)).
            let spread_sum: f64 = magnitudes
                .iter()
                .enumerate()
                .map(|(bin, &m)| {
                    let offset = bin as f64 * freq_bin_hz - frame_centroid;
                    offset * offset * m
                })
                .sum();
            bandwidths.push((spread_sum / mag_sum).sqrt());

            // Spectral flatness = geometric_mean / arithmetic_mean of the magnitudes:
            // 0.0 = pure tone (energy in one bin), 1.0 = white noise (uniform energy).
            // The geometric mean is computed in log-space, exp(mean(ln(x))), to avoid
            // underflow of the direct product; magnitudes are floored to avoid ln(0).
            let bin_count = magnitudes.len() as f64;
            let arithmetic_mean = mag_sum / bin_count;
            let log_sum: f64 = magnitudes.iter().map(|&m| m.max(LOG_FLOOR).ln()).sum();
            let geometric_mean = (log_sum / bin_count).exp();
            // A subnormal `mag_sum` can still divide down to exactly zero.
            let flatness = if arithmetic_mean > 0.0 {
                geometric_mean / arithmetic_mean
            } else {
                0.0
            };
            flatnesses.push(flatness.clamp(0.0, 1.0));

            // Spectral rolloff: first bin at which the cumulative magnitude reaches the
            // threshold fraction of the frame total (bin 0 if it is never reached).
            let threshold = mag_sum * ROLLOFF_FRACTION;
            let mut cumulative = 0.0;
            let rolloff_bin = magnitudes
                .iter()
                .position(|&m| {
                    cumulative += m;
                    cumulative >= threshold
                })
                .unwrap_or(0);
            rolloffs.push(rolloff_bin as f64 * freq_bin_hz);

            // Onset strength via spectral flux against the frame directly before this
            // one. After a silent frame every bin rose from zero, so the flux is simply
            // this frame's total magnitude.
            let flux = match previous {
                Previous::Stored => magnitude_frames
                    .last()
                    .map(|before| positive_spectral_flux(&magnitudes, before)),
                Previous::Silent => Some(mag_sum),
                Previous::Unavailable => None,
            };
            if let Some(flux) = flux {
                onset_diffs.push(flux);
            }

            // Retained for the later MFCC pass; also serves as the next frame's
            // predecessor, so no separate copy of the spectrum is kept.
            magnitude_frames.push(magnitudes);
            previous = Previous::Stored;
        } else if mag_sum == 0.0 {
            // Magnitudes are non-negative, so a zero sum means every bin is zero.
            previous = Previous::Silent;
        } else {
            // Non-finite sum (NaN in the input): the frame carries no usable spectrum.
            previous = Previous::Unavailable;
        }
    }

    // Centroid variance: population variance of the per-frame centroids (how much the
    // spectrum evolves). Undefined for fewer than two frames, reported as 0.0.
    let centroid_mean = mean_or_zero(&centroids);
    let centroid_variance = if centroids.len() < 2 {
        0.0
    } else {
        let squared_deviations: f64 = centroids
            .iter()
            .map(|&c| (c - centroid_mean).powi(2))
            .sum();
        squared_deviations / centroids.len() as f64
    };

    let features = SpectralFeatures {
        centroid: centroid_mean,
        flatness: mean_or_zero(&flatnesses),
        rolloff: mean_or_zero(&rolloffs),
        zero_crossing_rate: zcr,
        onset_strength: mean_or_zero(&onset_diffs),
        bandwidth: mean_or_zero(&bandwidths),
        centroid_variance,
    };

    (features, magnitude_frames)
}
237
/// Zero-crossing rate: the share of adjacent sample pairs whose signs differ.
///
/// A sample counts as non-negative when it is `>= 0.0`, so both `0.0` and `-0.0` sit on
/// the positive side. Fewer than two samples contain no pair and yield `0.0`.
#[cfg(feature = "analysis")]
fn zero_crossing_rate(samples: &[f32]) -> f64 {
    let pair_count = samples.len().saturating_sub(1);
    if pair_count == 0 {
        return 0.0;
    }

    let is_non_negative = |value: f32| value >= 0.0;
    let sign_flips = samples
        .iter()
        .zip(samples.iter().skip(1))
        .filter(|&(&left, &right)| is_non_negative(left) != is_non_negative(right))
        .count();

    sign_flips as f64 / pair_count as f64
}
250
#[cfg(test)]
#[cfg(feature = "analysis")]
mod tests {
    use super::*;

    /// Sample rate passed to the analysis in every test.
    const SAMPLE_RATE: u32 = 44100;

    /// One second (44100 samples) of a unit-amplitude sine at `freq_hz`.
    fn sine_wave(freq_hz: f32) -> Vec<f32> {
        (0..44100)
            .map(|i| (2.0 * std::f32::consts::PI * freq_hz * i as f32 / 44100.0).sin())
            .collect()
    }

    #[test]
    fn silence_features() {
        // All-zero input has no spectral content and no onsets.
        let features = compute_spectral_features(&[0.0f32; 4096], SAMPLE_RATE);
        assert_eq!(features.centroid, 0.0);
        assert_eq!(features.onset_strength, 0.0);
    }

    #[test]
    fn sine_has_low_flatness() {
        // A pure 440 Hz sine is tonal: low flatness, centroid close to the tone itself.
        let features = compute_spectral_features(&sine_wave(440.0), SAMPLE_RATE);
        assert!(
            features.flatness < 0.1,
            "pure sine flatness should be low, got {}",
            features.flatness
        );
        assert!(
            (features.centroid - 440.0).abs() < 100.0,
            "centroid should be near 440 Hz, got {}",
            features.centroid
        );
    }

    #[test]
    fn zcr_of_high_freq_is_higher() {
        // A faster oscillation changes sign more often per sample.
        let zcr_low = zero_crossing_rate(&sine_wave(100.0));
        let zcr_high = zero_crossing_rate(&sine_wave(10000.0));
        assert!(zcr_high > zcr_low);
    }

    #[test]
    fn short_signal() {
        // Shorter than one analysis window: only the ZCR is computed, and a constant
        // signal never changes sign.
        let features = compute_spectral_features(&[0.5f32; 100], SAMPLE_RATE);
        assert_eq!(features.zero_crossing_rate, 0.0);
    }
}
306