//! Search query builder: text search, filter by BPM/key/duration/classification/tags. use crate::db::Database; use crate::error::Result; use crate::id_types::{NodeId, VfsId}; use crate::vfs::{self as vfs_mod, VfsNodeWithAnalysis}; use tracing::instrument; /// Search scope: current folder or global across all VFS roots. #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub enum SearchScope { CurrentFolder, Global, } /// Key filter mode: exact match or expand to compatible keys. #[derive(Debug, Clone, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub enum KeyFilterMode { /// Match only the selected keys. #[default] Exact, /// Expand each selected key to include its compatible keys (relative + adjacent on circle of fifths). Compatible, } /// Filter criteria for searching samples. #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] #[serde(default)] pub struct SearchFilter { pub text_query: String, pub bpm_min: Option, pub bpm_max: Option, pub keys: Vec, pub key_mode: KeyFilterMode, pub duration_min: Option, pub duration_max: Option, pub peak_db_min: Option, pub peak_db_max: Option, pub classifications: Vec, pub required_tags: Vec, pub scope: SearchScope, } impl Default for SearchFilter { fn default() -> Self { Self { text_query: String::new(), bpm_min: None, bpm_max: None, keys: Vec::new(), key_mode: KeyFilterMode::Exact, duration_min: None, duration_max: None, peak_db_min: None, peak_db_max: None, classifications: Vec::new(), required_tags: Vec::new(), scope: SearchScope::CurrentFolder, } } } impl SearchFilter { /// Returns true if any filter is active. pub fn is_active(&self) -> bool { !self.text_query.is_empty() || self.bpm_min.is_some() || self.bpm_max.is_some() || !self.keys.is_empty() || self.duration_min.is_some() || self.duration_max.is_some() || self.peak_db_min.is_some() || self.peak_db_max.is_some() || !self.classifications.is_empty() || !self.required_tags.is_empty() } /// Count distinct active filter categories (excluding text query). pub fn active_count(&self) -> usize { let mut n = 0; if self.bpm_min.is_some() || self.bpm_max.is_some() { n += 1; } if self.duration_min.is_some() || self.duration_max.is_some() { n += 1; } if self.peak_db_min.is_some() || self.peak_db_max.is_some() { n += 1; } if !self.classifications.is_empty() { n += 1; } if !self.keys.is_empty() { n += 1; } if !self.required_tags.is_empty() { n += 1; } n } /// Clear all filter criteria. pub fn clear(&mut self) { *self = Self::default(); } /// Generate a human-readable description of active filters for use as a default collection name. pub fn describe(&self) -> String { let mut parts = Vec::new(); if !self.text_query.is_empty() { parts.push(format!("\"{}\"", self.text_query)); } if let (Some(min), Some(max)) = (self.bpm_min, self.bpm_max) { parts.push(format!("BPM {:.0}-{:.0}", min, max)); } else if let Some(min) = self.bpm_min { parts.push(format!("BPM {:.0}+", min)); } else if let Some(max) = self.bpm_max { parts.push(format!("BPM <{:.0}", max)); } if let (Some(min), Some(max)) = (self.duration_min, self.duration_max) { parts.push(format!("{:.1}-{:.1}s", min, max)); } else if let Some(min) = self.duration_min { parts.push(format!(">{:.1}s", min)); } else if let Some(max) = self.duration_max { parts.push(format!("<{:.1}s", max)); } if let (Some(min), Some(max)) = (self.peak_db_min, self.peak_db_max) { parts.push(format!("{:.0} to {:.0} dB", min, max)); } else if let Some(min) = self.peak_db_min { parts.push(format!(">{:.0} dB", min)); } else if let Some(max) = self.peak_db_max { parts.push(format!("<{:.0} dB", max)); } if !self.classifications.is_empty() { parts.push(self.classifications.join(", ")); } if !self.keys.is_empty() { parts.push(self.keys.join(", ")); } if !self.required_tags.is_empty() { parts.push(self.required_tags.iter().map(|t| format!("#{t}")).collect::>().join(" ")); } if parts.is_empty() { "All samples".to_string() } else { parts.join(" | ") } } } /// Search within a specific VFS folder. #[instrument(skip_all)] pub fn search_in_folder( db: &Database, filter: &SearchFilter, vfs_id: VfsId, parent_id: Option, ) -> Result> { let mut sql = String::from( "SELECT n.id, n.vfs_id, n.parent_id, n.name, n.node_type, n.sample_hash, n.created_at, a.bpm, a.musical_key, COALESCE(a.duration, s.duration), a.classification, a.peak_db, a.is_loop, s.cloud_only FROM vfs_nodes n LEFT JOIN audio_analysis a ON n.sample_hash = a.hash LEFT JOIN samples s ON n.sample_hash = s.hash WHERE n.vfs_id = ?1 AND n.parent_id IS ?2 AND s.deleted_at IS NULL", ); let mut params: Vec> = Vec::new(); params.push(Box::new(vfs_id)); params.push(Box::new(parent_id)); append_filter_clauses(&mut sql, &mut params, filter); sql.push_str(" ORDER BY n.node_type ASC, n.name ASC"); let mut stmt = db.conn().prepare(&sql)?; let param_refs: Vec<&dyn rusqlite::types::ToSql> = params.iter().map(|p| p.as_ref()).collect(); let rows = stmt.query_map(param_refs.as_slice(), vfs_mod::map_enriched_row)?; Ok(rows.collect::, _>>()?) } /// Search globally across all VFS roots. #[instrument(skip_all)] pub fn search_global( db: &Database, filter: &SearchFilter, ) -> Result> { let mut sql = String::from( "SELECT n.id, n.vfs_id, n.parent_id, n.name, n.node_type, n.sample_hash, n.created_at, a.bpm, a.musical_key, COALESCE(a.duration, s.duration), a.classification, a.peak_db, a.is_loop, s.cloud_only FROM vfs_nodes n LEFT JOIN audio_analysis a ON n.sample_hash = a.hash LEFT JOIN samples s ON n.sample_hash = s.hash WHERE s.deleted_at IS NULL", ); let mut params: Vec> = Vec::new(); append_filter_clauses(&mut sql, &mut params, filter); sql.push_str(" ORDER BY n.node_type ASC, n.name ASC LIMIT 500"); let mut stmt = db.conn().prepare(&sql)?; let param_refs: Vec<&dyn rusqlite::types::ToSql> = params.iter().map(|p| p.as_ref()).collect(); let rows = stmt.query_map(param_refs.as_slice(), vfs_mod::map_enriched_row)?; Ok(rows.collect::, _>>()?) } /// Return the set of musically compatible keys for a given key. /// /// Uses the circle of fifths: for a major key, returns the key itself, its relative minor, /// and the two adjacent major/minor pairs. For a minor key, finds the relative major first, /// then expands from there. Unknown keys return just themselves. pub fn compatible_keys(key: &str) -> Vec { // Circle of fifths for major keys (and their relative minors). // Each entry: (major, relative_minor) const CIRCLE: &[(&str, &str)] = &[ ("C major", "A minor"), ("G major", "E minor"), ("D major", "B minor"), ("A major", "F# minor"), ("E major", "C# minor"), ("B major", "G# minor"), ("F# major", "D# minor"), ("C# major", "A# minor"), ("G# major", "F minor"), // enharmonic: Ab = G# ("D# major", "C minor"), // enharmonic: Eb = D# ("A# major", "G minor"), // enharmonic: Bb = A# ("F major", "D minor"), ]; // Find position on the circle for a major key let find_major_pos = |major: &str| -> Option { CIRCLE.iter().position(|(m, _)| *m == major) }; // Find which major key this key belongs to (either it IS a major key, or it's the relative minor) let major_key = if key.ends_with("major") { key.to_string() } else if key.ends_with("minor") { match CIRCLE.iter().find(|(_, m)| *m == key) { Some((maj, _)) => maj.to_string(), None => return vec![key.to_string()], } } else { return vec![key.to_string()]; }; let pos = match find_major_pos(&major_key) { Some(p) => p, None => return vec![key.to_string()], }; let len = CIRCLE.len(); let left = (pos + len - 1) % len; let right = (pos + 1) % len; vec![ CIRCLE[pos].0.to_string(), CIRCLE[pos].1.to_string(), CIRCLE[right].0.to_string(), CIRCLE[right].1.to_string(), CIRCLE[left].0.to_string(), CIRCLE[left].1.to_string(), ] } fn append_filter_clauses( sql: &mut String, params: &mut Vec>, filter: &SearchFilter, ) { if !filter.text_query.is_empty() { let pattern = format!("%{}%", tagtree::escape_like(&filter.text_query)); sql.push_str(&format!( " AND (n.name LIKE ?{idx} ESCAPE '\\')", idx = params.len() + 1 )); params.push(Box::new(pattern)); } if let Some(bpm_min) = filter.bpm_min { sql.push_str(&format!(" AND a.bpm >= ?{}", params.len() + 1)); params.push(Box::new(bpm_min)); } if let Some(bpm_max) = filter.bpm_max { sql.push_str(&format!(" AND a.bpm <= ?{}", params.len() + 1)); params.push(Box::new(bpm_max)); } if let Some(dur_min) = filter.duration_min { sql.push_str(&format!(" AND COALESCE(a.duration, s.duration) >= ?{}", params.len() + 1)); params.push(Box::new(dur_min)); } if let Some(dur_max) = filter.duration_max { sql.push_str(&format!(" AND COALESCE(a.duration, s.duration) <= ?{}", params.len() + 1)); params.push(Box::new(dur_max)); } if let Some(peak_min) = filter.peak_db_min { sql.push_str(&format!(" AND a.peak_db >= ?{}", params.len() + 1)); params.push(Box::new(peak_min)); } if let Some(peak_max) = filter.peak_db_max { sql.push_str(&format!(" AND a.peak_db <= ?{}", params.len() + 1)); params.push(Box::new(peak_max)); } if !filter.classifications.is_empty() { let placeholders: Vec = filter .classifications .iter() .enumerate() .map(|(i, _)| format!("?{}", params.len() + 1 + i)) .collect(); sql.push_str(&format!( " AND a.classification IN ({})", placeholders.join(", ") )); for class in &filter.classifications { params.push(Box::new(class.clone())); } } if !filter.keys.is_empty() { use std::collections::HashSet; let expanded: Vec = match filter.key_mode { KeyFilterMode::Compatible => { let mut set = HashSet::new(); for key in &filter.keys { for compat in compatible_keys(key) { set.insert(compat); } } set.into_iter().collect() } KeyFilterMode::Exact => filter.keys.clone(), }; let placeholders: Vec = expanded .iter() .enumerate() .map(|(i, _)| format!("?{}", params.len() + 1 + i)) .collect(); sql.push_str(&format!( " AND a.musical_key IN ({})", placeholders.join(", ") )); for key in &expanded { params.push(Box::new(key.clone())); } } // Tag filter: no input validation needed here. Tags stored in the `tags` table are // already validated by `validate_tag` on insert (lowercase alphanumeric + dots/hyphens // only), and `escape_like` prevents SQL injection via the LIKE pattern. Filtering // against validated data with escaped patterns is defense-in-depth enough. for tag in &filter.required_tags { sql.push_str(&format!( " AND EXISTS (SELECT 1 FROM tags WHERE sample_hash = n.sample_hash AND tag LIKE ?{} ESCAPE '\\')", params.len() + 1 )); let pattern = format!("{}%", tagtree::escape_like(tag)); params.push(Box::new(pattern)); } } #[cfg(test)] mod tests { use super::*; use crate::test_helpers::{insert_fake_sample, insert_sample_with_analysis}; use crate::{tags, vfs}; fn setup_with_samples() -> (Database, VfsId) { let db = Database::open_in_memory().unwrap(); insert_fake_sample(&db, "hash1"); insert_fake_sample(&db, "hash2"); let vfs_id = vfs::create_vfs(&db, "Test").unwrap(); vfs::create_sample_link(&db, vfs_id, None, "kick.wav", "hash1").unwrap(); vfs::create_sample_link(&db, vfs_id, None, "snare.wav", "hash2").unwrap(); (db, vfs_id) } /// Set up a VFS with two analysed samples for filter tests. fn setup_with_analysis() -> (Database, VfsId) { let db = Database::open_in_memory().unwrap(); let vfs_id = vfs::create_vfs(&db, "Test").unwrap(); insert_sample_with_analysis( &db, "fast", "kick_130.wav", vfs_id, Some(130.0), Some("C major"), Some(0.5), Some("kick"), ); insert_sample_with_analysis( &db, "slow", "pad_90.wav", vfs_id, Some(90.0), Some("A minor"), Some(5.0), Some("pad"), ); (db, vfs_id) } #[test] fn empty_filter_returns_all() { let (db, vfs_id) = setup_with_samples(); let filter = SearchFilter::default(); let results = search_in_folder(&db, &filter, vfs_id, None).unwrap(); assert_eq!(results.len(), 2); } #[test] fn text_filter_matches_name() { let (db, vfs_id) = setup_with_samples(); let filter = SearchFilter { text_query: "kick".to_string(), ..Default::default() }; let results = search_in_folder(&db, &filter, vfs_id, None).unwrap(); assert_eq!(results.len(), 1); assert_eq!(results[0].node.name, "kick.wav"); } #[test] fn text_filter_no_match() { let (db, vfs_id) = setup_with_samples(); let filter = SearchFilter { text_query: "nonexistent".to_string(), ..Default::default() }; let results = search_in_folder(&db, &filter, vfs_id, None).unwrap(); assert_eq!(results.len(), 0); } #[test] fn filter_is_active() { let mut f = SearchFilter::default(); assert!(!f.is_active()); f.text_query = "test".to_string(); assert!(f.is_active()); f.clear(); assert!(!f.is_active()); } #[test] fn bpm_min_filter() { let (db, vfs_id) = setup_with_analysis(); let filter = SearchFilter { bpm_min: Some(120.0), ..Default::default() }; let results = search_in_folder(&db, &filter, vfs_id, None).unwrap(); assert_eq!(results.len(), 1); assert_eq!(results[0].node.name, "kick_130.wav"); } #[test] fn bpm_max_filter() { let (db, vfs_id) = setup_with_analysis(); let filter = SearchFilter { bpm_max: Some(100.0), ..Default::default() }; let results = search_in_folder(&db, &filter, vfs_id, None).unwrap(); assert_eq!(results.len(), 1); assert_eq!(results[0].node.name, "pad_90.wav"); } #[test] fn bpm_range_filter() { let (db, vfs_id) = setup_with_analysis(); // Range that includes both let filter = SearchFilter { bpm_min: Some(80.0), bpm_max: Some(140.0), ..Default::default() }; assert_eq!(search_in_folder(&db, &filter, vfs_id, None).unwrap().len(), 2); // Range that excludes both let filter = SearchFilter { bpm_min: Some(95.0), bpm_max: Some(125.0), ..Default::default() }; assert_eq!(search_in_folder(&db, &filter, vfs_id, None).unwrap().len(), 0); } #[test] fn duration_min_filter() { let (db, vfs_id) = setup_with_analysis(); let filter = SearchFilter { duration_min: Some(2.0), ..Default::default() }; let results = search_in_folder(&db, &filter, vfs_id, None).unwrap(); assert_eq!(results.len(), 1); assert_eq!(results[0].node.name, "pad_90.wav"); } #[test] fn duration_max_filter() { let (db, vfs_id) = setup_with_analysis(); let filter = SearchFilter { duration_max: Some(1.0), ..Default::default() }; let results = search_in_folder(&db, &filter, vfs_id, None).unwrap(); assert_eq!(results.len(), 1); assert_eq!(results[0].node.name, "kick_130.wav"); } #[test] fn key_filter() { let (db, vfs_id) = setup_with_analysis(); let filter = SearchFilter { keys: vec!["A minor".to_string()], ..Default::default() }; let results = search_in_folder(&db, &filter, vfs_id, None).unwrap(); assert_eq!(results.len(), 1); assert_eq!(results[0].node.name, "pad_90.wav"); } #[test] fn classification_filter() { let (db, vfs_id) = setup_with_analysis(); let filter = SearchFilter { classifications: vec!["kick".to_string()], ..Default::default() }; let results = search_in_folder(&db, &filter, vfs_id, None).unwrap(); assert_eq!(results.len(), 1); assert_eq!(results[0].node.name, "kick_130.wav"); } #[test] fn tag_filter() { let (db, vfs_id) = setup_with_analysis(); tags::add_tag(&db, "fast", "drums.kick").unwrap(); let filter = SearchFilter { required_tags: vec!["drums".to_string()], ..Default::default() }; let results = search_in_folder(&db, &filter, vfs_id, None).unwrap(); assert_eq!(results.len(), 1); assert_eq!(results[0].node.name, "kick_130.wav"); } #[test] fn combined_filters() { let (db, vfs_id) = setup_with_analysis(); // Text + BPM + classification — all match the kick sample let filter = SearchFilter { text_query: "kick".to_string(), bpm_min: Some(120.0), classifications: vec!["kick".to_string()], ..Default::default() }; let results = search_in_folder(&db, &filter, vfs_id, None).unwrap(); assert_eq!(results.len(), 1); assert_eq!(results[0].node.name, "kick_130.wav"); // Same filters but text doesn't match — AND logic means zero results let filter = SearchFilter { text_query: "pad".to_string(), bpm_min: Some(120.0), classifications: vec!["kick".to_string()], ..Default::default() }; assert_eq!(search_in_folder(&db, &filter, vfs_id, None).unwrap().len(), 0); } #[test] fn global_search() { let db = Database::open_in_memory().unwrap(); let vfs1 = vfs::create_vfs(&db, "VFS1").unwrap(); let vfs2 = vfs::create_vfs(&db, "VFS2").unwrap(); insert_sample_with_analysis( &db, "h1", "kick.wav", vfs1, Some(128.0), None, None, Some("kick"), ); insert_sample_with_analysis( &db, "h2", "snare.wav", vfs2, Some(128.0), None, None, Some("snare"), ); // Global search finds samples across VFS roots let filter = SearchFilter { classifications: vec!["kick".to_string()], scope: SearchScope::Global, ..Default::default() }; let results = search_global(&db, &filter).unwrap(); assert_eq!(results.len(), 1); assert_eq!(results[0].node.name, "kick.wav"); // No filter returns all let filter = SearchFilter { scope: SearchScope::Global, ..Default::default() }; assert_eq!(search_global(&db, &filter).unwrap().len(), 2); } // --- Key compatibility tests --- #[test] fn compatible_keys_major() { let compat = compatible_keys("C major"); assert!(compat.contains(&"C major".to_string())); assert!(compat.contains(&"A minor".to_string())); // relative minor assert!(compat.contains(&"G major".to_string())); // adjacent right assert!(compat.contains(&"E minor".to_string())); assert!(compat.contains(&"F major".to_string())); // adjacent left assert!(compat.contains(&"D minor".to_string())); assert_eq!(compat.len(), 6); } #[test] fn compatible_keys_minor() { // A minor's relative major is C major, so results should match C major's expansion let compat = compatible_keys("A minor"); assert!(compat.contains(&"C major".to_string())); assert!(compat.contains(&"A minor".to_string())); assert!(compat.contains(&"G major".to_string())); assert!(compat.contains(&"F major".to_string())); assert_eq!(compat.len(), 6); } #[test] fn compatible_keys_unknown() { let compat = compatible_keys("X dorian"); assert_eq!(compat, vec!["X dorian".to_string()]); } #[test] fn compatible_key_filter_expands() { let db = Database::open_in_memory().unwrap(); let vfs_id = vfs::create_vfs(&db, "Test").unwrap(); insert_sample_with_analysis(&db, "c", "c.wav", vfs_id, None, Some("C major"), None, None); insert_sample_with_analysis(&db, "am", "am.wav", vfs_id, None, Some("A minor"), None, None); insert_sample_with_analysis(&db, "g", "g.wav", vfs_id, None, Some("G major"), None, None); insert_sample_with_analysis(&db, "eb", "eb.wav", vfs_id, None, Some("D# major"), None, None); // Compatible mode: selecting "C major" should match C major, A minor, G major (and E minor, F major, D minor) let filter = SearchFilter { keys: vec!["C major".to_string()], key_mode: KeyFilterMode::Compatible, ..Default::default() }; let results = search_in_folder(&db, &filter, vfs_id, None).unwrap(); // Should find c.wav, am.wav, g.wav (D# major is not compatible) assert_eq!(results.len(), 3); } #[test] fn exact_key_mode_unchanged() { let db = Database::open_in_memory().unwrap(); let vfs_id = vfs::create_vfs(&db, "Test").unwrap(); insert_sample_with_analysis(&db, "c", "c.wav", vfs_id, None, Some("C major"), None, None); insert_sample_with_analysis(&db, "am", "am.wav", vfs_id, None, Some("A minor"), None, None); let filter = SearchFilter { keys: vec!["C major".to_string()], key_mode: KeyFilterMode::Exact, ..Default::default() }; let results = search_in_folder(&db, &filter, vfs_id, None).unwrap(); assert_eq!(results.len(), 1); assert_eq!(results[0].node.name, "c.wav"); } }