//! Similarity search: find samples with similar audio features using weighted Euclidean distance. use crate::db::Database; use crate::error::{CoreError, Result}; use crate::vp_tree::VpTree; use tracing::instrument; /// Audio feature vector for similarity comparison. /// /// Each field is optional — dimensions with `None` are skipped during distance calculation. #[derive(Debug, Clone, Default)] pub struct FeatureVector { pub bpm: Option, pub duration: Option, pub lufs: Option, pub spectral_centroid: Option, pub spectral_flatness: Option, pub spectral_rolloff: Option, pub zero_crossing_rate: Option, pub onset_strength: Option, pub spectral_bandwidth: Option, pub centroid_variance: Option, pub crest_factor: Option, pub attack_time: Option, } /// Per-dimension weights for the distance function. #[derive(Debug, Clone)] pub struct FeatureWeights { pub bpm: f64, pub duration: f64, pub lufs: f64, pub spectral_centroid: f64, pub spectral_flatness: f64, pub spectral_rolloff: f64, pub zero_crossing_rate: f64, pub onset_strength: f64, pub spectral_bandwidth: f64, pub centroid_variance: f64, pub crest_factor: f64, pub attack_time: f64, } impl Default for FeatureWeights { fn default() -> Self { Self { bpm: 1.0, duration: 1.0, lufs: 1.0, spectral_centroid: 1.0, spectral_flatness: 1.0, spectral_rolloff: 1.0, zero_crossing_rate: 1.0, onset_strength: 1.0, spectral_bandwidth: 1.0, centroid_variance: 1.0, crest_factor: 1.0, attack_time: 1.0, } } } /// A similarity search result: a sample hash with its distance from the reference. #[derive(Debug, Clone)] pub struct SimilarResult { pub hash: String, pub distance: f64, } /// Normalization ranges for each feature dimension, learned from the dataset. 
#[derive(Debug, Clone, Default)]
struct NormRanges {
    // Every tuple is `(min, max)` for the dimension of the same name.
    bpm: (f64, f64),
    duration: (f64, f64),
    lufs: (f64, f64),
    spectral_centroid: (f64, f64),
    spectral_flatness: (f64, f64),
    spectral_rolloff: (f64, f64),
    zero_crossing_rate: (f64, f64),
    onset_strength: (f64, f64),
    spectral_bandwidth: (f64, f64),
    centroid_variance: (f64, f64),
    crest_factor: (f64, f64),
    attack_time: (f64, f64),
}

/// Rescale `val` linearly so that `min` maps to 0.0 and `max` maps to 1.0.
///
/// A degenerate range (width below `f64::EPSILON`) yields 0.0. The result is
/// not clamped: inputs outside `min..=max` land outside `0.0..=1.0`.
fn norm(val: f64, min: f64, max: f64) -> f64 {
    let span = max - min;
    if span.abs() < f64::EPSILON {
        return 0.0;
    }
    (val - min) / span
}

/// Rescale every present dimension of `fv` with its `(min, max)` pair from
/// `ranges`; dimensions that are `None` stay `None`.
fn normalize(fv: &FeatureVector, ranges: &NormRanges) -> FeatureVector {
    let rescale = |value: Option<f64>, (lo, hi): (f64, f64)| value.map(|v| norm(v, lo, hi));

    FeatureVector {
        bpm: rescale(fv.bpm, ranges.bpm),
        duration: rescale(fv.duration, ranges.duration),
        lufs: rescale(fv.lufs, ranges.lufs),
        spectral_centroid: rescale(fv.spectral_centroid, ranges.spectral_centroid),
        spectral_flatness: rescale(fv.spectral_flatness, ranges.spectral_flatness),
        spectral_rolloff: rescale(fv.spectral_rolloff, ranges.spectral_rolloff),
        zero_crossing_rate: rescale(fv.zero_crossing_rate, ranges.zero_crossing_rate),
        onset_strength: rescale(fv.onset_strength, ranges.onset_strength),
        spectral_bandwidth: rescale(fv.spectral_bandwidth, ranges.spectral_bandwidth),
        centroid_variance: rescale(fv.centroid_variance, ranges.centroid_variance),
        crest_factor: rescale(fv.crest_factor, ranges.crest_factor),
        attack_time: rescale(fv.attack_time, ranges.attack_time),
    }
}

/// Compute weighted Euclidean distance between two feature vectors.
/// Skips dimensions where either vector has `None`.
///
/// For every dimension present in both vectors, the squared difference is
/// multiplied by that dimension's weight. The weighted sum is divided by the
/// number of shared dimensions before the square root is taken, so the result
/// is a root-mean-square over the dimensions that were actually compared.
///
/// Returns `1e10` when the two vectors share no dimension at all.
pub fn feature_distance(a: &FeatureVector, b: &FeatureVector, weights: &FeatureWeights) -> f64 {
    let pairs: [(Option<f64>, Option<f64>, f64); 12] = [
        (a.bpm, b.bpm, weights.bpm),
        (a.duration, b.duration, weights.duration),
        (a.lufs, b.lufs, weights.lufs),
        (a.spectral_centroid, b.spectral_centroid, weights.spectral_centroid),
        (a.spectral_flatness, b.spectral_flatness, weights.spectral_flatness),
        (a.spectral_rolloff, b.spectral_rolloff, weights.spectral_rolloff),
        (a.zero_crossing_rate, b.zero_crossing_rate, weights.zero_crossing_rate),
        (a.onset_strength, b.onset_strength, weights.onset_strength),
        (a.spectral_bandwidth, b.spectral_bandwidth, weights.spectral_bandwidth),
        (a.centroid_variance, b.centroid_variance, weights.centroid_variance),
        (a.crest_factor, b.crest_factor, weights.crest_factor),
        (a.attack_time, b.attack_time, weights.attack_time),
    ];

    let mut sum = 0.0;
    let mut dims = 0.0;
    for &(va, vb, w) in &pairs {
        if let (Some(va), Some(vb)) = (va, vb) {
            let diff = va - vb;
            sum += w * diff * diff;
            dims += 1.0;
        }
    }

    // Return a large finite distance when no dimensions overlap, preserving
    // the VP-tree triangle inequality contract (INFINITY violates it).
    if dims == 0.0 {
        1e10
    } else {
        (sum / dims).sqrt()
    }
}

/// Load the feature vector for a sample by hash.
///
/// Reads the twelve feature columns of the `audio_analysis` row whose `hash`
/// column equals `hash` and returns them as a [`FeatureVector`].
///
/// # Errors
///
/// Returns [`CoreError::SampleNotFound`] carrying `hash` whenever the query
/// fails.
#[instrument(skip_all)]
pub fn load_features(db: &Database, hash: &str) -> Result<FeatureVector> {
    db.conn()
        .query_row(
            "SELECT bpm, duration, lufs, spectral_centroid, spectral_flatness, \
             spectral_rolloff, zero_crossing_rate, onset_strength, spectral_bandwidth, \
             centroid_variance, crest_factor, attack_time \
             FROM audio_analysis WHERE hash = ?1",
            [hash],
            |row| {
                // Column indices follow the SELECT list above.
                Ok(FeatureVector {
                    bpm: row.get(0)?,
                    duration: row.get(1)?,
                    lufs: row.get(2)?,
                    spectral_centroid: row.get(3)?,
                    spectral_flatness: row.get(4)?,
                    spectral_rolloff: row.get(5)?,
                    zero_crossing_rate: row.get(6)?,
                    onset_strength: row.get(7)?,
                    spectral_bandwidth: row.get(8)?,
                    centroid_variance: row.get(9)?,
                    crest_factor: row.get(10)?,
                    attack_time: row.get(11)?,
                })
            },
        )
        // NOTE(review): every query failure, not only "no such row", is
        // reported as `SampleNotFound`; the underlying error is discarded.
        .map_err(|_| CoreError::SampleNotFound(hash.to_string()))
}

/// Find samples similar to the given reference hash, ranked by feature distance (linear scan).
///
/// Loads all analysed samples, normalizes features across the dataset, and returns
/// the top `limit` results (excluding the reference itself).
///
/// O(n) linear scan with per-query normalization. For indexed sub-linear queries,
/// use [`SimilarityIndex`].
#[instrument(skip_all, fields(hash = %hash))] pub fn find_similar(db: &Database, hash: &str, limit: usize) -> Result> { let ref_features = load_features(db, hash)?; // Load all features let mut stmt = db.conn().prepare( "SELECT hash, bpm, duration, lufs, spectral_centroid, spectral_flatness, spectral_rolloff, zero_crossing_rate, onset_strength, spectral_bandwidth, centroid_variance, crest_factor, attack_time FROM audio_analysis WHERE hash != ?1", )?; let all: Vec<(String, FeatureVector)> = stmt .query_map([hash], |row| { Ok(( row.get::<_, String>(0)?, FeatureVector { bpm: row.get(1)?, duration: row.get(2)?, lufs: row.get(3)?, spectral_centroid: row.get(4)?, spectral_flatness: row.get(5)?, spectral_rolloff: row.get(6)?, zero_crossing_rate: row.get(7)?, onset_strength: row.get(8)?, spectral_bandwidth: row.get(9)?, centroid_variance: row.get(10)?, crest_factor: row.get(11)?, attack_time: row.get(12)?, }, )) })? .collect::, _>>()?; if all.is_empty() { return Ok(Vec::new()); } // Compute normalization ranges across all samples + reference let ranges = compute_ranges(&ref_features, &all); let ref_norm = normalize(&ref_features, &ranges); let weights = FeatureWeights::default(); let mut results: Vec = all .iter() .map(|(h, fv)| { let fv_norm = normalize(fv, &ranges); SimilarResult { hash: h.clone(), distance: feature_distance(&ref_norm, &fv_norm, &weights), } }) .collect(); results.sort_by(|a, b| a.distance.total_cmp(&b.distance)); results.truncate(limit); Ok(results) } /// Compute min/max ranges for each feature across the reference and all other samples. fn compute_ranges(reference: &FeatureVector, others: &[(String, FeatureVector)]) -> NormRanges { let mut ranges = NormRanges::default(); // Helper macro to update ranges for a field macro_rules! 
update_range { ($field:ident, $range_field:ident) => { let mut min = f64::MAX; let mut max = f64::MIN; if let Some(v) = reference.$field { min = min.min(v); max = max.max(v); } for (_, fv) in others { if let Some(v) = fv.$field { min = min.min(v); max = max.max(v); } } if min < f64::MAX { ranges.$range_field = (min, max); } }; } update_range!(bpm, bpm); update_range!(duration, duration); update_range!(lufs, lufs); update_range!(spectral_centroid, spectral_centroid); update_range!(spectral_flatness, spectral_flatness); update_range!(spectral_rolloff, spectral_rolloff); update_range!(zero_crossing_rate, zero_crossing_rate); update_range!(onset_strength, onset_strength); update_range!(spectral_bandwidth, spectral_bandwidth); update_range!(centroid_variance, centroid_variance); update_range!(crest_factor, crest_factor); update_range!(attack_time, attack_time); ranges } // --- VP-tree index for fast similarity search --- /// Entry stored in the VP-tree: hash + pre-normalized feature vector. struct SimilarityEntry { hash: String, features: FeatureVector, } /// Weighted Euclidean distance on pre-normalized features (default weights). fn entry_distance(a: &SimilarityEntry, b: &SimilarityEntry) -> f64 { feature_distance(&a.features, &b.features, &DEFAULT_WEIGHTS) } /// Default weights (static to avoid repeated allocation). const DEFAULT_WEIGHTS: FeatureWeights = FeatureWeights { bpm: 1.0, duration: 1.0, lufs: 1.0, spectral_centroid: 1.0, spectral_flatness: 1.0, spectral_rolloff: 1.0, zero_crossing_rate: 1.0, onset_strength: 1.0, spectral_bandwidth: 1.0, centroid_variance: 1.0, crest_factor: 1.0, attack_time: 1.0, }; /// Pre-built VP-tree index for fast similarity search. /// /// Normalization ranges are computed once from the full dataset and cached. /// Query vectors are normalized with these fixed ranges. 
This is a minor /// semantic change from per-query normalization — a single new sample no /// longer shifts the normalization of all others — but the ranking impact /// is negligible for libraries of any practical size. /// /// Build cost: O(n log n) with cheap Euclidean distance (< 0.5s for 100K). /// Query cost: O(log n) — sub-millisecond for 100K. pub struct SimilarityIndex { tree: VpTree, ranges: NormRanges, } impl SimilarityIndex { /// Build an index from all analysed samples in the database. #[instrument(skip_all)] /// Load raw feature data from the database (fast, just I/O). pub fn load_data(db: &Database) -> Result> { let mut stmt = db.conn().prepare( "SELECT hash, bpm, duration, lufs, spectral_centroid, spectral_flatness, spectral_rolloff, zero_crossing_rate, onset_strength, spectral_bandwidth, centroid_variance, crest_factor, attack_time FROM audio_analysis", )?; let all: Vec<(String, FeatureVector)> = stmt .query_map([], |row| { Ok(( row.get::<_, String>(0)?, FeatureVector { bpm: row.get(1)?, duration: row.get(2)?, lufs: row.get(3)?, spectral_centroid: row.get(4)?, spectral_flatness: row.get(5)?, spectral_rolloff: row.get(6)?, zero_crossing_rate: row.get(7)?, onset_strength: row.get(8)?, spectral_bandwidth: row.get(9)?, centroid_variance: row.get(10)?, crest_factor: row.get(11)?, attack_time: row.get(12)?, }, )) })? .collect::, _>>()?; Ok(all) } /// Build the index from pre-loaded data (CPU-intensive, no DB needed). 
pub fn build_from_data(all: Vec<(String, FeatureVector)>) -> Self {
        if all.is_empty() {
            return Self {
                tree: VpTree::build(vec![], entry_distance),
                ranges: NormRanges::default(),
            };
        }

        // Ranges come from the whole dataset and are stored with the tree so
        // that queries are normalized the same way as the indexed entries.
        let ranges = compute_ranges_all(&all);
        let entries: Vec<SimilarityEntry> = all
            .into_iter()
            .map(|(hash, fv)| SimilarityEntry {
                hash,
                features: normalize(&fv, &ranges),
            })
            .collect();

        let tree = VpTree::build(entries, entry_distance);
        Self { tree, ranges }
    }

    /// Build an index from all analysed samples in the database.
    ///
    /// Equivalent to [`Self::load_data`] followed by [`Self::build_from_data`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Self::load_data`].
    pub fn build(db: &Database) -> Result<Self> {
        let all = Self::load_data(db)?;
        Ok(Self::build_from_data(all))
    }

    /// Number of samples in the index.
    pub fn len(&self) -> usize {
        self.tree.len()
    }

    /// Whether the index is empty.
    pub fn is_empty(&self) -> bool {
        self.tree.is_empty()
    }

    /// Find samples similar to the given features, ranked by distance.
    ///
    /// `features` is normalized with the ranges cached at build time. Entries
    /// whose hash equals `hash` are dropped from the result, and at most
    /// `limit` results are returned.
    #[instrument(skip_all)]
    pub fn find_similar(
        &self,
        hash: &str,
        features: &FeatureVector,
        limit: usize,
    ) -> Vec<SimilarResult> {
        // Nothing can be returned in either case, so skip the tree query.
        if limit == 0 || self.tree.is_empty() {
            return Vec::new();
        }

        let query = SimilarityEntry {
            hash: hash.to_string(),
            features: normalize(features, &self.ranges),
        };

        // Request limit+1 to account for self being in the tree. The sum
        // saturates so that `usize::MAX` cannot overflow, and it is capped at
        // the index size because the tree cannot yield more entries than it holds.
        let wanted = limit.saturating_add(1).min(self.tree.len());
        let candidates = self.tree.find_nearest(&query, wanted, entry_distance);

        candidates
            .into_iter()
            .filter(|c| self.tree.get(c.index).hash != hash)
            .take(limit)
            .map(|c| SimilarResult {
                hash: self.tree.get(c.index).hash.clone(),
                distance: c.distance,
            })
            .collect()
    }
}

/// Compute min/max ranges across all samples (no reference bias).
///
/// A dimension for which no sample carries a value keeps its default range.
fn compute_ranges_all(samples: &[(String, FeatureVector)]) -> NormRanges {
    let mut ranges = NormRanges::default();

    // `FeatureVector` and `NormRanges` use the same field names, so one
    // identifier is enough.
    macro_rules! update_range {
        ($field:ident) => {
            let mut min = f64::MAX;
            let mut max = f64::MIN;
            for v in samples.iter().filter_map(|(_, fv)| fv.$field) {
                min = min.min(v);
                max = max.max(v);
            }
            // `min` still at its sentinel means no value was observed.
            if min < f64::MAX {
                ranges.$field = (min, max);
            }
        };
    }

    update_range!(bpm);
    update_range!(duration);
    update_range!(lufs);
    update_range!(spectral_centroid);
    update_range!(spectral_flatness);
    update_range!(spectral_rolloff);
    update_range!(zero_crossing_rate);
    update_range!(onset_strength);
    update_range!(spectral_bandwidth);
    update_range!(centroid_variance);
    update_range!(crest_factor);
    update_range!(attack_time);

    ranges
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::insert_fake_sample;
    use crate::analysis::{self, AnalysisResult};

    /// Insert a sample whose analysis differs from the others only in `bpm`
    /// and `duration`; every other feature gets a fixed value.
    fn insert_with_features(db: &Database, hash: &str, bpm: f64, duration: f64) {
        insert_fake_sample(db, hash);
        let result = AnalysisResult {
            hash: hash.to_string(),
            duration,
            sample_rate: 44100,
            channels: 1,
            peak_db: None,
            rms_db: None,
            lufs: Some(-14.0),
            bpm: Some(bpm),
            musical_key: None,
            is_loop: None,
            spectral_centroid: Some(1000.0),
            spectral_flatness: Some(0.5),
            spectral_rolloff: Some(5000.0),
            zero_crossing_rate: Some(0.1),
            onset_strength: Some(20.0),
            classification: None,
            fingerprint: None,
            spectral_bandwidth: Some(2000.0),
            centroid_variance: Some(50000.0),
            crest_factor: Some(3.0),
            attack_time: Some(0.01),
            classification_confidence: None,
        };
        analysis::save_analysis(db, &result).unwrap();
    }

    #[test]
    fn normalize_values() {
        let fv = FeatureVector {
            bpm: Some(120.0),
            duration: Some(2.0),
            ..Default::default()
        };
        let ranges = NormRanges {
            bpm: (100.0, 200.0),
            duration: (1.0, 3.0),
            ..Default::default()
        };
        let normed = normalize(&fv, &ranges);
        assert!((normed.bpm.unwrap() - 0.2).abs() < 1e-10);
        assert!((normed.duration.unwrap() - 0.5).abs() < 1e-10);
    }

    /// `limit` values at both extremes must neither overflow nor panic.
    #[test]
    fn index_handles_extreme_limits() {
        let db = Database::open_in_memory().unwrap();
        insert_with_features(&db, "ref", 120.0, 1.0);
        insert_with_features(&db, "a", 121.0, 1.0);
        insert_with_features(&db, "b", 122.0, 1.0);

        let idx = SimilarityIndex::build(&db).unwrap();
        let ref_features = load_features(&db, "ref").unwrap();

        assert_eq!(idx.find_similar("ref", &ref_features, usize::MAX).len(), 2);
        assert!(idx.find_similar("ref", &ref_features, 0).is_empty());
    }

#[test]
    fn distance_zero_for_identical() {
        // Every dimension is populated, so all twelve take part in the sum.
        let half = Some(0.5);
        let sample = FeatureVector {
            bpm: half,
            duration: half,
            lufs: half,
            spectral_centroid: half,
            spectral_flatness: half,
            spectral_rolloff: half,
            zero_crossing_rate: half,
            onset_strength: half,
            spectral_bandwidth: half,
            centroid_variance: half,
            crest_factor: half,
            attack_time: half,
        };
        let weights = FeatureWeights::default();
        let distance = feature_distance(&sample, &sample, &weights);
        assert!(distance.abs() < f64::EPSILON);
    }

    #[test]
    fn distance_symmetric() {
        let left = FeatureVector {
            bpm: Some(0.0),
            duration: Some(1.0),
            ..Default::default()
        };
        let right = FeatureVector {
            bpm: Some(1.0),
            duration: Some(0.0),
            ..Default::default()
        };
        let weights = FeatureWeights::default();
        let forward = feature_distance(&left, &right, &weights);
        let backward = feature_distance(&right, &left, &weights);
        assert!((forward - backward).abs() < f64::EPSILON);
    }

    #[test]
    fn ranking_correctness() {
        let db = Database::open_in_memory().unwrap();
        for (hash, bpm, duration) in [("ref", 120.0, 1.0), ("close", 122.0, 1.1), ("far", 200.0, 10.0)] {
            insert_with_features(&db, hash, bpm, duration);
        }

        let ranked = find_similar(&db, "ref", 10).unwrap();

        // The reference itself is excluded; the nearer sample comes first.
        assert_eq!(ranked.len(), 2);
        assert_eq!(ranked[0].hash, "close");
        assert_eq!(ranked[1].hash, "far");
        assert!(ranked[0].distance < ranked[1].distance);
    }

    #[test]
    fn limit_respected() {
        let db = Database::open_in_memory().unwrap();
        for (hash, bpm) in [("ref", 120.0), ("a", 121.0), ("b", 122.0), ("c", 123.0)] {
            insert_with_features(&db, hash, bpm, 1.0);
        }

        let ranked = find_similar(&db, "ref", 2).unwrap();
        assert_eq!(ranked.len(), 2);
    }

    #[test]
    fn missing_hash_errors() {
        let db = Database::open_in_memory().unwrap();
        assert!(find_similar(&db, "nonexistent", 10).is_err());
    }

    // --- SimilarityIndex tests ---

    #[test]
    fn index_build_empty() {
        let db = Database::open_in_memory().unwrap();
        let index = SimilarityIndex::build(&db).unwrap();
        assert!(index.is_empty());
        assert_eq!(index.len(), 0);
    }

#[test] fn index_ranking_matches_linear() { let db = Database::open_in_memory().unwrap(); insert_with_features(&db, "ref", 120.0, 1.0); insert_with_features(&db, "close", 122.0, 1.1); insert_with_features(&db, "far", 200.0, 10.0); let linear = find_similar(&db, "ref", 10).unwrap(); let idx = SimilarityIndex::build(&db).unwrap(); let ref_features = load_features(&db, "ref").unwrap(); let indexed = idx.find_similar("ref", &ref_features, 10); // Same ranking order. assert_eq!(linear.len(), indexed.len()); for (l, i) in linear.iter().zip(indexed.iter()) { assert_eq!(l.hash, i.hash, "Ranking order differs"); } } #[test] fn index_limit_respected() { let db = Database::open_in_memory().unwrap(); insert_with_features(&db, "ref", 120.0, 1.0); insert_with_features(&db, "a", 121.0, 1.0); insert_with_features(&db, "b", 122.0, 1.0); insert_with_features(&db, "c", 123.0, 1.0); let idx = SimilarityIndex::build(&db).unwrap(); let ref_features = load_features(&db, "ref").unwrap(); let results = idx.find_similar("ref", &ref_features, 2); assert_eq!(results.len(), 2); } #[test] fn index_excludes_self() { let db = Database::open_in_memory().unwrap(); insert_with_features(&db, "only", 120.0, 1.0); let idx = SimilarityIndex::build(&db).unwrap(); let features = load_features(&db, "only").unwrap(); let results = idx.find_similar("only", &features, 10); assert!(results.is_empty()); } #[test] fn index_sorted_by_distance() { let db = Database::open_in_memory().unwrap(); insert_with_features(&db, "ref", 120.0, 1.0); insert_with_features(&db, "a", 125.0, 2.0); insert_with_features(&db, "b", 130.0, 3.0); insert_with_features(&db, "c", 140.0, 5.0); insert_with_features(&db, "d", 200.0, 10.0); let idx = SimilarityIndex::build(&db).unwrap(); let ref_features = load_features(&db, "ref").unwrap(); let results = idx.find_similar("ref", &ref_features, 10); for w in results.windows(2) { assert!( w[0].distance <= w[1].distance, "Results not sorted: {} > {}", w[0].distance, w[1].distance ); } } }