Skip to main content

max / audiofiles

21.4 KB · 681 lines History Blame Raw
1 //! Random Forest training binary for audiofiles Layer 2 classification.
2 //!
3 //! Supports training separate models for drum, bass, vocal, and synth sub-classes.
4 //! Scans labeled samples from a training directory, extracts 35-feature vectors,
5 //! trains a Random Forest, runs 5-fold stratified cross-validation, and serializes
6 //! the model to the appropriate JSON file.
7 //!
8 //! Usage:
9 //! cargo run -p audiofiles-train -- --target drum [/path/to/data]
10 //! cargo run -p audiofiles-train -- --target bass [/path/to/data]
11 //! cargo run -p audiofiles-train -- --target vocal [/path/to/data]
12 //! cargo run -p audiofiles-train -- --target synth [/path/to/data]
13
14 use std::collections::HashMap;
15 use std::path::PathBuf;
16
17 use audiofiles_core::analysis::classify::{RandomForestModel, TreeNode, NUM_FEATURES};
18 use audiofiles_core::analysis::config::AnalysisConfig;
19 use audiofiles_core::analysis::{self, mfcc, spectral};
20 use rand::seq::SliceRandom;
21 use rand::{Rng, SeedableRng};
22 use rayon::prelude::*;
23
24 // ── Target configuration ──
25
/// Training configuration for one Layer 2 target: its class labels (a label's
/// position in `class_names` is its numeric class ID) and the JSON file name
/// the trained model is written to.
struct TargetConfig {
    /// Target identifier; `main` also uses it as the default data sub-directory.
    name: &'static str,
    /// Class labels, indexed by class ID.
    class_names: Vec<&'static str>,
    /// Output file name for the serialized model.
    model_filename: &'static str,
}

/// Look up the training configuration for `target`.
///
/// Valid targets are `drum`, `bass`, `vocal` and `synth`; any other value
/// prints an error plus the list of valid targets and exits with status 1.
fn target_config(target: &str) -> TargetConfig {
    // Assemble a config from borrowed, statically known class names.
    fn build(
        name: &'static str,
        class_names: &[&'static str],
        model_filename: &'static str,
    ) -> TargetConfig {
        TargetConfig {
            name,
            class_names: class_names.to_vec(),
            model_filename,
        }
    }

    match target {
        "drum" => build(
            "drum",
            &["kick", "snare", "hihat", "cymbal", "clap", "tom", "percussion"],
            "layer2_drum.json",
        ),
        "bass" => build(
            "bass",
            &["guitar-bass", "synth-bass", "sub-bass"],
            "layer2_bass.json",
        ),
        "vocal" => build(
            "vocal",
            &["vocal-chop", "vocal-phrase", "vocal-choir"],
            "layer2_vocal.json",
        ),
        "synth" => build(
            "synth",
            &["synth-lead", "synth-stab", "synth-pluck", "synth-chord"],
            "layer2_synth.json",
        ),
        _ => {
            eprintln!("Unknown target: {target}");
            eprintln!("Valid targets: drum, bass, vocal, synth");
            std::process::exit(1);
        }
    }
}

/// Map a class-directory name to its class ID within `config`.
///
/// Returns the index of the first matching entry in `config.class_names`, or
/// `None` if `label` is not one of this target's classes.
fn label_to_class(label: &str, config: &TargetConfig) -> Option<u8> {
    config
        .class_names
        .iter()
        .enumerate()
        .find_map(|(idx, &candidate)| (candidate == label).then_some(idx as u8))
}
71
72 // ── Feature extraction ──
73
74 /// Extract 35-feature vector from an audio file.
75 fn extract_features(path: &std::path::Path) -> Option<[f64; NUM_FEATURES]> {
76 let config = AnalysisConfig {
77 loudness: true,
78 spectral: true,
79 bpm: false,
80 key: false,
81 loop_detect: false,
82 classify: false,
83 fingerprint: false,
84 auto_suggest_tags: false,
85 max_analysis_seconds: Some(5.0),
86 smart_skip: false,
87 };
88
89 let decoded = analysis::decode::decode_to_mono(path).ok()?;
90
91 let capped: &[f32] = if let Some(max_secs) = config.max_analysis_seconds {
92 let max_samples = (max_secs * decoded.sample_rate as f64) as usize;
93 &decoded.samples[..decoded.samples.len().min(max_samples)]
94 } else {
95 &decoded.samples
96 };
97
98 let crest_factor = analysis::basic::crest_factor(&decoded.samples);
99 let attack_time = analysis::basic::attack_time(&decoded.samples, decoded.sample_rate);
100
101 let (features, magnitude_frames) =
102 spectral::compute_spectral_features_with_frames(capped, decoded.sample_rate);
103
104 let mfcc_features = mfcc::compute_mfccs(&magnitude_frames, decoded.sample_rate, 1024);
105
106 let input = audiofiles_core::analysis::classify::ClassifyInput::with_mfccs(
107 &features,
108 decoded.duration,
109 crest_factor,
110 attack_time,
111 &mfcc_features,
112 );
113
114 Some(input.to_feature_array())
115 }
116
117 // ── Random Forest training ──
118
/// Gini impurity `1 - Σ pᵢ²` of a node whose class histogram is `counts`.
///
/// `total` is the number of samples the histogram covers. An empty node
/// (`total == 0`) is treated as pure and yields `0.0`.
fn gini_impurity(counts: &[u32], total: u32) -> f64 {
    if total == 0 {
        0.0
    } else {
        let n = f64::from(total);
        let sum_of_squares: f64 = counts
            .iter()
            .map(|&c| (f64::from(c) / n).powi(2))
            .sum();
        1.0 - sum_of_squares
    }
}
127
/// Most frequent class ID in `labels`.
///
/// Labels `>= num_classes` are ignored. Ties, including the all-zero case of an
/// empty slice, resolve to the highest class index. Returns 0 when
/// `num_classes` is 0.
fn majority_class(labels: &[u8], num_classes: usize) -> u8 {
    let mut tally = vec![0u32; num_classes];
    for &label in labels {
        if let Some(slot) = tally.get_mut(usize::from(label)) {
            *slot += 1;
        }
    }

    tally
        .iter()
        .enumerate()
        .fold(None, |best: Option<(usize, u32)>, (idx, &votes)| match best {
            // Later classes win ties (same as `Iterator::max_by_key`).
            Some((_, top)) if votes < top => best,
            _ => Some((idx, votes)),
        })
        .map_or(0, |(idx, _)| idx as u8)
}
143
144 /// Train a single decision tree on a dataset with random feature selection.
145 fn train_tree(
146 data: &[[f64; NUM_FEATURES]],
147 labels: &[u8],
148 num_classes: usize,
149 max_depth: usize,
150 min_leaf: usize,
151 features_per_split: usize,
152 rng: &mut impl Rng,
153 ) -> TreeNode {
154 train_tree_recursive(data, labels, num_classes, max_depth, min_leaf, features_per_split, rng)
155 }
156
157 fn train_tree_recursive(
158 data: &[[f64; NUM_FEATURES]],
159 labels: &[u8],
160 num_classes: usize,
161 depth: usize,
162 min_leaf: usize,
163 features_per_split: usize,
164 rng: &mut impl Rng,
165 ) -> TreeNode {
166 if depth == 0 || data.len() <= min_leaf {
167 return TreeNode::Leaf {
168 class: majority_class(labels, num_classes),
169 };
170 }
171
172 let first = labels[0];
173 if labels.iter().all(|&l| l == first) {
174 return TreeNode::Leaf { class: first };
175 }
176
177 let mut all_features: Vec<usize> = (0..NUM_FEATURES).collect();
178 all_features.shuffle(rng);
179 let candidates = &all_features[..features_per_split.min(NUM_FEATURES)];
180
181 let mut best_gini = f64::MAX;
182 let mut best_feature = 0;
183 let mut best_threshold = 0.0;
184
185 for &feat_idx in candidates {
186 let mut values: Vec<f64> = data.iter().map(|row| row[feat_idx]).collect();
187 values.sort_by(|a, b| a.total_cmp(b));
188 values.dedup();
189
190 if values.len() < 2 {
191 continue;
192 }
193
194 let max_thresholds = 200;
195 let step = if values.len() > max_thresholds + 1 {
196 values.len() / max_thresholds
197 } else {
198 1
199 };
200
201 let mut i = 0;
202 while i + 1 < values.len() {
203 let threshold = (values[i] + values[i + 1]) / 2.0;
204
205 let mut left_counts = vec![0u32; num_classes];
206 let mut right_counts = vec![0u32; num_classes];
207 let mut left_total = 0u32;
208 let mut right_total = 0u32;
209
210 for (row, &label) in data.iter().zip(labels.iter()) {
211 if row[feat_idx] <= threshold {
212 left_counts[label as usize] += 1;
213 left_total += 1;
214 } else {
215 right_counts[label as usize] += 1;
216 right_total += 1;
217 }
218 }
219
220 if left_total < min_leaf as u32 || right_total < min_leaf as u32 {
221 i += step;
222 continue;
223 }
224
225 let total = (left_total + right_total) as f64;
226 let weighted_gini = (left_total as f64 / total)
227 * gini_impurity(&left_counts, left_total)
228 + (right_total as f64 / total)
229 * gini_impurity(&right_counts, right_total);
230
231 if weighted_gini < best_gini {
232 best_gini = weighted_gini;
233 best_feature = feat_idx;
234 best_threshold = threshold;
235 }
236
237 i += step;
238 }
239 }
240
241 if best_gini == f64::MAX {
242 return TreeNode::Leaf {
243 class: majority_class(labels, num_classes),
244 };
245 }
246
247 let mut left_data = Vec::new();
248 let mut left_labels = Vec::new();
249 let mut right_data = Vec::new();
250 let mut right_labels = Vec::new();
251
252 for (row, &label) in data.iter().zip(labels.iter()) {
253 if row[best_feature] <= best_threshold {
254 left_data.push(*row);
255 left_labels.push(label);
256 } else {
257 right_data.push(*row);
258 right_labels.push(label);
259 }
260 }
261
262 TreeNode::Split {
263 feature: best_feature,
264 threshold: best_threshold,
265 left: Box::new(train_tree_recursive(
266 &left_data,
267 &left_labels,
268 num_classes,
269 depth - 1,
270 min_leaf,
271 features_per_split,
272 rng,
273 )),
274 right: Box::new(train_tree_recursive(
275 &right_data,
276 &right_labels,
277 num_classes,
278 depth - 1,
279 min_leaf,
280 features_per_split,
281 rng,
282 )),
283 }
284 }
285
286 /// Train a Random Forest: bootstrap sampling + random feature selection per tree.
287 fn train_random_forest(
288 data: &[[f64; NUM_FEATURES]],
289 labels: &[u8],
290 num_classes: usize,
291 n_trees: usize,
292 max_depth: usize,
293 min_leaf: usize,
294 ) -> Vec<TreeNode> {
295 let features_per_split = (NUM_FEATURES as f64).sqrt() as usize; // sqrt(35) ≈ 6
296
297 (0..n_trees)
298 .into_par_iter()
299 .map(|i| {
300 let mut rng = rand::rngs::StdRng::seed_from_u64(42 + i as u64);
301 let n = data.len();
302
303 let mut boot_data = Vec::with_capacity(n);
304 let mut boot_labels = Vec::with_capacity(n);
305 for _ in 0..n {
306 let idx = rng.random_range(0..n);
307 boot_data.push(data[idx]);
308 boot_labels.push(labels[idx]);
309 }
310
311 train_tree(
312 &boot_data,
313 &boot_labels,
314 num_classes,
315 max_depth,
316 min_leaf,
317 features_per_split,
318 &mut rng,
319 )
320 })
321 .collect()
322 }
323
324 /// Predict using a forest of trees. Returns (class, confidence).
325 fn predict_forest(trees: &[TreeNode], num_classes: usize, features: &[f64; NUM_FEATURES]) -> (u8, f64) {
326 let mut votes = vec![0u32; num_classes];
327 for tree in trees {
328 let class = tree.predict(features);
329 if (class as usize) < num_classes {
330 votes[class as usize] += 1;
331 }
332 }
333
334 let total = trees.len() as f64;
335 let (best_class, &best_count) = votes
336 .iter()
337 .enumerate()
338 .max_by_key(|(_, &c)| c)
339 .unwrap_or((0, &0));
340
341 (best_class as u8, best_count as f64 / total)
342 }
343
344 // ── Cross-validation ──
345
346 /// Stratified K-fold cross-validation.
347 fn cross_validate(
348 data: &[[f64; NUM_FEATURES]],
349 labels: &[u8],
350 class_names: &[&str],
351 k: usize,
352 n_trees: usize,
353 max_depth: usize,
354 min_leaf: usize,
355 ) {
356 let num_classes = class_names.len();
357 println!("\n{k}-fold stratified cross-validation:");
358 println!("{}", "=".repeat(60));
359
360 let mut class_indices: Vec<Vec<usize>> = vec![Vec::new(); num_classes];
361 for (i, &label) in labels.iter().enumerate() {
362 class_indices[label as usize].push(i);
363 }
364
365 let mut rng = rand::rngs::StdRng::seed_from_u64(123);
366 for indices in &mut class_indices {
367 indices.shuffle(&mut rng);
368 }
369
370 let mut folds: Vec<Vec<usize>> = vec![Vec::new(); k];
371 for indices in &class_indices {
372 for (i, &idx) in indices.iter().enumerate() {
373 folds[i % k].push(idx);
374 }
375 }
376
377 let mut all_correct = 0usize;
378 let mut all_total = 0usize;
379 let mut per_class_tp = vec![0u32; num_classes];
380 let mut per_class_fp = vec![0u32; num_classes];
381 let mut per_class_fn = vec![0u32; num_classes];
382
383 for fold in 0..k {
384 let test_indices = &folds[fold];
385 let train_indices: Vec<usize> = (0..k)
386 .filter(|&f| f != fold)
387 .flat_map(|f| folds[f].iter().copied())
388 .collect();
389
390 let train_data: Vec<[f64; NUM_FEATURES]> =
391 train_indices.iter().map(|&i| data[i]).collect();
392 let train_labels: Vec<u8> = train_indices.iter().map(|&i| labels[i]).collect();
393
394 let trees = train_random_forest(&train_data, &train_labels, num_classes, n_trees, max_depth, min_leaf);
395
396 let mut fold_correct = 0;
397 for &test_idx in test_indices {
398 let (predicted, _conf) = predict_forest(&trees, num_classes, &data[test_idx]);
399 let actual = labels[test_idx];
400
401 if predicted == actual {
402 fold_correct += 1;
403 per_class_tp[actual as usize] += 1;
404 } else {
405 per_class_fp[predicted as usize] += 1;
406 per_class_fn[actual as usize] += 1;
407 }
408 }
409
410 let fold_acc = fold_correct as f64 / test_indices.len() as f64 * 100.0;
411 println!(" Fold {}: {:.1}% ({}/{})", fold + 1, fold_acc, fold_correct, test_indices.len());
412
413 all_correct += fold_correct;
414 all_total += test_indices.len();
415 }
416
417 let overall_acc = all_correct as f64 / all_total as f64 * 100.0;
418 println!("\n Overall: {:.1}% ({}/{})", overall_acc, all_correct, all_total);
419
420 println!("\n Per-class metrics:");
421 println!(" {:<14} {:>9} {:>9} {:>9}", "Class", "Precision", "Recall", "F1");
422 for c in 0..num_classes {
423 let tp = per_class_tp[c] as f64;
424 let fp = per_class_fp[c] as f64;
425 let fn_ = per_class_fn[c] as f64;
426
427 let precision = if tp + fp > 0.0 { tp / (tp + fp) } else { 0.0 };
428 let recall = if tp + fn_ > 0.0 { tp / (tp + fn_) } else { 0.0 };
429 let f1 = if precision + recall > 0.0 {
430 2.0 * precision * recall / (precision + recall)
431 } else {
432 0.0
433 };
434
435 println!(
436 " {:<14} {:>8.1}% {:>8.1}% {:>8.1}%",
437 class_names[c],
438 precision * 100.0,
439 recall * 100.0,
440 f1 * 100.0
441 );
442 }
443
444 println!("{}", "=".repeat(60));
445 }
446
447 // ── Main ──
448
449 fn main() {
450 let args: Vec<String> = std::env::args().skip(1).collect();
451
452 // Parse --target flag
453 let mut target = "drum";
454 let mut data_path: Option<PathBuf> = None;
455 let mut i = 0;
456 while i < args.len() {
457 match args[i].as_str() {
458 "--target" => {
459 i += 1;
460 if i < args.len() {
461 target = Box::leak(args[i].clone().into_boxed_str());
462 } else {
463 eprintln!("--target requires a value (drum, bass, vocal, synth)");
464 std::process::exit(1);
465 }
466 }
467 "--help" | "-h" => {
468 println!("Usage: audiofiles-train [--target drum|bass|vocal|synth] [DATA_PATH]");
469 println!();
470 println!("Targets:");
471 println!(" drum 7 classes: kick, snare, hihat, cymbal, clap, tom, percussion");
472 println!(" bass 3 classes: guitar-bass, synth-bass, sub-bass");
473 println!(" vocal 3 classes: vocal-chop, vocal-phrase, vocal-choir");
474 println!(" synth 4 classes: synth-lead, synth-stab, synth-pluck, synth-chord");
475 println!();
476 println!("DATA_PATH defaults to samples/training/<target>/");
477 return;
478 }
479 other => {
480 if !other.starts_with('-') {
481 data_path = Some(PathBuf::from(other));
482 }
483 }
484 }
485 i += 1;
486 }
487
488 let config = target_config(target);
489 let num_classes = config.class_names.len();
490
491 println!("Training Layer 2 {} classifier ({} classes)", config.name, num_classes);
492 println!("Classes: {}", config.class_names.join(", "));
493
494 let test_data = data_path.unwrap_or_else(|| {
495 PathBuf::from(env!("CARGO_MANIFEST_DIR"))
496 .parent()
497 .unwrap()
498 .parent()
499 .unwrap()
500 .join(format!("samples/training/{}", config.name))
501 });
502
503 if !test_data.exists() {
504 eprintln!("Training data not found at {}", test_data.display());
505 eprintln!(
506 "Expected structure: <dir>/{{{}}}/",
507 config.class_names.join(",")
508 );
509 eprintln!("Override: cargo run -p audiofiles-train -- --target {} /path/to/data", config.name);
510 std::process::exit(1);
511 }
512
513 // 1. Extract features from all labeled samples
514 println!("\nExtracting features from labeled samples...");
515 let mut all_data: Vec<[f64; NUM_FEATURES]> = Vec::new();
516 let mut all_labels: Vec<u8> = Vec::new();
517 let mut class_counts: HashMap<u8, usize> = HashMap::new();
518
519 for &label in &config.class_names {
520 let dir = test_data.join(label);
521 if !dir.exists() {
522 eprintln!(" Skipping {label}: directory not found");
523 continue;
524 }
525
526 let class_id = match label_to_class(label, &config) {
527 Some(id) => id,
528 None => continue,
529 };
530
531 let mut entries: Vec<_> = std::fs::read_dir(&dir)
532 .expect("failed to read dir")
533 .filter_map(|e| e.ok())
534 .filter(|e| {
535 let name = e.file_name().to_string_lossy().to_lowercase();
536 name.ends_with(".wav")
537 || name.ends_with(".aif")
538 || name.ends_with(".aiff")
539 || name.ends_with(".mp3")
540 || name.ends_with(".ogg")
541 || name.ends_with(".flac")
542 })
543 .collect();
544 entries.sort_by_key(|e| e.file_name());
545
546 let mut count = 0;
547 let mut failed = 0;
548
549 for entry in &entries {
550 let path = std::fs::read_link(entry.path()).unwrap_or_else(|_| entry.path());
551 if !path.exists() {
552 continue;
553 }
554
555 match extract_features(&path) {
556 Some(features) => {
557 all_data.push(features);
558 all_labels.push(class_id);
559 count += 1;
560 }
561 None => {
562 failed += 1;
563 }
564 }
565 }
566
567 *class_counts.entry(class_id).or_default() += count;
568 println!(" {label:<14} {count:>4} samples extracted ({failed} failed)");
569 }
570
571 println!(
572 "\nTotal: {} samples across {} classes",
573 all_data.len(),
574 class_counts.len()
575 );
576 for (c, &name) in config.class_names.iter().enumerate() {
577 let count = class_counts.get(&(c as u8)).unwrap_or(&0);
578 println!(" {}: {} samples", name, count);
579 }
580
581 if all_data.len() < 100 {
582 eprintln!(
583 "Not enough samples for training (need at least 100, got {})",
584 all_data.len()
585 );
586 std::process::exit(1);
587 }
588
589 // Balance classes: cap each at the median class size
590 let max_per_class = {
591 let mut sizes: Vec<usize> = (0..num_classes as u8)
592 .map(|c| class_counts.get(&c).copied().unwrap_or(0))
593 .filter(|&s| s > 0)
594 .collect();
595 sizes.sort_unstable();
596 let median = sizes[sizes.len() / 2];
597 median.max(500)
598 };
599 println!("\nBalancing: capping each class at {max_per_class} samples");
600
601 let mut class_indices: Vec<Vec<usize>> = vec![Vec::new(); num_classes];
602 for (i, &label) in all_labels.iter().enumerate() {
603 class_indices[label as usize].push(i);
604 }
605 let mut balance_rng = rand::rngs::StdRng::seed_from_u64(42);
606 for indices in &mut class_indices {
607 indices.shuffle(&mut balance_rng);
608 indices.truncate(max_per_class);
609 }
610 let keep: Vec<usize> = class_indices.into_iter().flatten().collect();
611 let balanced_data: Vec<[f64; NUM_FEATURES]> = keep.iter().map(|&i| all_data[i]).collect();
612 let balanced_labels: Vec<u8> = keep.iter().map(|&i| all_labels[i]).collect();
613
614 println!("Balanced dataset: {} samples", balanced_data.len());
615 for (c, &name) in config.class_names.iter().enumerate() {
616 let count = balanced_labels.iter().filter(|&&l| l == c as u8).count();
617 println!(" {}: {} samples", name, count);
618 }
619
620 // 2. Cross-validate
621 let n_trees = 200;
622 let max_depth = 25;
623 let min_leaf = 3;
624
625 cross_validate(
626 &balanced_data,
627 &balanced_labels,
628 &config.class_names,
629 5,
630 n_trees,
631 max_depth,
632 min_leaf,
633 );
634
635 // 3. Train final model on balanced data
636 println!(
637 "\nTraining final model on {} balanced samples ({n_trees} trees)...",
638 balanced_data.len()
639 );
640 let trees = train_random_forest(&balanced_data, &balanced_labels, num_classes, n_trees, max_depth, min_leaf);
641
642 // Quick sanity check: predict training set
643 let mut train_correct = 0;
644 for (row, &label) in balanced_data.iter().zip(balanced_labels.iter()) {
645 let (pred, _) = predict_forest(&trees, num_classes, row);
646 if pred == label {
647 train_correct += 1;
648 }
649 }
650 println!(
651 "Training set accuracy: {:.1}% ({}/{})",
652 train_correct as f64 / balanced_data.len() as f64 * 100.0,
653 train_correct,
654 balanced_data.len()
655 );
656
657 // 4. Serialize model
658 let model = RandomForestModel {
659 trees,
660 num_classes: num_classes as u8,
661 class_names: config.class_names.iter().map(|s| s.to_string()).collect(),
662 };
663
664 let model_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
665 .parent()
666 .unwrap()
667 .join(format!("audiofiles-core/models/{}", config.model_filename));
668
669 let json = serde_json::to_string(&model).expect("failed to serialize model");
670 std::fs::write(&model_path, &json).expect("failed to write model file");
671
672 let size_mb = json.len() as f64 / 1_048_576.0;
673 println!(
674 "\nModel saved to {} ({:.1} MB)",
675 model_path.display(),
676 size_mb
677 );
678 println!("Trees: {}, Classes: {}", model.trees.len(), model.num_classes);
679 println!("\nDone. Rebuild audiofiles-core to embed the new model.");
680 }
681