| 1 |
|
| 2 |
|
| 3 |
|
| 4 |
|
| 5 |
|
| 6 |
|
| 7 |
|
| 8 |
|
| 9 |
|
| 10 |
|
| 11 |
|
| 12 |
|
| 13 |
|
| 14 |
|
| 15 |
|
| 16 |
use std::collections::HashMap; |
| 17 |
use std::path::PathBuf; |
| 18 |
|
| 19 |
use audiofiles_core::analysis::classify::SampleClass; |
| 20 |
use audiofiles_core::analysis::config::AnalysisConfig; |
| 21 |
use audiofiles_core::analysis::{self}; |
| 22 |
|
| 23 |
|
| 24 |
fn expected_class(label: &str) -> Option<SampleClass> { |
| 25 |
match label { |
| 26 |
"kick" => Some(SampleClass::Kick), |
| 27 |
"snare" => Some(SampleClass::Snare), |
| 28 |
"hihat" => Some(SampleClass::HiHat), |
| 29 |
"cymbal" => Some(SampleClass::Cymbal), |
| 30 |
"clap" | "tom" | "percussion" => Some(SampleClass::Percussion), |
| 31 |
"fx" => None, |
| 32 |
_ => None, |
| 33 |
} |
| 34 |
} |
| 35 |
|
| 36 |
|
| 37 |
|
| 38 |
fn acceptable_alternates(label: &str) -> &[SampleClass] { |
| 39 |
match label { |
| 40 |
"clap" => &[SampleClass::Snare, SampleClass::Impact], |
| 41 |
"tom" => &[SampleClass::Kick], |
| 42 |
"hihat" => &[SampleClass::Cymbal, SampleClass::Noise], |
| 43 |
"cymbal" => &[SampleClass::HiHat, SampleClass::Noise], |
| 44 |
"snare" => &[SampleClass::Percussion, SampleClass::Impact], |
| 45 |
"kick" => &[SampleClass::Impact, SampleClass::Bass], |
| 46 |
"fx" => &[], |
| 47 |
"percussion" => &[ |
| 48 |
SampleClass::HiHat, |
| 49 |
SampleClass::Cymbal, |
| 50 |
SampleClass::Noise, |
| 51 |
SampleClass::Impact, |
| 52 |
SampleClass::Foley, |
| 53 |
], |
| 54 |
_ => &[], |
| 55 |
} |
| 56 |
} |
| 57 |
|
| 58 |
struct ClassResults { |
| 59 |
total: usize, |
| 60 |
correct: usize, |
| 61 |
acceptable: usize, |
| 62 |
misclassified: Vec<(String, SampleClass)>, |
| 63 |
confidences: Vec<f64>, |
| 64 |
} |
| 65 |
|
| 66 |
|
| 67 |
/// Runs the classifier over a local, hand-labelled drum library and prints an accuracy report.
///
/// Samples are read from `$HOME/Git/Drums/test_data/<label>/` (`.wav`, `.aif`, `.aiff`).
/// The test is skipped when `HOME` is unset, the data directory is missing, or it contains no
/// analyzable samples. It fails when every analysis attempt errored, or when the overall
/// lenient accuracy (exact match or acceptable alternate) is below 40%.
#[test]
fn classify_drum_samples() {
    // The data set lives outside the repository, so a missing HOME is treated like missing
    // data: skip instead of panicking.
    let test_data = match std::env::var_os("HOME") {
        Some(home) => PathBuf::from(home).join("Git/Drums/test_data"),
        None => {
            eprintln!("Skipping classify_drum_samples: HOME not set");
            return;
        }
    };

    if !test_data.exists() {
        eprintln!(
            "Skipping classify_drum_samples: test data not found at {}",
            test_data.display()
        );
        return;
    }

    let config = AnalysisConfig {
        loudness: true,
        spectral: true,
        bpm: false,
        key: false,
        loop_detect: false,
        classify: true,
        fingerprint: false,
        auto_suggest_tags: false,
        max_analysis_seconds: Some(5.0),
        smart_skip: false,
    };

    let labels = [
        "kick",
        "snare",
        "hihat",
        "cymbal",
        "clap",
        "tom",
        "percussion",
        "fx",
    ];
    let mut results: HashMap<String, ClassResults> = HashMap::new();
    // Counts analyses that returned an error, to tell "no data" apart from "analysis broken".
    let mut analysis_failures = 0usize;

    for label in &labels {
        let dir = test_data.join(label);
        if !dir.exists() {
            continue;
        }

        // Labels without a single expected class (e.g. "fx") are not scored.
        let expected = match expected_class(label) {
            Some(c) => c,
            None => continue,
        };
        let alts = acceptable_alternates(label);

        let mut class_results = ClassResults {
            total: 0,
            correct: 0,
            acceptable: 0,
            misclassified: Vec::new(),
            confidences: Vec::new(),
        };

        let mut entries: Vec<_> = std::fs::read_dir(&dir)
            .expect("failed to read test data dir")
            .filter_map(|e| e.ok())
            .filter(|e| {
                let name = e.file_name().to_string_lossy().to_lowercase();
                name.ends_with(".wav") || name.ends_with(".aif") || name.ends_with(".aiff")
            })
            .collect();
        // Sort by file name so the run order (and the report) is deterministic.
        entries.sort_by_key(|e| e.file_name());

        for entry in &entries {
            // Entries may be symlinks into a sample library. A relative link target is relative
            // to the directory containing the link, not to the current working directory.
            let link = entry.path();
            let path = match std::fs::read_link(&link) {
                Ok(target) if target.is_relative() => dir.join(target),
                Ok(target) => target,
                Err(_) => link,
            };
            if !path.exists() {
                continue;
            }

            let fake_hash = format!("test_{:016x}", class_results.total);
            match analysis::analyze_sample(&fake_hash, &path, &config) {
                Ok(result) => {
                    class_results.total += 1;
                    if let Some(conf) = result.classification_confidence {
                        class_results.confidences.push(conf);
                    }
                    if let Some(class) = result.classification {
                        if class == expected {
                            class_results.correct += 1;
                        } else if alts.contains(&class) {
                            class_results.acceptable += 1;
                        } else {
                            class_results.misclassified.push((
                                entry.file_name().to_string_lossy().to_string(),
                                class,
                            ));
                        }
                    } else {
                        // No classification at all is recorded as a miss with class `Misc`.
                        class_results.misclassified.push((
                            entry.file_name().to_string_lossy().to_string(),
                            SampleClass::Misc,
                        ));
                    }
                }
                Err(e) => {
                    analysis_failures += 1;
                    eprintln!(
                        " Failed to analyze {}: {}",
                        entry.file_name().to_string_lossy(),
                        e
                    );
                }
            }
        }

        results.insert(label.to_string(), class_results);
    }

    // Without any analyzed sample the percentages below would be 0/0 = NaN and the final
    // assertion would fail with a meaningless message.
    let analyzed: usize = results.values().map(|r| r.total).sum();
    if analyzed == 0 {
        assert!(
            analysis_failures == 0,
            "All {} sample analyses failed; no classifications to validate",
            analysis_failures
        );
        eprintln!(
            "Skipping classify_drum_samples: no analyzable sample files found under {}",
            test_data.display()
        );
        return;
    }

    println!("\n============================================================");
    println!("CLASSIFICATION VALIDATION REPORT (Two-Layer ML)");
    println!("============================================================\n");

    let mut grand_total = 0usize;
    let mut grand_correct = 0usize;
    let mut grand_acceptable = 0usize;
    let mut all_confidences: Vec<f64> = Vec::new();

    for label in &labels {
        if let Some(r) = results.get(*label) {
            if r.total == 0 {
                continue;
            }
            let strict_pct = (r.correct as f64 / r.total as f64) * 100.0;
            let lenient_pct = ((r.correct + r.acceptable) as f64 / r.total as f64) * 100.0;

            let (conf_mean, conf_median, conf_p10) =
                confidence_stats(&r.confidences).unwrap_or((0.0, 0.0, 0.0));

            println!(
                "{:<12} {:>4} samples | strict: {:>5.1}% ({}/{}) | lenient: {:>5.1}% ({}/{})",
                label,
                r.total,
                strict_pct,
                r.correct,
                r.total,
                lenient_pct,
                r.correct + r.acceptable,
                r.total
            );
            println!(
                " confidence: mean={:.2} median={:.2} P10={:.2}",
                conf_mean, conf_median, conf_p10
            );

            if !r.misclassified.is_empty() {
                // Cap the per-label listing at ten entries to keep the report readable.
                let show = r.misclassified.len().min(10);
                for (name, got) in &r.misclassified[..show] {
                    println!(" MISS: {} -> {}", name, got.as_str());
                }
                if r.misclassified.len() > 10 {
                    println!(
                        " ... and {} more misclassifications",
                        r.misclassified.len() - 10
                    );
                }
            }
            println!();

            grand_total += r.total;
            grand_correct += r.correct;
            grand_acceptable += r.acceptable;
            all_confidences.extend_from_slice(&r.confidences);
        }
    }

    let overall_strict = (grand_correct as f64 / grand_total as f64) * 100.0;
    let overall_lenient =
        ((grand_correct + grand_acceptable) as f64 / grand_total as f64) * 100.0;

    println!("============================================================");
    println!(
        "OVERALL {:>4} samples | strict: {:>5.1}% | lenient: {:>5.1}%",
        grand_total, overall_strict, overall_lenient
    );

    if let Some((mean, median, _)) = confidence_stats(&all_confidences) {
        println!(
            " confidence: mean={:.2} median={:.2}",
            mean, median
        );
    }
    println!("============================================================");

    println!("\nCONFUSION SUMMARY (misclassifications by target class):");
    for label in &labels {
        if let Some(r) = results.get(*label) {
            if r.misclassified.is_empty() {
                continue;
            }
            let mut confusion: HashMap<&str, usize> = HashMap::new();
            for (_, got) in &r.misclassified {
                *confusion.entry(got.as_str()).or_default() += 1;
            }
            let mut sorted: Vec<_> = confusion.into_iter().collect();
            // Most frequent first; ties broken by class name so the output does not depend on
            // hash-map iteration order.
            sorted.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(b.0)));
            let items: Vec<String> = sorted.iter().map(|(c, n)| format!("{}:{}", c, n)).collect();
            println!(" {:<12} -> {}", label, items.join(", "));
        }
    }

    println!("\nPER-CLASS METRICS:");
    println!("{:<12} {:>9} {:>9} {:>9}", "Class", "Precision", "Recall", "F1");
    let drum_labels = ["kick", "snare", "hihat", "cymbal", "percussion"];
    for target_label in &drum_labels {
        let target_class =
            expected_class(target_label).expect("every entry of drum_labels maps to a class");

        // Several labels can share one expected class ("clap", "tom" and "percussion" all map
        // to `Percussion`), so true positives and false negatives are summed over every label
        // whose expected class is the target.
        let same_class: Vec<&ClassResults> = labels
            .iter()
            .filter(|&&l| expected_class(l) == Some(target_class))
            .filter_map(|l| results.get(*l))
            .collect();

        let tp = same_class.iter().map(|r| r.correct).sum::<usize>() as f64;

        // NOTE: predictions accepted as alternates are only counted, not recorded by class, so
        // they contribute to neither the false negatives here nor the false positives below.
        let fn_ = same_class
            .iter()
            .map(|r| r.misclassified.len())
            .sum::<usize>() as f64;

        // False positives: samples of any other class that were predicted as the target class.
        let fp: f64 = labels
            .iter()
            .filter(|&&l| expected_class(l) != Some(target_class))
            .filter_map(|l| results.get(*l))
            .flat_map(|r| r.misclassified.iter())
            .filter(|(_, got)| *got == target_class)
            .count() as f64;

        let precision = if tp + fp > 0.0 { tp / (tp + fp) } else { 0.0 };
        let recall = if tp + fn_ > 0.0 { tp / (tp + fn_) } else { 0.0 };
        let f1 = if precision + recall > 0.0 {
            2.0 * precision * recall / (precision + recall)
        } else {
            0.0
        };

        println!(
            "{:<12} {:>8.1}% {:>8.1}% {:>8.1}%",
            target_label,
            precision * 100.0,
            recall * 100.0,
            f1 * 100.0
        );
    }

    assert!(
        overall_lenient >= 40.0,
        "Overall lenient accuracy {:.1}% is below minimum threshold of 40%",
        overall_lenient
    );
}

/// Returns `(mean, median, p10)` of `values`, or `None` when `values` is empty.
///
/// The median is the element at index `len / 2` and P10 the element at index `len / 10` of the
/// sorted values (no interpolation). Sorting uses a total order, so a NaN value cannot panic.
fn confidence_stats(values: &[f64]) -> Option<(f64, f64, f64)> {
    if values.is_empty() {
        return None;
    }
    let mut sorted = values.to_vec();
    sorted.sort_by(f64::total_cmp);
    let mean = sorted.iter().sum::<f64>() / sorted.len() as f64;
    let median = sorted[sorted.len() / 2];
    let p10 = sorted[sorted.len() / 10];
    Some((mean, median, p10))
}
| 342 |
|