Skip to main content

max / audiofiles

38.6 KB · 1247 lines History Blame Raw
//! Two-layer ML classification system.
//!
//! Layer 1 (rule-based): broad class detection (Drum vs Bass/Vocal/Synth/etc.).
//! Layer 2 (Random Forest): fine-grained sub-classification for the Drum
//! (Kick/Snare/HiHat/Cymbal/Clap/Tom/Percussion), Bass, Vocal, and Synth broad classes.
//!
//! The RF models are trained offline by `audiofiles-train` and embedded via `include_bytes!`.
//! If the drum model is untrained (empty trees), classification falls back entirely to the
//! rule-based `classify_full()`; if a bass, vocal, or synth model is untrained, the Layer 1
//! broad class is returned unrefined.
8
9 use std::sync::OnceLock;
10
11 use super::mfcc::MfccFeatures;
12 use super::spectral::SpectralFeatures;
13 use tracing::instrument;
14
/// Number of features in the classification feature vector.
///
/// Layout (see `ClassifyInput::to_feature_array`): 9 scalar features, then 13 MFCC means,
/// then 13 MFCC variances. `TreeNode::predict` takes a feature array of exactly this size.
pub const NUM_FEATURES: usize = 35;
17
// ── SampleClass enum ──
19
/// High-level classification of a sample's content.
///
/// Use `SampleClass::as_str` / `FromStr` for the kebab-case string form and
/// `SampleClass::tag` for the dot-notation tag.
///
/// NOTE(review): the serde derives carry no `rename` attributes, so serde uses the
/// variant identifiers (e.g. `HiHat`, `GuitarBass`), which differ from `as_str()`
/// (`hihat`, `guitar-bass`). Confirm which form persisted data relies on before changing either.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum SampleClass {
    // Drum one-shots (the label set of the Layer 2 drum model).
    Kick,
    Snare,
    HiHat,
    Cymbal,
    Clap,
    Tom,
    Percussion,
    // Bass: generic broad label plus the Layer 2 bass sub-classes.
    Bass,
    GuitarBass,
    SynthBass,
    SubBass,
    // Vocal: generic broad label plus the Layer 2 vocal sub-classes.
    Vocal,
    VocalChop,
    VocalPhrase,
    VocalChoir,
    // Synth: generic broad label plus the Layer 2 synth sub-classes.
    Synth,
    SynthLead,
    SynthStab,
    SynthPluck,
    SynthChord,
    // Classes produced directly by the rule-based layers (no Layer 2 refinement).
    Pad,
    Misc,
    Noise,
    Music,
    Ambience,
    Impact,
    Foley,
    Texture,
}
52
impl SampleClass {
    /// Kebab-case identifier for this class (e.g. `"guitar-bass"`).
    ///
    /// This is the inverse of the `FromStr` impl, which additionally accepts `"fx"` as an
    /// alias for `Misc`.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Kick => "kick",
            Self::Snare => "snare",
            Self::HiHat => "hihat",
            Self::Cymbal => "cymbal",
            Self::Clap => "clap",
            Self::Tom => "tom",
            Self::Percussion => "percussion",
            Self::Bass => "bass",
            Self::GuitarBass => "guitar-bass",
            Self::SynthBass => "synth-bass",
            Self::SubBass => "sub-bass",
            Self::Vocal => "vocal",
            Self::VocalChop => "vocal-chop",
            Self::VocalPhrase => "vocal-phrase",
            Self::VocalChoir => "vocal-choir",
            Self::Synth => "synth",
            Self::SynthLead => "synth-lead",
            Self::SynthStab => "synth-stab",
            Self::SynthPluck => "synth-pluck",
            Self::SynthChord => "synth-chord",
            Self::Pad => "pad",
            Self::Misc => "misc",
            Self::Noise => "noise",
            Self::Music => "music",
            Self::Ambience => "ambience",
            Self::Impact => "impact",
            Self::Foley => "foley",
            Self::Texture => "texture",
        }
    }

    /// Dot-notation tag for this class.
    ///
    /// Drum hits live under `instrument.drum.*`, while `Percussion` maps to the broader
    /// `instrument.percussion`. Sub-classes extend their parent's tag (e.g.
    /// `instrument.bass` → `instrument.bass.guitar`). Non-instrument classes use
    /// `character.*`, except `Music`, which uses `type.music`.
    pub fn tag(&self) -> &'static str {
        match self {
            Self::Kick => "instrument.drum.kick",
            Self::Snare => "instrument.drum.snare",
            Self::HiHat => "instrument.drum.hihat",
            Self::Cymbal => "instrument.drum.cymbal",
            Self::Clap => "instrument.drum.clap",
            Self::Tom => "instrument.drum.tom",
            Self::Percussion => "instrument.percussion",
            Self::Bass => "instrument.bass",
            Self::GuitarBass => "instrument.bass.guitar",
            Self::SynthBass => "instrument.bass.synth",
            Self::SubBass => "instrument.bass.sub",
            Self::Vocal => "instrument.vocal",
            Self::VocalChop => "instrument.vocal.chop",
            Self::VocalPhrase => "instrument.vocal.phrase",
            Self::VocalChoir => "instrument.vocal.choir",
            Self::Synth => "instrument.synth",
            Self::SynthLead => "instrument.synth.lead",
            Self::SynthStab => "instrument.synth.stab",
            Self::SynthPluck => "instrument.synth.pluck",
            Self::SynthChord => "instrument.synth.chord",
            Self::Pad => "instrument.pad",
            Self::Misc => "character.misc",
            Self::Noise => "character.noise",
            Self::Music => "type.music",
            Self::Ambience => "character.ambience",
            Self::Impact => "character.impact",
            Self::Foley => "character.foley",
            Self::Texture => "character.texture",
        }
    }
}
121
122 impl std::fmt::Display for SampleClass {
123 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124 write!(f, "{}", self.as_str())
125 }
126 }
127
128 impl std::str::FromStr for SampleClass {
129 type Err = ();
130
131 fn from_str(s: &str) -> Result<Self, Self::Err> {
132 match s {
133 "kick" => Ok(Self::Kick),
134 "snare" => Ok(Self::Snare),
135 "hihat" => Ok(Self::HiHat),
136 "cymbal" => Ok(Self::Cymbal),
137 "clap" => Ok(Self::Clap),
138 "tom" => Ok(Self::Tom),
139 "percussion" => Ok(Self::Percussion),
140 "bass" => Ok(Self::Bass),
141 "guitar-bass" => Ok(Self::GuitarBass),
142 "synth-bass" => Ok(Self::SynthBass),
143 "sub-bass" => Ok(Self::SubBass),
144 "vocal" => Ok(Self::Vocal),
145 "vocal-chop" => Ok(Self::VocalChop),
146 "vocal-phrase" => Ok(Self::VocalPhrase),
147 "vocal-choir" => Ok(Self::VocalChoir),
148 "synth" => Ok(Self::Synth),
149 "synth-lead" => Ok(Self::SynthLead),
150 "synth-stab" => Ok(Self::SynthStab),
151 "synth-pluck" => Ok(Self::SynthPluck),
152 "synth-chord" => Ok(Self::SynthChord),
153 "pad" => Ok(Self::Pad),
154 "misc" | "fx" => Ok(Self::Misc),
155 "noise" => Ok(Self::Noise),
156 "music" => Ok(Self::Music),
157 "ambience" => Ok(Self::Ambience),
158 "impact" => Ok(Self::Impact),
159 "foley" => Ok(Self::Foley),
160 "texture" => Ok(Self::Texture),
161 _ => Err(()),
162 }
163 }
164 }
165
166 // ── Broad class (Layer 1) ──
167
/// Broad classification for Layer 1 routing.
///
/// Produced by `classify_broad`. In `classify_ml`, only `Drum`, `Bass`, `Vocal`, and
/// `Synth` are refined by a Layer 2 Random Forest; every other variant maps 1:1 to a
/// `SampleClass` via `broad_to_sample_class`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BroadClass {
    Drum,
    Bass,
    Vocal,
    Synth,
    Pad,
    Misc,
    Noise,
    Music,
    Ambience,
    Impact,
    Foley,
    Texture,
}
184
185 // ── Classification result ──
186
/// Result of the two-layer classifier.
#[derive(Debug, Clone)]
pub struct ClassificationResult {
    /// Final (possibly Layer 2 refined) class.
    pub class: SampleClass,
    /// Confidence in `[0.0, 1.0]`: the Random Forest vote fraction when a Layer 2 model
    /// produced the label, the fixed Layer 1 rule confidence otherwise, and `0.0` when
    /// `classify_ml` fell back to `classify_full`.
    pub confidence: f64,
}
193
194 // ── Decision tree types (custom format for embedded inference) ──
195
/// A node in a serialized decision tree.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum TreeNode {
    /// Internal node: tests `features[feature]` against `threshold`.
    /// Values `<= threshold` (and NaN) go `left`; larger values go `right`.
    Split {
        /// Index into the `NUM_FEATURES`-long feature array.
        feature: usize,
        threshold: f64,
        left: Box<TreeNode>,
        right: Box<TreeNode>,
    },
    /// Terminal node carrying a class id, i.e. an index into the per-model class
    /// table (e.g. `DRUM_CLASSES`).
    Leaf {
        class: u8,
    },
}
209
210 impl TreeNode {
211 pub fn predict(&self, features: &[f64; NUM_FEATURES]) -> u8 {
212 match self {
213 TreeNode::Split {
214 feature,
215 threshold,
216 left,
217 right,
218 } => {
219 if *feature >= NUM_FEATURES {
220 return 0; // fallback for malformed model
221 }
222 let val = features[*feature];
223 // NaN goes left (conservative path) instead of always right
224 if val.is_nan() || val <= *threshold {
225 left.predict(features)
226 } else {
227 right.predict(features)
228 }
229 }
230 TreeNode::Leaf { class } => *class,
231 }
232 }
233 }
234
/// A trained Random Forest model (collection of decision trees).
///
/// An empty `trees` vector marks the model as untrained; callers then fall back to the
/// rule-based classifiers.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RandomForestModel {
    /// Trees whose leaf predictions are tallied as votes by `predict_with_model`.
    pub trees: Vec<TreeNode>,
    /// Size of the vote table; leaf class ids `>= num_classes` are ignored.
    pub num_classes: u8,
    /// Human-readable class labels. Not consulted during inference in this module;
    /// labels come from the per-model `*_CLASSES` tables instead.
    pub class_names: Vec<String>,
}
242
243 // ── ClassifyInput ──
244
/// Input bundle for classification — collects all features used by the classifier.
///
/// The bracketed index on each field is its position in the Random Forest feature
/// array built by `ClassifyInput::to_feature_array`.
pub struct ClassifyInput {
    /// [0] Sample length (presumably seconds — TODO confirm with the caller).
    pub duration: f64,
    /// [1] Spectral centroid (presumably Hz).
    pub centroid: f64,
    /// [2] Spectral flatness; the rules treat values above 0.7 as noise-like.
    pub flatness: f64,
    /// [3] Zero-crossing rate.
    pub zcr: f64,
    /// [4] Onset strength. Not used by the rule-based classifiers; only an RF feature.
    pub onset_strength: f64,
    /// [5] Spectral bandwidth.
    pub bandwidth: f64,
    /// [6] Variance of the spectral centroid (presumably across analysis frames).
    pub centroid_variance: f64,
    /// [7] Crest factor.
    pub crest_factor: f64,
    /// [8] Attack time (presumably seconds).
    pub attack_time: f64,
    /// [9..22] MFCC means (13 coefficients).
    pub mfcc_means: [f64; 13],
    /// [22..35] MFCC variances (13 coefficients).
    pub mfcc_variances: [f64; 13],
}
259
260 impl ClassifyInput {
261 /// Build from spectral features + waveform measurements (no MFCCs).
262 pub fn new(
263 features: &SpectralFeatures,
264 duration: f64,
265 crest_factor: f64,
266 attack_time: f64,
267 ) -> Self {
268 Self::with_mfccs(
269 features,
270 duration,
271 crest_factor,
272 attack_time,
273 &MfccFeatures::default(),
274 )
275 }
276
277 /// Build from spectral features + waveform measurements + MFCCs.
278 pub fn with_mfccs(
279 features: &SpectralFeatures,
280 duration: f64,
281 crest_factor: f64,
282 attack_time: f64,
283 mfccs: &MfccFeatures,
284 ) -> Self {
285 Self {
286 duration,
287 centroid: features.centroid,
288 flatness: features.flatness,
289 zcr: features.zero_crossing_rate,
290 onset_strength: features.onset_strength,
291 bandwidth: features.bandwidth,
292 centroid_variance: features.centroid_variance,
293 crest_factor,
294 attack_time,
295 mfcc_means: mfccs.means,
296 mfcc_variances: mfccs.variances,
297 }
298 }
299
300 /// Convert to a flat 35-element feature array for RF inference.
301 ///
302 /// Layout: [0-8] scalar features, [9-21] MFCC means, [22-34] MFCC variances.
303 pub fn to_feature_array(&self) -> [f64; NUM_FEATURES] {
304 let mut arr = [0.0; NUM_FEATURES];
305 arr[0] = self.duration;
306 arr[1] = self.centroid;
307 arr[2] = self.flatness;
308 arr[3] = self.zcr;
309 arr[4] = self.onset_strength;
310 arr[5] = self.bandwidth;
311 arr[6] = self.centroid_variance;
312 arr[7] = self.crest_factor;
313 arr[8] = self.attack_time;
314 arr[9..22].copy_from_slice(&self.mfcc_means);
315 arr[22..35].copy_from_slice(&self.mfcc_variances);
316 arr
317 }
318 }
319
320 // ── Model loading ──
321
/// Layer 2 drum model embedded at compile time.
const LAYER2_DRUM_BYTES: &[u8] = include_bytes!("../../models/layer2_drum.json");
/// Layer 2 bass model embedded at compile time.
const LAYER2_BASS_BYTES: &[u8] = include_bytes!("../../models/layer2_bass.json");
/// Layer 2 vocal model embedded at compile time.
const LAYER2_VOCAL_BYTES: &[u8] = include_bytes!("../../models/layer2_vocal.json");
/// Layer 2 synth model embedded at compile time.
const LAYER2_SYNTH_BYTES: &[u8] = include_bytes!("../../models/layer2_synth.json");

/// Lazily deserialized Layer 2 models (one per broad class).
///
/// Each cell is filled on first access by its `layer2_*model()` accessor.
static LAYER2_DRUM_MODEL: OnceLock<RandomForestModel> = OnceLock::new();
static LAYER2_BASS_MODEL: OnceLock<RandomForestModel> = OnceLock::new();
static LAYER2_VOCAL_MODEL: OnceLock<RandomForestModel> = OnceLock::new();
static LAYER2_SYNTH_MODEL: OnceLock<RandomForestModel> = OnceLock::new();
333
334 /// Fallback empty model — classification falls back to rule-based layer 1 only.
335 fn empty_model() -> RandomForestModel {
336 RandomForestModel {
337 trees: vec![],
338 num_classes: 0,
339 class_names: vec![],
340 }
341 }
342
343 fn layer2_model() -> &'static RandomForestModel {
344 LAYER2_DRUM_MODEL.get_or_init(|| {
345 serde_json::from_slice(LAYER2_DRUM_BYTES).unwrap_or_else(|e| {
346 tracing::error!("Failed to deserialize embedded drum model: {e}");
347 empty_model()
348 })
349 })
350 }
351
352 fn layer2_bass_model() -> &'static RandomForestModel {
353 LAYER2_BASS_MODEL.get_or_init(|| {
354 serde_json::from_slice(LAYER2_BASS_BYTES).unwrap_or_else(|e| {
355 tracing::error!("Failed to deserialize embedded bass model: {e}");
356 empty_model()
357 })
358 })
359 }
360
361 fn layer2_vocal_model() -> &'static RandomForestModel {
362 LAYER2_VOCAL_MODEL.get_or_init(|| {
363 serde_json::from_slice(LAYER2_VOCAL_BYTES).unwrap_or_else(|e| {
364 tracing::error!("Failed to deserialize embedded vocal model: {e}");
365 empty_model()
366 })
367 })
368 }
369
370 fn layer2_synth_model() -> &'static RandomForestModel {
371 LAYER2_SYNTH_MODEL.get_or_init(|| {
372 serde_json::from_slice(LAYER2_SYNTH_BYTES).unwrap_or_else(|e| {
373 tracing::error!("Failed to deserialize embedded synth model: {e}");
374 empty_model()
375 })
376 })
377 }
378
/// Class labels per Layer 2 model. Index corresponds to class ID in TreeNode::Leaf.
///
/// NOTE(review): the order must match the label encoding used by `audiofiles-train`
/// when the embedded models were produced — confirm against the trainer before reordering.
///
/// Drum model labels.
const DRUM_CLASSES: [SampleClass; 7] = [
    SampleClass::Kick,
    SampleClass::Snare,
    SampleClass::HiHat,
    SampleClass::Cymbal,
    SampleClass::Clap,
    SampleClass::Tom,
    SampleClass::Percussion,
];

/// Bass model labels.
const BASS_CLASSES: [SampleClass; 3] = [
    SampleClass::GuitarBass,
    SampleClass::SynthBass,
    SampleClass::SubBass,
];

/// Vocal model labels.
const VOCAL_CLASSES: [SampleClass; 3] = [
    SampleClass::VocalChop,
    SampleClass::VocalPhrase,
    SampleClass::VocalChoir,
];

/// Synth model labels.
const SYNTH_CLASSES: [SampleClass; 4] = [
    SampleClass::SynthLead,
    SampleClass::SynthStab,
    SampleClass::SynthPluck,
    SampleClass::SynthChord,
];
408
409 // ── Layer 1: Rule-based broad classifier ──
410
/// Classify a sample into a broad category (Layer 1).
///
/// Uses simplified rules from the original `classify_full()` to determine if a sample
/// is a drum hit vs other content types. Drum detection heuristic:
/// short duration + fast attack + high crest factor.
///
/// Rules are checked top to bottom and the first match wins, so their order is
/// significant. Returns the broad class together with a fixed rule confidence in
/// `[0.5, 0.95]`; if no rule matches, the result is `(Misc, 0.5)`.
fn classify_broad(input: &ClassifyInput) -> (BroadClass, f64) {
    let d = input.duration;
    let c = input.centroid;
    let flat = input.flatness;
    let zcr = input.zcr;
    let bw = input.bandwidth;
    let cv = input.centroid_variance;
    let crest = input.crest_factor;
    let attack = input.attack_time;

    // Noise: energy spread nearly uniformly — but not short percussive sounds
    // (cymbals/hihats can have high flatness but are still drums)
    // A very flat sample with d <= 2.0 and attack <= 0.1 falls through to the drum rules.
    if flat > 0.7 && (d > 2.0 || attack > 0.1) {
        return (BroadClass::Noise, 0.9);
    }

    // Bright metallic sounds (cymbals, crashes) are drums even when sustained
    if d < 10.0 && c > 3000.0 && flat > 0.2 {
        return (BroadClass::Drum, 0.85);
    }

    // Drum detection: short + percussive transient
    // Base confidence 0.8, +0.05 for each extra drum cue, capped at 0.95.
    if d < 2.0 && (attack < 0.05 || crest > 2.5) {
        let mut conf: f64 = 0.8;
        if attack < 0.02 {
            conf += 0.05;
        }
        if crest > 4.0 {
            conf += 0.05;
        }
        if d < 1.0 {
            conf += 0.05;
        }
        return (BroadClass::Drum, conf.min(0.95));
    }

    // Extended drum catch: short sounds with fast attack that didn't match above
    // (catches deeper kicks with low crest, toms with moderate attack)
    if d < 2.0 && attack < 0.1 && crest > 1.5 {
        return (BroadClass::Drum, 0.75);
    }

    // Impact: very sharp non-drum transient (must be longer than typical drums)
    // Every d < 2.0 sample with attack < 0.005 already returned Drum above, so this
    // effectively fires only for 2.0 <= d < 5.0.
    if d > 1.0 && d < 5.0 && crest > 10.0 && attack < 0.005 {
        return (BroadClass::Impact, 0.85);
    }

    // Ambience: long, spectrally static, moderate noise
    if d > 5.0 && cv < 100_000.0 && flat > 0.15 && flat < 0.5 {
        return (BroadClass::Ambience, 0.85);
    }

    // Bass: low-frequency, tonal — but not short percussive sounds (kicks)
    if c < 400.0 && flat < 0.15 && d > 2.0 {
        return (BroadClass::Bass, 0.85);
    }

    // Vocal: mid-range, tonal, smooth waveform — require longer duration to avoid catching toms
    if d > 1.0 && c > 300.0 && c < 3000.0 && flat < 0.2 && zcr < 0.08 && crest < 2.0 {
        return (BroadClass::Vocal, 0.8);
    }

    // Pad: long, tonal, mid-range
    if d > 2.0 && flat < 0.2 && c > 200.0 && c < 2000.0 {
        return (BroadClass::Pad, 0.8);
    }

    // Texture: long, spectrally evolving
    if d > 2.0 && cv > 500_000.0 {
        return (BroadClass::Texture, 0.8);
    }

    // Foley: broadband, moderate noise — require longer duration to avoid catching drums
    if d > 1.0 && bw > 2000.0 && flat > 0.1 && flat < 0.5 {
        return (BroadClass::Foley, 0.75);
    }

    // Synth: tonal, mid-to-high centroid — require low crest to exclude drums
    // (this rule has no duration condition)
    if c > 500.0 && flat < 0.3 && zcr < 0.1 && crest < 2.0 {
        return (BroadClass::Synth, 0.75);
    }

    // Music: long catch-all
    if d > 3.0 {
        return (BroadClass::Music, 0.6);
    }

    // Short unclassified sounds are likely drums
    if d < 2.0 {
        return (BroadClass::Drum, 0.6);
    }

    // Misc: final catch-all (nothing else matched)
    // Reached e.g. by unmatched samples with 2.0 <= d <= 3.0.
    (BroadClass::Misc, 0.5)
}
511
512 /// Map a BroadClass (non-Drum) to the corresponding SampleClass.
513 fn broad_to_sample_class(broad: BroadClass) -> SampleClass {
514 match broad {
515 BroadClass::Drum => SampleClass::Percussion, // shouldn't reach here
516 BroadClass::Bass => SampleClass::Bass,
517 BroadClass::Vocal => SampleClass::Vocal,
518 BroadClass::Synth => SampleClass::Synth,
519 BroadClass::Pad => SampleClass::Pad,
520 BroadClass::Misc => SampleClass::Misc,
521 BroadClass::Noise => SampleClass::Noise,
522 BroadClass::Music => SampleClass::Music,
523 BroadClass::Ambience => SampleClass::Ambience,
524 BroadClass::Impact => SampleClass::Impact,
525 BroadClass::Foley => SampleClass::Foley,
526 BroadClass::Texture => SampleClass::Texture,
527 }
528 }
529
530 // ── Layer 2: Random Forest inference ──
531
532 /// Run Layer 2 classification using a model and its class mapping.
533 ///
534 /// Returns (SampleClass, confidence) where confidence is the vote fraction.
535 /// If the model has no trees, returns the fallback class with 0 confidence.
536 fn predict_with_model(
537 model: &RandomForestModel,
538 classes: &[SampleClass],
539 fallback: SampleClass,
540 features: &[f64; NUM_FEATURES],
541 ) -> (SampleClass, f64) {
542 if model.trees.is_empty() {
543 return (fallback, 0.0);
544 }
545
546 let mut votes = vec![0u32; model.num_classes as usize];
547 for tree in &model.trees {
548 let class_id = tree.predict(features) as usize;
549 if class_id < votes.len() {
550 votes[class_id] += 1;
551 }
552 }
553
554 let total = model.trees.len() as f64;
555 let (best_class_id, &best_count) = votes
556 .iter()
557 .enumerate()
558 .max_by_key(|(_, count)| **count)
559 .unwrap_or((0, &0));
560
561 let confidence = best_count as f64 / total;
562 let class = classes
563 .get(best_class_id)
564 .copied()
565 .unwrap_or(fallback);
566
567 (class, confidence)
568 }
569
570 // ── Main entry point ──
571
572 /// Two-layer ML classifier.
573 ///
574 /// Layer 1 (rule-based) determines broad class. Layer 2 (Random Forest) provides
575 /// fine-grained sub-classification for Drum, Bass, Vocal, and Synth classes.
576 /// If a Layer 2 model is not trained (empty trees), falls back to the broad class.
577 #[instrument(skip_all)]
578 pub fn classify_ml(input: &ClassifyInput) -> ClassificationResult {
579 let drum_model = layer2_model();
580
581 // If no drum model at all, fall back entirely to rule-based
582 if drum_model.trees.is_empty() {
583 return ClassificationResult {
584 class: classify_full(input),
585 confidence: 0.0,
586 };
587 }
588
589 let (broad, broad_conf) = classify_broad(input);
590 let features = input.to_feature_array();
591
592 match broad {
593 BroadClass::Drum => {
594 let (class, conf) = predict_with_model(drum_model, &DRUM_CLASSES, SampleClass::Percussion, &features);
595 ClassificationResult { class, confidence: conf }
596 }
597 BroadClass::Bass => {
598 let model = layer2_bass_model();
599 if model.trees.is_empty() {
600 ClassificationResult { class: SampleClass::Bass, confidence: broad_conf }
601 } else {
602 let (class, conf) = predict_with_model(model, &BASS_CLASSES, SampleClass::Bass, &features);
603 ClassificationResult { class, confidence: conf }
604 }
605 }
606 BroadClass::Vocal => {
607 let model = layer2_vocal_model();
608 if model.trees.is_empty() {
609 ClassificationResult { class: SampleClass::Vocal, confidence: broad_conf }
610 } else {
611 let (class, conf) = predict_with_model(model, &VOCAL_CLASSES, SampleClass::Vocal, &features);
612 ClassificationResult { class, confidence: conf }
613 }
614 }
615 BroadClass::Synth => {
616 let model = layer2_synth_model();
617 if model.trees.is_empty() {
618 ClassificationResult { class: SampleClass::Synth, confidence: broad_conf }
619 } else {
620 let (class, conf) = predict_with_model(model, &SYNTH_CLASSES, SampleClass::Synth, &features);
621 ClassificationResult { class, confidence: conf }
622 }
623 }
624 _ => ClassificationResult {
625 class: broad_to_sample_class(broad),
626 confidence: broad_conf,
627 },
628 }
629 }
630
631 // ── Smart skip helpers ──
632
633 impl SampleClass {
634 /// Whether BPM detection is meaningful for this sample type.
635 ///
636 /// Drums, impacts, noise, foley, ambience, and textures have no rhythmic
637 /// content that BPM detection can usefully extract. Short chops also skip.
638 pub fn has_rhythm(&self) -> bool {
639 matches!(
640 self,
641 Self::Bass
642 | Self::GuitarBass
643 | Self::SynthBass
644 | Self::SubBass
645 | Self::Vocal
646 | Self::VocalPhrase
647 | Self::VocalChoir
648 | Self::Synth
649 | Self::SynthLead
650 | Self::SynthChord
651 | Self::Pad
652 | Self::Music
653 | Self::Misc
654 )
655 }
656
657 /// Whether key detection is meaningful for this sample type.
658 ///
659 /// Noise, foley, ambience, textures, and drum one-shots lack tonal content.
660 /// Short chops and stabs are too brief for reliable key detection.
661 pub fn has_pitch(&self) -> bool {
662 matches!(
663 self,
664 Self::Bass
665 | Self::GuitarBass
666 | Self::SynthBass
667 | Self::SubBass
668 | Self::Vocal
669 | Self::VocalPhrase
670 | Self::VocalChoir
671 | Self::Synth
672 | Self::SynthLead
673 | Self::SynthChord
674 | Self::Pad
675 | Self::Music
676 | Self::Misc
677 )
678 }
679 }
680
681 // ── Legacy functions ──
682
/// Classify a sample using the full feature set (16 classes, priority-ordered decision tree).
///
/// Uses threshold-based rules — no ML. Rules are evaluated in priority order. The tree is
/// structured in two phases: drum/percussion rules first (for short sounds with percussive
/// transients), then non-drum rules for sustained/tonal content. This prevents tonal drum
/// hits from leaking into Vocal/Synth/Foley categories.
///
/// The expanded feature set (crest factor, attack time, bandwidth, centroid variance)
/// enables game audio categories (ambience, impact, foley, texture) that the original
/// 4-feature classifier collapsed into Misc/Percussion/Music.
///
/// The first matching rule wins. This function never returns `Clap`, `Tom`, or any of the
/// Layer 2 sub-classes (e.g. `GuitarBass`, `VocalChop`, `SynthLead`).
#[instrument(skip_all)]
pub fn classify_full(input: &ClassifyInput) -> SampleClass {
    let d = input.duration;
    let c = input.centroid;
    let flat = input.flatness;
    let zcr = input.zcr;
    // Bound but not used by any rule below; the underscore silences the unused warning.
    let _onset = input.onset_strength;
    let bw = input.bandwidth;
    let cv = input.centroid_variance;
    let crest = input.crest_factor;
    let attack = input.attack_time;

    // 1. Noise: energy spread nearly uniformly across all frequencies
    if flat > 0.7 {
        return SampleClass::Noise;
    }

    // --- Phase 1: Drum/percussion classification (short sounds first) ---

    // 2. Kick: short, low-frequency dominant
    if d < 2.0 && c < 1200.0 && flat < 0.35 {
        return SampleClass::Kick;
    }

    // 3. HiHat: short, bright, metallic (relatively tonal)
    if d < 2.0 && c > 3500.0 && zcr > 0.1 && flat < 0.3 {
        return SampleClass::HiHat;
    }

    // 4. Cymbal: bright, noisy sustain (crashes, rides)
    if d > 0.5 && d < 10.0 && c > 3000.0 && flat > 0.25 {
        return SampleClass::Cymbal;
    }

    // 5. Snare: short, mid-to-high frequency, broadband noise from snare wires
    if d < 2.0 && c > 800.0 && flat > 0.2 && bw > 2000.0 {
        return SampleClass::Snare;
    }

    // 6. Percussion: short percussive catch-all
    // With attack_time == 0.0 (as passed by the legacy `classify`), this matches every d < 2.0.
    if d < 2.0 && (attack < 0.05 || crest > 3.0) {
        return SampleClass::Percussion;
    }

    // --- Phase 2: Non-drum classification ---

    // 7. Impact: very sharp non-drum transient
    // Rule 6 already returned for every d < 2.0 with attack < 0.05, so this effectively
    // fires only for 2.0 <= d < 3.0.
    if d < 3.0 && crest > 10.0 && attack < 0.005 {
        return SampleClass::Impact;
    }

    // 8. Ambience: long, spectrally static, moderate noise
    if d > 5.0 && cv < 100_000.0 && flat > 0.15 && flat < 0.5 {
        return SampleClass::Ambience;
    }

    // 9. Bass: low-frequency, tonal
    if c < 400.0 && flat < 0.15 {
        return SampleClass::Bass;
    }

    // 10. Vocal: mid-range, tonal, smooth waveform, sustained
    if d > 0.5 && c > 300.0 && c < 3000.0 && flat < 0.2 && zcr < 0.08 {
        return SampleClass::Vocal;
    }

    // 11. Pad: long, tonal, mid-range
    if d > 2.0 && flat < 0.2 && c > 200.0 && c < 2000.0 {
        return SampleClass::Pad;
    }

    // 12. Texture: long, spectrally evolving
    if d > 2.0 && cv > 500_000.0 {
        return SampleClass::Texture;
    }

    // 13. Foley: broadband, moderate noise
    if d > 0.1 && bw > 2000.0 && flat > 0.1 && flat < 0.5 {
        return SampleClass::Foley;
    }

    // 14. Synth: tonal, mid-to-high centroid, smooth waveform
    if c > 500.0 && flat < 0.3 && zcr < 0.1 {
        return SampleClass::Synth;
    }

    // 15. Percussion: short catch-all
    if d < 2.0 {
        return SampleClass::Percussion;
    }

    // 16. Music: long catch-all
    if d > 3.0 {
        return SampleClass::Music;
    }

    // 17. Misc: final catch-all (nothing else matched)
    SampleClass::Misc
}
792
793 /// Legacy classify function — uses only spectral features and duration.
794 #[instrument(skip_all)]
795 pub fn classify(features: &SpectralFeatures, duration: f64) -> Option<SampleClass> {
796 let input = ClassifyInput::new(features, duration, 0.0, 0.0);
797 Some(classify_full(&input))
798 }
799
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn kick_classification() {
        let features = SpectralFeatures {
            centroid: 600.0,
            flatness: 0.15,
            rolloff: 1200.0,
            zero_crossing_rate: 0.04,
            onset_strength: 50.0,
            ..Default::default()
        };
        assert_eq!(classify(&features, 0.3), Some(SampleClass::Kick));
    }

    #[test]
    fn hihat_classification() {
        let features = SpectralFeatures {
            centroid: 5000.0,
            flatness: 0.2,
            rolloff: 12000.0,
            zero_crossing_rate: 0.15,
            onset_strength: 30.0,
            ..Default::default()
        };
        assert_eq!(classify(&features, 0.2), Some(SampleClass::HiHat));
    }

    #[test]
    fn snare_classification() {
        let input = ClassifyInput {
            duration: 0.5,
            centroid: 2500.0,
            flatness: 0.25,
            zcr: 0.12,
            onset_strength: 40.0,
            bandwidth: 3000.0,
            centroid_variance: 200_000.0,
            crest_factor: 6.0,
            attack_time: 0.003,
            mfcc_means: [0.0; 13],
            mfcc_variances: [0.0; 13],
        };
        assert_eq!(classify_full(&input), SampleClass::Snare);
    }

    #[test]
    fn noise_classification() {
        let features = SpectralFeatures {
            centroid: 5000.0,
            flatness: 0.75,
            rolloff: 15000.0,
            zero_crossing_rate: 0.3,
            onset_strength: 10.0,
            ..Default::default()
        };
        assert_eq!(classify(&features, 2.0), Some(SampleClass::Noise));
    }

    #[test]
    fn bass_classification() {
        let features = SpectralFeatures {
            centroid: 200.0,
            flatness: 0.08,
            rolloff: 500.0,
            zero_crossing_rate: 0.03,
            onset_strength: 20.0,
            ..Default::default()
        };
        assert_eq!(classify(&features, 3.0), Some(SampleClass::Bass));
    }

    #[test]
    fn tag_format() {
        assert_eq!(SampleClass::Kick.tag(), "instrument.drum.kick");
        assert_eq!(SampleClass::Noise.tag(), "character.noise");
        assert_eq!(SampleClass::Music.tag(), "type.music");
        assert_eq!(SampleClass::Ambience.tag(), "character.ambience");
        assert_eq!(SampleClass::Impact.tag(), "character.impact");
        assert_eq!(SampleClass::Foley.tag(), "character.foley");
        assert_eq!(SampleClass::Texture.tag(), "character.texture");
    }

    #[test]
    fn impact_classification() {
        let input = ClassifyInput {
            duration: 2.5,
            centroid: 1500.0,
            flatness: 0.1,
            zcr: 0.06,
            onset_strength: 50.0,
            bandwidth: 2000.0,
            centroid_variance: 200_000.0,
            crest_factor: 12.0,
            attack_time: 0.002,
            mfcc_means: [0.0; 13],
            mfcc_variances: [0.0; 13],
        };
        assert_eq!(classify_full(&input), SampleClass::Impact);
    }

    #[test]
    fn ambience_classification() {
        let input = ClassifyInput {
            duration: 10.0,
            centroid: 2000.0,
            flatness: 0.3,
            zcr: 0.05,
            onset_strength: 5.0,
            bandwidth: 3000.0,
            centroid_variance: 50_000.0,
            crest_factor: 2.5,
            attack_time: 0.5,
            mfcc_means: [0.0; 13],
            mfcc_variances: [0.0; 13],
        };
        assert_eq!(classify_full(&input), SampleClass::Ambience);
    }

    #[test]
    fn foley_classification() {
        let input = ClassifyInput {
            duration: 3.0,
            centroid: 1500.0,
            flatness: 0.25,
            zcr: 0.06,
            onset_strength: 15.0,
            bandwidth: 4000.0,
            centroid_variance: 300_000.0,
            crest_factor: 2.0,
            attack_time: 0.1,
            mfcc_means: [0.0; 13],
            mfcc_variances: [0.0; 13],
        };
        assert_eq!(classify_full(&input), SampleClass::Foley);
    }

    #[test]
    fn texture_classification() {
        let input = ClassifyInput {
            duration: 5.0,
            centroid: 2000.0,
            flatness: 0.12,
            zcr: 0.12,
            onset_strength: 10.0,
            bandwidth: 3000.0,
            centroid_variance: 1_000_000.0,
            crest_factor: 3.0,
            attack_time: 0.1,
            mfcc_means: [0.0; 13],
            mfcc_variances: [0.0; 13],
        };
        assert_eq!(classify_full(&input), SampleClass::Texture);
    }

    #[test]
    fn round_trip_from_str() {
        for class in [
            SampleClass::Kick,
            SampleClass::Snare,
            SampleClass::HiHat,
            SampleClass::Cymbal,
            SampleClass::Clap,
            SampleClass::Tom,
            SampleClass::Percussion,
            SampleClass::Bass,
            SampleClass::GuitarBass,
            SampleClass::SynthBass,
            SampleClass::SubBass,
            SampleClass::Vocal,
            SampleClass::VocalChop,
            SampleClass::VocalPhrase,
            SampleClass::VocalChoir,
            SampleClass::Synth,
            SampleClass::SynthLead,
            SampleClass::SynthStab,
            SampleClass::SynthPluck,
            SampleClass::SynthChord,
            SampleClass::Pad,
            SampleClass::Misc,
            SampleClass::Noise,
            SampleClass::Music,
            SampleClass::Ambience,
            SampleClass::Impact,
            SampleClass::Foley,
            SampleClass::Texture,
        ] {
            let s = class.as_str();
            let parsed: SampleClass = s.parse().unwrap();
            assert_eq!(parsed, class, "round-trip failed for {s}");
        }
    }

    #[test]
    fn feature_array_layout() {
        let input = ClassifyInput {
            duration: 1.0,
            centroid: 2.0,
            flatness: 3.0,
            zcr: 4.0,
            onset_strength: 5.0,
            bandwidth: 6.0,
            centroid_variance: 7.0,
            crest_factor: 8.0,
            attack_time: 9.0,
            mfcc_means: [10.0; 13],
            mfcc_variances: [20.0; 13],
        };
        let arr = input.to_feature_array();
        assert_eq!(arr[0], 1.0); // duration
        assert_eq!(arr[1], 2.0); // centroid
        assert_eq!(arr[8], 9.0); // attack_time
        assert_eq!(arr[9], 10.0); // mfcc_mean[0]
        assert_eq!(arr[21], 10.0); // mfcc_mean[12]
        assert_eq!(arr[22], 20.0); // mfcc_var[0]
        assert_eq!(arr[34], 20.0); // mfcc_var[12]
    }

    #[test]
    fn classify_ml_returns_drum_for_kick_input() {
        let input = ClassifyInput {
            duration: 0.3,
            centroid: 600.0,
            flatness: 0.15,
            zcr: 0.04,
            onset_strength: 50.0,
            bandwidth: 500.0,
            centroid_variance: 10_000.0,
            crest_factor: 5.0,
            attack_time: 0.003,
            mfcc_means: [0.0; 13],
            mfcc_variances: [0.0; 13],
        };
        let result = classify_ml(&input);
        // Should classify as a drum class, whether via the Layer 2 drum model or the
        // `classify_full` fallback (which reports confidence 0.0).
        assert!(
            matches!(
                result.class,
                SampleClass::Kick
                    | SampleClass::Snare
                    | SampleClass::HiHat
                    | SampleClass::Cymbal
                    | SampleClass::Clap
                    | SampleClass::Tom
                    | SampleClass::Percussion
            ),
            "expected drum class, got {:?}",
            result.class
        );
    }

    #[test]
    fn tree_node_prediction() {
        let tree = TreeNode::Split {
            feature: 1,
            threshold: 1000.0,
            left: Box::new(TreeNode::Leaf { class: 0 }), // kick (low centroid)
            right: Box::new(TreeNode::Leaf { class: 2 }), // hihat (high centroid)
        };
        let mut features = [0.0; NUM_FEATURES];
        features[1] = 500.0; // centroid below threshold
        assert_eq!(tree.predict(&features), 0);
        features[1] = 5000.0; // centroid above threshold
        assert_eq!(tree.predict(&features), 2);
    }

    #[test]
    fn tree_node_nan_goes_left_and_bad_feature_falls_back() {
        let tree = TreeNode::Split {
            feature: 1,
            threshold: 1000.0,
            left: Box::new(TreeNode::Leaf { class: 3 }),
            right: Box::new(TreeNode::Leaf { class: 2 }),
        };
        let mut features = [0.0; NUM_FEATURES];
        features[1] = f64::NAN; // NaN takes the left branch
        assert_eq!(tree.predict(&features), 3);

        // A split referencing a feature index past the array is treated as malformed.
        let malformed = TreeNode::Split {
            feature: NUM_FEATURES,
            threshold: 0.0,
            left: Box::new(TreeNode::Leaf { class: 1 }),
            right: Box::new(TreeNode::Leaf { class: 2 }),
        };
        assert_eq!(malformed.predict(&features), 0);
    }

    #[test]
    fn broad_classifier_detects_drums() {
        let input = ClassifyInput {
            duration: 0.3,
            centroid: 600.0,
            flatness: 0.15,
            zcr: 0.04,
            onset_strength: 50.0,
            bandwidth: 500.0,
            centroid_variance: 10_000.0,
            crest_factor: 5.0,
            attack_time: 0.003,
            mfcc_means: [0.0; 13],
            mfcc_variances: [0.0; 13],
        };
        let (broad, conf) = classify_broad(&input);
        assert_eq!(broad, BroadClass::Drum);
        assert!(conf >= 0.8);
    }

    #[test]
    fn broad_classifier_detects_noise() {
        let input = ClassifyInput {
            duration: 2.0,
            centroid: 5000.0,
            flatness: 0.8,
            zcr: 0.3,
            onset_strength: 10.0,
            bandwidth: 5000.0,
            centroid_variance: 100_000.0,
            crest_factor: 1.5,
            attack_time: 0.5,
            mfcc_means: [0.0; 13],
            mfcc_variances: [0.0; 13],
        };
        let (broad, _) = classify_broad(&input);
        assert_eq!(broad, BroadClass::Noise);
    }

    #[test]
    fn smart_skip_drum_classes_skip_bpm_key() {
        // All drum sub-classes should skip BPM/key
        for class in [
            SampleClass::Kick,
            SampleClass::Snare,
            SampleClass::HiHat,
            SampleClass::Cymbal,
            SampleClass::Clap,
            SampleClass::Tom,
            SampleClass::Percussion,
        ] {
            assert!(!class.has_rhythm(), "{class:?} should not have rhythm");
            assert!(!class.has_pitch(), "{class:?} should not have pitch");
        }
    }

    #[test]
    fn smart_skip_tonal_classes_keep_bpm_key() {
        for class in [
            SampleClass::Bass,
            SampleClass::GuitarBass,
            SampleClass::SynthBass,
            SampleClass::SubBass,
            SampleClass::Vocal,
            SampleClass::VocalPhrase,
            SampleClass::VocalChoir,
            SampleClass::Synth,
            SampleClass::SynthLead,
            SampleClass::SynthChord,
            SampleClass::Pad,
            SampleClass::Music,
        ] {
            assert!(class.has_rhythm(), "{class:?} should have rhythm");
            assert!(class.has_pitch(), "{class:?} should have pitch");
        }
    }

    #[test]
    fn smart_skip_short_classes_skip_bpm_key() {
        // Short chops and stabs are too brief for BPM/key detection
        for class in [
            SampleClass::VocalChop,
            SampleClass::SynthStab,
            SampleClass::SynthPluck,
        ] {
            assert!(!class.has_rhythm(), "{class:?} should not have rhythm");
            assert!(!class.has_pitch(), "{class:?} should not have pitch");
        }
    }

    #[test]
    fn smart_skip_non_tonal_non_drum_skip() {
        for class in [
            SampleClass::Noise,
            SampleClass::Ambience,
            SampleClass::Impact,
            SampleClass::Foley,
            SampleClass::Texture,
        ] {
            assert!(!class.has_rhythm(), "{class:?} should not have rhythm");
            assert!(!class.has_pitch(), "{class:?} should not have pitch");
        }
    }

    #[test]
    fn smart_skip_pad_has_pitch_and_rhythm() {
        // Pads are tonal (key applies) and sustained (could have rhythm),
        // so both checks stay enabled.
        assert!(SampleClass::Pad.has_pitch());
        assert!(SampleClass::Pad.has_rhythm());
    }

    #[test]
    fn classify_ml_bass_routes_to_sub_class() {
        // Bass-like input: low centroid, tonal, sustained.
        // NOTE: assumes the embedded drum AND bass models are trained. If the drum model
        // is empty, classify_ml falls back to classify_full, which returns plain `Bass`.
        let input = ClassifyInput {
            duration: 3.0,
            centroid: 200.0,
            flatness: 0.08,
            zcr: 0.03,
            onset_strength: 20.0,
            bandwidth: 400.0,
            centroid_variance: 50_000.0,
            crest_factor: 1.5,
            attack_time: 0.05,
            mfcc_means: [0.0; 13],
            mfcc_variances: [0.0; 13],
        };
        let result = classify_ml(&input);
        // With trained bass model, should classify as a bass sub-class
        assert!(
            matches!(
                result.class,
                SampleClass::GuitarBass | SampleClass::SynthBass | SampleClass::SubBass
            ),
            "expected bass sub-class, got {:?}",
            result.class
        );
        assert!(result.confidence > 0.0);
    }

    #[test]
    fn classify_ml_vocal_fallback_without_model() {
        // Vocal-like input: mid-range, tonal, smooth, sustained.
        // NOTE: expects the embedded vocal model to be untrained (empty trees). Once a
        // vocal model is trained, this will yield a vocal sub-class instead of `Vocal`.
        let input = ClassifyInput {
            duration: 2.0,
            centroid: 1000.0,
            flatness: 0.1,
            zcr: 0.05,
            onset_strength: 15.0,
            bandwidth: 1500.0,
            centroid_variance: 80_000.0,
            crest_factor: 1.5,
            attack_time: 0.2,
            mfcc_means: [0.0; 13],
            mfcc_variances: [0.0; 13],
        };
        let result = classify_ml(&input);
        assert_eq!(result.class, SampleClass::Vocal);
    }

    #[test]
    fn predict_with_model_returns_fallback_for_empty() {
        let empty_model = RandomForestModel {
            trees: vec![],
            num_classes: 3,
            class_names: vec!["a".into(), "b".into(), "c".into()],
        };
        let features = [0.0; NUM_FEATURES];
        let (class, conf) = predict_with_model(
            &empty_model,
            &BASS_CLASSES,
            SampleClass::Bass,
            &features,
        );
        assert_eq!(class, SampleClass::Bass);
        assert_eq!(conf, 0.0);
    }
}
1247