Skip to main content

max / audiofiles

15.1 KB · 516 lines History Blame Raw
//! Vantage-point tree for sub-linear nearest-neighbor and range queries in metric spaces.
//!
//! A VP-tree partitions data by distance from selected "vantage points", enabling
//! efficient pruning during search. It requires a distance function that satisfies
//! the triangle inequality.
//!
//! Expected costs, for data that spreads well around each vantage point:
//!
//! - Build: O(n log n) distance computations
//! - k-nearest query: O(log n) expected distance computations
//! - Range query: O(log n + results) expected distance computations
10
11 use std::collections::BinaryHeap;
12
13 /// A vantage-point tree for nearest-neighbor and range queries.
14 pub struct VpTree<T> {
15 items: Vec<T>,
16 root: Option<usize>,
17 nodes: Vec<VpNode>,
18 }
19
/// One node of the tree: a vantage point plus the radius that splits its subtree.
struct VpNode {
    /// Position of this node's vantage point in the tree's `items`.
    item_idx: usize,
    /// Splitting radius around the vantage point. The build uses the median distance
    /// to the other items of the subtree, `0.0` for a leaf, and `f64::INFINITY` for
    /// nodes of the depth-cap fallback chain.
    threshold: f64,
    /// Root of the "inside" subtree (items at distance `<= threshold`), if any.
    left: Option<usize>,
    /// Root of the "outside" subtree (items at distance `> threshold`), if any.
    right: Option<usize>,
}
30
/// A single query hit.
#[derive(Debug, Clone)]
pub struct VpMatch {
    /// Position of the matched item in the vector the tree was built from.
    pub index: usize,
    /// Distance between the query and the matched item.
    pub distance: f64,
}
37
38 impl<T> VpTree<T> {
39 /// Build a VP-tree from a set of items.
40 ///
41 /// The distance function must be a metric (non-negative, symmetric, triangle
42 /// inequality). Build cost is O(n log n) distance computations.
43 pub fn build(items: Vec<T>, dist: impl Fn(&T, &T) -> f64) -> Self {
44 let n = items.len();
45 if n == 0 {
46 return Self {
47 items,
48 root: None,
49 nodes: Vec::new(),
50 };
51 }
52 let mut indices: Vec<usize> = (0..n).collect();
53 let mut nodes = Vec::with_capacity(n);
54 let root = build_recursive(&items, &mut indices, &dist, &mut nodes, 0);
55 Self {
56 items,
57 root: Some(root),
58 nodes,
59 }
60 }
61
62 /// Find the `k` nearest items to `query`, sorted by ascending distance.
63 pub fn find_nearest(
64 &self,
65 query: &T,
66 k: usize,
67 dist: impl Fn(&T, &T) -> f64,
68 ) -> Vec<VpMatch> {
69 if self.items.is_empty() || k == 0 {
70 return Vec::new();
71 }
72 let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k + 1);
73 let mut tau = f64::INFINITY;
74 if let Some(root) = self.root {
75 search_nearest(
76 &self.items,
77 &self.nodes,
78 root,
79 query,
80 k,
81 &dist,
82 &mut heap,
83 &mut tau,
84 );
85 }
86 let mut results: Vec<VpMatch> = heap
87 .into_iter()
88 .map(|e| VpMatch {
89 index: e.index,
90 distance: e.distance,
91 })
92 .collect();
93 results.sort_by(|a, b| a.distance.total_cmp(&b.distance));
94 results
95 }
96
97 /// Find all items within `radius` of `query`, sorted by ascending distance.
98 pub fn find_within(
99 &self,
100 query: &T,
101 radius: f64,
102 dist: impl Fn(&T, &T) -> f64,
103 ) -> Vec<VpMatch> {
104 if self.items.is_empty() {
105 return Vec::new();
106 }
107 let mut results = Vec::new();
108 if let Some(root) = self.root {
109 search_within(
110 &self.items,
111 &self.nodes,
112 root,
113 query,
114 radius,
115 &dist,
116 &mut results,
117 );
118 }
119 results.sort_by(|a, b| a.distance.total_cmp(&b.distance));
120 results
121 }
122
123 /// Number of items in the tree.
124 pub fn len(&self) -> usize {
125 self.items.len()
126 }
127
128 /// Whether the tree is empty.
129 pub fn is_empty(&self) -> bool {
130 self.items.is_empty()
131 }
132
133 /// Reference to an item by index.
134 pub fn get(&self, index: usize) -> &T {
135 &self.items[index]
136 }
137 }
138
139 // --- Build ---
140
141 /// Maximum recursion depth to prevent stack overflow on degenerate input
142 /// (e.g. all-identical feature vectors causing O(n)-deep recursion).
143 const MAX_BUILD_DEPTH: usize = 64;
144
145 fn build_recursive<T>(
146 items: &[T],
147 indices: &mut [usize],
148 dist: &impl Fn(&T, &T) -> f64,
149 nodes: &mut Vec<VpNode>,
150 depth: usize,
151 ) -> usize {
152 debug_assert!(!indices.is_empty());
153
154 let vp_idx = indices[0];
155 // Leaf node: single item
156 if indices.len() == 1 {
157 let node_idx = nodes.len();
158 nodes.push(VpNode {
159 item_idx: vp_idx,
160 threshold: 0.0,
161 left: None,
162 right: None,
163 });
164 return node_idx;
165 }
166
167 // Depth cap: chain remaining items as a flat linked list (iteratively)
168 // so none are lost. Each node holds one item, threshold=INFINITY ensures
169 // the search always visits the left child.
170 if depth >= MAX_BUILD_DEPTH {
171 let first_node_idx = nodes.len();
172 // Create all leaf nodes
173 for &idx in indices.iter() {
174 nodes.push(VpNode {
175 item_idx: idx,
176 threshold: f64::INFINITY,
177 left: None,
178 right: None,
179 });
180 }
181 // Link them: each node's left points to the next
182 for i in 0..indices.len() - 1 {
183 nodes[first_node_idx + i].left = Some(first_node_idx + i + 1);
184 }
185 return first_node_idx;
186 }
187
188 // Compute distances from vantage point to all other items in this subtree.
189 let mut dists: Vec<(usize, f64)> = indices[1..]
190 .iter()
191 .map(|&idx| (idx, dist(&items[vp_idx], &items[idx])))
192 .collect();
193
194 // Partition at median distance.
195 let median_pos = dists.len() / 2;
196 dists.select_nth_unstable_by(median_pos, |a, b| a.1.total_cmp(&b.1));
197 let threshold = dists[median_pos].1;
198
199 // Split into inside (<= threshold) and outside (> threshold).
200 let mut inside = Vec::with_capacity(median_pos + 1);
201 let mut outside = Vec::with_capacity(dists.len() - median_pos);
202 for &(idx, d) in &dists {
203 if d <= threshold {
204 inside.push(idx);
205 } else {
206 outside.push(idx);
207 }
208 }
209
210 // Allocate node, fill children after recursion.
211 let node_idx = nodes.len();
212 nodes.push(VpNode {
213 item_idx: vp_idx,
214 threshold,
215 left: None,
216 right: None,
217 });
218
219 let left = if inside.is_empty() {
220 None
221 } else {
222 Some(build_recursive(items, &mut inside, dist, nodes, depth + 1))
223 };
224 let right = if outside.is_empty() {
225 None
226 } else {
227 Some(build_recursive(items, &mut outside, dist, nodes, depth + 1))
228 };
229
230 nodes[node_idx].left = left;
231 nodes[node_idx].right = right;
232 node_idx
233 }
234
235 // --- k-nearest search ---
236
/// Candidate kept in the bounded max-heap of the k-nearest search.
///
/// Ordering and equality consider `distance` only (through `f64::total_cmp`), so the
/// top of a `BinaryHeap<HeapEntry>` is always the farthest candidate retained so far.
struct HeapEntry {
    distance: f64,
    index: usize,
}

impl Ord for HeapEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        f64::total_cmp(&self.distance, &other.distance)
    }
}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        matches!(Ord::cmp(self, other), std::cmp::Ordering::Equal)
    }
}

impl Eq for HeapEntry {}
258
259 #[allow(clippy::too_many_arguments)]
260 fn search_nearest<T>(
261 items: &[T],
262 nodes: &[VpNode],
263 node_idx: usize,
264 query: &T,
265 k: usize,
266 dist: &impl Fn(&T, &T) -> f64,
267 heap: &mut BinaryHeap<HeapEntry>,
268 tau: &mut f64,
269 ) {
270 let node = &nodes[node_idx];
271 let d = dist(query, &items[node.item_idx]);
272
273 // Consider this node's vantage point.
274 if d < *tau || heap.len() < k {
275 heap.push(HeapEntry {
276 distance: d,
277 index: node.item_idx,
278 });
279 if heap.len() > k {
280 heap.pop();
281 }
282 if heap.len() == k {
283 *tau = heap.peek().unwrap().distance;
284 }
285 }
286
287 // Search the closer subtree first for better pruning.
288 if d <= node.threshold {
289 // Query is inside — search inside first.
290 if let Some(left) = node.left
291 && d - *tau <= node.threshold {
292 search_nearest(items, nodes, left, query, k, dist, heap, tau);
293 }
294 if let Some(right) = node.right
295 && d + *tau > node.threshold {
296 search_nearest(items, nodes, right, query, k, dist, heap, tau);
297 }
298 } else {
299 // Query is outside — search outside first.
300 if let Some(right) = node.right
301 && d + *tau > node.threshold {
302 search_nearest(items, nodes, right, query, k, dist, heap, tau);
303 }
304 if let Some(left) = node.left
305 && d - *tau <= node.threshold {
306 search_nearest(items, nodes, left, query, k, dist, heap, tau);
307 }
308 }
309 }
310
311 // --- Range search ---
312
313 fn search_within<T>(
314 items: &[T],
315 nodes: &[VpNode],
316 node_idx: usize,
317 query: &T,
318 radius: f64,
319 dist: &impl Fn(&T, &T) -> f64,
320 results: &mut Vec<VpMatch>,
321 ) {
322 let node = &nodes[node_idx];
323 let d = dist(query, &items[node.item_idx]);
324
325 if d <= radius {
326 results.push(VpMatch {
327 index: node.item_idx,
328 distance: d,
329 });
330 }
331
332 // Prune subtrees using triangle inequality bounds.
333 if let Some(left) = node.left
334 && d - radius <= node.threshold {
335 search_within(items, nodes, left, query, radius, dist, results);
336 }
337 if let Some(right) = node.right
338 && d + radius > node.threshold {
339 search_within(items, nodes, right, query, radius, dist, results);
340 }
341 }
342
#[cfg(test)]
mod tests {
    //! Unit tests for the VP-tree: edge cases, ordering guarantees, the depth-cap
    //! fallback, and agreement with a brute-force scan.

    use super::*;

    fn euclidean_1d(a: &f64, b: &f64) -> f64 {
        (a - b).abs()
    }

    fn euclidean_2d(a: &[f64; 2], b: &[f64; 2]) -> f64 {
        ((a[0] - b[0]).powi(2) + (a[1] - b[1]).powi(2)).sqrt()
    }

    #[test]
    fn empty_tree() {
        let tree: VpTree<f64> = VpTree::build(vec![], euclidean_1d);
        assert!(tree.is_empty());
        assert_eq!(tree.len(), 0);
        assert!(tree.find_nearest(&0.0, 5, euclidean_1d).is_empty());
        assert!(tree.find_within(&0.0, 1.0, euclidean_1d).is_empty());
    }

    #[test]
    fn single_item() {
        let tree = VpTree::build(vec![5.0], euclidean_1d);
        assert_eq!(tree.len(), 1);
        let results = tree.find_nearest(&5.0, 1, euclidean_1d);
        assert_eq!(results.len(), 1);
        assert!(results[0].distance.abs() < f64::EPSILON);
    }

    #[test]
    fn k_nearest_basic() {
        let items = vec![1.0, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0];
        let tree = VpTree::build(items, euclidean_1d);
        let results = tree.find_nearest(&6.0, 3, euclidean_1d);
        assert_eq!(results.len(), 3);
        let values: Vec<f64> = results.iter().map(|r| *tree.get(r.index)).collect();
        assert!(values.contains(&5.0));
        assert!(values.contains(&7.0));
    }

    #[test]
    fn find_within_basic() {
        let items = vec![1.0, 3.0, 5.0, 7.0, 9.0];
        let tree = VpTree::build(items, euclidean_1d);
        let results = tree.find_within(&5.0, 2.5, euclidean_1d);
        let values: Vec<f64> = results.iter().map(|r| *tree.get(r.index)).collect();
        assert!(values.contains(&3.0));
        assert!(values.contains(&5.0));
        assert!(values.contains(&7.0));
        assert_eq!(values.len(), 3);
    }

    #[test]
    fn results_sorted_by_distance() {
        let items: Vec<f64> = (0..100).map(|i| i as f64).collect();
        let tree = VpTree::build(items, euclidean_1d);

        let nearest = tree.find_nearest(&50.0, 10, euclidean_1d);
        for pair in nearest.windows(2) {
            assert!(pair[0].distance <= pair[1].distance);
        }

        let within = tree.find_within(&50.0, 5.0, euclidean_1d);
        for pair in within.windows(2) {
            assert!(pair[0].distance <= pair[1].distance);
        }
    }

    #[test]
    fn k_larger_than_n() {
        let items = vec![1.0, 2.0, 3.0];
        let tree = VpTree::build(items, euclidean_1d);
        let results = tree.find_nearest(&0.0, 10, euclidean_1d);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn k_zero() {
        let items = vec![1.0, 2.0, 3.0];
        let tree = VpTree::build(items, euclidean_1d);
        let results = tree.find_nearest(&0.0, 0, euclidean_1d);
        assert!(results.is_empty());
    }

    #[test]
    fn two_dimensional() {
        let items = vec![
            [0.0, 0.0],
            [1.0, 0.0],
            [0.0, 1.0],
            [1.0, 1.0],
            [10.0, 10.0],
        ];
        let tree = VpTree::build(items, euclidean_2d);
        let results = tree.find_nearest(&[0.5, 0.5], 4, euclidean_2d);
        assert_eq!(results.len(), 4);
        // [10, 10] should not be in the top 4.
        assert!(!results.iter().any(|r| tree.get(r.index)[0] > 5.0));
    }

    #[test]
    fn find_within_excludes_far_items() {
        let items = vec![0.0, 1.0, 2.0, 100.0, 200.0];
        let tree = VpTree::build(items, euclidean_1d);
        let results = tree.find_within(&1.0, 1.5, euclidean_1d);
        assert_eq!(results.len(), 3); // 0.0, 1.0, 2.0
        assert!(results.iter().all(|r| *tree.get(r.index) <= 2.5));
    }

    #[test]
    fn correctness_vs_brute_force() {
        // 1000 deterministic pseudo-random points.
        let items: Vec<f64> = (0..1000)
            .map(|i| ((i as f64 * 0.618033988749895) % 1.0) * 100.0)
            .collect();
        let tree = VpTree::build(items.clone(), euclidean_1d);
        let query = 42.0;

        // k-nearest: distances must agree rank by rank.
        let k = 10;
        let tree_results = tree.find_nearest(&query, k, euclidean_1d);
        let mut brute: Vec<(usize, f64)> = items
            .iter()
            .enumerate()
            .map(|(i, &v)| (i, (v - query).abs()))
            .collect();
        brute.sort_by(|a, b| a.1.total_cmp(&b.1));

        assert_eq!(tree_results.len(), k);
        for (tr, br) in tree_results.iter().zip(brute.iter()) {
            assert!(
                (tr.distance - br.1).abs() < 1e-10,
                "VP-tree distance {} != brute force distance {}",
                tr.distance,
                br.1
            );
        }

        // find_within: the same items must be reported, not merely the same number of them.
        let radius = 5.0;
        let tree_within = tree.find_within(&query, radius, euclidean_1d);
        let brute_within: Vec<usize> = items
            .iter()
            .enumerate()
            .filter(|&(_, v)| (*v - query).abs() <= radius)
            .map(|(i, _)| i)
            .collect();
        assert_eq!(
            tree_within.len(),
            brute_within.len(),
            "VP-tree found {} items within radius, brute force found {}",
            tree_within.len(),
            brute_within.len()
        );
        let mut tree_indices: Vec<usize> = tree_within.iter().map(|m| m.index).collect();
        tree_indices.sort_unstable();
        assert_eq!(tree_indices, brute_within);
    }

    #[test]
    fn two_items() {
        let tree = VpTree::build(vec![0.0, 10.0], euclidean_1d);
        let results = tree.find_nearest(&3.0, 1, euclidean_1d);
        assert_eq!(results.len(), 1);
        assert!((tree.get(results[0].index) - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn duplicate_distances() {
        // Multiple items at the same distance from each other.
        let items = vec![0.0, 5.0, 5.0, 5.0, 10.0];
        let tree = VpTree::build(items, euclidean_1d);
        let results = tree.find_within(&5.0, 0.0, euclidean_1d);
        assert_eq!(results.len(), 3); // three items at distance 0
    }

    #[test]
    fn identical_items_exceed_depth_cap() {
        // More duplicates than MAX_BUILD_DEPTH forces the flat-chain fallback;
        // every item must remain reachable by both query kinds.
        let n = MAX_BUILD_DEPTH * 4;
        let tree = VpTree::build(vec![7.0; n], euclidean_1d);
        assert_eq!(tree.len(), n);

        let within = tree.find_within(&7.0, 0.0, euclidean_1d);
        assert_eq!(within.len(), n);

        let nearest = tree.find_nearest(&7.0, n + 5, euclidean_1d);
        assert_eq!(nearest.len(), n);
        assert!(nearest.iter().all(|m| m.distance == 0.0));

        let mut indices: Vec<usize> = nearest.iter().map(|m| m.index).collect();
        indices.sort_unstable();
        assert_eq!(indices, (0..n).collect::<Vec<usize>>());
    }

    #[test]
    fn duplicates_mixed_with_distinct_items() {
        // A large cluster of duplicates plus a few outliers: the outliers must not be
        // reported for a tight radius, and the closest outlier must rank first when
        // the query sits next to it.
        let mut items = vec![1.0; MAX_BUILD_DEPTH * 2];
        items.extend([50.0, 60.0, 70.0]);
        let n = items.len();
        let tree = VpTree::build(items, euclidean_1d);

        let cluster = tree.find_within(&1.0, 0.5, euclidean_1d);
        assert_eq!(cluster.len(), n - 3);

        let nearest = tree.find_nearest(&59.0, 1, euclidean_1d);
        assert_eq!(nearest.len(), 1);
        assert!((*tree.get(nearest[0].index) - 60.0).abs() < f64::EPSILON);
    }
}
516