| 1 |
|
| 2 |
|
| 3 |
|
| 4 |
|
| 5 |
|
| 6 |
|
| 7 |
|
| 8 |
|
| 9 |
use std::sync::OnceLock; |
| 10 |
|
| 11 |
use super::mfcc::MfccFeatures; |
| 12 |
use super::spectral::SpectralFeatures; |
| 13 |
use tracing::instrument; |
| 14 |
|
| 15 |
|
| 16 |
/// Length of the flat feature vector consumed by the tree models:
/// 9 scalar descriptors followed by 13 MFCC means and 13 MFCC variances
/// (see `ClassifyInput::to_feature_array` for the exact layout).
pub const NUM_FEATURES: usize = 9 + 13 + 13;
| 17 |
|
| 18 |
|
| 19 |
|
| 20 |
|
| 21 |
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] |
| 22 |
pub enum SampleClass { |
| 23 |
Kick, |
| 24 |
Snare, |
| 25 |
HiHat, |
| 26 |
Cymbal, |
| 27 |
Clap, |
| 28 |
Tom, |
| 29 |
Percussion, |
| 30 |
Bass, |
| 31 |
GuitarBass, |
| 32 |
SynthBass, |
| 33 |
SubBass, |
| 34 |
Vocal, |
| 35 |
VocalChop, |
| 36 |
VocalPhrase, |
| 37 |
VocalChoir, |
| 38 |
Synth, |
| 39 |
SynthLead, |
| 40 |
SynthStab, |
| 41 |
SynthPluck, |
| 42 |
SynthChord, |
| 43 |
Pad, |
| 44 |
Misc, |
| 45 |
Noise, |
| 46 |
Music, |
| 47 |
Ambience, |
| 48 |
Impact, |
| 49 |
Foley, |
| 50 |
Texture, |
| 51 |
} |
| 52 |
|
| 53 |
impl SampleClass { |
| 54 |
pub fn as_str(&self) -> &'static str { |
| 55 |
match self { |
| 56 |
Self::Kick => "kick", |
| 57 |
Self::Snare => "snare", |
| 58 |
Self::HiHat => "hihat", |
| 59 |
Self::Cymbal => "cymbal", |
| 60 |
Self::Clap => "clap", |
| 61 |
Self::Tom => "tom", |
| 62 |
Self::Percussion => "percussion", |
| 63 |
Self::Bass => "bass", |
| 64 |
Self::GuitarBass => "guitar-bass", |
| 65 |
Self::SynthBass => "synth-bass", |
| 66 |
Self::SubBass => "sub-bass", |
| 67 |
Self::Vocal => "vocal", |
| 68 |
Self::VocalChop => "vocal-chop", |
| 69 |
Self::VocalPhrase => "vocal-phrase", |
| 70 |
Self::VocalChoir => "vocal-choir", |
| 71 |
Self::Synth => "synth", |
| 72 |
Self::SynthLead => "synth-lead", |
| 73 |
Self::SynthStab => "synth-stab", |
| 74 |
Self::SynthPluck => "synth-pluck", |
| 75 |
Self::SynthChord => "synth-chord", |
| 76 |
Self::Pad => "pad", |
| 77 |
Self::Misc => "misc", |
| 78 |
Self::Noise => "noise", |
| 79 |
Self::Music => "music", |
| 80 |
Self::Ambience => "ambience", |
| 81 |
Self::Impact => "impact", |
| 82 |
Self::Foley => "foley", |
| 83 |
Self::Texture => "texture", |
| 84 |
} |
| 85 |
} |
| 86 |
|
| 87 |
|
| 88 |
pub fn tag(&self) -> &'static str { |
| 89 |
match self { |
| 90 |
Self::Kick => "instrument.drum.kick", |
| 91 |
Self::Snare => "instrument.drum.snare", |
| 92 |
Self::HiHat => "instrument.drum.hihat", |
| 93 |
Self::Cymbal => "instrument.drum.cymbal", |
| 94 |
Self::Clap => "instrument.drum.clap", |
| 95 |
Self::Tom => "instrument.drum.tom", |
| 96 |
Self::Percussion => "instrument.percussion", |
| 97 |
Self::Bass => "instrument.bass", |
| 98 |
Self::GuitarBass => "instrument.bass.guitar", |
| 99 |
Self::SynthBass => "instrument.bass.synth", |
| 100 |
Self::SubBass => "instrument.bass.sub", |
| 101 |
Self::Vocal => "instrument.vocal", |
| 102 |
Self::VocalChop => "instrument.vocal.chop", |
| 103 |
Self::VocalPhrase => "instrument.vocal.phrase", |
| 104 |
Self::VocalChoir => "instrument.vocal.choir", |
| 105 |
Self::Synth => "instrument.synth", |
| 106 |
Self::SynthLead => "instrument.synth.lead", |
| 107 |
Self::SynthStab => "instrument.synth.stab", |
| 108 |
Self::SynthPluck => "instrument.synth.pluck", |
| 109 |
Self::SynthChord => "instrument.synth.chord", |
| 110 |
Self::Pad => "instrument.pad", |
| 111 |
Self::Misc => "character.misc", |
| 112 |
Self::Noise => "character.noise", |
| 113 |
Self::Music => "type.music", |
| 114 |
Self::Ambience => "character.ambience", |
| 115 |
Self::Impact => "character.impact", |
| 116 |
Self::Foley => "character.foley", |
| 117 |
Self::Texture => "character.texture", |
| 118 |
} |
| 119 |
} |
| 120 |
} |
| 121 |
|
| 122 |
impl std::fmt::Display for SampleClass { |
| 123 |
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| 124 |
write!(f, "{}", self.as_str()) |
| 125 |
} |
| 126 |
} |
| 127 |
|
| 128 |
impl std::str::FromStr for SampleClass { |
| 129 |
type Err = (); |
| 130 |
|
| 131 |
fn from_str(s: &str) -> Result<Self, Self::Err> { |
| 132 |
match s { |
| 133 |
"kick" => Ok(Self::Kick), |
| 134 |
"snare" => Ok(Self::Snare), |
| 135 |
"hihat" => Ok(Self::HiHat), |
| 136 |
"cymbal" => Ok(Self::Cymbal), |
| 137 |
"clap" => Ok(Self::Clap), |
| 138 |
"tom" => Ok(Self::Tom), |
| 139 |
"percussion" => Ok(Self::Percussion), |
| 140 |
"bass" => Ok(Self::Bass), |
| 141 |
"guitar-bass" => Ok(Self::GuitarBass), |
| 142 |
"synth-bass" => Ok(Self::SynthBass), |
| 143 |
"sub-bass" => Ok(Self::SubBass), |
| 144 |
"vocal" => Ok(Self::Vocal), |
| 145 |
"vocal-chop" => Ok(Self::VocalChop), |
| 146 |
"vocal-phrase" => Ok(Self::VocalPhrase), |
| 147 |
"vocal-choir" => Ok(Self::VocalChoir), |
| 148 |
"synth" => Ok(Self::Synth), |
| 149 |
"synth-lead" => Ok(Self::SynthLead), |
| 150 |
"synth-stab" => Ok(Self::SynthStab), |
| 151 |
"synth-pluck" => Ok(Self::SynthPluck), |
| 152 |
"synth-chord" => Ok(Self::SynthChord), |
| 153 |
"pad" => Ok(Self::Pad), |
| 154 |
"misc" | "fx" => Ok(Self::Misc), |
| 155 |
"noise" => Ok(Self::Noise), |
| 156 |
"music" => Ok(Self::Music), |
| 157 |
"ambience" => Ok(Self::Ambience), |
| 158 |
"impact" => Ok(Self::Impact), |
| 159 |
"foley" => Ok(Self::Foley), |
| 160 |
"texture" => Ok(Self::Texture), |
| 161 |
_ => Err(()), |
| 162 |
} |
| 163 |
} |
| 164 |
} |
| 165 |
|
| 166 |
|
| 167 |
|
| 168 |
|
| 169 |
/// Coarse category chosen by the rule-based classifier (`classify_broad`).
///
/// `Drum`, `Bass`, `Vocal` and `Synth` are refined further by a tree model in
/// `classify_ml`; every other category maps directly onto the `SampleClass`
/// of the same name.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BroadClass {
    // Categories that have a dedicated sub-classifier.
    Drum,
    Bass,
    Vocal,
    Synth,
    // Categories that are final as-is.
    Pad,
    Misc,
    Noise,
    Music,
    Ambience,
    Impact,
    Foley,
    Texture,
}
| 184 |
|
| 185 |
|
| 186 |
|
| 187 |
|
| 188 |
#[derive(Debug, Clone)] |
| 189 |
pub struct ClassificationResult { |
| 190 |
pub class: SampleClass, |
| 191 |
pub confidence: f64, |
| 192 |
} |
| 193 |
|
| 194 |
|
| 195 |
|
| 196 |
|
| 197 |
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] |
| 198 |
pub enum TreeNode { |
| 199 |
Split { |
| 200 |
feature: usize, |
| 201 |
threshold: f64, |
| 202 |
left: Box<TreeNode>, |
| 203 |
right: Box<TreeNode>, |
| 204 |
}, |
| 205 |
Leaf { |
| 206 |
class: u8, |
| 207 |
}, |
| 208 |
} |
| 209 |
|
| 210 |
impl TreeNode { |
| 211 |
pub fn predict(&self, features: &[f64; NUM_FEATURES]) -> u8 { |
| 212 |
match self { |
| 213 |
TreeNode::Split { |
| 214 |
feature, |
| 215 |
threshold, |
| 216 |
left, |
| 217 |
right, |
| 218 |
} => { |
| 219 |
if *feature >= NUM_FEATURES { |
| 220 |
return 0; |
| 221 |
} |
| 222 |
let val = features[*feature]; |
| 223 |
|
| 224 |
if val.is_nan() || val <= *threshold { |
| 225 |
left.predict(features) |
| 226 |
} else { |
| 227 |
right.predict(features) |
| 228 |
} |
| 229 |
} |
| 230 |
TreeNode::Leaf { class } => *class, |
| 231 |
} |
| 232 |
} |
| 233 |
} |
| 234 |
|
| 235 |
|
| 236 |
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] |
| 237 |
pub struct RandomForestModel { |
| 238 |
pub trees: Vec<TreeNode>, |
| 239 |
pub num_classes: u8, |
| 240 |
pub class_names: Vec<String>, |
| 241 |
} |
| 242 |
|
| 243 |
|
| 244 |
|
| 245 |
|
| 246 |
/// All per-sample measurements the classifiers look at.
///
/// The fields are flattened, in declaration order, into the model feature
/// vector by `to_feature_array`.
///
/// Units are not stated in this file; the rule thresholds suggest seconds for
/// times and Hz for frequencies — TODO confirm against the feature extractors.
#[derive(Debug, Clone)]
pub struct ClassifyInput {
    /// Length of the sample.
    pub duration: f64,
    /// Spectral centroid.
    pub centroid: f64,
    /// Spectral flatness; values above 0.7 are treated as noise by the rules.
    pub flatness: f64,
    /// Zero-crossing rate.
    pub zcr: f64,
    /// Onset strength; only used as a model feature, not by the rules.
    pub onset_strength: f64,
    /// Spectral bandwidth.
    pub bandwidth: f64,
    /// Variance of the spectral centroid.
    pub centroid_variance: f64,
    /// Crest factor of the signal.
    pub crest_factor: f64,
    /// Attack time of the signal.
    pub attack_time: f64,
    /// Mean of each of the 13 MFCC coefficients.
    pub mfcc_means: [f64; 13],
    /// Variance of each of the 13 MFCC coefficients.
    pub mfcc_variances: [f64; 13],
}
| 259 |
|
| 260 |
impl ClassifyInput { |
| 261 |
|
| 262 |
pub fn new( |
| 263 |
features: &SpectralFeatures, |
| 264 |
duration: f64, |
| 265 |
crest_factor: f64, |
| 266 |
attack_time: f64, |
| 267 |
) -> Self { |
| 268 |
Self::with_mfccs( |
| 269 |
features, |
| 270 |
duration, |
| 271 |
crest_factor, |
| 272 |
attack_time, |
| 273 |
&MfccFeatures::default(), |
| 274 |
) |
| 275 |
} |
| 276 |
|
| 277 |
|
| 278 |
pub fn with_mfccs( |
| 279 |
features: &SpectralFeatures, |
| 280 |
duration: f64, |
| 281 |
crest_factor: f64, |
| 282 |
attack_time: f64, |
| 283 |
mfccs: &MfccFeatures, |
| 284 |
) -> Self { |
| 285 |
Self { |
| 286 |
duration, |
| 287 |
centroid: features.centroid, |
| 288 |
flatness: features.flatness, |
| 289 |
zcr: features.zero_crossing_rate, |
| 290 |
onset_strength: features.onset_strength, |
| 291 |
bandwidth: features.bandwidth, |
| 292 |
centroid_variance: features.centroid_variance, |
| 293 |
crest_factor, |
| 294 |
attack_time, |
| 295 |
mfcc_means: mfccs.means, |
| 296 |
mfcc_variances: mfccs.variances, |
| 297 |
} |
| 298 |
} |
| 299 |
|
| 300 |
|
| 301 |
|
| 302 |
|
| 303 |
pub fn to_feature_array(&self) -> [f64; NUM_FEATURES] { |
| 304 |
let mut arr = [0.0; NUM_FEATURES]; |
| 305 |
arr[0] = self.duration; |
| 306 |
arr[1] = self.centroid; |
| 307 |
arr[2] = self.flatness; |
| 308 |
arr[3] = self.zcr; |
| 309 |
arr[4] = self.onset_strength; |
| 310 |
arr[5] = self.bandwidth; |
| 311 |
arr[6] = self.centroid_variance; |
| 312 |
arr[7] = self.crest_factor; |
| 313 |
arr[8] = self.attack_time; |
| 314 |
arr[9..22].copy_from_slice(&self.mfcc_means); |
| 315 |
arr[22..35].copy_from_slice(&self.mfcc_variances); |
| 316 |
arr |
| 317 |
} |
| 318 |
} |
| 319 |
|
| 320 |
|
| 321 |
|
| 322 |
|
| 323 |
const LAYER2_DRUM_BYTES: &[u8] = include_bytes!("../../models/layer2_drum.json"); |
| 324 |
const LAYER2_BASS_BYTES: &[u8] = include_bytes!("../../models/layer2_bass.json"); |
| 325 |
const LAYER2_VOCAL_BYTES: &[u8] = include_bytes!("../../models/layer2_vocal.json"); |
| 326 |
const LAYER2_SYNTH_BYTES: &[u8] = include_bytes!("../../models/layer2_synth.json"); |
| 327 |
|
| 328 |
|
| 329 |
static LAYER2_DRUM_MODEL: OnceLock<RandomForestModel> = OnceLock::new(); |
| 330 |
static LAYER2_BASS_MODEL: OnceLock<RandomForestModel> = OnceLock::new(); |
| 331 |
static LAYER2_VOCAL_MODEL: OnceLock<RandomForestModel> = OnceLock::new(); |
| 332 |
static LAYER2_SYNTH_MODEL: OnceLock<RandomForestModel> = OnceLock::new(); |
| 333 |
|
| 334 |
|
| 335 |
fn empty_model() -> RandomForestModel { |
| 336 |
RandomForestModel { |
| 337 |
trees: vec![], |
| 338 |
num_classes: 0, |
| 339 |
class_names: vec![], |
| 340 |
} |
| 341 |
} |
| 342 |
|
| 343 |
fn layer2_model() -> &'static RandomForestModel { |
| 344 |
LAYER2_DRUM_MODEL.get_or_init(|| { |
| 345 |
serde_json::from_slice(LAYER2_DRUM_BYTES).unwrap_or_else(|e| { |
| 346 |
tracing::error!("Failed to deserialize embedded drum model: {e}"); |
| 347 |
empty_model() |
| 348 |
}) |
| 349 |
}) |
| 350 |
} |
| 351 |
|
| 352 |
fn layer2_bass_model() -> &'static RandomForestModel { |
| 353 |
LAYER2_BASS_MODEL.get_or_init(|| { |
| 354 |
serde_json::from_slice(LAYER2_BASS_BYTES).unwrap_or_else(|e| { |
| 355 |
tracing::error!("Failed to deserialize embedded bass model: {e}"); |
| 356 |
empty_model() |
| 357 |
}) |
| 358 |
}) |
| 359 |
} |
| 360 |
|
| 361 |
fn layer2_vocal_model() -> &'static RandomForestModel { |
| 362 |
LAYER2_VOCAL_MODEL.get_or_init(|| { |
| 363 |
serde_json::from_slice(LAYER2_VOCAL_BYTES).unwrap_or_else(|e| { |
| 364 |
tracing::error!("Failed to deserialize embedded vocal model: {e}"); |
| 365 |
empty_model() |
| 366 |
}) |
| 367 |
}) |
| 368 |
} |
| 369 |
|
| 370 |
fn layer2_synth_model() -> &'static RandomForestModel { |
| 371 |
LAYER2_SYNTH_MODEL.get_or_init(|| { |
| 372 |
serde_json::from_slice(LAYER2_SYNTH_BYTES).unwrap_or_else(|e| { |
| 373 |
tracing::error!("Failed to deserialize embedded synth model: {e}"); |
| 374 |
empty_model() |
| 375 |
}) |
| 376 |
}) |
| 377 |
} |
| 378 |
|
| 379 |
|
| 380 |
const DRUM_CLASSES: [SampleClass; 7] = [ |
| 381 |
SampleClass::Kick, |
| 382 |
SampleClass::Snare, |
| 383 |
SampleClass::HiHat, |
| 384 |
SampleClass::Cymbal, |
| 385 |
SampleClass::Clap, |
| 386 |
SampleClass::Tom, |
| 387 |
SampleClass::Percussion, |
| 388 |
]; |
| 389 |
|
| 390 |
const BASS_CLASSES: [SampleClass; 3] = [ |
| 391 |
SampleClass::GuitarBass, |
| 392 |
SampleClass::SynthBass, |
| 393 |
SampleClass::SubBass, |
| 394 |
]; |
| 395 |
|
| 396 |
const VOCAL_CLASSES: [SampleClass; 3] = [ |
| 397 |
SampleClass::VocalChop, |
| 398 |
SampleClass::VocalPhrase, |
| 399 |
SampleClass::VocalChoir, |
| 400 |
]; |
| 401 |
|
| 402 |
const SYNTH_CLASSES: [SampleClass; 4] = [ |
| 403 |
SampleClass::SynthLead, |
| 404 |
SampleClass::SynthStab, |
| 405 |
SampleClass::SynthPluck, |
| 406 |
SampleClass::SynthChord, |
| 407 |
]; |
| 408 |
|
| 409 |
|
| 410 |
|
| 411 |
|
| 412 |
|
| 413 |
|
| 414 |
|
| 415 |
|
| 416 |
fn classify_broad(input: &ClassifyInput) -> (BroadClass, f64) { |
| 417 |
let d = input.duration; |
| 418 |
let c = input.centroid; |
| 419 |
let flat = input.flatness; |
| 420 |
let zcr = input.zcr; |
| 421 |
let bw = input.bandwidth; |
| 422 |
let cv = input.centroid_variance; |
| 423 |
let crest = input.crest_factor; |
| 424 |
let attack = input.attack_time; |
| 425 |
|
| 426 |
|
| 427 |
|
| 428 |
if flat > 0.7 && (d > 2.0 || attack > 0.1) { |
| 429 |
return (BroadClass::Noise, 0.9); |
| 430 |
} |
| 431 |
|
| 432 |
|
| 433 |
if d < 10.0 && c > 3000.0 && flat > 0.2 { |
| 434 |
return (BroadClass::Drum, 0.85); |
| 435 |
} |
| 436 |
|
| 437 |
|
| 438 |
if d < 2.0 && (attack < 0.05 || crest > 2.5) { |
| 439 |
let mut conf: f64 = 0.8; |
| 440 |
if attack < 0.02 { |
| 441 |
conf += 0.05; |
| 442 |
} |
| 443 |
if crest > 4.0 { |
| 444 |
conf += 0.05; |
| 445 |
} |
| 446 |
if d < 1.0 { |
| 447 |
conf += 0.05; |
| 448 |
} |
| 449 |
return (BroadClass::Drum, conf.min(0.95)); |
| 450 |
} |
| 451 |
|
| 452 |
|
| 453 |
|
| 454 |
if d < 2.0 && attack < 0.1 && crest > 1.5 { |
| 455 |
return (BroadClass::Drum, 0.75); |
| 456 |
} |
| 457 |
|
| 458 |
|
| 459 |
if d > 1.0 && d < 5.0 && crest > 10.0 && attack < 0.005 { |
| 460 |
return (BroadClass::Impact, 0.85); |
| 461 |
} |
| 462 |
|
| 463 |
|
| 464 |
if d > 5.0 && cv < 100_000.0 && flat > 0.15 && flat < 0.5 { |
| 465 |
return (BroadClass::Ambience, 0.85); |
| 466 |
} |
| 467 |
|
| 468 |
|
| 469 |
if c < 400.0 && flat < 0.15 && d > 2.0 { |
| 470 |
return (BroadClass::Bass, 0.85); |
| 471 |
} |
| 472 |
|
| 473 |
|
| 474 |
if d > 1.0 && c > 300.0 && c < 3000.0 && flat < 0.2 && zcr < 0.08 && crest < 2.0 { |
| 475 |
return (BroadClass::Vocal, 0.8); |
| 476 |
} |
| 477 |
|
| 478 |
|
| 479 |
if d > 2.0 && flat < 0.2 && c > 200.0 && c < 2000.0 { |
| 480 |
return (BroadClass::Pad, 0.8); |
| 481 |
} |
| 482 |
|
| 483 |
|
| 484 |
if d > 2.0 && cv > 500_000.0 { |
| 485 |
return (BroadClass::Texture, 0.8); |
| 486 |
} |
| 487 |
|
| 488 |
|
| 489 |
if d > 1.0 && bw > 2000.0 && flat > 0.1 && flat < 0.5 { |
| 490 |
return (BroadClass::Foley, 0.75); |
| 491 |
} |
| 492 |
|
| 493 |
|
| 494 |
if c > 500.0 && flat < 0.3 && zcr < 0.1 && crest < 2.0 { |
| 495 |
return (BroadClass::Synth, 0.75); |
| 496 |
} |
| 497 |
|
| 498 |
|
| 499 |
if d > 3.0 { |
| 500 |
return (BroadClass::Music, 0.6); |
| 501 |
} |
| 502 |
|
| 503 |
|
| 504 |
if d < 2.0 { |
| 505 |
return (BroadClass::Drum, 0.6); |
| 506 |
} |
| 507 |
|
| 508 |
|
| 509 |
(BroadClass::Misc, 0.5) |
| 510 |
} |
| 511 |
|
| 512 |
|
| 513 |
fn broad_to_sample_class(broad: BroadClass) -> SampleClass { |
| 514 |
match broad { |
| 515 |
BroadClass::Drum => SampleClass::Percussion, |
| 516 |
BroadClass::Bass => SampleClass::Bass, |
| 517 |
BroadClass::Vocal => SampleClass::Vocal, |
| 518 |
BroadClass::Synth => SampleClass::Synth, |
| 519 |
BroadClass::Pad => SampleClass::Pad, |
| 520 |
BroadClass::Misc => SampleClass::Misc, |
| 521 |
BroadClass::Noise => SampleClass::Noise, |
| 522 |
BroadClass::Music => SampleClass::Music, |
| 523 |
BroadClass::Ambience => SampleClass::Ambience, |
| 524 |
BroadClass::Impact => SampleClass::Impact, |
| 525 |
BroadClass::Foley => SampleClass::Foley, |
| 526 |
BroadClass::Texture => SampleClass::Texture, |
| 527 |
} |
| 528 |
} |
| 529 |
|
| 530 |
|
| 531 |
|
| 532 |
|
| 533 |
|
| 534 |
|
| 535 |
|
| 536 |
fn predict_with_model( |
| 537 |
model: &RandomForestModel, |
| 538 |
classes: &[SampleClass], |
| 539 |
fallback: SampleClass, |
| 540 |
features: &[f64; NUM_FEATURES], |
| 541 |
) -> (SampleClass, f64) { |
| 542 |
if model.trees.is_empty() { |
| 543 |
return (fallback, 0.0); |
| 544 |
} |
| 545 |
|
| 546 |
let mut votes = vec![0u32; model.num_classes as usize]; |
| 547 |
for tree in &model.trees { |
| 548 |
let class_id = tree.predict(features) as usize; |
| 549 |
if class_id < votes.len() { |
| 550 |
votes[class_id] += 1; |
| 551 |
} |
| 552 |
} |
| 553 |
|
| 554 |
let total = model.trees.len() as f64; |
| 555 |
let (best_class_id, &best_count) = votes |
| 556 |
.iter() |
| 557 |
.enumerate() |
| 558 |
.max_by_key(|(_, count)| **count) |
| 559 |
.unwrap_or((0, &0)); |
| 560 |
|
| 561 |
let confidence = best_count as f64 / total; |
| 562 |
let class = classes |
| 563 |
.get(best_class_id) |
| 564 |
.copied() |
| 565 |
.unwrap_or(fallback); |
| 566 |
|
| 567 |
(class, confidence) |
| 568 |
} |
| 569 |
|
| 570 |
|
| 571 |
|
| 572 |
|
| 573 |
|
| 574 |
|
| 575 |
|
| 576 |
|
| 577 |
#[instrument(skip_all)] |
| 578 |
pub fn classify_ml(input: &ClassifyInput) -> ClassificationResult { |
| 579 |
let drum_model = layer2_model(); |
| 580 |
|
| 581 |
|
| 582 |
if drum_model.trees.is_empty() { |
| 583 |
return ClassificationResult { |
| 584 |
class: classify_full(input), |
| 585 |
confidence: 0.0, |
| 586 |
}; |
| 587 |
} |
| 588 |
|
| 589 |
let (broad, broad_conf) = classify_broad(input); |
| 590 |
let features = input.to_feature_array(); |
| 591 |
|
| 592 |
match broad { |
| 593 |
BroadClass::Drum => { |
| 594 |
let (class, conf) = predict_with_model(drum_model, &DRUM_CLASSES, SampleClass::Percussion, &features); |
| 595 |
ClassificationResult { class, confidence: conf } |
| 596 |
} |
| 597 |
BroadClass::Bass => { |
| 598 |
let model = layer2_bass_model(); |
| 599 |
if model.trees.is_empty() { |
| 600 |
ClassificationResult { class: SampleClass::Bass, confidence: broad_conf } |
| 601 |
} else { |
| 602 |
let (class, conf) = predict_with_model(model, &BASS_CLASSES, SampleClass::Bass, &features); |
| 603 |
ClassificationResult { class, confidence: conf } |
| 604 |
} |
| 605 |
} |
| 606 |
BroadClass::Vocal => { |
| 607 |
let model = layer2_vocal_model(); |
| 608 |
if model.trees.is_empty() { |
| 609 |
ClassificationResult { class: SampleClass::Vocal, confidence: broad_conf } |
| 610 |
} else { |
| 611 |
let (class, conf) = predict_with_model(model, &VOCAL_CLASSES, SampleClass::Vocal, &features); |
| 612 |
ClassificationResult { class, confidence: conf } |
| 613 |
} |
| 614 |
} |
| 615 |
BroadClass::Synth => { |
| 616 |
let model = layer2_synth_model(); |
| 617 |
if model.trees.is_empty() { |
| 618 |
ClassificationResult { class: SampleClass::Synth, confidence: broad_conf } |
| 619 |
} else { |
| 620 |
let (class, conf) = predict_with_model(model, &SYNTH_CLASSES, SampleClass::Synth, &features); |
| 621 |
ClassificationResult { class, confidence: conf } |
| 622 |
} |
| 623 |
} |
| 624 |
_ => ClassificationResult { |
| 625 |
class: broad_to_sample_class(broad), |
| 626 |
confidence: broad_conf, |
| 627 |
}, |
| 628 |
} |
| 629 |
} |
| 630 |
|
| 631 |
|
| 632 |
|
| 633 |
impl SampleClass { |
| 634 |
|
| 635 |
|
| 636 |
|
| 637 |
|
| 638 |
pub fn has_rhythm(&self) -> bool { |
| 639 |
matches!( |
| 640 |
self, |
| 641 |
Self::Bass |
| 642 |
| Self::GuitarBass |
| 643 |
| Self::SynthBass |
| 644 |
| Self::SubBass |
| 645 |
| Self::Vocal |
| 646 |
| Self::VocalPhrase |
| 647 |
| Self::VocalChoir |
| 648 |
| Self::Synth |
| 649 |
| Self::SynthLead |
| 650 |
| Self::SynthChord |
| 651 |
| Self::Pad |
| 652 |
| Self::Music |
| 653 |
| Self::Misc |
| 654 |
) |
| 655 |
} |
| 656 |
|
| 657 |
|
| 658 |
|
| 659 |
|
| 660 |
|
| 661 |
pub fn has_pitch(&self) -> bool { |
| 662 |
matches!( |
| 663 |
self, |
| 664 |
Self::Bass |
| 665 |
| Self::GuitarBass |
| 666 |
| Self::SynthBass |
| 667 |
| Self::SubBass |
| 668 |
| Self::Vocal |
| 669 |
| Self::VocalPhrase |
| 670 |
| Self::VocalChoir |
| 671 |
| Self::Synth |
| 672 |
| Self::SynthLead |
| 673 |
| Self::SynthChord |
| 674 |
| Self::Pad |
| 675 |
| Self::Music |
| 676 |
| Self::Misc |
| 677 |
) |
| 678 |
} |
| 679 |
} |
| 680 |
|
| 681 |
|
| 682 |
|
| 683 |
|
| 684 |
|
| 685 |
|
| 686 |
|
| 687 |
|
| 688 |
|
| 689 |
|
| 690 |
|
| 691 |
|
| 692 |
|
| 693 |
#[instrument(skip_all)] |
| 694 |
pub fn classify_full(input: &ClassifyInput) -> SampleClass { |
| 695 |
let d = input.duration; |
| 696 |
let c = input.centroid; |
| 697 |
let flat = input.flatness; |
| 698 |
let zcr = input.zcr; |
| 699 |
let _onset = input.onset_strength; |
| 700 |
let bw = input.bandwidth; |
| 701 |
let cv = input.centroid_variance; |
| 702 |
let crest = input.crest_factor; |
| 703 |
let attack = input.attack_time; |
| 704 |
|
| 705 |
|
| 706 |
if flat > 0.7 { |
| 707 |
return SampleClass::Noise; |
| 708 |
} |
| 709 |
|
| 710 |
|
| 711 |
|
| 712 |
|
| 713 |
if d < 2.0 && c < 1200.0 && flat < 0.35 { |
| 714 |
return SampleClass::Kick; |
| 715 |
} |
| 716 |
|
| 717 |
|
| 718 |
if d < 2.0 && c > 3500.0 && zcr > 0.1 && flat < 0.3 { |
| 719 |
return SampleClass::HiHat; |
| 720 |
} |
| 721 |
|
| 722 |
|
| 723 |
if d > 0.5 && d < 10.0 && c > 3000.0 && flat > 0.25 { |
| 724 |
return SampleClass::Cymbal; |
| 725 |
} |
| 726 |
|
| 727 |
|
| 728 |
if d < 2.0 && c > 800.0 && flat > 0.2 && bw > 2000.0 { |
| 729 |
return SampleClass::Snare; |
| 730 |
} |
| 731 |
|
| 732 |
|
| 733 |
if d < 2.0 && (attack < 0.05 || crest > 3.0) { |
| 734 |
return SampleClass::Percussion; |
| 735 |
} |
| 736 |
|
| 737 |
|
| 738 |
|
| 739 |
|
| 740 |
if d < 3.0 && crest > 10.0 && attack < 0.005 { |
| 741 |
return SampleClass::Impact; |
| 742 |
} |
| 743 |
|
| 744 |
|
| 745 |
if d > 5.0 && cv < 100_000.0 && flat > 0.15 && flat < 0.5 { |
| 746 |
return SampleClass::Ambience; |
| 747 |
} |
| 748 |
|
| 749 |
|
| 750 |
if c < 400.0 && flat < 0.15 { |
| 751 |
return SampleClass::Bass; |
| 752 |
} |
| 753 |
|
| 754 |
|
| 755 |
if d > 0.5 && c > 300.0 && c < 3000.0 && flat < 0.2 && zcr < 0.08 { |
| 756 |
return SampleClass::Vocal; |
| 757 |
} |
| 758 |
|
| 759 |
|
| 760 |
if d > 2.0 && flat < 0.2 && c > 200.0 && c < 2000.0 { |
| 761 |
return SampleClass::Pad; |
| 762 |
} |
| 763 |
|
| 764 |
|
| 765 |
if d > 2.0 && cv > 500_000.0 { |
| 766 |
return SampleClass::Texture; |
| 767 |
} |
| 768 |
|
| 769 |
|
| 770 |
if d > 0.1 && bw > 2000.0 && flat > 0.1 && flat < 0.5 { |
| 771 |
return SampleClass::Foley; |
| 772 |
} |
| 773 |
|
| 774 |
|
| 775 |
if c > 500.0 && flat < 0.3 && zcr < 0.1 { |
| 776 |
return SampleClass::Synth; |
| 777 |
} |
| 778 |
|
| 779 |
|
| 780 |
if d < 2.0 { |
| 781 |
return SampleClass::Percussion; |
| 782 |
} |
| 783 |
|
| 784 |
|
| 785 |
if d > 3.0 { |
| 786 |
return SampleClass::Music; |
| 787 |
} |
| 788 |
|
| 789 |
|
| 790 |
SampleClass::Misc |
| 791 |
} |
| 792 |
|
| 793 |
|
| 794 |
#[instrument(skip_all)] |
| 795 |
pub fn classify(features: &SpectralFeatures, duration: f64) -> Option<SampleClass> { |
| 796 |
let input = ClassifyInput::new(features, duration, 0.0, 0.0); |
| 797 |
Some(classify_full(&input)) |
| 798 |
} |
| 799 |
|
| 800 |
#[cfg(test)] |
| 801 |
mod tests { |
| 802 |
use super::*; |
| 803 |
|
| 804 |
#[test] |
| 805 |
fn kick_classification() { |
| 806 |
let features = SpectralFeatures { |
| 807 |
centroid: 600.0, |
| 808 |
flatness: 0.15, |
| 809 |
rolloff: 1200.0, |
| 810 |
zero_crossing_rate: 0.04, |
| 811 |
onset_strength: 50.0, |
| 812 |
..Default::default() |
| 813 |
}; |
| 814 |
assert_eq!(classify(&features, 0.3), Some(SampleClass::Kick)); |
| 815 |
} |
| 816 |
|
| 817 |
#[test] |
| 818 |
fn hihat_classification() { |
| 819 |
let features = SpectralFeatures { |
| 820 |
centroid: 5000.0, |
| 821 |
flatness: 0.2, |
| 822 |
rolloff: 12000.0, |
| 823 |
zero_crossing_rate: 0.15, |
| 824 |
onset_strength: 30.0, |
| 825 |
..Default::default() |
| 826 |
}; |
| 827 |
assert_eq!(classify(&features, 0.2), Some(SampleClass::HiHat)); |
| 828 |
} |
| 829 |
|
| 830 |
#[test] |
| 831 |
fn snare_classification() { |
| 832 |
let input = ClassifyInput { |
| 833 |
duration: 0.5, |
| 834 |
centroid: 2500.0, |
| 835 |
flatness: 0.25, |
| 836 |
zcr: 0.12, |
| 837 |
onset_strength: 40.0, |
| 838 |
bandwidth: 3000.0, |
| 839 |
centroid_variance: 200_000.0, |
| 840 |
crest_factor: 6.0, |
| 841 |
attack_time: 0.003, |
| 842 |
mfcc_means: [0.0; 13], |
| 843 |
mfcc_variances: [0.0; 13], |
| 844 |
}; |
| 845 |
assert_eq!(classify_full(&input), SampleClass::Snare); |
| 846 |
} |
| 847 |
|
| 848 |
#[test] |
| 849 |
fn noise_classification() { |
| 850 |
let features = SpectralFeatures { |
| 851 |
centroid: 5000.0, |
| 852 |
flatness: 0.75, |
| 853 |
rolloff: 15000.0, |
| 854 |
zero_crossing_rate: 0.3, |
| 855 |
onset_strength: 10.0, |
| 856 |
..Default::default() |
| 857 |
}; |
| 858 |
assert_eq!(classify(&features, 2.0), Some(SampleClass::Noise)); |
| 859 |
} |
| 860 |
|
| 861 |
#[test] |
| 862 |
fn bass_classification() { |
| 863 |
let features = SpectralFeatures { |
| 864 |
centroid: 200.0, |
| 865 |
flatness: 0.08, |
| 866 |
rolloff: 500.0, |
| 867 |
zero_crossing_rate: 0.03, |
| 868 |
onset_strength: 20.0, |
| 869 |
..Default::default() |
| 870 |
}; |
| 871 |
assert_eq!(classify(&features, 3.0), Some(SampleClass::Bass)); |
| 872 |
} |
| 873 |
|
| 874 |
#[test] |
| 875 |
fn tag_format() { |
| 876 |
assert_eq!(SampleClass::Kick.tag(), "instrument.drum.kick"); |
| 877 |
assert_eq!(SampleClass::Noise.tag(), "character.noise"); |
| 878 |
assert_eq!(SampleClass::Music.tag(), "type.music"); |
| 879 |
assert_eq!(SampleClass::Ambience.tag(), "character.ambience"); |
| 880 |
assert_eq!(SampleClass::Impact.tag(), "character.impact"); |
| 881 |
assert_eq!(SampleClass::Foley.tag(), "character.foley"); |
| 882 |
assert_eq!(SampleClass::Texture.tag(), "character.texture"); |
| 883 |
} |
| 884 |
|
| 885 |
#[test] |
| 886 |
fn impact_classification() { |
| 887 |
let input = ClassifyInput { |
| 888 |
duration: 2.5, |
| 889 |
centroid: 1500.0, |
| 890 |
flatness: 0.1, |
| 891 |
zcr: 0.06, |
| 892 |
onset_strength: 50.0, |
| 893 |
bandwidth: 2000.0, |
| 894 |
centroid_variance: 200_000.0, |
| 895 |
crest_factor: 12.0, |
| 896 |
attack_time: 0.002, |
| 897 |
mfcc_means: [0.0; 13], |
| 898 |
mfcc_variances: [0.0; 13], |
| 899 |
}; |
| 900 |
assert_eq!(classify_full(&input), SampleClass::Impact); |
| 901 |
} |
| 902 |
|
| 903 |
#[test] |
| 904 |
fn ambience_classification() { |
| 905 |
let input = ClassifyInput { |
| 906 |
duration: 10.0, |
| 907 |
centroid: 2000.0, |
| 908 |
flatness: 0.3, |
| 909 |
zcr: 0.05, |
| 910 |
onset_strength: 5.0, |
| 911 |
bandwidth: 3000.0, |
| 912 |
centroid_variance: 50_000.0, |
| 913 |
crest_factor: 2.5, |
| 914 |
attack_time: 0.5, |
| 915 |
mfcc_means: [0.0; 13], |
| 916 |
mfcc_variances: [0.0; 13], |
| 917 |
}; |
| 918 |
assert_eq!(classify_full(&input), SampleClass::Ambience); |
| 919 |
} |
| 920 |
|
| 921 |
#[test] |
| 922 |
fn foley_classification() { |
| 923 |
let input = ClassifyInput { |
| 924 |
duration: 3.0, |
| 925 |
centroid: 1500.0, |
| 926 |
flatness: 0.25, |
| 927 |
zcr: 0.06, |
| 928 |
onset_strength: 15.0, |
| 929 |
bandwidth: 4000.0, |
| 930 |
centroid_variance: 300_000.0, |
| 931 |
crest_factor: 2.0, |
| 932 |
attack_time: 0.1, |
| 933 |
mfcc_means: [0.0; 13], |
| 934 |
mfcc_variances: [0.0; 13], |
| 935 |
}; |
| 936 |
assert_eq!(classify_full(&input), SampleClass::Foley); |
| 937 |
} |
| 938 |
|
| 939 |
#[test] |
| 940 |
fn texture_classification() { |
| 941 |
let input = ClassifyInput { |
| 942 |
duration: 5.0, |
| 943 |
centroid: 2000.0, |
| 944 |
flatness: 0.12, |
| 945 |
zcr: 0.12, |
| 946 |
onset_strength: 10.0, |
| 947 |
bandwidth: 3000.0, |
| 948 |
centroid_variance: 1_000_000.0, |
| 949 |
crest_factor: 3.0, |
| 950 |
attack_time: 0.1, |
| 951 |
mfcc_means: [0.0; 13], |
| 952 |
mfcc_variances: [0.0; 13], |
| 953 |
}; |
| 954 |
assert_eq!(classify_full(&input), SampleClass::Texture); |
| 955 |
} |
| 956 |
|
| 957 |
#[test] |
| 958 |
fn round_trip_from_str() { |
| 959 |
for class in [ |
| 960 |
SampleClass::Kick, |
| 961 |
SampleClass::Snare, |
| 962 |
SampleClass::HiHat, |
| 963 |
SampleClass::Cymbal, |
| 964 |
SampleClass::Clap, |
| 965 |
SampleClass::Tom, |
| 966 |
SampleClass::Percussion, |
| 967 |
SampleClass::Bass, |
| 968 |
SampleClass::GuitarBass, |
| 969 |
SampleClass::SynthBass, |
| 970 |
SampleClass::SubBass, |
| 971 |
SampleClass::Vocal, |
| 972 |
SampleClass::VocalChop, |
| 973 |
SampleClass::VocalPhrase, |
| 974 |
SampleClass::VocalChoir, |
| 975 |
SampleClass::Synth, |
| 976 |
SampleClass::SynthLead, |
| 977 |
SampleClass::SynthStab, |
| 978 |
SampleClass::SynthPluck, |
| 979 |
SampleClass::SynthChord, |
| 980 |
SampleClass::Pad, |
| 981 |
SampleClass::Misc, |
| 982 |
SampleClass::Noise, |
| 983 |
SampleClass::Music, |
| 984 |
SampleClass::Ambience, |
| 985 |
SampleClass::Impact, |
| 986 |
SampleClass::Foley, |
| 987 |
SampleClass::Texture, |
| 988 |
] { |
| 989 |
let s = class.as_str(); |
| 990 |
let parsed: SampleClass = s.parse().unwrap(); |
| 991 |
assert_eq!(parsed, class, "round-trip failed for {s}"); |
| 992 |
} |
| 993 |
} |
| 994 |
|
| 995 |
#[test] |
| 996 |
fn feature_array_layout() { |
| 997 |
let input = ClassifyInput { |
| 998 |
duration: 1.0, |
| 999 |
centroid: 2.0, |
| 1000 |
flatness: 3.0, |
| 1001 |
zcr: 4.0, |
| 1002 |
onset_strength: 5.0, |
| 1003 |
bandwidth: 6.0, |
| 1004 |
centroid_variance: 7.0, |
| 1005 |
crest_factor: 8.0, |
| 1006 |
attack_time: 9.0, |
| 1007 |
mfcc_means: [10.0; 13], |
| 1008 |
mfcc_variances: [20.0; 13], |
| 1009 |
}; |
| 1010 |
let arr = input.to_feature_array(); |
| 1011 |
assert_eq!(arr[0], 1.0); |
| 1012 |
assert_eq!(arr[1], 2.0); |
| 1013 |
assert_eq!(arr[8], 9.0); |
| 1014 |
assert_eq!(arr[9], 10.0); |
| 1015 |
assert_eq!(arr[21], 10.0); |
| 1016 |
assert_eq!(arr[22], 20.0); |
| 1017 |
assert_eq!(arr[34], 20.0); |
| 1018 |
} |
| 1019 |
|
| 1020 |
#[test] |
| 1021 |
fn classify_ml_returns_drum_for_kick_input() { |
| 1022 |
let input = ClassifyInput { |
| 1023 |
duration: 0.3, |
| 1024 |
centroid: 600.0, |
| 1025 |
flatness: 0.15, |
| 1026 |
zcr: 0.04, |
| 1027 |
onset_strength: 50.0, |
| 1028 |
bandwidth: 500.0, |
| 1029 |
centroid_variance: 10_000.0, |
| 1030 |
crest_factor: 5.0, |
| 1031 |
attack_time: 0.003, |
| 1032 |
mfcc_means: [0.0; 13], |
| 1033 |
mfcc_variances: [0.0; 13], |
| 1034 |
}; |
| 1035 |
let result = classify_ml(&input); |
| 1036 |
|
| 1037 |
assert!( |
| 1038 |
matches!( |
| 1039 |
result.class, |
| 1040 |
SampleClass::Kick |
| 1041 |
| SampleClass::Snare |
| 1042 |
| SampleClass::HiHat |
| 1043 |
| SampleClass::Cymbal |
| 1044 |
| SampleClass::Clap |
| 1045 |
| SampleClass::Tom |
| 1046 |
| SampleClass::Percussion |
| 1047 |
), |
| 1048 |
"expected drum class, got {:?}", |
| 1049 |
result.class |
| 1050 |
); |
| 1051 |
} |
| 1052 |
|
| 1053 |
#[test] |
| 1054 |
fn tree_node_prediction() { |
| 1055 |
let tree = TreeNode::Split { |
| 1056 |
feature: 1, |
| 1057 |
threshold: 1000.0, |
| 1058 |
left: Box::new(TreeNode::Leaf { class: 0 }), |
| 1059 |
right: Box::new(TreeNode::Leaf { class: 2 }), |
| 1060 |
}; |
| 1061 |
let mut features = [0.0; NUM_FEATURES]; |
| 1062 |
features[1] = 500.0; |
| 1063 |
assert_eq!(tree.predict(&features), 0); |
| 1064 |
features[1] = 5000.0; |
| 1065 |
assert_eq!(tree.predict(&features), 2); |
| 1066 |
} |
| 1067 |
|
| 1068 |
#[test] |
| 1069 |
fn broad_classifier_detects_drums() { |
| 1070 |
let input = ClassifyInput { |
| 1071 |
duration: 0.3, |
| 1072 |
centroid: 600.0, |
| 1073 |
flatness: 0.15, |
| 1074 |
zcr: 0.04, |
| 1075 |
onset_strength: 50.0, |
| 1076 |
bandwidth: 500.0, |
| 1077 |
centroid_variance: 10_000.0, |
| 1078 |
crest_factor: 5.0, |
| 1079 |
attack_time: 0.003, |
| 1080 |
mfcc_means: [0.0; 13], |
| 1081 |
mfcc_variances: [0.0; 13], |
| 1082 |
}; |
| 1083 |
let (broad, conf) = classify_broad(&input); |
| 1084 |
assert_eq!(broad, BroadClass::Drum); |
| 1085 |
assert!(conf >= 0.8); |
| 1086 |
} |
| 1087 |
|
| 1088 |
#[test]
fn broad_classifier_detects_noise() {
    // Hand-built input with high flatness (0.8) and high zero-crossing rate
    // (0.3); the broad classifier is expected to label it as noise.
    let hiss = ClassifyInput {
        // Timing / envelope descriptors.
        duration: 2.0,
        attack_time: 0.5,
        onset_strength: 10.0,
        crest_factor: 1.5,
        // Spectral descriptors.
        centroid: 5000.0,
        centroid_variance: 100_000.0,
        bandwidth: 5000.0,
        flatness: 0.8,
        zcr: 0.3,
        // MFCC statistics are left at zero for this case.
        mfcc_means: [0.0; 13],
        mfcc_variances: [0.0; 13],
    };

    // Only the predicted class is checked; the confidence is ignored.
    let broad = classify_broad(&hiss).0;
    assert_eq!(broad, BroadClass::Noise);
}
| 1106 |
|
| 1107 |
#[test]
fn smart_skip_drum_classes_skip_bpm_key() {
    // Every drum-family class must report neither rhythm nor pitch.
    // NOTE(review): per the test name these flags presumably gate BPM and
    // key detection respectively -- confirm against the callers.
    let drum_classes = [
        SampleClass::Kick,
        SampleClass::Snare,
        SampleClass::HiHat,
        SampleClass::Cymbal,
        SampleClass::Clap,
        SampleClass::Tom,
        SampleClass::Percussion,
    ];

    for class in drum_classes {
        let rhythmic = class.has_rhythm();
        assert!(!rhythmic, "{class:?} should not have rhythm");

        let pitched = class.has_pitch();
        assert!(!pitched, "{class:?} should not have pitch");
    }
}
| 1123 |
|
| 1124 |
#[test]
fn smart_skip_tonal_classes_keep_bpm_key() {
    // Bass, sustained vocal, synth, pad and music classes must report both
    // rhythm and pitch.
    let tonal_classes = [
        // Bass family.
        SampleClass::Bass,
        SampleClass::GuitarBass,
        SampleClass::SynthBass,
        SampleClass::SubBass,
        // Vocal family (VocalChop is covered by the short-classes test).
        SampleClass::Vocal,
        SampleClass::VocalPhrase,
        SampleClass::VocalChoir,
        // Synth family (SynthStab / SynthPluck are covered elsewhere).
        SampleClass::Synth,
        SampleClass::SynthLead,
        SampleClass::SynthChord,
        // Remaining tonal material.
        SampleClass::Pad,
        SampleClass::Music,
    ];

    for class in tonal_classes {
        let rhythmic = class.has_rhythm();
        assert!(rhythmic, "{class:?} should have rhythm");

        let pitched = class.has_pitch();
        assert!(pitched, "{class:?} should have pitch");
    }
}
| 1144 |
|
| 1145 |
#[test]
fn smart_skip_short_classes_skip_bpm_key() {
    // Chop / stab / pluck classes must report neither rhythm nor pitch.
    // NOTE(review): the test name suggests these are treated as too short
    // for BPM/key analysis -- confirm against `has_rhythm` / `has_pitch`.
    let short_classes = [
        SampleClass::VocalChop,
        SampleClass::SynthStab,
        SampleClass::SynthPluck,
    ];

    for class in short_classes {
        let rhythmic = class.has_rhythm();
        assert!(!rhythmic, "{class:?} should not have rhythm");

        let pitched = class.has_pitch();
        assert!(!pitched, "{class:?} should not have pitch");
    }
}
| 1157 |
|
| 1158 |
#[test]
fn smart_skip_non_tonal_non_drum_skip() {
    // Noise / ambience / effect-style classes must report neither rhythm
    // nor pitch.
    let effect_classes = [
        SampleClass::Noise,
        SampleClass::Ambience,
        SampleClass::Impact,
        SampleClass::Foley,
        SampleClass::Texture,
    ];

    for class in effect_classes {
        let rhythmic = class.has_rhythm();
        assert!(!rhythmic, "{class:?} should not have rhythm");

        let pitched = class.has_pitch();
        assert!(!pitched, "{class:?} should not have pitch");
    }
}
| 1171 |
|
| 1172 |
#[test]
fn smart_skip_pad_has_pitch_but_no_rhythm() {
    // Pins the current behaviour for pads: both pitch and rhythm are
    // reported as present.
    // NOTE(review): the test name says "no rhythm", yet the assertion below
    // requires `has_rhythm()` to be true (and the tonal-classes test expects
    // the same for `Pad`). Either the name is stale or the expectation is --
    // confirm which one is intended.
    let pad = SampleClass::Pad;

    assert!(pad.has_pitch());
    assert!(pad.has_rhythm());
}
| 1179 |
|
| 1180 |
#[test]
fn classify_ml_bass_routes_to_sub_class() {
    // Low-centroid (200.0), low-flatness input. The full classifier is
    // expected to return one of the specific bass sub-classes; the generic
    // `SampleClass::Bass` is not accepted by this test.
    let low_tone = ClassifyInput {
        // Timing / envelope descriptors.
        duration: 3.0,
        attack_time: 0.05,
        onset_strength: 20.0,
        crest_factor: 1.5,
        // Spectral descriptors.
        centroid: 200.0,
        centroid_variance: 50_000.0,
        bandwidth: 400.0,
        flatness: 0.08,
        zcr: 0.03,
        // MFCC statistics are left at zero for this case.
        mfcc_means: [0.0; 13],
        mfcc_variances: [0.0; 13],
    };

    let result = classify_ml(&low_tone);

    let is_bass_sub_class = matches!(
        result.class,
        SampleClass::GuitarBass | SampleClass::SynthBass | SampleClass::SubBass
    );
    assert!(
        is_bass_sub_class,
        "expected bass sub-class, got {:?}",
        result.class
    );

    // A sub-class prediction must carry a strictly positive confidence.
    assert!(result.confidence > 0.0);
}
| 1208 |
|
| 1209 |
#[test]
fn classify_ml_vocal_fallback_without_model() {
    // The classifier is expected to return the generic `Vocal` class for
    // this input rather than a vocal sub-class.
    // NOTE(review): per the test name this presumably relies on no vocal
    // sub-model being available in the test build -- confirm.
    let voiced = ClassifyInput {
        // Timing / envelope descriptors.
        duration: 2.0,
        attack_time: 0.2,
        onset_strength: 15.0,
        crest_factor: 1.5,
        // Spectral descriptors.
        centroid: 1000.0,
        centroid_variance: 80_000.0,
        bandwidth: 1500.0,
        flatness: 0.1,
        zcr: 0.05,
        // MFCC statistics are left at zero for this case.
        mfcc_means: [0.0; 13],
        mfcc_variances: [0.0; 13],
    };

    let predicted = classify_ml(&voiced).class;
    assert_eq!(predicted, SampleClass::Vocal);
}
| 1228 |
|
| 1229 |
#[test]
fn predict_with_model_returns_fallback_for_empty() {
    // A forest with no trees: `predict_with_model` is expected to hand back
    // the supplied fallback class (`Bass`) with a confidence of exactly 0.0.
    let treeless = RandomForestModel {
        trees: vec![],
        num_classes: 3,
        class_names: vec!["a".into(), "b".into(), "c".into()],
    };
    let zeroed = [0.0; NUM_FEATURES];

    let prediction = predict_with_model(&treeless, &BASS_CLASSES, SampleClass::Bass, &zeroed);

    assert_eq!(prediction.0, SampleClass::Bass);
    assert_eq!(prediction.1, 0.0);
}
| 1246 |
} |
| 1247 |
|