Skip to main content

max / audiofiles

11.8 KB · 342 lines History Blame Raw
1 //! Classification validation test against labeled drum machine samples.
2 //!
3 //! Reads pre-labeled samples from `~/Git/Drums/test_data/{kick,snare,hihat,...}/`
4 //! and checks what percentage the two-layer ML classifier gets right.
5 //!
6 //! Label-to-SampleClass mapping:
7 //! kick -> Kick
8 //! snare -> Snare
9 //! hihat -> HiHat
10 //! cymbal -> Cymbal
11 //! clap -> Percussion (claps are percussive hits)
12 //! tom -> Percussion (toms are pitched percussion)
13 //! percussion -> Percussion (congas, cowbell, shaker, etc.)
14 //! fx -> (excluded from validation — catch-all junk drawer, not a learnable class)
15
16 use std::collections::HashMap;
17 use std::path::PathBuf;
18
19 use audiofiles_core::analysis::classify::SampleClass;
20 use audiofiles_core::analysis::config::AnalysisConfig;
21 use audiofiles_core::analysis::{self};
22
23 /// Map a test data folder name to the expected SampleClass.
24 fn expected_class(label: &str) -> Option<SampleClass> {
25 match label {
26 "kick" => Some(SampleClass::Kick),
27 "snare" => Some(SampleClass::Snare),
28 "hihat" => Some(SampleClass::HiHat),
29 "cymbal" => Some(SampleClass::Cymbal),
30 "clap" | "tom" | "percussion" => Some(SampleClass::Percussion),
31 "fx" => None, // excluded — misc catch-all, not a classifiable category
32 _ => None,
33 }
34 }
35
36 /// Acceptable alternate classifications for a given label.
37 /// Some boundary cases are legitimately ambiguous.
38 fn acceptable_alternates(label: &str) -> &[SampleClass] {
39 match label {
40 "clap" => &[SampleClass::Snare, SampleClass::Impact],
41 "tom" => &[SampleClass::Kick],
42 "hihat" => &[SampleClass::Cymbal, SampleClass::Noise],
43 "cymbal" => &[SampleClass::HiHat, SampleClass::Noise],
44 "snare" => &[SampleClass::Percussion, SampleClass::Impact],
45 "kick" => &[SampleClass::Impact, SampleClass::Bass],
46 "fx" => &[], // excluded from validation
47 "percussion" => &[
48 SampleClass::HiHat,
49 SampleClass::Cymbal,
50 SampleClass::Noise,
51 SampleClass::Impact,
52 SampleClass::Foley,
53 ],
54 _ => &[],
55 }
56 }
57
58 struct ClassResults {
59 total: usize,
60 correct: usize,
61 acceptable: usize,
62 misclassified: Vec<(String, SampleClass)>,
63 confidences: Vec<f64>,
64 }
65
/// Run classification on all labeled samples and report accuracy.
///
/// For each label folder under `$HOME/Git/Drums/test_data/`, every `.wav`, `.aif` and
/// `.aiff` file is analyzed with classification enabled. Each prediction is scored
/// against `expected_class`:
/// - an exact match is "correct";
/// - a class listed by `acceptable_alternates` is "acceptable";
/// - anything else is a miss, including no classification at all (recorded as `Misc`).
///
/// Labels without an expected class (`"fx"`) are skipped. The report prints:
/// - per-label strict/lenient accuracy and confidence statistics;
/// - a confusion summary;
/// - per-class precision/recall/F1.
///
/// The test returns early (passes) when the test data directory does not exist.
///
/// # Panics
///
/// Panics if the data directory exists but no sample could be analyzed, or if the
/// overall lenient accuracy is below 40%.
#[test]
fn classify_drum_samples() {
    const MAX_MISSES_SHOWN: usize = 10;

    let test_data =
        PathBuf::from(std::env::var("HOME").expect("HOME not set")).join("Git/Drums/test_data");

    if !test_data.exists() {
        eprintln!(
            "Skipping classify_drum_samples: test data not found at {}",
            test_data.display()
        );
        return;
    }

    let config = AnalysisConfig {
        loudness: true,
        spectral: true,
        bpm: false,
        key: false,
        loop_detect: false,
        classify: true,
        fingerprint: false,
        auto_suggest_tags: false,
        max_analysis_seconds: Some(5.0),
        smart_skip: false,
    };

    let labels = [
        "kick",
        "snare",
        "hihat",
        "cymbal",
        "clap",
        "tom",
        "percussion",
        "fx",
    ];
    let mut results: HashMap<String, ClassResults> = HashMap::new();
    // Counts every sample we attempt, across all labels. This gives each file a distinct
    // fake hash, even after a failed analysis or when switching to a new label folder.
    let mut sample_counter: usize = 0;

    for label in &labels {
        let dir = test_data.join(label);
        if !dir.exists() {
            continue;
        }

        // Labels without an expected class (e.g. "fx") are excluded from validation.
        let expected = match expected_class(label) {
            Some(c) => c,
            None => continue,
        };
        let alts = acceptable_alternates(label);

        let mut class_results = ClassResults {
            total: 0,
            correct: 0,
            acceptable: 0,
            misclassified: Vec::new(),
            confidences: Vec::new(),
        };

        let mut entries: Vec<_> = std::fs::read_dir(&dir)
            .expect("failed to read test data dir")
            .filter_map(|e| e.ok())
            .filter(|e| {
                let name = e.file_name().to_string_lossy().to_lowercase();
                name.ends_with(".wav") || name.ends_with(".aif") || name.ends_with(".aiff")
            })
            .collect();
        // Sort by file name so the report order is deterministic.
        entries.sort_by_key(|e| e.file_name());

        for entry in &entries {
            // Resolve symlinks to an absolute path. `read_link` would return a relative
            // target relative to the link's directory, not the CWD. Broken links and
            // vanished files fail to canonicalize and are skipped.
            let path = match std::fs::canonicalize(entry.path()) {
                Ok(p) => p,
                Err(_) => continue,
            };
            let file_name = entry.file_name().to_string_lossy().into_owned();

            let fake_hash = format!("test_{:016x}", sample_counter);
            sample_counter += 1;

            match analysis::analyze_sample(&fake_hash, &path, &config) {
                Ok(result) => {
                    class_results.total += 1;
                    if let Some(conf) = result.classification_confidence {
                        class_results.confidences.push(conf);
                    }
                    match result.classification {
                        Some(class) if class == expected => class_results.correct += 1,
                        Some(class) if alts.contains(&class) => class_results.acceptable += 1,
                        // A missing classification is reported as `Misc`.
                        other => class_results
                            .misclassified
                            .push((file_name, other.unwrap_or(SampleClass::Misc))),
                    }
                }
                Err(e) => {
                    eprintln!(" Failed to analyze {}: {}", file_name, e);
                }
            }
        }

        results.insert(label.to_string(), class_results);
    }

    // Print report
    println!("\n============================================================");
    println!("CLASSIFICATION VALIDATION REPORT (Two-Layer ML)");
    println!("============================================================\n");

    let mut grand_total: usize = 0;
    let mut grand_correct: usize = 0;
    let mut grand_acceptable: usize = 0;
    let mut all_confidences: Vec<f64> = Vec::new();

    for label in &labels {
        let r = match results.get(*label) {
            Some(r) if r.total > 0 => r,
            _ => continue,
        };
        let strict_pct = percent(r.correct, r.total);
        let lenient_pct = percent(r.correct + r.acceptable, r.total);
        let (conf_mean, conf_median, conf_p10) =
            confidence_stats(&r.confidences).unwrap_or((0.0, 0.0, 0.0));

        println!(
            "{:<12} {:>4} samples | strict: {:>5.1}% ({}/{}) | lenient: {:>5.1}% ({}/{})",
            label,
            r.total,
            strict_pct,
            r.correct,
            r.total,
            lenient_pct,
            r.correct + r.acceptable,
            r.total
        );
        println!(
            " confidence: mean={:.2} median={:.2} P10={:.2}",
            conf_mean, conf_median, conf_p10
        );

        // Show up to MAX_MISSES_SHOWN misclassifications per class.
        for (name, got) in r.misclassified.iter().take(MAX_MISSES_SHOWN) {
            println!(" MISS: {} -> {}", name, got.as_str());
        }
        if r.misclassified.len() > MAX_MISSES_SHOWN {
            println!(
                " ... and {} more misclassifications",
                r.misclassified.len() - MAX_MISSES_SHOWN
            );
        }
        println!();

        grand_total += r.total;
        grand_correct += r.correct;
        grand_acceptable += r.acceptable;
        all_confidences.extend_from_slice(&r.confidences);
    }

    // The data directory can exist yet yield no analyzable audio. Without this guard the
    // overall percentages would be NaN and the final assertion would fail with "NaN%".
    assert!(
        grand_total > 0,
        "No labeled samples were analyzed under {}",
        test_data.display()
    );

    let overall_strict = percent(grand_correct, grand_total);
    let overall_lenient = percent(grand_correct + grand_acceptable, grand_total);

    println!("============================================================");
    println!(
        "OVERALL {:>4} samples | strict: {:>5.1}% | lenient: {:>5.1}%",
        grand_total, overall_strict, overall_lenient
    );
    if let Some((mean, median, _)) = confidence_stats(&all_confidences) {
        println!(" confidence: mean={:.2} median={:.2}", mean, median);
    }
    println!("============================================================");

    // Confusion matrix summary
    println!("\nCONFUSION SUMMARY (misclassifications by target class):");
    for label in &labels {
        let r = match results.get(*label) {
            Some(r) if !r.misclassified.is_empty() => r,
            _ => continue,
        };
        let mut confusion: HashMap<&str, usize> = HashMap::new();
        for (_, got) in &r.misclassified {
            *confusion.entry(got.as_str()).or_default() += 1;
        }
        let mut counts: Vec<_> = confusion.into_iter().collect();
        // Most frequent first; ties broken by class name so the output is deterministic.
        counts.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(b.0)));
        let items: Vec<String> = counts.iter().map(|(c, n)| format!("{}:{}", c, n)).collect();
        println!(" {:<12} -> {}", label, items.join(", "));
    }

    // Per-class precision/recall/F1.
    //
    // Every label whose expected class is the target feeds that class's true positives
    // and false negatives, so "clap", "tom" and "percussion" all count toward the
    // Percussion row. Predictions accepted as alternates are not recorded per class, so
    // they count as neither hits nor misses: these are lenient metrics.
    println!("\nPER-CLASS METRICS:");
    println!("{:<12} {:>9} {:>9} {:>9}", "Class", "Precision", "Recall", "F1");
    let drum_labels = ["kick", "snare", "hihat", "cymbal", "percussion"];
    for target_label in &drum_labels {
        let target_class =
            expected_class(target_label).expect("every drum label maps to a SampleClass");

        let mut tp: usize = 0;
        let mut fn_: usize = 0;
        let mut fp: usize = 0;
        for label in &labels {
            let r = match results.get(*label) {
                Some(r) => r,
                None => continue,
            };
            if expected_class(label) == Some(target_class) {
                // Samples of this class: correct -> TP, hard miss -> FN.
                tp += r.correct;
                fn_ += r.misclassified.len();
            } else {
                // Samples of other classes wrongly predicted as this class -> FP.
                fp += r
                    .misclassified
                    .iter()
                    .filter(|(_, got)| *got == target_class)
                    .count();
            }
        }

        let (tp, fp, fn_) = (tp as f64, fp as f64, fn_ as f64);
        let precision = if tp + fp > 0.0 { tp / (tp + fp) } else { 0.0 };
        let recall = if tp + fn_ > 0.0 { tp / (tp + fn_) } else { 0.0 };
        let f1 = if precision + recall > 0.0 {
            2.0 * precision * recall / (precision + recall)
        } else {
            0.0
        };

        println!(
            "{:<12} {:>8.1}% {:>8.1}% {:>8.1}%",
            target_label,
            precision * 100.0,
            recall * 100.0,
            f1 * 100.0
        );
    }

    // Assert minimum accuracy
    // With trained RF model: target >= 90% strict
    // Without model (rule-based fallback): maintain >= 40% lenient
    assert!(
        overall_lenient >= 40.0,
        "Overall lenient accuracy {:.1}% is below minimum threshold of 40%",
        overall_lenient
    );
}

/// Summary statistics `(mean, median, p10)` of classifier confidences.
///
/// Returns `None` for an empty slice. The median is the element at index `len / 2` of
/// the ascending sort (the upper middle for even lengths). P10 is the element at index
/// `len / 10`. Sorting uses `f64::total_cmp`, so a NaN confidence gets a fixed position
/// instead of causing a panic.
fn confidence_stats(confidences: &[f64]) -> Option<(f64, f64, f64)> {
    if confidences.is_empty() {
        return None;
    }
    let mut sorted = confidences.to_vec();
    sorted.sort_by(f64::total_cmp);
    let mean = sorted.iter().sum::<f64>() / sorted.len() as f64;
    let median = sorted[sorted.len() / 2];
    let p10 = sorted[sorted.len() / 10];
    Some((mean, median, p10))
}

/// `part / whole` expressed as a percentage. Callers ensure `whole > 0`.
fn percent(part: usize, whole: usize) -> f64 {
    part as f64 / whole as f64 * 100.0
}
342