| 1 |
|
| 2 |
|
| 3 |
|
| 4 |
|
| 5 |
|
| 6 |
|
| 7 |
|
| 8 |
|
| 9 |
|
| 10 |
|
| 11 |
|
| 12 |
|
| 13 |
|
| 14 |
use std::collections::HashMap; |
| 15 |
use std::path::PathBuf; |
| 16 |
|
| 17 |
use audiofiles_core::analysis::classify::{RandomForestModel, TreeNode, NUM_FEATURES}; |
| 18 |
use audiofiles_core::analysis::config::AnalysisConfig; |
| 19 |
use audiofiles_core::analysis::{self, mfcc, spectral}; |
| 20 |
use rand::seq::SliceRandom; |
| 21 |
use rand::{Rng, SeedableRng}; |
| 22 |
use rayon::prelude::*; |
| 23 |
|
| 24 |
|
| 25 |
|
| 26 |
|
| 27 |
/// Static description of one Layer 2 classifier target (drum, bass, vocal, synth).
#[derive(Debug, Clone)]
struct TargetConfig {
    /// Target identifier; printed in messages and used for the default data directory
    /// `samples/training/<name>/`.
    name: &'static str,
    /// Class labels in class-id order: a label's index is its numeric class id, and each
    /// label is also the name of the sub-directory holding that class's samples.
    class_names: Vec<&'static str>,
    /// File name the trained model is written to under `audiofiles-core/models/`.
    model_filename: &'static str,
}
| 32 |
|
| 33 |
fn target_config(target: &str) -> TargetConfig { |
| 34 |
match target { |
| 35 |
"drum" => TargetConfig { |
| 36 |
name: "drum", |
| 37 |
class_names: vec!["kick", "snare", "hihat", "cymbal", "clap", "tom", "percussion"], |
| 38 |
model_filename: "layer2_drum.json", |
| 39 |
}, |
| 40 |
"bass" => TargetConfig { |
| 41 |
name: "bass", |
| 42 |
class_names: vec!["guitar-bass", "synth-bass", "sub-bass"], |
| 43 |
model_filename: "layer2_bass.json", |
| 44 |
}, |
| 45 |
"vocal" => TargetConfig { |
| 46 |
name: "vocal", |
| 47 |
class_names: vec!["vocal-chop", "vocal-phrase", "vocal-choir"], |
| 48 |
model_filename: "layer2_vocal.json", |
| 49 |
}, |
| 50 |
"synth" => TargetConfig { |
| 51 |
name: "synth", |
| 52 |
class_names: vec!["synth-lead", "synth-stab", "synth-pluck", "synth-chord"], |
| 53 |
model_filename: "layer2_synth.json", |
| 54 |
}, |
| 55 |
_ => { |
| 56 |
eprintln!("Unknown target: {target}"); |
| 57 |
eprintln!("Valid targets: drum, bass, vocal, synth"); |
| 58 |
std::process::exit(1); |
| 59 |
} |
| 60 |
} |
| 61 |
} |
| 62 |
|
| 63 |
|
| 64 |
fn label_to_class(label: &str, config: &TargetConfig) -> Option<u8> { |
| 65 |
config |
| 66 |
.class_names |
| 67 |
.iter() |
| 68 |
.position(|&name| name == label) |
| 69 |
.map(|i| i as u8) |
| 70 |
} |
| 71 |
|
| 72 |
|
| 73 |
|
| 74 |
|
| 75 |
fn extract_features(path: &std::path::Path) -> Option<[f64; NUM_FEATURES]> { |
| 76 |
let config = AnalysisConfig { |
| 77 |
loudness: true, |
| 78 |
spectral: true, |
| 79 |
bpm: false, |
| 80 |
key: false, |
| 81 |
loop_detect: false, |
| 82 |
classify: false, |
| 83 |
fingerprint: false, |
| 84 |
auto_suggest_tags: false, |
| 85 |
max_analysis_seconds: Some(5.0), |
| 86 |
smart_skip: false, |
| 87 |
}; |
| 88 |
|
| 89 |
let decoded = analysis::decode::decode_to_mono(path).ok()?; |
| 90 |
|
| 91 |
let capped: &[f32] = if let Some(max_secs) = config.max_analysis_seconds { |
| 92 |
let max_samples = (max_secs * decoded.sample_rate as f64) as usize; |
| 93 |
&decoded.samples[..decoded.samples.len().min(max_samples)] |
| 94 |
} else { |
| 95 |
&decoded.samples |
| 96 |
}; |
| 97 |
|
| 98 |
let crest_factor = analysis::basic::crest_factor(&decoded.samples); |
| 99 |
let attack_time = analysis::basic::attack_time(&decoded.samples, decoded.sample_rate); |
| 100 |
|
| 101 |
let (features, magnitude_frames) = |
| 102 |
spectral::compute_spectral_features_with_frames(capped, decoded.sample_rate); |
| 103 |
|
| 104 |
let mfcc_features = mfcc::compute_mfccs(&magnitude_frames, decoded.sample_rate, 1024); |
| 105 |
|
| 106 |
let input = audiofiles_core::analysis::classify::ClassifyInput::with_mfccs( |
| 107 |
&features, |
| 108 |
decoded.duration, |
| 109 |
crest_factor, |
| 110 |
attack_time, |
| 111 |
&mfcc_features, |
| 112 |
); |
| 113 |
|
| 114 |
Some(input.to_feature_array()) |
| 115 |
} |
| 116 |
|
| 117 |
|
| 118 |
|
| 119 |
|
| 120 |
/// Gini impurity `1 - Σ p_c²` of a class histogram, with `p_c = counts[c] / total`.
///
/// An empty node (`total == 0`) is treated as pure and returns `0.0`.
fn gini_impurity(counts: &[u32], total: u32) -> f64 {
    match total {
        0 => 0.0,
        n => {
            let denom = f64::from(n);
            let sum_sq: f64 = counts
                .iter()
                .map(|&c| {
                    let p = f64::from(c) / denom;
                    p.powi(2)
                })
                .sum();
            1.0 - sum_sq
        }
    }
}
| 127 |
|
| 128 |
|
| 129 |
/// Returns the most frequent class id in `labels`.
///
/// Labels `>= num_classes` are ignored. Ties resolve to the *highest* tied class id. This
/// includes the all-zero tally of an empty `labels`, which therefore yields
/// `num_classes - 1`. `num_classes == 0` yields `0`.
fn majority_class(labels: &[u8], num_classes: usize) -> u8 {
    let mut tally = vec![0u32; num_classes];
    for class in labels.iter().map(|&l| usize::from(l)) {
        if let Some(slot) = tally.get_mut(class) {
            *slot += 1;
        }
    }

    // Keep replacing on ties so the last (highest-index) maximum wins.
    tally
        .iter()
        .enumerate()
        .fold(None, |best: Option<(usize, u32)>, (idx, &n)| match best {
            Some((_, top)) if n < top => best,
            _ => Some((idx, n)),
        })
        .map_or(0, |(idx, _)| idx as u8)
}
| 143 |
|
| 144 |
|
| 145 |
fn train_tree( |
| 146 |
data: &[[f64; NUM_FEATURES]], |
| 147 |
labels: &[u8], |
| 148 |
num_classes: usize, |
| 149 |
max_depth: usize, |
| 150 |
min_leaf: usize, |
| 151 |
features_per_split: usize, |
| 152 |
rng: &mut impl Rng, |
| 153 |
) -> TreeNode { |
| 154 |
train_tree_recursive(data, labels, num_classes, max_depth, min_leaf, features_per_split, rng) |
| 155 |
} |
| 156 |
|
| 157 |
fn train_tree_recursive( |
| 158 |
data: &[[f64; NUM_FEATURES]], |
| 159 |
labels: &[u8], |
| 160 |
num_classes: usize, |
| 161 |
depth: usize, |
| 162 |
min_leaf: usize, |
| 163 |
features_per_split: usize, |
| 164 |
rng: &mut impl Rng, |
| 165 |
) -> TreeNode { |
| 166 |
if depth == 0 || data.len() <= min_leaf { |
| 167 |
return TreeNode::Leaf { |
| 168 |
class: majority_class(labels, num_classes), |
| 169 |
}; |
| 170 |
} |
| 171 |
|
| 172 |
let first = labels[0]; |
| 173 |
if labels.iter().all(|&l| l == first) { |
| 174 |
return TreeNode::Leaf { class: first }; |
| 175 |
} |
| 176 |
|
| 177 |
let mut all_features: Vec<usize> = (0..NUM_FEATURES).collect(); |
| 178 |
all_features.shuffle(rng); |
| 179 |
let candidates = &all_features[..features_per_split.min(NUM_FEATURES)]; |
| 180 |
|
| 181 |
let mut best_gini = f64::MAX; |
| 182 |
let mut best_feature = 0; |
| 183 |
let mut best_threshold = 0.0; |
| 184 |
|
| 185 |
for &feat_idx in candidates { |
| 186 |
let mut values: Vec<f64> = data.iter().map(|row| row[feat_idx]).collect(); |
| 187 |
values.sort_by(|a, b| a.total_cmp(b)); |
| 188 |
values.dedup(); |
| 189 |
|
| 190 |
if values.len() < 2 { |
| 191 |
continue; |
| 192 |
} |
| 193 |
|
| 194 |
let max_thresholds = 200; |
| 195 |
let step = if values.len() > max_thresholds + 1 { |
| 196 |
values.len() / max_thresholds |
| 197 |
} else { |
| 198 |
1 |
| 199 |
}; |
| 200 |
|
| 201 |
let mut i = 0; |
| 202 |
while i + 1 < values.len() { |
| 203 |
let threshold = (values[i] + values[i + 1]) / 2.0; |
| 204 |
|
| 205 |
let mut left_counts = vec![0u32; num_classes]; |
| 206 |
let mut right_counts = vec![0u32; num_classes]; |
| 207 |
let mut left_total = 0u32; |
| 208 |
let mut right_total = 0u32; |
| 209 |
|
| 210 |
for (row, &label) in data.iter().zip(labels.iter()) { |
| 211 |
if row[feat_idx] <= threshold { |
| 212 |
left_counts[label as usize] += 1; |
| 213 |
left_total += 1; |
| 214 |
} else { |
| 215 |
right_counts[label as usize] += 1; |
| 216 |
right_total += 1; |
| 217 |
} |
| 218 |
} |
| 219 |
|
| 220 |
if left_total < min_leaf as u32 || right_total < min_leaf as u32 { |
| 221 |
i += step; |
| 222 |
continue; |
| 223 |
} |
| 224 |
|
| 225 |
let total = (left_total + right_total) as f64; |
| 226 |
let weighted_gini = (left_total as f64 / total) |
| 227 |
* gini_impurity(&left_counts, left_total) |
| 228 |
+ (right_total as f64 / total) |
| 229 |
* gini_impurity(&right_counts, right_total); |
| 230 |
|
| 231 |
if weighted_gini < best_gini { |
| 232 |
best_gini = weighted_gini; |
| 233 |
best_feature = feat_idx; |
| 234 |
best_threshold = threshold; |
| 235 |
} |
| 236 |
|
| 237 |
i += step; |
| 238 |
} |
| 239 |
} |
| 240 |
|
| 241 |
if best_gini == f64::MAX { |
| 242 |
return TreeNode::Leaf { |
| 243 |
class: majority_class(labels, num_classes), |
| 244 |
}; |
| 245 |
} |
| 246 |
|
| 247 |
let mut left_data = Vec::new(); |
| 248 |
let mut left_labels = Vec::new(); |
| 249 |
let mut right_data = Vec::new(); |
| 250 |
let mut right_labels = Vec::new(); |
| 251 |
|
| 252 |
for (row, &label) in data.iter().zip(labels.iter()) { |
| 253 |
if row[best_feature] <= best_threshold { |
| 254 |
left_data.push(*row); |
| 255 |
left_labels.push(label); |
| 256 |
} else { |
| 257 |
right_data.push(*row); |
| 258 |
right_labels.push(label); |
| 259 |
} |
| 260 |
} |
| 261 |
|
| 262 |
TreeNode::Split { |
| 263 |
feature: best_feature, |
| 264 |
threshold: best_threshold, |
| 265 |
left: Box::new(train_tree_recursive( |
| 266 |
&left_data, |
| 267 |
&left_labels, |
| 268 |
num_classes, |
| 269 |
depth - 1, |
| 270 |
min_leaf, |
| 271 |
features_per_split, |
| 272 |
rng, |
| 273 |
)), |
| 274 |
right: Box::new(train_tree_recursive( |
| 275 |
&right_data, |
| 276 |
&right_labels, |
| 277 |
num_classes, |
| 278 |
depth - 1, |
| 279 |
min_leaf, |
| 280 |
features_per_split, |
| 281 |
rng, |
| 282 |
)), |
| 283 |
} |
| 284 |
} |
| 285 |
|
| 286 |
|
| 287 |
fn train_random_forest( |
| 288 |
data: &[[f64; NUM_FEATURES]], |
| 289 |
labels: &[u8], |
| 290 |
num_classes: usize, |
| 291 |
n_trees: usize, |
| 292 |
max_depth: usize, |
| 293 |
min_leaf: usize, |
| 294 |
) -> Vec<TreeNode> { |
| 295 |
let features_per_split = (NUM_FEATURES as f64).sqrt() as usize; |
| 296 |
|
| 297 |
(0..n_trees) |
| 298 |
.into_par_iter() |
| 299 |
.map(|i| { |
| 300 |
let mut rng = rand::rngs::StdRng::seed_from_u64(42 + i as u64); |
| 301 |
let n = data.len(); |
| 302 |
|
| 303 |
let mut boot_data = Vec::with_capacity(n); |
| 304 |
let mut boot_labels = Vec::with_capacity(n); |
| 305 |
for _ in 0..n { |
| 306 |
let idx = rng.random_range(0..n); |
| 307 |
boot_data.push(data[idx]); |
| 308 |
boot_labels.push(labels[idx]); |
| 309 |
} |
| 310 |
|
| 311 |
train_tree( |
| 312 |
&boot_data, |
| 313 |
&boot_labels, |
| 314 |
num_classes, |
| 315 |
max_depth, |
| 316 |
min_leaf, |
| 317 |
features_per_split, |
| 318 |
&mut rng, |
| 319 |
) |
| 320 |
}) |
| 321 |
.collect() |
| 322 |
} |
| 323 |
|
| 324 |
|
| 325 |
fn predict_forest(trees: &[TreeNode], num_classes: usize, features: &[f64; NUM_FEATURES]) -> (u8, f64) { |
| 326 |
let mut votes = vec![0u32; num_classes]; |
| 327 |
for tree in trees { |
| 328 |
let class = tree.predict(features); |
| 329 |
if (class as usize) < num_classes { |
| 330 |
votes[class as usize] += 1; |
| 331 |
} |
| 332 |
} |
| 333 |
|
| 334 |
let total = trees.len() as f64; |
| 335 |
let (best_class, &best_count) = votes |
| 336 |
.iter() |
| 337 |
.enumerate() |
| 338 |
.max_by_key(|(_, &c)| c) |
| 339 |
.unwrap_or((0, &0)); |
| 340 |
|
| 341 |
(best_class as u8, best_count as f64 / total) |
| 342 |
} |
| 343 |
|
| 344 |
|
| 345 |
|
| 346 |
|
| 347 |
fn cross_validate( |
| 348 |
data: &[[f64; NUM_FEATURES]], |
| 349 |
labels: &[u8], |
| 350 |
class_names: &[&str], |
| 351 |
k: usize, |
| 352 |
n_trees: usize, |
| 353 |
max_depth: usize, |
| 354 |
min_leaf: usize, |
| 355 |
) { |
| 356 |
let num_classes = class_names.len(); |
| 357 |
println!("\n{k}-fold stratified cross-validation:"); |
| 358 |
println!("{}", "=".repeat(60)); |
| 359 |
|
| 360 |
let mut class_indices: Vec<Vec<usize>> = vec![Vec::new(); num_classes]; |
| 361 |
for (i, &label) in labels.iter().enumerate() { |
| 362 |
class_indices[label as usize].push(i); |
| 363 |
} |
| 364 |
|
| 365 |
let mut rng = rand::rngs::StdRng::seed_from_u64(123); |
| 366 |
for indices in &mut class_indices { |
| 367 |
indices.shuffle(&mut rng); |
| 368 |
} |
| 369 |
|
| 370 |
let mut folds: Vec<Vec<usize>> = vec![Vec::new(); k]; |
| 371 |
for indices in &class_indices { |
| 372 |
for (i, &idx) in indices.iter().enumerate() { |
| 373 |
folds[i % k].push(idx); |
| 374 |
} |
| 375 |
} |
| 376 |
|
| 377 |
let mut all_correct = 0usize; |
| 378 |
let mut all_total = 0usize; |
| 379 |
let mut per_class_tp = vec![0u32; num_classes]; |
| 380 |
let mut per_class_fp = vec![0u32; num_classes]; |
| 381 |
let mut per_class_fn = vec![0u32; num_classes]; |
| 382 |
|
| 383 |
for fold in 0..k { |
| 384 |
let test_indices = &folds[fold]; |
| 385 |
let train_indices: Vec<usize> = (0..k) |
| 386 |
.filter(|&f| f != fold) |
| 387 |
.flat_map(|f| folds[f].iter().copied()) |
| 388 |
.collect(); |
| 389 |
|
| 390 |
let train_data: Vec<[f64; NUM_FEATURES]> = |
| 391 |
train_indices.iter().map(|&i| data[i]).collect(); |
| 392 |
let train_labels: Vec<u8> = train_indices.iter().map(|&i| labels[i]).collect(); |
| 393 |
|
| 394 |
let trees = train_random_forest(&train_data, &train_labels, num_classes, n_trees, max_depth, min_leaf); |
| 395 |
|
| 396 |
let mut fold_correct = 0; |
| 397 |
for &test_idx in test_indices { |
| 398 |
let (predicted, _conf) = predict_forest(&trees, num_classes, &data[test_idx]); |
| 399 |
let actual = labels[test_idx]; |
| 400 |
|
| 401 |
if predicted == actual { |
| 402 |
fold_correct += 1; |
| 403 |
per_class_tp[actual as usize] += 1; |
| 404 |
} else { |
| 405 |
per_class_fp[predicted as usize] += 1; |
| 406 |
per_class_fn[actual as usize] += 1; |
| 407 |
} |
| 408 |
} |
| 409 |
|
| 410 |
let fold_acc = fold_correct as f64 / test_indices.len() as f64 * 100.0; |
| 411 |
println!(" Fold {}: {:.1}% ({}/{})", fold + 1, fold_acc, fold_correct, test_indices.len()); |
| 412 |
|
| 413 |
all_correct += fold_correct; |
| 414 |
all_total += test_indices.len(); |
| 415 |
} |
| 416 |
|
| 417 |
let overall_acc = all_correct as f64 / all_total as f64 * 100.0; |
| 418 |
println!("\n Overall: {:.1}% ({}/{})", overall_acc, all_correct, all_total); |
| 419 |
|
| 420 |
println!("\n Per-class metrics:"); |
| 421 |
println!(" {:<14} {:>9} {:>9} {:>9}", "Class", "Precision", "Recall", "F1"); |
| 422 |
for c in 0..num_classes { |
| 423 |
let tp = per_class_tp[c] as f64; |
| 424 |
let fp = per_class_fp[c] as f64; |
| 425 |
let fn_ = per_class_fn[c] as f64; |
| 426 |
|
| 427 |
let precision = if tp + fp > 0.0 { tp / (tp + fp) } else { 0.0 }; |
| 428 |
let recall = if tp + fn_ > 0.0 { tp / (tp + fn_) } else { 0.0 }; |
| 429 |
let f1 = if precision + recall > 0.0 { |
| 430 |
2.0 * precision * recall / (precision + recall) |
| 431 |
} else { |
| 432 |
0.0 |
| 433 |
}; |
| 434 |
|
| 435 |
println!( |
| 436 |
" {:<14} {:>8.1}% {:>8.1}% {:>8.1}%", |
| 437 |
class_names[c], |
| 438 |
precision * 100.0, |
| 439 |
recall * 100.0, |
| 440 |
f1 * 100.0 |
| 441 |
); |
| 442 |
} |
| 443 |
|
| 444 |
println!("{}", "=".repeat(60)); |
| 445 |
} |
| 446 |
|
| 447 |
|
| 448 |
|
| 449 |
fn main() { |
| 450 |
let args: Vec<String> = std::env::args().skip(1).collect(); |
| 451 |
|
| 452 |
|
| 453 |
let mut target = "drum"; |
| 454 |
let mut data_path: Option<PathBuf> = None; |
| 455 |
let mut i = 0; |
| 456 |
while i < args.len() { |
| 457 |
match args[i].as_str() { |
| 458 |
"--target" => { |
| 459 |
i += 1; |
| 460 |
if i < args.len() { |
| 461 |
target = Box::leak(args[i].clone().into_boxed_str()); |
| 462 |
} else { |
| 463 |
eprintln!("--target requires a value (drum, bass, vocal, synth)"); |
| 464 |
std::process::exit(1); |
| 465 |
} |
| 466 |
} |
| 467 |
"--help" | "-h" => { |
| 468 |
println!("Usage: audiofiles-train [--target drum|bass|vocal|synth] [DATA_PATH]"); |
| 469 |
println!(); |
| 470 |
println!("Targets:"); |
| 471 |
println!(" drum 7 classes: kick, snare, hihat, cymbal, clap, tom, percussion"); |
| 472 |
println!(" bass 3 classes: guitar-bass, synth-bass, sub-bass"); |
| 473 |
println!(" vocal 3 classes: vocal-chop, vocal-phrase, vocal-choir"); |
| 474 |
println!(" synth 4 classes: synth-lead, synth-stab, synth-pluck, synth-chord"); |
| 475 |
println!(); |
| 476 |
println!("DATA_PATH defaults to samples/training/<target>/"); |
| 477 |
return; |
| 478 |
} |
| 479 |
other => { |
| 480 |
if !other.starts_with('-') { |
| 481 |
data_path = Some(PathBuf::from(other)); |
| 482 |
} |
| 483 |
} |
| 484 |
} |
| 485 |
i += 1; |
| 486 |
} |
| 487 |
|
| 488 |
let config = target_config(target); |
| 489 |
let num_classes = config.class_names.len(); |
| 490 |
|
| 491 |
println!("Training Layer 2 {} classifier ({} classes)", config.name, num_classes); |
| 492 |
println!("Classes: {}", config.class_names.join(", ")); |
| 493 |
|
| 494 |
let test_data = data_path.unwrap_or_else(|| { |
| 495 |
PathBuf::from(env!("CARGO_MANIFEST_DIR")) |
| 496 |
.parent() |
| 497 |
.unwrap() |
| 498 |
.parent() |
| 499 |
.unwrap() |
| 500 |
.join(format!("samples/training/{}", config.name)) |
| 501 |
}); |
| 502 |
|
| 503 |
if !test_data.exists() { |
| 504 |
eprintln!("Training data not found at {}", test_data.display()); |
| 505 |
eprintln!( |
| 506 |
"Expected structure: <dir>/{{{}}}/", |
| 507 |
config.class_names.join(",") |
| 508 |
); |
| 509 |
eprintln!("Override: cargo run -p audiofiles-train -- --target {} /path/to/data", config.name); |
| 510 |
std::process::exit(1); |
| 511 |
} |
| 512 |
|
| 513 |
|
| 514 |
println!("\nExtracting features from labeled samples..."); |
| 515 |
let mut all_data: Vec<[f64; NUM_FEATURES]> = Vec::new(); |
| 516 |
let mut all_labels: Vec<u8> = Vec::new(); |
| 517 |
let mut class_counts: HashMap<u8, usize> = HashMap::new(); |
| 518 |
|
| 519 |
for &label in &config.class_names { |
| 520 |
let dir = test_data.join(label); |
| 521 |
if !dir.exists() { |
| 522 |
eprintln!(" Skipping {label}: directory not found"); |
| 523 |
continue; |
| 524 |
} |
| 525 |
|
| 526 |
let class_id = match label_to_class(label, &config) { |
| 527 |
Some(id) => id, |
| 528 |
None => continue, |
| 529 |
}; |
| 530 |
|
| 531 |
let mut entries: Vec<_> = std::fs::read_dir(&dir) |
| 532 |
.expect("failed to read dir") |
| 533 |
.filter_map(|e| e.ok()) |
| 534 |
.filter(|e| { |
| 535 |
let name = e.file_name().to_string_lossy().to_lowercase(); |
| 536 |
name.ends_with(".wav") |
| 537 |
|| name.ends_with(".aif") |
| 538 |
|| name.ends_with(".aiff") |
| 539 |
|| name.ends_with(".mp3") |
| 540 |
|| name.ends_with(".ogg") |
| 541 |
|| name.ends_with(".flac") |
| 542 |
}) |
| 543 |
.collect(); |
| 544 |
entries.sort_by_key(|e| e.file_name()); |
| 545 |
|
| 546 |
let mut count = 0; |
| 547 |
let mut failed = 0; |
| 548 |
|
| 549 |
for entry in &entries { |
| 550 |
let path = std::fs::read_link(entry.path()).unwrap_or_else(|_| entry.path()); |
| 551 |
if !path.exists() { |
| 552 |
continue; |
| 553 |
} |
| 554 |
|
| 555 |
match extract_features(&path) { |
| 556 |
Some(features) => { |
| 557 |
all_data.push(features); |
| 558 |
all_labels.push(class_id); |
| 559 |
count += 1; |
| 560 |
} |
| 561 |
None => { |
| 562 |
failed += 1; |
| 563 |
} |
| 564 |
} |
| 565 |
} |
| 566 |
|
| 567 |
*class_counts.entry(class_id).or_default() += count; |
| 568 |
println!(" {label:<14} {count:>4} samples extracted ({failed} failed)"); |
| 569 |
} |
| 570 |
|
| 571 |
println!( |
| 572 |
"\nTotal: {} samples across {} classes", |
| 573 |
all_data.len(), |
| 574 |
class_counts.len() |
| 575 |
); |
| 576 |
for (c, &name) in config.class_names.iter().enumerate() { |
| 577 |
let count = class_counts.get(&(c as u8)).unwrap_or(&0); |
| 578 |
println!(" {}: {} samples", name, count); |
| 579 |
} |
| 580 |
|
| 581 |
if all_data.len() < 100 { |
| 582 |
eprintln!( |
| 583 |
"Not enough samples for training (need at least 100, got {})", |
| 584 |
all_data.len() |
| 585 |
); |
| 586 |
std::process::exit(1); |
| 587 |
} |
| 588 |
|
| 589 |
|
| 590 |
let max_per_class = { |
| 591 |
let mut sizes: Vec<usize> = (0..num_classes as u8) |
| 592 |
.map(|c| class_counts.get(&c).copied().unwrap_or(0)) |
| 593 |
.filter(|&s| s > 0) |
| 594 |
.collect(); |
| 595 |
sizes.sort_unstable(); |
| 596 |
let median = sizes[sizes.len() / 2]; |
| 597 |
median.max(500) |
| 598 |
}; |
| 599 |
println!("\nBalancing: capping each class at {max_per_class} samples"); |
| 600 |
|
| 601 |
let mut class_indices: Vec<Vec<usize>> = vec![Vec::new(); num_classes]; |
| 602 |
for (i, &label) in all_labels.iter().enumerate() { |
| 603 |
class_indices[label as usize].push(i); |
| 604 |
} |
| 605 |
let mut balance_rng = rand::rngs::StdRng::seed_from_u64(42); |
| 606 |
for indices in &mut class_indices { |
| 607 |
indices.shuffle(&mut balance_rng); |
| 608 |
indices.truncate(max_per_class); |
| 609 |
} |
| 610 |
let keep: Vec<usize> = class_indices.into_iter().flatten().collect(); |
| 611 |
let balanced_data: Vec<[f64; NUM_FEATURES]> = keep.iter().map(|&i| all_data[i]).collect(); |
| 612 |
let balanced_labels: Vec<u8> = keep.iter().map(|&i| all_labels[i]).collect(); |
| 613 |
|
| 614 |
println!("Balanced dataset: {} samples", balanced_data.len()); |
| 615 |
for (c, &name) in config.class_names.iter().enumerate() { |
| 616 |
let count = balanced_labels.iter().filter(|&&l| l == c as u8).count(); |
| 617 |
println!(" {}: {} samples", name, count); |
| 618 |
} |
| 619 |
|
| 620 |
|
| 621 |
let n_trees = 200; |
| 622 |
let max_depth = 25; |
| 623 |
let min_leaf = 3; |
| 624 |
|
| 625 |
cross_validate( |
| 626 |
&balanced_data, |
| 627 |
&balanced_labels, |
| 628 |
&config.class_names, |
| 629 |
5, |
| 630 |
n_trees, |
| 631 |
max_depth, |
| 632 |
min_leaf, |
| 633 |
); |
| 634 |
|
| 635 |
|
| 636 |
println!( |
| 637 |
"\nTraining final model on {} balanced samples ({n_trees} trees)...", |
| 638 |
balanced_data.len() |
| 639 |
); |
| 640 |
let trees = train_random_forest(&balanced_data, &balanced_labels, num_classes, n_trees, max_depth, min_leaf); |
| 641 |
|
| 642 |
|
| 643 |
let mut train_correct = 0; |
| 644 |
for (row, &label) in balanced_data.iter().zip(balanced_labels.iter()) { |
| 645 |
let (pred, _) = predict_forest(&trees, num_classes, row); |
| 646 |
if pred == label { |
| 647 |
train_correct += 1; |
| 648 |
} |
| 649 |
} |
| 650 |
println!( |
| 651 |
"Training set accuracy: {:.1}% ({}/{})", |
| 652 |
train_correct as f64 / balanced_data.len() as f64 * 100.0, |
| 653 |
train_correct, |
| 654 |
balanced_data.len() |
| 655 |
); |
| 656 |
|
| 657 |
|
| 658 |
let model = RandomForestModel { |
| 659 |
trees, |
| 660 |
num_classes: num_classes as u8, |
| 661 |
class_names: config.class_names.iter().map(|s| s.to_string()).collect(), |
| 662 |
}; |
| 663 |
|
| 664 |
let model_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) |
| 665 |
.parent() |
| 666 |
.unwrap() |
| 667 |
.join(format!("audiofiles-core/models/{}", config.model_filename)); |
| 668 |
|
| 669 |
let json = serde_json::to_string(&model).expect("failed to serialize model"); |
| 670 |
std::fs::write(&model_path, &json).expect("failed to write model file"); |
| 671 |
|
| 672 |
let size_mb = json.len() as f64 / 1_048_576.0; |
| 673 |
println!( |
| 674 |
"\nModel saved to {} ({:.1} MB)", |
| 675 |
model_path.display(), |
| 676 |
size_mb |
| 677 |
); |
| 678 |
println!("Trees: {}, Classes: {}", model.trees.len(), model.num_classes); |
| 679 |
println!("\nDone. Rebuild audiofiles-core to embed the new model."); |
| 680 |
} |
| 681 |
|