//! Random Forest training binary for audiofiles Layer 2 classification. //! //! Supports training separate models for drum, bass, vocal, and synth sub-classes. //! Scans labeled samples from a training directory, extracts 35-feature vectors, //! trains a Random Forest, runs 5-fold stratified cross-validation, and serializes //! the model to the appropriate JSON file. //! //! Usage: //! cargo run -p audiofiles-train -- --target drum [/path/to/data] //! cargo run -p audiofiles-train -- --target bass [/path/to/data] //! cargo run -p audiofiles-train -- --target vocal [/path/to/data] //! cargo run -p audiofiles-train -- --target synth [/path/to/data] use std::collections::HashMap; use std::path::PathBuf; use audiofiles_core::analysis::classify::{RandomForestModel, TreeNode, NUM_FEATURES}; use audiofiles_core::analysis::config::AnalysisConfig; use audiofiles_core::analysis::{self, mfcc, spectral}; use rand::seq::SliceRandom; use rand::{Rng, SeedableRng}; use rayon::prelude::*; // ── Target configuration ── /// Per-target training config: class names, label→ID mapping, output filename. 
///
/// The position of a label in `class_names` is its numeric class ID (see
/// [`label_to_class`]); the serialized model stores the names in the same order.
#[derive(Debug, Clone)]
struct TargetConfig {
    /// Target name; also the default training sub-directory (`samples/training/<name>`).
    name: &'static str,
    /// Sub-class labels, one per data sub-folder, in class-ID order.
    class_names: Vec<&'static str>,
    /// File name of the serialized model written under `audiofiles-core/models/`.
    model_filename: &'static str,
}

/// Look up the training configuration for a target name.
///
/// Prints the valid targets and exits the process with status 1 if `target` is not one
/// of `drum`, `bass`, `vocal` or `synth`.
fn target_config(target: &str) -> TargetConfig {
    let (name, class_names, model_filename) = match target {
        "drum" => (
            "drum",
            vec!["kick", "snare", "hihat", "cymbal", "clap", "tom", "percussion"],
            "layer2_drum.json",
        ),
        "bass" => (
            "bass",
            vec!["guitar-bass", "synth-bass", "sub-bass"],
            "layer2_bass.json",
        ),
        "vocal" => (
            "vocal",
            vec!["vocal-chop", "vocal-phrase", "vocal-choir"],
            "layer2_vocal.json",
        ),
        "synth" => (
            "synth",
            vec!["synth-lead", "synth-stab", "synth-pluck", "synth-chord"],
            "layer2_synth.json",
        ),
        _ => {
            eprintln!("Unknown target: {target}");
            eprintln!("Valid targets: drum, bass, vocal, synth");
            std::process::exit(1);
        }
    };
    TargetConfig {
        name,
        class_names,
        model_filename,
    }
}

/// Map a folder name to a class ID for the given target config.
///
/// Returns the label's index in `config.class_names`, or `None` if the label is not one
/// of the target's classes (or its index would not fit in a `u8`).
fn label_to_class(label: &str, config: &TargetConfig) -> Option<u8> {
    config
        .class_names
        .iter()
        .position(|&name| name == label)
        .and_then(|index| u8::try_from(index).ok())
}

// ── Feature extraction ──

/// Extract 35-feature vector from an audio file.
fn extract_features(path: &std::path::Path) -> Option<[f64; NUM_FEATURES]> { let config = AnalysisConfig { loudness: true, spectral: true, bpm: false, key: false, loop_detect: false, classify: false, fingerprint: false, auto_suggest_tags: false, max_analysis_seconds: Some(5.0), smart_skip: false, }; let decoded = analysis::decode::decode_to_mono(path).ok()?; let capped: &[f32] = if let Some(max_secs) = config.max_analysis_seconds { let max_samples = (max_secs * decoded.sample_rate as f64) as usize; &decoded.samples[..decoded.samples.len().min(max_samples)] } else { &decoded.samples }; let crest_factor = analysis::basic::crest_factor(&decoded.samples); let attack_time = analysis::basic::attack_time(&decoded.samples, decoded.sample_rate); let (features, magnitude_frames) = spectral::compute_spectral_features_with_frames(capped, decoded.sample_rate); let mfcc_features = mfcc::compute_mfccs(&magnitude_frames, decoded.sample_rate, 1024); let input = audiofiles_core::analysis::classify::ClassifyInput::with_mfccs( &features, decoded.duration, crest_factor, attack_time, &mfcc_features, ); Some(input.to_feature_array()) } // ── Random Forest training ── /// Gini impurity for a set of class counts. fn gini_impurity(counts: &[u32], total: u32) -> f64 { if total == 0 { return 0.0; } let t = total as f64; 1.0 - counts.iter().map(|&c| (c as f64 / t).powi(2)).sum::() } /// Find the majority class in a label slice. fn majority_class(labels: &[u8], num_classes: usize) -> u8 { let mut counts = vec![0u32; num_classes]; for &l in labels { if (l as usize) < num_classes { counts[l as usize] += 1; } } counts .iter() .enumerate() .max_by_key(|(_, &c)| c) .map(|(i, _)| i as u8) .unwrap_or(0) } /// Train a single decision tree on a dataset with random feature selection. 
///
/// Entry point for growing one tree:
/// - `max_depth` limits the number of split levels.
/// - `min_leaf` is the minimum number of samples required on each side of a split.
/// - `features_per_split` is how many randomly chosen features are tried at each node.
fn train_tree(
    data: &[[f64; NUM_FEATURES]],
    labels: &[u8],
    num_classes: usize,
    max_depth: usize,
    min_leaf: usize,
    features_per_split: usize,
    rng: &mut impl Rng,
) -> TreeNode {
    train_tree_recursive(data, labels, num_classes, max_depth, min_leaf, features_per_split, rng)
}

/// Recursively grow a decision tree by choosing the split with the lowest weighted Gini impurity.
///
/// The node becomes a leaf holding the majority class when:
/// - `depth` is 0,
/// - the node has `min_leaf` or fewer samples, or
/// - no candidate split leaves at least `min_leaf` samples on both sides.
///
/// If all labels agree, the node becomes a leaf of that shared class.
///
/// Otherwise `features_per_split` features are drawn at random. For each one, thresholds
/// are evaluated at midpoints between consecutive distinct values. When a feature has more
/// than 201 distinct values, only every `len / 200`-th midpoint is tried. On ties, the first
/// best split found is kept.
///
/// Every label must be `< num_classes`; an out-of-range label panics when the class
/// counts are indexed.
fn train_tree_recursive(
    data: &[[f64; NUM_FEATURES]],
    labels: &[u8],
    num_classes: usize,
    depth: usize,
    min_leaf: usize,
    features_per_split: usize,
    rng: &mut impl Rng,
) -> TreeNode {
    // Above this many distinct values, candidate thresholds are subsampled to bound cost.
    const MAX_THRESHOLDS: usize = 200;

    if depth == 0 || data.len() <= min_leaf {
        return TreeNode::Leaf {
            class: majority_class(labels, num_classes),
        };
    }
    // `labels` is non-empty here because `data.len() > min_leaf >= 0`.
    let first = labels[0];
    if labels.iter().all(|&label| label == first) {
        return TreeNode::Leaf { class: first };
    }

    // Random subspace: shuffle all feature indices and keep the first `features_per_split`.
    let mut all_features: Vec<usize> = (0..NUM_FEATURES).collect();
    all_features.shuffle(rng);
    let candidates = &all_features[..features_per_split.min(NUM_FEATURES)];

    // Best split found so far, as (weighted Gini, feature index, threshold).
    let mut best: Option<(f64, usize, f64)> = None;
    // Count buffers are reused across thresholds instead of being reallocated each time.
    let mut left_counts = vec![0u32; num_classes];
    let mut right_counts = vec![0u32; num_classes];

    for &feat_idx in candidates {
        let mut values: Vec<f64> = data.iter().map(|row| row[feat_idx]).collect();
        values.sort_unstable_by(f64::total_cmp);
        values.dedup();
        if values.len() < 2 {
            continue;
        }
        let step = if values.len() > MAX_THRESHOLDS + 1 {
            values.len() / MAX_THRESHOLDS
        } else {
            1
        };

        for i in (0..values.len() - 1).step_by(step) {
            let threshold = (values[i] + values[i + 1]) / 2.0;

            left_counts.fill(0);
            right_counts.fill(0);
            let mut left_total = 0u32;
            let mut right_total = 0u32;
            for (row, &label) in data.iter().zip(labels) {
                if row[feat_idx] <= threshold {
                    left_counts[usize::from(label)] += 1;
                    left_total += 1;
                } else {
                    right_counts[usize::from(label)] += 1;
                    right_total += 1;
                }
            }
            if left_total < min_leaf as u32 || right_total < min_leaf as u32 {
                continue;
            }

            let total = f64::from(left_total + right_total);
            let weighted_gini = (f64::from(left_total) / total)
                * gini_impurity(&left_counts, left_total)
                + (f64::from(right_total) / total) * gini_impurity(&right_counts, right_total);
            let improves = match best {
                Some((best_gini, _, _)) => weighted_gini < best_gini,
                None => true,
            };
            if improves {
                best = Some((weighted_gini, feat_idx, threshold));
            }
        }
    }

    let (best_feature, best_threshold) = match best {
        Some((_, feature, threshold)) => (feature, threshold),
        None => {
            return TreeNode::Leaf {
                class: majority_class(labels, num_classes),
            }
        }
    };

    // Partition this node's samples by the chosen split.
    let mut left_data = Vec::new();
    let mut left_labels = Vec::new();
    let mut right_data = Vec::new();
    let mut right_labels = Vec::new();
    for (row, &label) in data.iter().zip(labels) {
        if row[best_feature] <= best_threshold {
            left_data.push(*row);
            left_labels.push(label);
        } else {
            right_data.push(*row);
            right_labels.push(label);
        }
    }

    TreeNode::Split {
        feature: best_feature,
        threshold: best_threshold,
        left: Box::new(train_tree_recursive(
            &left_data,
            &left_labels,
            num_classes,
            depth - 1,
            min_leaf,
            features_per_split,
            rng,
        )),
        right: Box::new(train_tree_recursive(
            &right_data,
            &right_labels,
            num_classes,
            depth - 1,
            min_leaf,
            features_per_split,
            rng,
        )),
    }
}

/// Train a Random Forest: bootstrap sampling + random feature selection per tree.
///
/// Trees are trained in parallel with rayon. Tree `i` draws its bootstrap sample and its
/// feature subsets from a `StdRng` seeded with `42 + i`, so the same input always produces
/// the same forest, regardless of thread scheduling.
fn train_random_forest(
    data: &[[f64; NUM_FEATURES]],
    labels: &[u8],
    num_classes: usize,
    n_trees: usize,
    max_depth: usize,
    min_leaf: usize,
) -> Vec<TreeNode> {
    // floor(sqrt(NUM_FEATURES)): the cast truncates, giving 5 for the 35-feature vector
    // (sqrt(35) ≈ 5.92). Never fewer than one candidate feature per split.
    let features_per_split = ((NUM_FEATURES as f64).sqrt() as usize).max(1);
    let n = data.len();

    (0..n_trees)
        .into_par_iter()
        .map(|i| {
            let mut rng = rand::rngs::StdRng::seed_from_u64(42 + i as u64);
            // Bootstrap: draw `n` samples with replacement.
            let (boot_data, boot_labels): (Vec<[f64; NUM_FEATURES]>, Vec<u8>) = (0..n)
                .map(|_| {
                    let idx = rng.random_range(0..n);
                    (data[idx], labels[idx])
                })
                .unzip();
            train_tree(
                &boot_data,
                &boot_labels,
                num_classes,
                max_depth,
                min_leaf,
                features_per_split,
                &mut rng,
            )
        })
        .collect()
}

/// Predict using a forest of trees. Returns (class, confidence).
fn predict_forest(trees: &[TreeNode], num_classes: usize, features: &[f64; NUM_FEATURES]) -> (u8, f64) { let mut votes = vec![0u32; num_classes]; for tree in trees { let class = tree.predict(features); if (class as usize) < num_classes { votes[class as usize] += 1; } } let total = trees.len() as f64; let (best_class, &best_count) = votes .iter() .enumerate() .max_by_key(|(_, &c)| c) .unwrap_or((0, &0)); (best_class as u8, best_count as f64 / total) } // ── Cross-validation ── /// Stratified K-fold cross-validation. fn cross_validate( data: &[[f64; NUM_FEATURES]], labels: &[u8], class_names: &[&str], k: usize, n_trees: usize, max_depth: usize, min_leaf: usize, ) { let num_classes = class_names.len(); println!("\n{k}-fold stratified cross-validation:"); println!("{}", "=".repeat(60)); let mut class_indices: Vec> = vec![Vec::new(); num_classes]; for (i, &label) in labels.iter().enumerate() { class_indices[label as usize].push(i); } let mut rng = rand::rngs::StdRng::seed_from_u64(123); for indices in &mut class_indices { indices.shuffle(&mut rng); } let mut folds: Vec> = vec![Vec::new(); k]; for indices in &class_indices { for (i, &idx) in indices.iter().enumerate() { folds[i % k].push(idx); } } let mut all_correct = 0usize; let mut all_total = 0usize; let mut per_class_tp = vec![0u32; num_classes]; let mut per_class_fp = vec![0u32; num_classes]; let mut per_class_fn = vec![0u32; num_classes]; for fold in 0..k { let test_indices = &folds[fold]; let train_indices: Vec = (0..k) .filter(|&f| f != fold) .flat_map(|f| folds[f].iter().copied()) .collect(); let train_data: Vec<[f64; NUM_FEATURES]> = train_indices.iter().map(|&i| data[i]).collect(); let train_labels: Vec = train_indices.iter().map(|&i| labels[i]).collect(); let trees = train_random_forest(&train_data, &train_labels, num_classes, n_trees, max_depth, min_leaf); let mut fold_correct = 0; for &test_idx in test_indices { let (predicted, _conf) = predict_forest(&trees, num_classes, &data[test_idx]); let actual = 
labels[test_idx]; if predicted == actual { fold_correct += 1; per_class_tp[actual as usize] += 1; } else { per_class_fp[predicted as usize] += 1; per_class_fn[actual as usize] += 1; } } let fold_acc = fold_correct as f64 / test_indices.len() as f64 * 100.0; println!(" Fold {}: {:.1}% ({}/{})", fold + 1, fold_acc, fold_correct, test_indices.len()); all_correct += fold_correct; all_total += test_indices.len(); } let overall_acc = all_correct as f64 / all_total as f64 * 100.0; println!("\n Overall: {:.1}% ({}/{})", overall_acc, all_correct, all_total); println!("\n Per-class metrics:"); println!(" {:<14} {:>9} {:>9} {:>9}", "Class", "Precision", "Recall", "F1"); for c in 0..num_classes { let tp = per_class_tp[c] as f64; let fp = per_class_fp[c] as f64; let fn_ = per_class_fn[c] as f64; let precision = if tp + fp > 0.0 { tp / (tp + fp) } else { 0.0 }; let recall = if tp + fn_ > 0.0 { tp / (tp + fn_) } else { 0.0 }; let f1 = if precision + recall > 0.0 { 2.0 * precision * recall / (precision + recall) } else { 0.0 }; println!( " {:<14} {:>8.1}% {:>8.1}% {:>8.1}%", class_names[c], precision * 100.0, recall * 100.0, f1 * 100.0 ); } println!("{}", "=".repeat(60)); } // ── Main ── fn main() { let args: Vec = std::env::args().skip(1).collect(); // Parse --target flag let mut target = "drum"; let mut data_path: Option = None; let mut i = 0; while i < args.len() { match args[i].as_str() { "--target" => { i += 1; if i < args.len() { target = Box::leak(args[i].clone().into_boxed_str()); } else { eprintln!("--target requires a value (drum, bass, vocal, synth)"); std::process::exit(1); } } "--help" | "-h" => { println!("Usage: audiofiles-train [--target drum|bass|vocal|synth] [DATA_PATH]"); println!(); println!("Targets:"); println!(" drum 7 classes: kick, snare, hihat, cymbal, clap, tom, percussion"); println!(" bass 3 classes: guitar-bass, synth-bass, sub-bass"); println!(" vocal 3 classes: vocal-chop, vocal-phrase, vocal-choir"); println!(" synth 4 classes: synth-lead, 
synth-stab, synth-pluck, synth-chord"); println!(); println!("DATA_PATH defaults to samples/training//"); return; } other => { if !other.starts_with('-') { data_path = Some(PathBuf::from(other)); } } } i += 1; } let config = target_config(target); let num_classes = config.class_names.len(); println!("Training Layer 2 {} classifier ({} classes)", config.name, num_classes); println!("Classes: {}", config.class_names.join(", ")); let test_data = data_path.unwrap_or_else(|| { PathBuf::from(env!("CARGO_MANIFEST_DIR")) .parent() .unwrap() .parent() .unwrap() .join(format!("samples/training/{}", config.name)) }); if !test_data.exists() { eprintln!("Training data not found at {}", test_data.display()); eprintln!( "Expected structure: /{{{}}}/", config.class_names.join(",") ); eprintln!("Override: cargo run -p audiofiles-train -- --target {} /path/to/data", config.name); std::process::exit(1); } // 1. Extract features from all labeled samples println!("\nExtracting features from labeled samples..."); let mut all_data: Vec<[f64; NUM_FEATURES]> = Vec::new(); let mut all_labels: Vec = Vec::new(); let mut class_counts: HashMap = HashMap::new(); for &label in &config.class_names { let dir = test_data.join(label); if !dir.exists() { eprintln!(" Skipping {label}: directory not found"); continue; } let class_id = match label_to_class(label, &config) { Some(id) => id, None => continue, }; let mut entries: Vec<_> = std::fs::read_dir(&dir) .expect("failed to read dir") .filter_map(|e| e.ok()) .filter(|e| { let name = e.file_name().to_string_lossy().to_lowercase(); name.ends_with(".wav") || name.ends_with(".aif") || name.ends_with(".aiff") || name.ends_with(".mp3") || name.ends_with(".ogg") || name.ends_with(".flac") }) .collect(); entries.sort_by_key(|e| e.file_name()); let mut count = 0; let mut failed = 0; for entry in &entries { let path = std::fs::read_link(entry.path()).unwrap_or_else(|_| entry.path()); if !path.exists() { continue; } match extract_features(&path) { 
Some(features) => { all_data.push(features); all_labels.push(class_id); count += 1; } None => { failed += 1; } } } *class_counts.entry(class_id).or_default() += count; println!(" {label:<14} {count:>4} samples extracted ({failed} failed)"); } println!( "\nTotal: {} samples across {} classes", all_data.len(), class_counts.len() ); for (c, &name) in config.class_names.iter().enumerate() { let count = class_counts.get(&(c as u8)).unwrap_or(&0); println!(" {}: {} samples", name, count); } if all_data.len() < 100 { eprintln!( "Not enough samples for training (need at least 100, got {})", all_data.len() ); std::process::exit(1); } // Balance classes: cap each at the median class size let max_per_class = { let mut sizes: Vec = (0..num_classes as u8) .map(|c| class_counts.get(&c).copied().unwrap_or(0)) .filter(|&s| s > 0) .collect(); sizes.sort_unstable(); let median = sizes[sizes.len() / 2]; median.max(500) }; println!("\nBalancing: capping each class at {max_per_class} samples"); let mut class_indices: Vec> = vec![Vec::new(); num_classes]; for (i, &label) in all_labels.iter().enumerate() { class_indices[label as usize].push(i); } let mut balance_rng = rand::rngs::StdRng::seed_from_u64(42); for indices in &mut class_indices { indices.shuffle(&mut balance_rng); indices.truncate(max_per_class); } let keep: Vec = class_indices.into_iter().flatten().collect(); let balanced_data: Vec<[f64; NUM_FEATURES]> = keep.iter().map(|&i| all_data[i]).collect(); let balanced_labels: Vec = keep.iter().map(|&i| all_labels[i]).collect(); println!("Balanced dataset: {} samples", balanced_data.len()); for (c, &name) in config.class_names.iter().enumerate() { let count = balanced_labels.iter().filter(|&&l| l == c as u8).count(); println!(" {}: {} samples", name, count); } // 2. Cross-validate let n_trees = 200; let max_depth = 25; let min_leaf = 3; cross_validate( &balanced_data, &balanced_labels, &config.class_names, 5, n_trees, max_depth, min_leaf, ); // 3. 
Train final model on balanced data println!( "\nTraining final model on {} balanced samples ({n_trees} trees)...", balanced_data.len() ); let trees = train_random_forest(&balanced_data, &balanced_labels, num_classes, n_trees, max_depth, min_leaf); // Quick sanity check: predict training set let mut train_correct = 0; for (row, &label) in balanced_data.iter().zip(balanced_labels.iter()) { let (pred, _) = predict_forest(&trees, num_classes, row); if pred == label { train_correct += 1; } } println!( "Training set accuracy: {:.1}% ({}/{})", train_correct as f64 / balanced_data.len() as f64 * 100.0, train_correct, balanced_data.len() ); // 4. Serialize model let model = RandomForestModel { trees, num_classes: num_classes as u8, class_names: config.class_names.iter().map(|s| s.to_string()).collect(), }; let model_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) .parent() .unwrap() .join(format!("audiofiles-core/models/{}", config.model_filename)); let json = serde_json::to_string(&model).expect("failed to serialize model"); std::fs::write(&model_path, &json).expect("failed to write model file"); let size_mb = json.len() as f64 / 1_048_576.0; println!( "\nModel saved to {} ({:.1} MB)", model_path.display(), size_mb ); println!("Trees: {}, Classes: {}", model.trees.len(), model.num_classes); println!("\nDone. Rebuild audiofiles-core to embed the new model."); }