//! Two-layer ML classification system. //! //! Layer 1 (rule-based): Broad class detection (Drum vs Bass/Vocal/Synth/etc.) //! Layer 2 (Random Forest): Fine-grained drum sub-classification (Kick/Snare/HiHat/Cymbal/Percussion) //! //! The RF model is trained offline by `audiofiles-train` and embedded via `include_bytes!`. //! If no trained model is available (empty trees), falls back to the rule-based `classify_full()`. use std::sync::OnceLock; use super::mfcc::MfccFeatures; use super::spectral::SpectralFeatures; use tracing::instrument; /// Number of features in the classification feature vector. pub const NUM_FEATURES: usize = 35; // ── SampleClass enum (unchanged) ── /// High-level classification of a sample's content. #[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub enum SampleClass { Kick, Snare, HiHat, Cymbal, Clap, Tom, Percussion, Bass, GuitarBass, SynthBass, SubBass, Vocal, VocalChop, VocalPhrase, VocalChoir, Synth, SynthLead, SynthStab, SynthPluck, SynthChord, Pad, Misc, Noise, Music, Ambience, Impact, Foley, Texture, } impl SampleClass { pub fn as_str(&self) -> &'static str { match self { Self::Kick => "kick", Self::Snare => "snare", Self::HiHat => "hihat", Self::Cymbal => "cymbal", Self::Clap => "clap", Self::Tom => "tom", Self::Percussion => "percussion", Self::Bass => "bass", Self::GuitarBass => "guitar-bass", Self::SynthBass => "synth-bass", Self::SubBass => "sub-bass", Self::Vocal => "vocal", Self::VocalChop => "vocal-chop", Self::VocalPhrase => "vocal-phrase", Self::VocalChoir => "vocal-choir", Self::Synth => "synth", Self::SynthLead => "synth-lead", Self::SynthStab => "synth-stab", Self::SynthPluck => "synth-pluck", Self::SynthChord => "synth-chord", Self::Pad => "pad", Self::Misc => "misc", Self::Noise => "noise", Self::Music => "music", Self::Ambience => "ambience", Self::Impact => "impact", Self::Foley => "foley", Self::Texture => "texture", } } /// Dot-notation tag for this class. 
pub fn tag(&self) -> &'static str { match self { Self::Kick => "instrument.drum.kick", Self::Snare => "instrument.drum.snare", Self::HiHat => "instrument.drum.hihat", Self::Cymbal => "instrument.drum.cymbal", Self::Clap => "instrument.drum.clap", Self::Tom => "instrument.drum.tom", Self::Percussion => "instrument.percussion", Self::Bass => "instrument.bass", Self::GuitarBass => "instrument.bass.guitar", Self::SynthBass => "instrument.bass.synth", Self::SubBass => "instrument.bass.sub", Self::Vocal => "instrument.vocal", Self::VocalChop => "instrument.vocal.chop", Self::VocalPhrase => "instrument.vocal.phrase", Self::VocalChoir => "instrument.vocal.choir", Self::Synth => "instrument.synth", Self::SynthLead => "instrument.synth.lead", Self::SynthStab => "instrument.synth.stab", Self::SynthPluck => "instrument.synth.pluck", Self::SynthChord => "instrument.synth.chord", Self::Pad => "instrument.pad", Self::Misc => "character.misc", Self::Noise => "character.noise", Self::Music => "type.music", Self::Ambience => "character.ambience", Self::Impact => "character.impact", Self::Foley => "character.foley", Self::Texture => "character.texture", } } } impl std::fmt::Display for SampleClass { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.as_str()) } } impl std::str::FromStr for SampleClass { type Err = (); fn from_str(s: &str) -> Result { match s { "kick" => Ok(Self::Kick), "snare" => Ok(Self::Snare), "hihat" => Ok(Self::HiHat), "cymbal" => Ok(Self::Cymbal), "clap" => Ok(Self::Clap), "tom" => Ok(Self::Tom), "percussion" => Ok(Self::Percussion), "bass" => Ok(Self::Bass), "guitar-bass" => Ok(Self::GuitarBass), "synth-bass" => Ok(Self::SynthBass), "sub-bass" => Ok(Self::SubBass), "vocal" => Ok(Self::Vocal), "vocal-chop" => Ok(Self::VocalChop), "vocal-phrase" => Ok(Self::VocalPhrase), "vocal-choir" => Ok(Self::VocalChoir), "synth" => Ok(Self::Synth), "synth-lead" => Ok(Self::SynthLead), "synth-stab" => Ok(Self::SynthStab), "synth-pluck" 
=> Ok(Self::SynthPluck), "synth-chord" => Ok(Self::SynthChord), "pad" => Ok(Self::Pad), "misc" | "fx" => Ok(Self::Misc), "noise" => Ok(Self::Noise), "music" => Ok(Self::Music), "ambience" => Ok(Self::Ambience), "impact" => Ok(Self::Impact), "foley" => Ok(Self::Foley), "texture" => Ok(Self::Texture), _ => Err(()), } } } // ── Broad class (Layer 1) ── /// Broad classification for Layer 1 routing. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum BroadClass { Drum, Bass, Vocal, Synth, Pad, Misc, Noise, Music, Ambience, Impact, Foley, Texture, } // ── Classification result ── /// Result of the two-layer classifier. #[derive(Debug, Clone)] pub struct ClassificationResult { pub class: SampleClass, pub confidence: f64, } // ── Decision tree types (custom format for embedded inference) ── /// A node in a serialized decision tree. #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub enum TreeNode { Split { feature: usize, threshold: f64, left: Box, right: Box, }, Leaf { class: u8, }, } impl TreeNode { pub fn predict(&self, features: &[f64; NUM_FEATURES]) -> u8 { match self { TreeNode::Split { feature, threshold, left, right, } => { if *feature >= NUM_FEATURES { return 0; // fallback for malformed model } let val = features[*feature]; // NaN goes left (conservative path) instead of always right if val.is_nan() || val <= *threshold { left.predict(features) } else { right.predict(features) } } TreeNode::Leaf { class } => *class, } } } /// A trained Random Forest model (collection of decision trees). #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct RandomForestModel { pub trees: Vec, pub num_classes: u8, pub class_names: Vec, } // ── ClassifyInput ── /// Input bundle for classification — collects all features used by the classifier. 
pub struct ClassifyInput { pub duration: f64, pub centroid: f64, pub flatness: f64, pub zcr: f64, pub onset_strength: f64, pub bandwidth: f64, pub centroid_variance: f64, pub crest_factor: f64, pub attack_time: f64, pub mfcc_means: [f64; 13], pub mfcc_variances: [f64; 13], } impl ClassifyInput { /// Build from spectral features + waveform measurements (no MFCCs). pub fn new( features: &SpectralFeatures, duration: f64, crest_factor: f64, attack_time: f64, ) -> Self { Self::with_mfccs( features, duration, crest_factor, attack_time, &MfccFeatures::default(), ) } /// Build from spectral features + waveform measurements + MFCCs. pub fn with_mfccs( features: &SpectralFeatures, duration: f64, crest_factor: f64, attack_time: f64, mfccs: &MfccFeatures, ) -> Self { Self { duration, centroid: features.centroid, flatness: features.flatness, zcr: features.zero_crossing_rate, onset_strength: features.onset_strength, bandwidth: features.bandwidth, centroid_variance: features.centroid_variance, crest_factor, attack_time, mfcc_means: mfccs.means, mfcc_variances: mfccs.variances, } } /// Convert to a flat 35-element feature array for RF inference. /// /// Layout: [0-8] scalar features, [9-21] MFCC means, [22-34] MFCC variances. pub fn to_feature_array(&self) -> [f64; NUM_FEATURES] { let mut arr = [0.0; NUM_FEATURES]; arr[0] = self.duration; arr[1] = self.centroid; arr[2] = self.flatness; arr[3] = self.zcr; arr[4] = self.onset_strength; arr[5] = self.bandwidth; arr[6] = self.centroid_variance; arr[7] = self.crest_factor; arr[8] = self.attack_time; arr[9..22].copy_from_slice(&self.mfcc_means); arr[22..35].copy_from_slice(&self.mfcc_variances); arr } } // ── Model loading ── /// Layer 2 drum model embedded at compile time. 
const LAYER2_DRUM_BYTES: &[u8] = include_bytes!("../../models/layer2_drum.json");
/// Layer 2 bass model embedded at compile time.
const LAYER2_BASS_BYTES: &[u8] = include_bytes!("../../models/layer2_bass.json");
/// Layer 2 vocal model embedded at compile time.
const LAYER2_VOCAL_BYTES: &[u8] = include_bytes!("../../models/layer2_vocal.json");
/// Layer 2 synth model embedded at compile time.
const LAYER2_SYNTH_BYTES: &[u8] = include_bytes!("../../models/layer2_synth.json");

/// Lazily deserialized Layer 2 models (one per broad class).
static LAYER2_DRUM_MODEL: OnceLock<RandomForestModel> = OnceLock::new();
static LAYER2_BASS_MODEL: OnceLock<RandomForestModel> = OnceLock::new();
static LAYER2_VOCAL_MODEL: OnceLock<RandomForestModel> = OnceLock::new();
static LAYER2_SYNTH_MODEL: OnceLock<RandomForestModel> = OnceLock::new();

/// Fallback empty model — classification falls back to rule-based layer 1 only.
fn empty_model() -> RandomForestModel {
    RandomForestModel {
        trees: vec![],
        num_classes: 0,
        class_names: vec![],
    }
}

/// Deserialize an embedded JSON model.
///
/// On a parse failure the error is logged and [`empty_model`] is returned instead, so a
/// corrupt embedded file degrades classification rather than panicking. `label` names the
/// model in the log message (e.g. `"drum"`).
fn parse_embedded_model(bytes: &[u8], label: &str) -> RandomForestModel {
    serde_json::from_slice(bytes).unwrap_or_else(|e| {
        tracing::error!("Failed to deserialize embedded {label} model: {e}");
        empty_model()
    })
}

/// Layer 2 drum model, deserialized on first use.
fn layer2_model() -> &'static RandomForestModel {
    LAYER2_DRUM_MODEL.get_or_init(|| parse_embedded_model(LAYER2_DRUM_BYTES, "drum"))
}

/// Layer 2 bass model, deserialized on first use.
fn layer2_bass_model() -> &'static RandomForestModel {
    LAYER2_BASS_MODEL.get_or_init(|| parse_embedded_model(LAYER2_BASS_BYTES, "bass"))
}

/// Layer 2 vocal model, deserialized on first use.
fn layer2_vocal_model() -> &'static RandomForestModel {
    LAYER2_VOCAL_MODEL.get_or_init(|| parse_embedded_model(LAYER2_VOCAL_BYTES, "vocal"))
}

/// Layer 2 synth model, deserialized on first use.
fn layer2_synth_model() -> &'static RandomForestModel {
    LAYER2_SYNTH_MODEL.get_or_init(|| parse_embedded_model(LAYER2_SYNTH_BYTES, "synth"))
}

/// Class labels per Layer 2 model. Index corresponds to class ID in TreeNode::Leaf.
const DRUM_CLASSES: [SampleClass; 7] = [ SampleClass::Kick, SampleClass::Snare, SampleClass::HiHat, SampleClass::Cymbal, SampleClass::Clap, SampleClass::Tom, SampleClass::Percussion, ]; const BASS_CLASSES: [SampleClass; 3] = [ SampleClass::GuitarBass, SampleClass::SynthBass, SampleClass::SubBass, ]; const VOCAL_CLASSES: [SampleClass; 3] = [ SampleClass::VocalChop, SampleClass::VocalPhrase, SampleClass::VocalChoir, ]; const SYNTH_CLASSES: [SampleClass; 4] = [ SampleClass::SynthLead, SampleClass::SynthStab, SampleClass::SynthPluck, SampleClass::SynthChord, ]; // ── Layer 1: Rule-based broad classifier ── /// Classify a sample into a broad category (Layer 1). /// /// Uses simplified rules from the original `classify_full()` to determine if a sample /// is a drum hit vs other content types. Drum detection heuristic: /// short duration + fast attack + high crest factor. fn classify_broad(input: &ClassifyInput) -> (BroadClass, f64) { let d = input.duration; let c = input.centroid; let flat = input.flatness; let zcr = input.zcr; let bw = input.bandwidth; let cv = input.centroid_variance; let crest = input.crest_factor; let attack = input.attack_time; // Noise: energy spread nearly uniformly — but not short percussive sounds // (cymbals/hihats can have high flatness but are still drums) if flat > 0.7 && (d > 2.0 || attack > 0.1) { return (BroadClass::Noise, 0.9); } // Bright metallic sounds (cymbals, crashes) are drums even when sustained if d < 10.0 && c > 3000.0 && flat > 0.2 { return (BroadClass::Drum, 0.85); } // Drum detection: short + percussive transient if d < 2.0 && (attack < 0.05 || crest > 2.5) { let mut conf: f64 = 0.8; if attack < 0.02 { conf += 0.05; } if crest > 4.0 { conf += 0.05; } if d < 1.0 { conf += 0.05; } return (BroadClass::Drum, conf.min(0.95)); } // Extended drum catch: short sounds with fast attack that didn't match above // (catches deeper kicks with low crest, toms with moderate attack) if d < 2.0 && attack < 0.1 && crest > 1.5 { return 
(BroadClass::Drum, 0.75); } // Impact: very sharp non-drum transient (must be longer than typical drums) if d > 1.0 && d < 5.0 && crest > 10.0 && attack < 0.005 { return (BroadClass::Impact, 0.85); } // Ambience: long, spectrally static, moderate noise if d > 5.0 && cv < 100_000.0 && flat > 0.15 && flat < 0.5 { return (BroadClass::Ambience, 0.85); } // Bass: low-frequency, tonal — but not short percussive sounds (kicks) if c < 400.0 && flat < 0.15 && d > 2.0 { return (BroadClass::Bass, 0.85); } // Vocal: mid-range, tonal, smooth waveform — require longer duration to avoid catching toms if d > 1.0 && c > 300.0 && c < 3000.0 && flat < 0.2 && zcr < 0.08 && crest < 2.0 { return (BroadClass::Vocal, 0.8); } // Pad: long, tonal, mid-range if d > 2.0 && flat < 0.2 && c > 200.0 && c < 2000.0 { return (BroadClass::Pad, 0.8); } // Texture: long, spectrally evolving if d > 2.0 && cv > 500_000.0 { return (BroadClass::Texture, 0.8); } // Foley: broadband, moderate noise — require longer duration to avoid catching drums if d > 1.0 && bw > 2000.0 && flat > 0.1 && flat < 0.5 { return (BroadClass::Foley, 0.75); } // Synth: tonal, mid-to-high centroid — require low crest to exclude drums if c > 500.0 && flat < 0.3 && zcr < 0.1 && crest < 2.0 { return (BroadClass::Synth, 0.75); } // Music: long catch-all if d > 3.0 { return (BroadClass::Music, 0.6); } // Short unclassified sounds are likely drums if d < 2.0 { return (BroadClass::Drum, 0.6); } // Misc: final catch-all (nothing else matched) (BroadClass::Misc, 0.5) } /// Map a BroadClass (non-Drum) to the corresponding SampleClass. 
fn broad_to_sample_class(broad: BroadClass) -> SampleClass { match broad { BroadClass::Drum => SampleClass::Percussion, // shouldn't reach here BroadClass::Bass => SampleClass::Bass, BroadClass::Vocal => SampleClass::Vocal, BroadClass::Synth => SampleClass::Synth, BroadClass::Pad => SampleClass::Pad, BroadClass::Misc => SampleClass::Misc, BroadClass::Noise => SampleClass::Noise, BroadClass::Music => SampleClass::Music, BroadClass::Ambience => SampleClass::Ambience, BroadClass::Impact => SampleClass::Impact, BroadClass::Foley => SampleClass::Foley, BroadClass::Texture => SampleClass::Texture, } } // ── Layer 2: Random Forest inference ── /// Run Layer 2 classification using a model and its class mapping. /// /// Returns (SampleClass, confidence) where confidence is the vote fraction. /// If the model has no trees, returns the fallback class with 0 confidence. fn predict_with_model( model: &RandomForestModel, classes: &[SampleClass], fallback: SampleClass, features: &[f64; NUM_FEATURES], ) -> (SampleClass, f64) { if model.trees.is_empty() { return (fallback, 0.0); } let mut votes = vec![0u32; model.num_classes as usize]; for tree in &model.trees { let class_id = tree.predict(features) as usize; if class_id < votes.len() { votes[class_id] += 1; } } let total = model.trees.len() as f64; let (best_class_id, &best_count) = votes .iter() .enumerate() .max_by_key(|(_, count)| **count) .unwrap_or((0, &0)); let confidence = best_count as f64 / total; let class = classes .get(best_class_id) .copied() .unwrap_or(fallback); (class, confidence) } // ── Main entry point ── /// Two-layer ML classifier. /// /// Layer 1 (rule-based) determines broad class. Layer 2 (Random Forest) provides /// fine-grained sub-classification for Drum, Bass, Vocal, and Synth classes. /// If a Layer 2 model is not trained (empty trees), falls back to the broad class. 
#[instrument(skip_all)] pub fn classify_ml(input: &ClassifyInput) -> ClassificationResult { let drum_model = layer2_model(); // If no drum model at all, fall back entirely to rule-based if drum_model.trees.is_empty() { return ClassificationResult { class: classify_full(input), confidence: 0.0, }; } let (broad, broad_conf) = classify_broad(input); let features = input.to_feature_array(); match broad { BroadClass::Drum => { let (class, conf) = predict_with_model(drum_model, &DRUM_CLASSES, SampleClass::Percussion, &features); ClassificationResult { class, confidence: conf } } BroadClass::Bass => { let model = layer2_bass_model(); if model.trees.is_empty() { ClassificationResult { class: SampleClass::Bass, confidence: broad_conf } } else { let (class, conf) = predict_with_model(model, &BASS_CLASSES, SampleClass::Bass, &features); ClassificationResult { class, confidence: conf } } } BroadClass::Vocal => { let model = layer2_vocal_model(); if model.trees.is_empty() { ClassificationResult { class: SampleClass::Vocal, confidence: broad_conf } } else { let (class, conf) = predict_with_model(model, &VOCAL_CLASSES, SampleClass::Vocal, &features); ClassificationResult { class, confidence: conf } } } BroadClass::Synth => { let model = layer2_synth_model(); if model.trees.is_empty() { ClassificationResult { class: SampleClass::Synth, confidence: broad_conf } } else { let (class, conf) = predict_with_model(model, &SYNTH_CLASSES, SampleClass::Synth, &features); ClassificationResult { class, confidence: conf } } } _ => ClassificationResult { class: broad_to_sample_class(broad), confidence: broad_conf, }, } } // ── Smart skip helpers ── impl SampleClass { /// Whether BPM detection is meaningful for this sample type. /// /// Drums, impacts, noise, foley, ambience, and textures have no rhythmic /// content that BPM detection can usefully extract. Short chops also skip. 
pub fn has_rhythm(&self) -> bool { matches!( self, Self::Bass | Self::GuitarBass | Self::SynthBass | Self::SubBass | Self::Vocal | Self::VocalPhrase | Self::VocalChoir | Self::Synth | Self::SynthLead | Self::SynthChord | Self::Pad | Self::Music | Self::Misc ) } /// Whether key detection is meaningful for this sample type. /// /// Noise, foley, ambience, textures, and drum one-shots lack tonal content. /// Short chops and stabs are too brief for reliable key detection. pub fn has_pitch(&self) -> bool { matches!( self, Self::Bass | Self::GuitarBass | Self::SynthBass | Self::SubBass | Self::Vocal | Self::VocalPhrase | Self::VocalChoir | Self::Synth | Self::SynthLead | Self::SynthChord | Self::Pad | Self::Music | Self::Misc ) } } // ── Legacy functions ── /// Classify a sample using the full feature set (16 classes, priority-ordered decision tree). /// /// Uses threshold-based rules — no ML. Rules are evaluated in priority order. The tree is /// structured in two phases: drum/percussion rules first (for short sounds with percussive /// transients), then non-drum rules for sustained/tonal content. This prevents tonal drum /// hits from leaking into Vocal/Synth/Foley categories. /// /// The expanded feature set (crest factor, attack time, bandwidth, centroid variance) /// enables game audio categories (ambience, impact, foley, texture) that the original /// 4-feature classifier collapsed into Misc/Percussion/Music. #[instrument(skip_all)] pub fn classify_full(input: &ClassifyInput) -> SampleClass { let d = input.duration; let c = input.centroid; let flat = input.flatness; let zcr = input.zcr; let _onset = input.onset_strength; let bw = input.bandwidth; let cv = input.centroid_variance; let crest = input.crest_factor; let attack = input.attack_time; // 1. Noise: energy spread nearly uniformly across all frequencies if flat > 0.7 { return SampleClass::Noise; } // --- Phase 1: Drum/percussion classification (short sounds first) --- // 2. 
Kick: short, low-frequency dominant if d < 2.0 && c < 1200.0 && flat < 0.35 { return SampleClass::Kick; } // 3. HiHat: short, bright, metallic (relatively tonal) if d < 2.0 && c > 3500.0 && zcr > 0.1 && flat < 0.3 { return SampleClass::HiHat; } // 4. Cymbal: bright, noisy sustain (crashes, rides) if d > 0.5 && d < 10.0 && c > 3000.0 && flat > 0.25 { return SampleClass::Cymbal; } // 5. Snare: short, mid-to-high frequency, broadband noise from snare wires if d < 2.0 && c > 800.0 && flat > 0.2 && bw > 2000.0 { return SampleClass::Snare; } // 6. Percussion: short percussive catch-all if d < 2.0 && (attack < 0.05 || crest > 3.0) { return SampleClass::Percussion; } // --- Phase 2: Non-drum classification --- // 7. Impact: very sharp non-drum transient if d < 3.0 && crest > 10.0 && attack < 0.005 { return SampleClass::Impact; } // 8. Ambience: long, spectrally static, moderate noise if d > 5.0 && cv < 100_000.0 && flat > 0.15 && flat < 0.5 { return SampleClass::Ambience; } // 9. Bass: low-frequency, tonal if c < 400.0 && flat < 0.15 { return SampleClass::Bass; } // 10. Vocal: mid-range, tonal, smooth waveform, sustained if d > 0.5 && c > 300.0 && c < 3000.0 && flat < 0.2 && zcr < 0.08 { return SampleClass::Vocal; } // 11. Pad: long, tonal, mid-range if d > 2.0 && flat < 0.2 && c > 200.0 && c < 2000.0 { return SampleClass::Pad; } // 12. Texture: long, spectrally evolving if d > 2.0 && cv > 500_000.0 { return SampleClass::Texture; } // 13. Foley: broadband, moderate noise if d > 0.1 && bw > 2000.0 && flat > 0.1 && flat < 0.5 { return SampleClass::Foley; } // 14. Synth: tonal, mid-to-high centroid, smooth waveform if c > 500.0 && flat < 0.3 && zcr < 0.1 { return SampleClass::Synth; } // 15. Percussion: short catch-all if d < 2.0 { return SampleClass::Percussion; } // 16. Music: long catch-all if d > 3.0 { return SampleClass::Music; } // 17. 
Misc: final catch-all (nothing else matched) SampleClass::Misc } /// Legacy classify function — uses only spectral features and duration. #[instrument(skip_all)] pub fn classify(features: &SpectralFeatures, duration: f64) -> Option { let input = ClassifyInput::new(features, duration, 0.0, 0.0); Some(classify_full(&input)) } #[cfg(test)] mod tests { use super::*; #[test] fn kick_classification() { let features = SpectralFeatures { centroid: 600.0, flatness: 0.15, rolloff: 1200.0, zero_crossing_rate: 0.04, onset_strength: 50.0, ..Default::default() }; assert_eq!(classify(&features, 0.3), Some(SampleClass::Kick)); } #[test] fn hihat_classification() { let features = SpectralFeatures { centroid: 5000.0, flatness: 0.2, rolloff: 12000.0, zero_crossing_rate: 0.15, onset_strength: 30.0, ..Default::default() }; assert_eq!(classify(&features, 0.2), Some(SampleClass::HiHat)); } #[test] fn snare_classification() { let input = ClassifyInput { duration: 0.5, centroid: 2500.0, flatness: 0.25, zcr: 0.12, onset_strength: 40.0, bandwidth: 3000.0, centroid_variance: 200_000.0, crest_factor: 6.0, attack_time: 0.003, mfcc_means: [0.0; 13], mfcc_variances: [0.0; 13], }; assert_eq!(classify_full(&input), SampleClass::Snare); } #[test] fn noise_classification() { let features = SpectralFeatures { centroid: 5000.0, flatness: 0.75, rolloff: 15000.0, zero_crossing_rate: 0.3, onset_strength: 10.0, ..Default::default() }; assert_eq!(classify(&features, 2.0), Some(SampleClass::Noise)); } #[test] fn bass_classification() { let features = SpectralFeatures { centroid: 200.0, flatness: 0.08, rolloff: 500.0, zero_crossing_rate: 0.03, onset_strength: 20.0, ..Default::default() }; assert_eq!(classify(&features, 3.0), Some(SampleClass::Bass)); } #[test] fn tag_format() { assert_eq!(SampleClass::Kick.tag(), "instrument.drum.kick"); assert_eq!(SampleClass::Noise.tag(), "character.noise"); assert_eq!(SampleClass::Music.tag(), "type.music"); assert_eq!(SampleClass::Ambience.tag(), "character.ambience"); 
assert_eq!(SampleClass::Impact.tag(), "character.impact"); assert_eq!(SampleClass::Foley.tag(), "character.foley"); assert_eq!(SampleClass::Texture.tag(), "character.texture"); } #[test] fn impact_classification() { let input = ClassifyInput { duration: 2.5, centroid: 1500.0, flatness: 0.1, zcr: 0.06, onset_strength: 50.0, bandwidth: 2000.0, centroid_variance: 200_000.0, crest_factor: 12.0, attack_time: 0.002, mfcc_means: [0.0; 13], mfcc_variances: [0.0; 13], }; assert_eq!(classify_full(&input), SampleClass::Impact); } #[test] fn ambience_classification() { let input = ClassifyInput { duration: 10.0, centroid: 2000.0, flatness: 0.3, zcr: 0.05, onset_strength: 5.0, bandwidth: 3000.0, centroid_variance: 50_000.0, crest_factor: 2.5, attack_time: 0.5, mfcc_means: [0.0; 13], mfcc_variances: [0.0; 13], }; assert_eq!(classify_full(&input), SampleClass::Ambience); } #[test] fn foley_classification() { let input = ClassifyInput { duration: 3.0, centroid: 1500.0, flatness: 0.25, zcr: 0.06, onset_strength: 15.0, bandwidth: 4000.0, centroid_variance: 300_000.0, crest_factor: 2.0, attack_time: 0.1, mfcc_means: [0.0; 13], mfcc_variances: [0.0; 13], }; assert_eq!(classify_full(&input), SampleClass::Foley); } #[test] fn texture_classification() { let input = ClassifyInput { duration: 5.0, centroid: 2000.0, flatness: 0.12, zcr: 0.12, onset_strength: 10.0, bandwidth: 3000.0, centroid_variance: 1_000_000.0, crest_factor: 3.0, attack_time: 0.1, mfcc_means: [0.0; 13], mfcc_variances: [0.0; 13], }; assert_eq!(classify_full(&input), SampleClass::Texture); } #[test] fn round_trip_from_str() { for class in [ SampleClass::Kick, SampleClass::Snare, SampleClass::HiHat, SampleClass::Cymbal, SampleClass::Clap, SampleClass::Tom, SampleClass::Percussion, SampleClass::Bass, SampleClass::GuitarBass, SampleClass::SynthBass, SampleClass::SubBass, SampleClass::Vocal, SampleClass::VocalChop, SampleClass::VocalPhrase, SampleClass::VocalChoir, SampleClass::Synth, SampleClass::SynthLead, 
SampleClass::SynthStab, SampleClass::SynthPluck, SampleClass::SynthChord, SampleClass::Pad, SampleClass::Misc, SampleClass::Noise, SampleClass::Music, SampleClass::Ambience, SampleClass::Impact, SampleClass::Foley, SampleClass::Texture, ] { let s = class.as_str(); let parsed: SampleClass = s.parse().unwrap(); assert_eq!(parsed, class, "round-trip failed for {s}"); } } #[test] fn feature_array_layout() { let input = ClassifyInput { duration: 1.0, centroid: 2.0, flatness: 3.0, zcr: 4.0, onset_strength: 5.0, bandwidth: 6.0, centroid_variance: 7.0, crest_factor: 8.0, attack_time: 9.0, mfcc_means: [10.0; 13], mfcc_variances: [20.0; 13], }; let arr = input.to_feature_array(); assert_eq!(arr[0], 1.0); // duration assert_eq!(arr[1], 2.0); // centroid assert_eq!(arr[8], 9.0); // attack_time assert_eq!(arr[9], 10.0); // mfcc_mean[0] assert_eq!(arr[21], 10.0); // mfcc_mean[12] assert_eq!(arr[22], 20.0); // mfcc_var[0] assert_eq!(arr[34], 20.0); // mfcc_var[12] } #[test] fn classify_ml_returns_drum_for_kick_input() { let input = ClassifyInput { duration: 0.3, centroid: 600.0, flatness: 0.15, zcr: 0.04, onset_strength: 50.0, bandwidth: 500.0, centroid_variance: 10_000.0, crest_factor: 5.0, attack_time: 0.003, mfcc_means: [0.0; 13], mfcc_variances: [0.0; 13], }; let result = classify_ml(&input); // Should classify as a drum class with nonzero confidence assert!( matches!( result.class, SampleClass::Kick | SampleClass::Snare | SampleClass::HiHat | SampleClass::Cymbal | SampleClass::Clap | SampleClass::Tom | SampleClass::Percussion ), "expected drum class, got {:?}", result.class ); } #[test] fn tree_node_prediction() { let tree = TreeNode::Split { feature: 1, threshold: 1000.0, left: Box::new(TreeNode::Leaf { class: 0 }), // kick (low centroid) right: Box::new(TreeNode::Leaf { class: 2 }), // hihat (high centroid) }; let mut features = [0.0; NUM_FEATURES]; features[1] = 500.0; // centroid below threshold assert_eq!(tree.predict(&features), 0); features[1] = 5000.0; // centroid 
above threshold assert_eq!(tree.predict(&features), 2); } #[test] fn broad_classifier_detects_drums() { let input = ClassifyInput { duration: 0.3, centroid: 600.0, flatness: 0.15, zcr: 0.04, onset_strength: 50.0, bandwidth: 500.0, centroid_variance: 10_000.0, crest_factor: 5.0, attack_time: 0.003, mfcc_means: [0.0; 13], mfcc_variances: [0.0; 13], }; let (broad, conf) = classify_broad(&input); assert_eq!(broad, BroadClass::Drum); assert!(conf >= 0.8); } #[test] fn broad_classifier_detects_noise() { let input = ClassifyInput { duration: 2.0, centroid: 5000.0, flatness: 0.8, zcr: 0.3, onset_strength: 10.0, bandwidth: 5000.0, centroid_variance: 100_000.0, crest_factor: 1.5, attack_time: 0.5, mfcc_means: [0.0; 13], mfcc_variances: [0.0; 13], }; let (broad, _) = classify_broad(&input); assert_eq!(broad, BroadClass::Noise); } #[test] fn smart_skip_drum_classes_skip_bpm_key() { // All drum sub-classes should skip BPM/key for class in [ SampleClass::Kick, SampleClass::Snare, SampleClass::HiHat, SampleClass::Cymbal, SampleClass::Clap, SampleClass::Tom, SampleClass::Percussion, ] { assert!(!class.has_rhythm(), "{class:?} should not have rhythm"); assert!(!class.has_pitch(), "{class:?} should not have pitch"); } } #[test] fn smart_skip_tonal_classes_keep_bpm_key() { for class in [ SampleClass::Bass, SampleClass::GuitarBass, SampleClass::SynthBass, SampleClass::SubBass, SampleClass::Vocal, SampleClass::VocalPhrase, SampleClass::VocalChoir, SampleClass::Synth, SampleClass::SynthLead, SampleClass::SynthChord, SampleClass::Pad, SampleClass::Music, ] { assert!(class.has_rhythm(), "{class:?} should have rhythm"); assert!(class.has_pitch(), "{class:?} should have pitch"); } } #[test] fn smart_skip_short_classes_skip_bpm_key() { // Short chops and stabs are too brief for BPM/key detection for class in [ SampleClass::VocalChop, SampleClass::SynthStab, SampleClass::SynthPluck, ] { assert!(!class.has_rhythm(), "{class:?} should not have rhythm"); assert!(!class.has_pitch(), "{class:?} 
should not have pitch"); } } #[test] fn smart_skip_non_tonal_non_drum_skip() { for class in [ SampleClass::Noise, SampleClass::Ambience, SampleClass::Impact, SampleClass::Foley, SampleClass::Texture, ] { assert!(!class.has_rhythm(), "{class:?} should not have rhythm"); assert!(!class.has_pitch(), "{class:?} should not have pitch"); } } #[test] fn smart_skip_pad_has_pitch_but_no_rhythm() { // Pads are tonal (key applies) and sustained (could have rhythm) // Current design: both true for pads assert!(SampleClass::Pad.has_pitch()); assert!(SampleClass::Pad.has_rhythm()); } #[test] fn classify_ml_bass_routes_to_sub_class() { // Bass-like input: low centroid, tonal, sustained let input = ClassifyInput { duration: 3.0, centroid: 200.0, flatness: 0.08, zcr: 0.03, onset_strength: 20.0, bandwidth: 400.0, centroid_variance: 50_000.0, crest_factor: 1.5, attack_time: 0.05, mfcc_means: [0.0; 13], mfcc_variances: [0.0; 13], }; let result = classify_ml(&input); // With trained bass model, should classify as a bass sub-class assert!( matches!( result.class, SampleClass::GuitarBass | SampleClass::SynthBass | SampleClass::SubBass ), "expected bass sub-class, got {:?}", result.class ); assert!(result.confidence > 0.0); } #[test] fn classify_ml_vocal_fallback_without_model() { // Vocal-like input: mid-range, tonal, smooth, sustained let input = ClassifyInput { duration: 2.0, centroid: 1000.0, flatness: 0.1, zcr: 0.05, onset_strength: 15.0, bandwidth: 1500.0, centroid_variance: 80_000.0, crest_factor: 1.5, attack_time: 0.2, mfcc_means: [0.0; 13], mfcc_variances: [0.0; 13], }; let result = classify_ml(&input); assert_eq!(result.class, SampleClass::Vocal); } #[test] fn predict_with_model_returns_fallback_for_empty() { let empty_model = RandomForestModel { trees: vec![], num_classes: 3, class_names: vec!["a".into(), "b".into(), "c".into()], }; let features = [0.0; NUM_FEATURES]; let (class, conf) = predict_with_model( &empty_model, &BASS_CLASSES, SampleClass::Bass, &features, ); 
assert_eq!(class, SampleClass::Bass); assert_eq!(conf, 0.0); } }