//! Classification validation test against labeled drum machine samples. //! //! Reads pre-labeled samples from `~/Git/Drums/test_data/{kick,snare,hihat,...}/` //! and checks what percentage the two-layer ML classifier gets right. //! //! Label-to-SampleClass mapping: //! kick -> Kick //! snare -> Snare //! hihat -> HiHat //! cymbal -> Cymbal //! clap -> Percussion (claps are percussive hits) //! tom -> Percussion (toms are pitched percussion) //! percussion -> Percussion (congas, cowbell, shaker, etc.) //! fx -> (excluded from validation — catch-all junk drawer, not a learnable class) use std::collections::HashMap; use std::path::PathBuf; use audiofiles_core::analysis::classify::SampleClass; use audiofiles_core::analysis::config::AnalysisConfig; use audiofiles_core::analysis::{self}; /// Map a test data folder name to the expected SampleClass. fn expected_class(label: &str) -> Option { match label { "kick" => Some(SampleClass::Kick), "snare" => Some(SampleClass::Snare), "hihat" => Some(SampleClass::HiHat), "cymbal" => Some(SampleClass::Cymbal), "clap" | "tom" | "percussion" => Some(SampleClass::Percussion), "fx" => None, // excluded — misc catch-all, not a classifiable category _ => None, } } /// Acceptable alternate classifications for a given label. /// Some boundary cases are legitimately ambiguous. fn acceptable_alternates(label: &str) -> &[SampleClass] { match label { "clap" => &[SampleClass::Snare, SampleClass::Impact], "tom" => &[SampleClass::Kick], "hihat" => &[SampleClass::Cymbal, SampleClass::Noise], "cymbal" => &[SampleClass::HiHat, SampleClass::Noise], "snare" => &[SampleClass::Percussion, SampleClass::Impact], "kick" => &[SampleClass::Impact, SampleClass::Bass], "fx" => &[], // excluded from validation "percussion" => &[ SampleClass::HiHat, SampleClass::Cymbal, SampleClass::Noise, SampleClass::Impact, SampleClass::Foley, ], _ => &[], } } struct ClassResults { total: usize, correct: usize, acceptable: usize, misclassified: Vec<(String, SampleClass)>, confidences: Vec, } /// Run classification on all labeled samples and report accuracy. #[test] fn classify_drum_samples() { let test_data = PathBuf::from(std::env::var("HOME").expect("HOME not set")).join("Git/Drums/test_data"); if !test_data.exists() { eprintln!( "Skipping classify_drum_samples: test data not found at {}", test_data.display() ); return; } let config = AnalysisConfig { loudness: true, spectral: true, bpm: false, key: false, loop_detect: false, classify: true, fingerprint: false, auto_suggest_tags: false, max_analysis_seconds: Some(5.0), smart_skip: false, }; let labels = [ "kick", "snare", "hihat", "cymbal", "clap", "tom", "percussion", "fx", ]; let mut results: HashMap = HashMap::new(); for label in &labels { let dir = test_data.join(label); if !dir.exists() { continue; } let expected = match expected_class(label) { Some(c) => c, None => continue, }; let alts = acceptable_alternates(label); let mut class_results = ClassResults { total: 0, correct: 0, acceptable: 0, misclassified: Vec::new(), confidences: Vec::new(), }; let mut entries: Vec<_> = std::fs::read_dir(&dir) .expect("failed to read test data dir") .filter_map(|e| e.ok()) .filter(|e| { let name = e.file_name().to_string_lossy().to_lowercase(); name.ends_with(".wav") || name.ends_with(".aif") || name.ends_with(".aiff") }) .collect(); entries.sort_by_key(|e| e.file_name()); for entry in &entries { let path = std::fs::read_link(entry.path()).unwrap_or_else(|_| entry.path()); if !path.exists() { continue; } let fake_hash = format!("test_{:016x}", class_results.total); match analysis::analyze_sample(&fake_hash, &path, &config) { Ok(result) => { class_results.total += 1; if let Some(conf) = result.classification_confidence { class_results.confidences.push(conf); } if let Some(class) = result.classification { if class == expected { class_results.correct += 1; } else if alts.contains(&class) { class_results.acceptable += 1; } else { class_results.misclassified.push(( entry.file_name().to_string_lossy().to_string(), class, )); } } else { class_results.misclassified.push(( entry.file_name().to_string_lossy().to_string(), SampleClass::Misc, )); } } Err(e) => { eprintln!( " Failed to analyze {}: {}", entry.file_name().to_string_lossy(), e ); } } } results.insert(label.to_string(), class_results); } // Print report println!("\n============================================================"); println!("CLASSIFICATION VALIDATION REPORT (Two-Layer ML)"); println!("============================================================\n"); let mut grand_total = 0; let mut grand_correct = 0; let mut grand_acceptable = 0; let mut all_confidences = Vec::new(); for label in &labels { if let Some(r) = results.get(*label) { if r.total == 0 { continue; } let strict_pct = (r.correct as f64 / r.total as f64) * 100.0; let lenient_pct = ((r.correct + r.acceptable) as f64 / r.total as f64) * 100.0; // Confidence stats let (conf_mean, conf_median, conf_p10) = if r.confidences.is_empty() { (0.0, 0.0, 0.0) } else { let mut sorted = r.confidences.clone(); sorted.sort_by(|a, b| a.partial_cmp(b).unwrap()); let mean = sorted.iter().sum::() / sorted.len() as f64; let median = sorted[sorted.len() / 2]; let p10_idx = sorted.len() / 10; let p10 = sorted[p10_idx]; (mean, median, p10) }; println!( "{:<12} {:>4} samples | strict: {:>5.1}% ({}/{}) | lenient: {:>5.1}% ({}/{})", label, r.total, strict_pct, r.correct, r.total, lenient_pct, r.correct + r.acceptable, r.total ); println!( " confidence: mean={:.2} median={:.2} P10={:.2}", conf_mean, conf_median, conf_p10 ); // Show up to 10 misclassifications per class if !r.misclassified.is_empty() { let show = r.misclassified.len().min(10); for (name, got) in &r.misclassified[..show] { println!(" MISS: {} -> {}", name, got.as_str()); } if r.misclassified.len() > 10 { println!( " ... and {} more misclassifications", r.misclassified.len() - 10 ); } } println!(); grand_total += r.total; grand_correct += r.correct; grand_acceptable += r.acceptable; all_confidences.extend_from_slice(&r.confidences); } } let overall_strict = (grand_correct as f64 / grand_total as f64) * 100.0; let overall_lenient = ((grand_correct + grand_acceptable) as f64 / grand_total as f64) * 100.0; println!("============================================================"); println!( "OVERALL {:>4} samples | strict: {:>5.1}% | lenient: {:>5.1}%", grand_total, overall_strict, overall_lenient ); if !all_confidences.is_empty() { let mut sorted = all_confidences.clone(); sorted.sort_by(|a, b| a.partial_cmp(b).unwrap()); let mean = sorted.iter().sum::() / sorted.len() as f64; let median = sorted[sorted.len() / 2]; println!( " confidence: mean={:.2} median={:.2}", mean, median ); } println!("============================================================"); // Confusion matrix summary println!("\nCONFUSION SUMMARY (misclassifications by target class):"); for label in &labels { if let Some(r) = results.get(*label) { if r.misclassified.is_empty() { continue; } let mut confusion: HashMap<&str, usize> = HashMap::new(); for (_, got) in &r.misclassified { *confusion.entry(got.as_str()).or_default() += 1; } let mut sorted: Vec<_> = confusion.into_iter().collect(); sorted.sort_by_key(|x| std::cmp::Reverse(x.1)); let items: Vec = sorted.iter().map(|(c, n)| format!("{}:{}", c, n)).collect(); println!(" {:<12} -> {}", label, items.join(", ")); } } // Per-class precision/recall/F1 println!("\nPER-CLASS METRICS:"); println!("{:<12} {:>9} {:>9} {:>9}", "Class", "Precision", "Recall", "F1"); let drum_labels = ["kick", "snare", "hihat", "cymbal", "percussion"]; for target_label in &drum_labels { let target_class = expected_class(target_label).unwrap(); // True positives: samples of this class predicted correctly let tp = results .get(*target_label) .map(|r| r.correct) .unwrap_or(0) as f64; // False negatives: samples of this class predicted as something else let fn_ = results .get(*target_label) .map(|r| r.misclassified.len()) .unwrap_or(0) as f64; // False positives: other classes predicted as this class let fp: f64 = labels .iter() .filter(|&&l| expected_class(l) != Some(target_class)) .filter_map(|l| results.get(*l)) .flat_map(|r| r.misclassified.iter()) .filter(|(_, got)| *got == target_class) .count() as f64; let precision = if tp + fp > 0.0 { tp / (tp + fp) } else { 0.0 }; let recall = if tp + fn_ > 0.0 { tp / (tp + fn_) } else { 0.0 }; let f1 = if precision + recall > 0.0 { 2.0 * precision * recall / (precision + recall) } else { 0.0 }; println!( "{:<12} {:>8.1}% {:>8.1}% {:>8.1}%", target_label, precision * 100.0, recall * 100.0, f1 * 100.0 ); } // Assert minimum accuracy // With trained RF model: target >= 90% strict // Without model (rule-based fallback): maintain >= 40% lenient assert!( overall_lenient >= 40.0, "Overall lenient accuracy {:.1}% is below minimum threshold of 40%", overall_lenient ); }