| 1 |
|
| 2 |
|
| 3 |
|
| 4 |
|
| 5 |
|
| 6 |
use std::path::Path; |
| 7 |
|
| 8 |
use symphonia::core::audio::SampleBuffer; |
| 9 |
use symphonia::core::codecs::DecoderOptions; |
| 10 |
use symphonia::core::formats::FormatOptions; |
| 11 |
use symphonia::core::io::MediaSourceStream; |
| 12 |
use symphonia::core::meta::MetadataOptions; |
| 13 |
use symphonia::core::probe::Hint; |
| 14 |
use tracing::instrument; |
| 15 |
|
| 16 |
use crate::error::{io_err, AnalysisError, CoreError}; |
| 17 |
|
| 18 |
|
| 19 |
/// Audio decoded by [`decode_multichannel`], keeping the source track's
/// sample rate and channel count (no downmixing or resampling is applied).
#[derive(Debug, Clone, PartialEq)]
pub struct DecodedMultichannel {
    /// Interleaved samples: the sample for frame `i`, channel `c` is at
    /// index `i * channels + c`.
    pub samples: Vec<f32>,
    /// Sample rate of the source track, in Hz.
    pub sample_rate: u32,
    /// Number of channels interleaved in `samples`.
    pub channels: u16,
}
| 28 |
|
| 29 |
|
| 30 |
#[instrument(skip_all)] |
| 31 |
pub fn decode_multichannel(path: &Path) -> Result<DecodedMultichannel, CoreError> { |
| 32 |
let file = std::fs::File::open(path).map_err(|e| io_err(path, e))?; |
| 33 |
let mss = MediaSourceStream::new(Box::new(file), Default::default()); |
| 34 |
|
| 35 |
let mut hint = Hint::new(); |
| 36 |
if let Some(ext) = path.extension().and_then(|e| e.to_str()) { |
| 37 |
hint.with_extension(ext); |
| 38 |
} |
| 39 |
|
| 40 |
let probed = symphonia::default::get_probe() |
| 41 |
.format( |
| 42 |
&hint, |
| 43 |
mss, |
| 44 |
&FormatOptions::default(), |
| 45 |
&MetadataOptions::default(), |
| 46 |
) |
| 47 |
.map_err(|e| AnalysisError::ProbeFailed(e.to_string()))?; |
| 48 |
|
| 49 |
let mut format = probed.format; |
| 50 |
|
| 51 |
let track = format |
| 52 |
.default_track() |
| 53 |
.ok_or(AnalysisError::NoAudioTrack)?; |
| 54 |
|
| 55 |
let track_id = track.id; |
| 56 |
let source_sample_rate = track |
| 57 |
.codec_params |
| 58 |
.sample_rate |
| 59 |
.ok_or(AnalysisError::ProbeFailed("missing sample rate".to_string()))?; |
| 60 |
let source_channels = track |
| 61 |
.codec_params |
| 62 |
.channels |
| 63 |
.map(|c| c.count() as u16) |
| 64 |
.ok_or(AnalysisError::ProbeFailed("missing channel count".to_string()))?; |
| 65 |
|
| 66 |
let mut decoder = symphonia::default::get_codecs() |
| 67 |
.make(&track.codec_params, &DecoderOptions::default()) |
| 68 |
.map_err(|e| AnalysisError::DecoderFailed(e.to_string()))?; |
| 69 |
|
| 70 |
let mut all_samples: Vec<f32> = Vec::new(); |
| 71 |
|
| 72 |
loop { |
| 73 |
let packet = match format.next_packet() { |
| 74 |
Ok(p) => p, |
| 75 |
Err(symphonia::core::errors::Error::IoError(ref e)) |
| 76 |
if e.kind() == std::io::ErrorKind::UnexpectedEof => |
| 77 |
{ |
| 78 |
break; |
| 79 |
} |
| 80 |
Err(e) => return Err(AnalysisError::PacketError(e.to_string()).into()), |
| 81 |
}; |
| 82 |
|
| 83 |
if packet.track_id() != track_id { |
| 84 |
continue; |
| 85 |
} |
| 86 |
|
| 87 |
let decoded = match decoder.decode(&packet) { |
| 88 |
Ok(d) => d, |
| 89 |
Err(symphonia::core::errors::Error::DecodeError(_)) => continue, |
| 90 |
Err(e) => return Err(AnalysisError::DecodeError(e.to_string()).into()), |
| 91 |
}; |
| 92 |
|
| 93 |
let num_frames = decoded.frames(); |
| 94 |
let mut sample_buf = SampleBuffer::<f32>::new(num_frames as u64, *decoded.spec()); |
| 95 |
sample_buf.copy_interleaved_ref(decoded); |
| 96 |
all_samples.extend_from_slice(sample_buf.samples()); |
| 97 |
} |
| 98 |
|
| 99 |
if all_samples.is_empty() { |
| 100 |
return Err(AnalysisError::NoAudioData.into()); |
| 101 |
} |
| 102 |
|
| 103 |
Ok(DecodedMultichannel { |
| 104 |
samples: all_samples, |
| 105 |
sample_rate: source_sample_rate, |
| 106 |
channels: source_channels, |
| 107 |
}) |
| 108 |
} |
| 109 |
|
| 110 |
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::path::PathBuf;

    /// Writes a minimal WAV file (44-byte header, no extra chunks) whose data
    /// chunk holds the given interleaved 32-bit IEEE float samples.
    fn write_wav(path: &Path, channels: u16, sample_rate: u32, samples: &[f32]) {
        const FLOAT_BYTES: u16 = 4;
        const IEEE_FLOAT_TAG: u16 = 3;
        const FMT_CHUNK_LEN: u32 = 16;

        let frame_bytes = channels * FLOAT_BYTES;
        let payload_len = (samples.len() as u32) * u32::from(FLOAT_BYTES);
        let byte_rate = sample_rate * u32::from(frame_bytes);

        let mut bytes: Vec<u8> = Vec::with_capacity(44 + payload_len as usize);

        // RIFF header; the size field counts every byte after itself.
        bytes.extend_from_slice(b"RIFF");
        bytes.extend_from_slice(&(36 + payload_len).to_le_bytes());
        bytes.extend_from_slice(b"WAVE");

        // Format chunk describing the float layout.
        bytes.extend_from_slice(b"fmt ");
        bytes.extend_from_slice(&FMT_CHUNK_LEN.to_le_bytes());
        bytes.extend_from_slice(&IEEE_FLOAT_TAG.to_le_bytes());
        bytes.extend_from_slice(&channels.to_le_bytes());
        bytes.extend_from_slice(&sample_rate.to_le_bytes());
        bytes.extend_from_slice(&byte_rate.to_le_bytes());
        bytes.extend_from_slice(&frame_bytes.to_le_bytes());
        bytes.extend_from_slice(&(FLOAT_BYTES * 8).to_le_bytes());

        // Data chunk with the raw little-endian samples.
        bytes.extend_from_slice(b"data");
        bytes.extend_from_slice(&payload_len.to_le_bytes());
        bytes.extend(samples.iter().flat_map(|sample| sample.to_le_bytes()));

        std::fs::File::create(path)
            .unwrap()
            .write_all(&bytes)
            .unwrap();
    }

    #[test]
    fn decode_preserves_mono() {
        let tmp = tempfile::tempdir().unwrap();
        let wav_path = tmp.path().join("mono.wav");
        let input = [0.5_f32, -0.5, 0.25];
        write_wav(&wav_path, 1, 44100, &input);

        let out = decode_multichannel(&wav_path).unwrap();

        assert_eq!(out.channels, 1);
        assert_eq!(out.samples, input.to_vec());
        assert_eq!(out.sample_rate, 44100);
    }

    #[test]
    fn decode_preserves_stereo() {
        let tmp = tempfile::tempdir().unwrap();
        let wav_path = tmp.path().join("stereo.wav");
        let input: Vec<f32> = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6];
        write_wav(&wav_path, 2, 48000, &input);

        let out = decode_multichannel(&wav_path).unwrap();

        assert_eq!(out.channels, 2);
        assert_eq!(out.samples, input);
        assert_eq!(out.sample_rate, 48000);
    }

    #[test]
    fn decode_nonexistent_returns_error() {
        let missing = PathBuf::from("/nonexistent/audio.wav");
        assert!(decode_multichannel(&missing).is_err());
    }
}
| 176 |
|