Skip to main content

max / audiofiles

15.5 KB · 424 lines History Blame Raw
1 //! Audio analysis pipeline: orchestrates decoding, feature extraction, classification, and DB persistence.
2 //!
3 //! The pipeline decodes audio to mono f32 via Symphonia, then runs configurable
4 //! stages (loudness, spectral features, BPM/key detection, loop detection,
5 //! classification) and persists results to the `audio_analysis` table.
6
7 pub mod basic;
8 pub mod bpm;
9 pub mod classify;
10 pub mod config;
11 pub mod decode;
12 pub mod loop_detect;
13 #[cfg(feature = "analysis")]
14 pub mod loudness;
15 pub mod mfcc;
16 pub mod spectral;
17 pub mod suggest;
18 pub mod waveform;
19 pub mod worker;
20
21 use std::path::Path;
22
23 use classify::SampleClass;
24 use config::AnalysisConfig;
25
26 use crate::db::Database;
27 use crate::error::{unix_now, CoreError};
28 use crate::fingerprint;
29 use tracing::instrument;
30
31 /// Complete analysis result for a single sample.
32 #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
33 pub struct AnalysisResult {
34 /// Content-addressed hash identifying this sample in the store.
35 pub hash: String,
36 /// Total duration in seconds.
37 pub duration: f64,
38 /// Sample rate in Hz.
39 pub sample_rate: u32,
40 /// Number of channels in the source file.
41 pub channels: u16,
42 /// Peak amplitude in dBFS.
43 pub peak_db: Option<f64>,
44 /// RMS loudness in dBFS.
45 pub rms_db: Option<f64>,
46 /// Integrated loudness in LUFS.
47 pub lufs: Option<f64>,
48 /// Estimated tempo in beats per minute.
49 pub bpm: Option<f64>,
50 /// Estimated musical key (e.g. "A minor").
51 pub musical_key: Option<String>,
52 /// Whether the sample is detected as a seamless loop.
53 pub is_loop: Option<bool>,
54 /// Spectral centroid in Hz (brightness measure).
55 pub spectral_centroid: Option<f64>,
56 /// Spectral flatness (0 = tonal, 1 = noise-like).
57 pub spectral_flatness: Option<f64>,
58 /// Spectral rolloff frequency in Hz.
59 pub spectral_rolloff: Option<f64>,
60 /// Zero-crossing rate (proportion of sign changes per sample).
61 pub zero_crossing_rate: Option<f64>,
62 /// Onset detection strength.
63 pub onset_strength: Option<f64>,
64 /// Heuristic sample classification (kick, snare, pad, etc.).
65 pub classification: Option<SampleClass>,
66 /// Peak envelope fingerprint for near-duplicate detection.
67 pub fingerprint: Option<Vec<u8>>,
68 /// Spectral bandwidth in Hz (spread of energy around centroid).
69 pub spectral_bandwidth: Option<f64>,
70 /// Variance of per-frame spectral centroids (spectral evolution).
71 pub centroid_variance: Option<f64>,
72 /// Peak-to-RMS ratio in linear domain (transient sharpness).
73 pub crest_factor: Option<f64>,
74 /// Time to 90% of peak amplitude in seconds (onset speed).
75 pub attack_time: Option<f64>,
76 /// ML classifier confidence (0.0-1.0). 0.0 when using rule-based fallback.
77 pub classification_confidence: Option<f64>,
78 }
79
80 /// Run all configured analyses on a single sample file.
81 #[instrument(skip_all)]
82 pub fn analyze_sample(
83 hash: &str,
84 path: &Path,
85 config: &AnalysisConfig,
86 ) -> Result<AnalysisResult, CoreError> {
87 // Guard against memory exhaustion: reject files over 2 GB before decoding.
88 // A 2 GB compressed file would expand to several GB of f32 samples.
89 const MAX_FILE_SIZE: u64 = 2 * 1024 * 1024 * 1024;
90 if let Ok(metadata) = std::fs::metadata(path)
91 && metadata.len() > MAX_FILE_SIZE {
92 return Err(CoreError::Analysis(crate::error::AnalysisError::ProbeFailed(
93 format!("file too large for analysis ({} MB, max {} MB)",
94 metadata.len() / (1024 * 1024), MAX_FILE_SIZE / (1024 * 1024)),
95 )));
96 }
97
98 let decoded = decode::decode_to_mono(path)?;
99
100 // Hard cap: reject files over 30 minutes to prevent memory exhaustion.
101 // A 30-minute 96kHz mono signal is ~660 MB of f32 — beyond that is almost
102 // certainly not a sample.
103 const MAX_DECODE_DURATION: f64 = 1800.0;
104 if decoded.duration > MAX_DECODE_DURATION {
105 return Err(CoreError::Analysis(crate::error::AnalysisError::ProbeFailed(
106 format!("file too long for analysis ({:.0}s, max {MAX_DECODE_DURATION}s)", decoded.duration),
107 )));
108 }
109
110 // Cap samples for expensive analyses (STFT, BPM/key). Cheap analyses and
111 // fingerprint use the full signal.
112 let capped_samples: &[f32] = if let Some(max_secs) = config.max_analysis_seconds {
113 let max_samples = (max_secs * decoded.sample_rate as f64) as usize;
114 &decoded.samples[..decoded.samples.len().min(max_samples)]
115 } else {
116 &decoded.samples
117 };
118
119 let mut result = AnalysisResult {
120 hash: hash.to_string(),
121 duration: decoded.duration,
122 sample_rate: decoded.sample_rate,
123 channels: decoded.channels,
124 peak_db: None,
125 rms_db: None,
126 lufs: None,
127 bpm: None,
128 musical_key: None,
129 is_loop: None,
130 spectral_centroid: None,
131 spectral_flatness: None,
132 spectral_rolloff: None,
133 zero_crossing_rate: None,
134 onset_strength: None,
135 classification: None,
136 fingerprint: None,
137 spectral_bandwidth: None,
138 centroid_variance: None,
139 crest_factor: None,
140 attack_time: None,
141 classification_confidence: None,
142 };
143
144 // Basic loudness (always fast — uses full signal)
145 if config.loudness {
146 result.peak_db = Some(basic::peak_db(&decoded.samples));
147 result.rms_db = Some(basic::rms_db(&decoded.samples));
148 result.crest_factor = Some(basic::crest_factor(&decoded.samples));
149 result.attack_time = Some(basic::attack_time(&decoded.samples, decoded.sample_rate));
150 #[cfg(feature = "analysis")]
151 {
152 result.lufs =
153 Some(loudness::measure_lufs(&decoded.samples, decoded.sample_rate));
154 }
155 }
156
157 // Spectral features (uses capped samples)
158 #[cfg(feature = "analysis")]
159 if config.spectral {
160 let (features, magnitude_frames) =
161 spectral::compute_spectral_features_with_frames(capped_samples, decoded.sample_rate);
162 result.spectral_centroid = Some(features.centroid);
163 result.spectral_flatness = Some(features.flatness);
164 result.spectral_rolloff = Some(features.rolloff);
165 result.zero_crossing_rate = Some(features.zero_crossing_rate);
166 result.onset_strength = Some(features.onset_strength);
167 result.spectral_bandwidth = Some(features.bandwidth);
168 result.centroid_variance = Some(features.centroid_variance);
169
170 // Classification requires spectral + waveform features
171 if config.classify {
172 // Compute MFCCs from STFT magnitude frames
173 let mfcc_features =
174 mfcc::compute_mfccs(&magnitude_frames, decoded.sample_rate, 1024);
175
176 let input = classify::ClassifyInput::with_mfccs(
177 &features,
178 decoded.duration,
179 result.crest_factor.unwrap_or(0.0),
180 result.attack_time.unwrap_or(0.0),
181 &mfcc_features,
182 );
183
184 let ml_result = classify::classify_ml(&input);
185 result.classification = Some(ml_result.class);
186 result.classification_confidence = Some(ml_result.confidence);
187 }
188 }
189
190 // Smart skip: use classification to decide if BPM/key/loop make sense.
191 // Drums, impacts, noise, foley, ambience, and textures skip these expensive stages.
192 let (want_bpm, want_key, want_loop) = if config.smart_skip {
193 if let Some(ref class) = result.classification {
194 (
195 config.bpm && class.has_rhythm(),
196 config.key && class.has_pitch(),
197 config.loop_detect && class.has_rhythm(),
198 )
199 } else {
200 // No classification result — run everything requested
201 (config.bpm, config.key, config.loop_detect)
202 }
203 } else {
204 (config.bpm, config.key, config.loop_detect)
205 };
206
207 // BPM + key detection (uses capped samples)
208 if want_bpm || want_key {
209 let bpm_key = bpm::detect_bpm_key(capped_samples, decoded.sample_rate, 2.0);
210 if want_bpm {
211 result.bpm = bpm_key.bpm;
212 }
213 if want_key {
214 result.musical_key = bpm_key.key;
215 }
216 }
217
218 // Loop detection
219 if want_loop {
220 result.is_loop = Some(loop_detect::is_loop(
221 &decoded.samples,
222 decoded.sample_rate,
223 result.bpm,
224 ));
225 }
226
227 // Fingerprint for near-duplicate detection (uses full signal)
228 if config.fingerprint {
229 result.fingerprint = Some(fingerprint::compute_envelope(
230 &decoded.samples,
231 decoded.sample_rate,
232 ));
233 }
234
235 Ok(result)
236 }
237
238 /// Save analysis results to the database, overwriting any previous results for this hash.
239 #[instrument(skip_all)]
240 pub fn save_analysis(db: &Database, result: &AnalysisResult) -> Result<(), CoreError> {
241 let now = unix_now();
242 db.conn().execute(
243 "INSERT OR REPLACE INTO audio_analysis (
244 hash, duration, sample_rate, channels,
245 peak_db, rms_db, lufs,
246 bpm, musical_key,
247 is_loop,
248 spectral_centroid, spectral_flatness, spectral_rolloff,
249 zero_crossing_rate, onset_strength,
250 classification,
251 spectral_bandwidth, centroid_variance, crest_factor, attack_time,
252 classification_confidence,
253 analyzed_at
254 ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17, ?18, ?19, ?20, ?21, ?22)",
255 rusqlite::params![
256 result.hash,
257 result.duration,
258 result.sample_rate,
259 result.channels,
260 result.peak_db,
261 result.rms_db,
262 result.lufs,
263 result.bpm,
264 result.musical_key,
265 result.is_loop,
266 result.spectral_centroid,
267 result.spectral_flatness,
268 result.spectral_rolloff,
269 result.zero_crossing_rate,
270 result.onset_strength,
271 result.classification.as_ref().map(|c| c.as_str()),
272 result.spectral_bandwidth,
273 result.centroid_variance,
274 result.crest_factor,
275 result.attack_time,
276 result.classification_confidence,
277 now,
278 ],
279 )?;
280
281 if let Some(ref envelope) = result.fingerprint {
282 fingerprint::save_fingerprint(
283 db,
284 &fingerprint::Fingerprint {
285 hash: result.hash.clone(),
286 envelope: envelope.clone(),
287 sample_rate: result.sample_rate,
288 },
289 )?;
290 }
291
292 Ok(())
293 }
294
295 /// Load analysis results for a sample by hash. Returns `None` if no analysis exists.
296 #[instrument(skip_all)]
297 pub fn load_analysis(db: &Database, hash: &str) -> Option<AnalysisResult> {
298 db.conn()
299 .query_row(
300 "SELECT hash, duration, sample_rate, channels, peak_db, rms_db, lufs,
301 bpm, musical_key, is_loop, spectral_centroid, spectral_flatness,
302 spectral_rolloff, zero_crossing_rate, onset_strength, classification,
303 spectral_bandwidth, centroid_variance, crest_factor, attack_time,
304 classification_confidence
305 FROM audio_analysis WHERE hash = ?1",
306 [hash],
307 |row| {
308 let class_str: Option<String> = row.get(15)?;
309 Ok(AnalysisResult {
310 hash: row.get(0)?,
311 duration: row.get(1)?,
312 sample_rate: row.get(2)?,
313 channels: row.get(3)?,
314 peak_db: row.get(4)?,
315 rms_db: row.get(5)?,
316 lufs: row.get(6)?,
317 bpm: row.get(7)?,
318 musical_key: row.get(8)?,
319 is_loop: row.get(9)?,
320 spectral_centroid: row.get(10)?,
321 spectral_flatness: row.get(11)?,
322 spectral_rolloff: row.get(12)?,
323 zero_crossing_rate: row.get(13)?,
324 onset_strength: row.get(14)?,
325 classification: class_str
326 .and_then(|s| s.parse::<classify::SampleClass>().ok()),
327 spectral_bandwidth: row.get(16)?,
328 centroid_variance: row.get(17)?,
329 crest_factor: row.get(18)?,
330 attack_time: row.get(19)?,
331 classification_confidence: row.get(20)?,
332 fingerprint: None,
333 })
334 },
335 )
336 .ok()
337 }
338
#[cfg(test)]
mod tests {
    use super::*;

    /// An `AnalysisResult` with only the required metadata set keeps every optional
    /// field `None`.
    #[test]
    fn analysis_result_construction_with_defaults() {
        let result = AnalysisResult {
            hash: "abc123".to_string(),
            duration: 1.5,
            sample_rate: 48000,
            channels: 1,
            peak_db: None,
            rms_db: None,
            lufs: None,
            bpm: None,
            musical_key: None,
            is_loop: None,
            spectral_centroid: None,
            spectral_flatness: None,
            spectral_rolloff: None,
            zero_crossing_rate: None,
            onset_strength: None,
            classification: None,
            fingerprint: None,
            spectral_bandwidth: None,
            centroid_variance: None,
            crest_factor: None,
            attack_time: None,
            classification_confidence: None,
        };
        assert_eq!(result.hash, "abc123");
        assert_eq!(result.sample_rate, 48000);
        assert_eq!(result.channels, 1);
        assert!((result.duration - 1.5).abs() < f64::EPSILON);
        assert!(result.peak_db.is_none());
        assert!(result.bpm.is_none());
        assert!(result.classification.is_none());
        assert!(result.spectral_bandwidth.is_none());
        assert!(result.crest_factor.is_none());
        assert!(result.fingerprint.is_none());
        assert!(result.classification_confidence.is_none());
    }

    /// Every field of a fully populated `AnalysisResult` reads back as it was set.
    #[test]
    fn analysis_result_fully_populated() {
        let result = AnalysisResult {
            hash: "def456".to_string(),
            duration: 3.2,
            sample_rate: 44100,
            channels: 2,
            peak_db: Some(-0.5),
            rms_db: Some(-12.0),
            lufs: Some(-14.0),
            bpm: Some(128.0),
            musical_key: Some("A minor".to_string()),
            is_loop: Some(true),
            spectral_centroid: Some(1500.0),
            spectral_flatness: Some(0.15),
            spectral_rolloff: Some(8000.0),
            zero_crossing_rate: Some(0.05),
            onset_strength: Some(30.0),
            classification: Some(SampleClass::Kick),
            fingerprint: None,
            spectral_bandwidth: Some(2500.0),
            centroid_variance: Some(50000.0),
            crest_factor: Some(4.5),
            attack_time: Some(0.005),
            classification_confidence: Some(0.87),
        };
        assert_eq!(result.hash, "def456");
        assert!((result.duration - 3.2).abs() < f64::EPSILON);
        assert_eq!(result.sample_rate, 44100);
        assert_eq!(result.channels, 2);
        assert_eq!(result.peak_db, Some(-0.5));
        assert_eq!(result.rms_db, Some(-12.0));
        assert_eq!(result.lufs, Some(-14.0));
        assert_eq!(result.bpm, Some(128.0));
        assert_eq!(result.musical_key.as_deref(), Some("A minor"));
        assert_eq!(result.is_loop, Some(true));
        assert_eq!(result.spectral_centroid, Some(1500.0));
        assert_eq!(result.spectral_flatness, Some(0.15));
        assert_eq!(result.spectral_rolloff, Some(8000.0));
        assert_eq!(result.zero_crossing_rate, Some(0.05));
        assert_eq!(result.onset_strength, Some(30.0));
        assert_eq!(result.classification, Some(SampleClass::Kick));
        assert!(result.fingerprint.is_none());
        assert_eq!(result.spectral_bandwidth, Some(2500.0));
        assert_eq!(result.centroid_variance, Some(50000.0));
        assert_eq!(result.crest_factor, Some(4.5));
        assert_eq!(result.attack_time, Some(0.005));
        assert_eq!(result.classification_confidence, Some(0.87));
    }
}
424