//! Vantage-point tree for sub-linear nearest-neighbor and range queries in metric spaces.
//!
//! A VP-tree partitions data by distance from selected "vantage points", enabling
//! efficient pruning during search. Requires a distance function satisfying the
//! triangle inequality.
//!
//! - Build: O(n log n) distance computations. Each node splits the rest of its
//!   items at the median distance *by position*, so the tree stays balanced
//!   (height about log2 n) even when many items are equidistant, e.g. duplicates.
//! - k-nearest and range queries: typically sub-linear for data of low intrinsic
//!   dimensionality; pruning weakens as dimensionality grows and the worst case
//!   is O(n) distance computations.
//! - Queries walk the tree with an explicit stack rather than recursion, so their
//!   call-stack usage does not grow with the size of the tree.

use std::collections::BinaryHeap;

/// A vantage-point tree for nearest-neighbor and range queries.
///
/// The tree stores the items but not the distance function: every query must be
/// given the same metric that was used to build the tree.
#[derive(Debug, Clone)]
pub struct VpTree<T> {
    items: Vec<T>,
    /// Index into `nodes` of the root; `None` iff `items` is empty.
    root: Option<usize>,
    /// Flat node arena; children are referenced by index into this vector.
    nodes: Vec<VpNode>,
}

/// One tree node: a vantage point plus its two subtrees.
#[derive(Debug, Clone)]
struct VpNode {
    /// Index into `items`.
    item_idx: usize,
    /// Median distance from this vantage point to the items in its subtrees
    /// (0.0 and unused for leaves).
    threshold: f64,
    /// Inside subtree: every item has distance <= threshold from the vantage point.
    left: Option<usize>,
    /// Outside subtree: every item has distance >= threshold from the vantage point.
    /// Items exactly at the threshold may land on either side.
    right: Option<usize>,
}

/// A search result: index into the original items vector plus distance from the query.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct VpMatch {
    /// Index of the matched item, usable with [`VpTree::get`].
    pub index: usize,
    /// Distance from the query to the matched item.
    pub distance: f64,
}

impl<T> VpTree<T> {
    /// Build a VP-tree from a set of items.
    ///
    /// The distance function must be a metric (non-negative, symmetric, triangle
    /// inequality); queries rely on the triangle inequality to prune subtrees.
    /// Build cost is O(n log n) distance computations.
    pub fn build(items: Vec<T>, dist: impl Fn(&T, &T) -> f64) -> Self {
        let mut indices: Vec<usize> = (0..items.len()).collect();
        let mut nodes = Vec::with_capacity(items.len());
        let root = if indices.is_empty() {
            None
        } else {
            Some(build_recursive(&items, &mut indices, &dist, &mut nodes))
        };
        Self { items, root, nodes }
    }

    /// Find the `k` nearest items to `query`, sorted by ascending distance.
    ///
    /// Returns `min(k, self.len())` matches. An empty tree or `k == 0` yields an
    /// empty vector. Items at equal distance are returned in an unspecified order.
    pub fn find_nearest(
        &self,
        query: &T,
        k: usize,
        dist: impl Fn(&T, &T) -> f64,
    ) -> Vec<VpMatch> {
        let root = match self.root {
            Some(root) if k > 0 => root,
            _ => return Vec::new(),
        };
        // At most `len()` candidates can ever be held. Clamping keeps a huge `k`
        // (e.g. `usize::MAX` meaning "everything") from overflowing the heap's
        // `k + 1` capacity computation or requesting an impossible allocation.
        let k = k.min(self.items.len());
        search_nearest(&self.items, &self.nodes, root, query, k, &dist)
            .into_sorted_vec()
            .into_iter()
            .map(|entry| VpMatch {
                index: entry.index,
                distance: entry.distance,
            })
            .collect()
    }

    /// Find all items within `radius` of `query` (inclusive), sorted by ascending distance.
    pub fn find_within(
        &self,
        query: &T,
        radius: f64,
        dist: impl Fn(&T, &T) -> f64,
    ) -> Vec<VpMatch> {
        let mut results = Vec::new();
        if let Some(root) = self.root {
            search_within(
                &self.items,
                &self.nodes,
                root,
                query,
                radius,
                &dist,
                &mut results,
            );
        }
        results.sort_by(|a, b| a.distance.total_cmp(&b.distance));
        results
    }

    /// Number of items in the tree.
    pub fn len(&self) -> usize {
        self.items.len()
    }

    /// Whether the tree is empty.
    pub fn is_empty(&self) -> bool {
        self.items.is_empty()
    }

    /// Reference to an item by index (as reported in [`VpMatch::index`]).
    ///
    /// # Panics
    ///
    /// Panics if `index >= self.len()`.
    pub fn get(&self, index: usize) -> &T {
        &self.items[index]
    }
}

// --- Build ---

/// Build the subtree for `indices` and return the arena index of its root node.
///
/// `indices[0]` becomes the vantage point. The remaining indices are reordered
/// in place: the first half (positionally, after a median selection on distance)
/// forms the inside subtree, the rest the outside subtree. Splitting by position
/// rather than by value keeps the tree balanced when many distances tie (e.g. all
/// items identical), bounding recursion depth to about log2(len).
fn build_recursive<T>(
    items: &[T],
    indices: &mut [usize],
    dist: &impl Fn(&T, &T) -> f64,
    nodes: &mut Vec<VpNode>,
) -> usize {
    let (vp_slot, rest) = indices
        .split_first_mut()
        .expect("build_recursive requires at least one index");
    let vp_idx = *vp_slot;

    // Push the node first so nodes are stored in pre-order; children are patched in below.
    let node_idx = nodes.len();
    nodes.push(VpNode {
        item_idx: vp_idx,
        threshold: 0.0,
        left: None,
        right: None,
    });
    if rest.is_empty() {
        return node_idx;
    }

    // Distances from the vantage point to every other item in this subtree.
    let vantage = &items[vp_idx];
    let mut dists: Vec<(usize, f64)> = rest
        .iter()
        .map(|&idx| (idx, dist(vantage, &items[idx])))
        .collect();

    // After selection: dists[..median] <= dists[median] <= dists[median + 1..].
    let median_pos = dists.len() / 2;
    dists.select_nth_unstable_by(median_pos, |a, b| a.1.total_cmp(&b.1));
    let threshold = dists[median_pos].1;

    // Write the partitioned order back so each subtree is a contiguous sub-slice.
    for (slot, &(idx, _)) in rest.iter_mut().zip(&dists) {
        *slot = idx;
    }
    // `inside` is never empty (median_pos + 1 >= 1); `outside` may be.
    let (inside, outside) = rest.split_at_mut(median_pos + 1);

    let left = build_recursive(items, inside, dist, nodes);
    let right = if outside.is_empty() {
        None
    } else {
        Some(build_recursive(items, outside, dist, nodes))
    };

    let node = &mut nodes[node_idx];
    node.threshold = threshold;
    node.left = Some(left);
    node.right = right;
    node_idx
}

// --- k-nearest search ---

/// Candidate for the k-nearest heap. Ordered by distance so the `BinaryHeap`
/// (a max-heap) keeps the worst of the current candidates on top.
struct HeapEntry {
    distance: f64,
    index: usize,
}

impl Eq for HeapEntry {}

impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        self.distance.total_cmp(&other.distance) == std::cmp::Ordering::Equal
    }
}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for HeapEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.distance.total_cmp(&other.distance)
    }
}

/// Collect the `k` items nearest to `query` into a max-heap keyed on distance.
///
/// Depth-first traversal driven by an explicit stack (no recursion). Each stack
/// entry carries a lower bound on the distance from `query` to every item in that
/// subtree, derived from the triangle inequality. Once `k` candidates are held,
/// a subtree whose bound exceeds the current k-th best distance (`tau`) is skipped.
fn search_nearest<T>(
    items: &[T],
    nodes: &[VpNode],
    root: usize,
    query: &T,
    k: usize,
    dist: &impl Fn(&T, &T) -> f64,
) -> BinaryHeap<HeapEntry> {
    let mut heap = BinaryHeap::with_capacity(k.saturating_add(1));
    // Distance of the current k-th best candidate; infinite until `k` are held.
    let mut tau = f64::INFINITY;
    let mut stack = vec![(root, f64::NEG_INFINITY)];

    while let Some((node_idx, lower_bound)) = stack.pop() {
        // Checked when popped (not when pushed) so the far child benefits from
        // any tighter `tau` found while exploring the near child first.
        if lower_bound > tau {
            continue;
        }

        let node = &nodes[node_idx];
        let d = dist(query, &items[node.item_idx]);

        if heap.len() < k || d < tau {
            heap.push(HeapEntry {
                distance: d,
                index: node.item_idx,
            });
            if heap.len() > k {
                heap.pop();
            }
            if heap.len() == k {
                tau = heap.peek().map_or(tau, |worst| worst.distance);
            }
        }

        // Triangle inequality: for x inside, dist(q, x) >= d - threshold; for x
        // outside, dist(q, x) >= threshold - d. Ancestor bounds still apply.
        let inside = node
            .left
            .map(|child| (child, lower_bound.max(d - node.threshold)));
        let outside = node
            .right
            .map(|child| (child, lower_bound.max(node.threshold - d)));

        // Push the far side first so the side containing the query is explored
        // first (LIFO), which tightens `tau` as early as possible.
        let (near, far) = if d <= node.threshold {
            (inside, outside)
        } else {
            (outside, inside)
        };
        stack.extend(far);
        stack.extend(near);
    }

    heap
}

// --- Range search ---

/// Append every item within `radius` of `query` to `results` (unsorted).
///
/// Iterative depth-first traversal. A subtree is entered only if its
/// triangle-inequality lower bound on distance to `query` is <= `radius`.
fn search_within<T>(
    items: &[T],
    nodes: &[VpNode],
    root: usize,
    query: &T,
    radius: f64,
    dist: &impl Fn(&T, &T) -> f64,
    results: &mut Vec<VpMatch>,
) {
    let mut stack = vec![root];

    while let Some(node_idx) = stack.pop() {
        let node = &nodes[node_idx];
        let d = dist(query, &items[node.item_idx]);

        if d <= radius {
            results.push(VpMatch {
                index: node.item_idx,
                distance: d,
            });
        }

        let inside = node.left.filter(|_| d - node.threshold <= radius);
        let outside = node.right.filter(|_| node.threshold - d <= radius);
        // Outside pushed first so the inside subtree is visited first.
        stack.extend(outside);
        stack.extend(inside);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn euclidean_1d(a: &f64, b: &f64) -> f64 {
        (a - b).abs()
    }

    fn euclidean_2d(a: &[f64; 2], b: &[f64; 2]) -> f64 {
        ((a[0] - b[0]).powi(2) + (a[1] - b[1]).powi(2)).sqrt()
    }

    #[test]
    fn empty_tree() {
        let tree: VpTree<f64> = VpTree::build(vec![], euclidean_1d);
        assert!(tree.is_empty());
        assert_eq!(tree.len(), 0);
        assert!(tree.find_nearest(&0.0, 5, euclidean_1d).is_empty());
        assert!(tree.find_within(&0.0, 1.0, euclidean_1d).is_empty());
    }

    #[test]
    fn single_item() {
        let tree = VpTree::build(vec![5.0], euclidean_1d);
        assert_eq!(tree.len(), 1);
        let results = tree.find_nearest(&5.0, 1, euclidean_1d);
        assert_eq!(results.len(), 1);
        assert!(results[0].distance.abs() < f64::EPSILON);
    }

    #[test]
    fn k_nearest_basic() {
        let items = vec![1.0, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0];
        let tree = VpTree::build(items, euclidean_1d);
        let results = tree.find_nearest(&6.0, 3, euclidean_1d);
        assert_eq!(results.len(), 3);
        let values: Vec<f64> = results.iter().map(|r| *tree.get(r.index)).collect();
        assert!(values.contains(&5.0));
        assert!(values.contains(&7.0));
    }

    #[test]
    fn find_within_basic() {
        let items = vec![1.0, 3.0, 5.0, 7.0, 9.0];
        let tree = VpTree::build(items, euclidean_1d);
        let results = tree.find_within(&5.0, 2.5, euclidean_1d);
        let values: Vec<f64> = results.iter().map(|r| *tree.get(r.index)).collect();
        assert!(values.contains(&3.0));
        assert!(values.contains(&5.0));
        assert!(values.contains(&7.0));
        assert_eq!(values.len(), 3);
    }

    #[test]
    fn results_sorted_by_distance() {
        let items: Vec<f64> = (0..100).map(|i| i as f64).collect();
        let tree = VpTree::build(items, euclidean_1d);
        let nearest = tree.find_nearest(&50.0, 10, euclidean_1d);
        for w in nearest.windows(2) {
            assert!(w[0].distance <= w[1].distance);
        }
        let within = tree.find_within(&50.0, 5.0, euclidean_1d);
        for w in within.windows(2) {
            assert!(w[0].distance <= w[1].distance);
        }
    }

    #[test]
    fn k_larger_than_n() {
        let items = vec![1.0, 2.0, 3.0];
        let tree = VpTree::build(items, euclidean_1d);
        let results = tree.find_nearest(&0.0, 10, euclidean_1d);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn k_zero() {
        let items = vec![1.0, 2.0, 3.0];
        let tree = VpTree::build(items, euclidean_1d);
        let results = tree.find_nearest(&0.0, 0, euclidean_1d);
        assert!(results.is_empty());
    }

    #[test]
    fn k_usize_max_returns_everything() {
        let tree = VpTree::build(vec![3.0, 1.0, 2.0], euclidean_1d);
        let results = tree.find_nearest(&0.0, usize::MAX, euclidean_1d);
        let values: Vec<f64> = results.iter().map(|r| *tree.get(r.index)).collect();
        assert_eq!(values, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn two_dimensional() {
        let items = vec![
            [0.0, 0.0],
            [1.0, 0.0],
            [0.0, 1.0],
            [1.0, 1.0],
            [10.0, 10.0],
        ];
        let tree = VpTree::build(items, euclidean_2d);
        let results = tree.find_nearest(&[0.5, 0.5], 4, euclidean_2d);
        assert_eq!(results.len(), 4);
        // [10, 10] should not be in top-4
        assert!(!results.iter().any(|r| tree.get(r.index)[0] > 5.0));
    }

    #[test]
    fn find_within_excludes_far_items() {
        let items = vec![0.0, 1.0, 2.0, 100.0, 200.0];
        let tree = VpTree::build(items, euclidean_1d);
        let results = tree.find_within(&1.0, 1.5, euclidean_1d);
        assert_eq!(results.len(), 3); // 0.0, 1.0, 2.0
        assert!(results.iter().all(|r| *tree.get(r.index) <= 2.5));
    }

    #[test]
    fn correctness_vs_brute_force() {
        // 1000 deterministic pseudo-random points.
        let items: Vec<f64> = (0..1000)
            .map(|i| ((i as f64 * 0.618033988749895) % 1.0) * 100.0)
            .collect();
        let tree = VpTree::build(items.clone(), euclidean_1d);
        let query = 42.0;

        // k-nearest
        let k = 10;
        let tree_results = tree.find_nearest(&query, k, euclidean_1d);
        let mut brute: Vec<(usize, f64)> = items
            .iter()
            .enumerate()
            .map(|(i, &v)| (i, (v - query).abs()))
            .collect();
        brute.sort_by(|a, b| a.1.total_cmp(&b.1));
        assert_eq!(tree_results.len(), k);
        for (tr, br) in tree_results.iter().zip(brute.iter()) {
            assert!(
                (tr.distance - br.1).abs() < 1e-10,
                "VP-tree distance {} != brute force distance {}",
                tr.distance,
                br.1
            );
        }

        // find_within
        let radius = 5.0;
        let tree_within = tree.find_within(&query, radius, euclidean_1d);
        let brute_within: Vec<f64> = items
            .iter()
            .filter(|&&v| (v - query).abs() <= radius)
            .copied()
            .collect();
        assert_eq!(
            tree_within.len(),
            brute_within.len(),
            "VP-tree found {} items within radius, brute force found {}",
            tree_within.len(),
            brute_within.len()
        );
    }

    #[test]
    fn two_items() {
        let tree = VpTree::build(vec![0.0, 10.0], euclidean_1d);
        let results = tree.find_nearest(&3.0, 1, euclidean_1d);
        assert_eq!(results.len(), 1);
        assert!((tree.get(results[0].index) - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn duplicate_distances() {
        // Multiple items at the same distance from each other.
        let items = vec![0.0, 5.0, 5.0, 5.0, 10.0];
        let tree = VpTree::build(items, euclidean_1d);
        let results = tree.find_within(&5.0, 0.0, euclidean_1d);
        assert_eq!(results.len(), 3); // three items at distance 0
    }

    #[test]
    fn many_identical_items() {
        // Identical items used to degenerate into a long chain that queries
        // walked recursively, one stack frame per item.
        let tree = VpTree::build(vec![7.0; 50_000], euclidean_1d);
        let nearest = tree.find_nearest(&7.0, 3, euclidean_1d);
        assert_eq!(nearest.len(), 3);
        assert!(nearest.iter().all(|m| m.distance.abs() < f64::EPSILON));
        assert_eq!(tree.find_within(&7.0, 0.0, euclidean_1d).len(), 50_000);
        assert!(tree.find_within(&0.0, 1.0, euclidean_1d).is_empty());
    }

    #[test]
    fn duplicates_vs_brute_force() {
        let items: Vec<f64> = (0..500).map(|i| (i % 7) as f64).collect();
        let tree = VpTree::build(items.clone(), euclidean_1d);
        for &query in &[-1.0, 0.0, 2.5, 3.0, 6.0, 9.0] {
            let mut brute: Vec<f64> = items.iter().map(|v| (v - query).abs()).collect();
            brute.sort_by(|a, b| a.total_cmp(b));
            let k = 40;
            let got: Vec<f64> = tree
                .find_nearest(&query, k, euclidean_1d)
                .iter()
                .map(|m| m.distance)
                .collect();
            assert_eq!(got, brute[..k].to_vec());
            for &radius in &[0.0, 0.5, 1.0, 2.0] {
                let expected = brute.iter().filter(|&&d| d <= radius).count();
                assert_eq!(tree.find_within(&query, radius, euclidean_1d).len(), expected);
            }
        }
    }
}