using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.Linq;
using System.Text;
using MathNet.Numerics;
using MathNet.Numerics.LinearAlgebra;
using MathNet.Numerics.LinearAlgebra.Generic;
using MathNet.Numerics.Statistics;

namespace KdTree
{
    // NOTE(review): the generic type arguments in this file (on KdTreeNode, Vector, IArithmetic, Arithmetic,
    // the C5 collections, etc.) were reconstructed from how the members are used -- confirm them against the
    // declarations of those types.

    /// <summary>
    /// Represents a kd-tree data structure, used for the binary partitioning of k-dimensional space.
    /// </summary>
    /// <typeparam name="TValue">The type of the value of the contained nodes.</typeparam>
    /// <typeparam name="TField">The type of the underlying field of the vector space of locations.</typeparam>
    /// <remarks>
    /// For a node that splits on dimension d, values in the left sub-tree have a d-component less than or equal
    /// to that of the node, and values in the right sub-tree have a d-component greater than or equal to it.
    /// Ties can end up on either side (after construction from a sorted array, or after a removal), so exact
    /// look-ups must follow both branches when the components are equal.
    /// </remarks>
    public class KdTree<TValue, TField> : ICollection<TValue>
        where TField : struct, IEquatable<TField>, IFormattable
    {
        // Default location getter function: assumes the element itself is its location vector.
        private static readonly Func<TValue, Vector<TField>> defaultLocationGetter =
            el => (Vector<TField>)(object)el;

        // Compares values of underlying field of vector space of locations.
        private static readonly IComparer<TField> fieldComparer = Comparer<TField>.Default;

        // Performs arithmetic on values of field type.
        private static readonly IArithmetic<TField> fieldArithmetic = Arithmetic<TField>.Default;

        // Tests node values for equality (null-safe).
        private static readonly IEqualityComparer<TValue> valueComparer = EqualityComparer<TValue>.Default;

        /// <summary>
        /// Constructs a kd-tree from the specified collection of elements.
        /// </summary>
        /// <param name="dimensionality">The dimensionality of each location vector.</param>
        /// <param name="elements">The collection of elements from which to create the tree.</param>
        /// <param name="locationGetter">A function that gets the location of a given element in kd-space.
        /// When <c>null</c>, each element is cast directly to a vector.</param>
        /// <returns>The constructed kd-tree. Its root is <c>null</c> when <paramref name="elements"/> is empty.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="elements"/> is <c>null</c>.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="dimensionality"/> is less than 1.
        /// </exception>
        public static KdTree<TValue, TField> Construct(int dimensionality, IEnumerable<TValue> elements,
            Func<TValue, Vector<TField>> locationGetter = null)
        {
            if (elements == null)
                throw new ArgumentNullException("elements");
            if (dimensionality < 1)
                throw new ArgumentOutOfRangeException("dimensionality", "Dimensionality must be at least 1.");

            // Create and initialize kd-tree.
            var tree = new KdTree<TValue, TField>();
            tree.dimensionality = dimensionality;
            if (locationGetter != null)
                tree.locationGetter = locationGetter;

            // Construct nodes of tree. The array is a private copy, so it can be sorted in place.
            var elementsArray = elements.ToArray();
            tree.root = Construct(tree, elementsArray, 0, elementsArray.Length - 1, 0,
                new ValueLocationComparer(tree.locationGetter));
            return tree;
        }

        // Recursively builds the sub-tree for elements[startIndex..endIndex] (inclusive), splitting on the
        // dimension selected by the current depth. Returns null for an empty range.
        private static KdTreeNode<TValue> Construct(KdTree<TValue, TField> tree, TValue[] elements,
            int startIndex, int endIndex, int depth, ValueLocationComparer valueLocationComparer)
        {
            var length = endIndex - startIndex + 1;
            if (length <= 0)
                return null;

            // Sort array of elements by component of chosen dimension, in ascending magnitude.
            // The comparer is shared between recursion levels, so its dimension must be set before every sort.
            valueLocationComparer.Dimension = depth % tree.dimensionality;
            Array.Sort(elements, startIndex, length, valueLocationComparer);

            // Select median element as pivot.
            var medianIndex = startIndex + length / 2;

            // Create node and construct sub-trees around pivot element.
            var node = new KdTreeNode<TValue>(elements[medianIndex]);
            node.LeftChild = Construct(tree, elements, startIndex, medianIndex - 1, depth + 1,
                valueLocationComparer);
            node.RightChild = Construct(tree, elements, medianIndex + 1, endIndex, depth + 1,
                valueLocationComparer);
            return node;
        }

        // Dimensionality of location vectors.
        private int dimensionality;

        // Root node of tree.
        private KdTreeNode<TValue> root;

        // Function that returns location vector of given element.
        private Func<TValue, Vector<TField>> locationGetter;

        /// <summary>
        /// Initializes a new instance of the <see cref="KdTree{TValue, TField}"/> class with the specified
        /// root node.
        /// </summary>
        /// <param name="dimensionality">The dimensionality of the kd-space.</param>
        /// <param name="root">The root node of the tree.</param>
        /// <param name="locationGetter">A function that returns the location of a given element in kd-space.
        /// When <c>null</c>, each element is cast directly to a vector.</param>
        /// <exception cref="ArgumentNullException"><paramref name="root"/> is <c>null</c>.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="dimensionality"/> is less than 1.
        /// </exception>
        public KdTree(int dimensionality, KdTreeNode<TValue> root,
            Func<TValue, Vector<TField>> locationGetter = null)
            : this()
        {
            if (root == null)
                throw new ArgumentNullException("root");
            if (dimensionality < 1)
                throw new ArgumentOutOfRangeException("dimensionality", "Dimensionality must be at least 1.");

            this.dimensionality = dimensionality;
            this.root = root;
            if (locationGetter != null)
                this.locationGetter = locationGetter;
        }

        private KdTree()
        {
            this.locationGetter = defaultLocationGetter;
        }

        /// <summary>
        /// Gets the root node of the tree.
        /// </summary>
        /// <value>The root node, or <c>null</c> if the tree is empty.</value>
        public KdTreeNode<TValue> Root
        {
            get { return this.root; }
        }

        /// <summary>
        /// Gets the function that returns the location of a given element in kd-space.
        /// </summary>
        /// <value>The location getter function.</value>
        public Func<TValue, Vector<TField>> LocationGetter
        {
            get { return this.locationGetter; }
        }

        /// <summary>
        /// Finds all values in the tree that lie within the specified range of a location.
        /// </summary>
        /// <param name="location">The location around which to search.</param>
        /// <param name="range">The range in which to search for values.</param>
        /// <returns>A read-only collection of the values whose distance from <paramref name="location"/> is
        /// less than <paramref name="range"/>. Values located (almost) exactly at
        /// <paramref name="location"/> are excluded.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="location"/> is <c>null</c>.</exception>
        public IEnumerable<TValue> FindInRange(Vector<TField> location, TField range)
        {
            if (location == null)
                throw new ArgumentNullException("location");

            var found = new List<TValue>();
            FindInRange(location, this.root, range, found, 0);
            return found.AsReadOnly();
        }

        private void FindInRange(Vector<TField> location, KdTreeNode<TValue> node, TField range,
            IList<TValue> valuesList, int depth)
        {
            if (node == null)
                return;

            var dimension = depth % this.dimensionality;
            var nodeLocation = this.locationGetter(node.Value);
            var distance = GetDistance(nodeLocation, location);

            // Add current node to list if it lies within given range.
            // Current node cannot be same as search location.
            if (!fieldArithmetic.AlmostEqual(distance, fieldArithmetic.Zero)
                && fieldComparer.Compare(distance, range) < 0)
            {
                valuesList.Add(node.Value);
            }

            // The near child is on the same side of the splitting hyperplane as the search location.
            var nearIsLeft = fieldComparer.Compare(location[dimension], nodeLocation[dimension]) < 0;
            var nearChild = nearIsLeft ? node.LeftChild : node.RightChild;
            var farChild = nearIsLeft ? node.RightChild : node.LeftChild;

            FindInRange(location, nearChild, range, valuesList, depth + 1);

            // The far sub-tree can only contain matches if the splitting hyperplane is closer than the range.
            var planeDistance = fieldArithmetic.Abs(
                fieldArithmetic.Subtract(nodeLocation[dimension], location[dimension]));
            if (fieldComparer.Compare(range, planeDistance) > 0)
                FindInRange(location, farChild, range, valuesList, depth + 1);
        }

        /// <summary>
        /// Finds the N values in the tree that are nearest to the specified location.
        /// </summary>
        /// <param name="location">The location for which to find the N nearest neighbors.</param>
        /// <param name="numNeighbors">N, the number of nearest neighbors to find.</param>
        /// <returns>A priority queue, ordered by distance from <paramref name="location"/>, holding at most
        /// <paramref name="numNeighbors"/> values. Values located (almost) exactly at
        /// <paramref name="location"/> are excluded. The queue is empty when the tree is empty.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="location"/> is <c>null</c>.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="numNeighbors"/> is less than 1.
        /// </exception>
        public C5.IPriorityQueue<TValue> FindNearestNNeighbors(Vector<TField> location, int numNeighbors)
        {
            if (location == null)
                throw new ArgumentNullException("location");
            if (numNeighbors < 1)
                throw new ArgumentOutOfRangeException("numNeighbors", "At least one neighbor must be requested.");

            var neighbors = new C5.IntervalHeap<TValue>(numNeighbors,
                new ValueDistanceComparer(this.locationGetter, this.dimensionality, location));

            // Until the queue is full every candidate is accepted, so start from the largest distance.
            var maxBestDistance = fieldArithmetic.MaxValue;
            FindNearestNNeighbors(location, this.root, ref maxBestDistance, numNeighbors, neighbors, 0);
            return neighbors;
        }

        // maxBestDistance is the distance of the farthest value currently kept in valuesList once the list
        // is full; it is the pruning radius for the search.
        private void FindNearestNNeighbors(Vector<TField> location, KdTreeNode<TValue> node,
            ref TField maxBestDistance, int numNeighbors, C5.IPriorityQueue<TValue> valuesList, int depth)
        {
            if (node == null)
                return;

            var dimension = depth % this.dimensionality;
            var nodeLocation = this.locationGetter(node.Value);
            var distance = GetDistance(nodeLocation, location);

            // Check if current node is better than the farthest kept value, and replace that value with it.
            // Current node cannot be same as search location.
            if (!fieldArithmetic.AlmostEqual(distance, fieldArithmetic.Zero)
                && fieldComparer.Compare(distance, maxBestDistance) < 0)
            {
                if (valuesList.Count == numNeighbors)
                    valuesList.DeleteMax();
                valuesList.Add(node.Value);

                if (valuesList.Count == numNeighbors)
                    maxBestDistance = GetDistance(this.locationGetter(valuesList.FindMax()), location);
            }

            // The near child is on the same side of the splitting hyperplane as the search location.
            var nearIsLeft = fieldComparer.Compare(location[dimension], nodeLocation[dimension]) < 0;
            var nearChild = nearIsLeft ? node.LeftChild : node.RightChild;
            var farChild = nearIsLeft ? node.RightChild : node.LeftChild;

            FindNearestNNeighbors(location, nearChild, ref maxBestDistance, numNeighbors, valuesList,
                depth + 1);

            // Check whether splitting hyperplane given by current node intersects with hypersphere of current
            // pruning radius around given location.
            var planeDistance = fieldArithmetic.Abs(
                fieldArithmetic.Subtract(nodeLocation[dimension], location[dimension]));
            if (fieldComparer.Compare(maxBestDistance, planeDistance) > 0)
            {
                FindNearestNNeighbors(location, farChild, ref maxBestDistance, numNeighbors, valuesList,
                    depth + 1);
            }
        }

        /// <summary>
        /// Finds the value in the tree that is nearest to the specified location.
        /// </summary>
        /// <param name="location">The location for which to find the nearest neighbor.</param>
        /// <returns>The value whose location is nearest to <paramref name="location"/>, ignoring values
        /// located (almost) exactly at <paramref name="location"/>. If no other value exists, the value of
        /// the root node is returned.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="location"/> is <c>null</c>.</exception>
        /// <exception cref="InvalidOperationException">The tree is empty.</exception>
        public TValue FindNearestNeighbor(Vector<TField> location)
        {
            if (location == null)
                throw new ArgumentNullException("location");
            if (this.root == null)
                throw new InvalidOperationException("The tree is empty.");

            var bestValue = this.root.Value;
            var bestDistance = fieldArithmetic.MaxValue;
            FindNearestNeighbor(location, this.root, ref bestValue, ref bestDistance, 0);
            return bestValue;
        }

        // The best value and its distance are carried by reference so that the pruning radius is never
        // recomputed from a placeholder value (which could wrongly shrink it to zero).
        private void FindNearestNeighbor(Vector<TField> location, KdTreeNode<TValue> node,
            ref TValue bestValue, ref TField bestDistance, int depth)
        {
            if (node == null)
                return;

            var dimension = depth % this.dimensionality;
            var nodeLocation = this.locationGetter(node.Value);
            var distance = GetDistance(nodeLocation, location);

            // Check if current node is better than best node.
            // Current node cannot be same as search location.
            if (!fieldArithmetic.AlmostEqual(distance, fieldArithmetic.Zero)
                && fieldComparer.Compare(distance, bestDistance) < 0)
            {
                bestValue = node.Value;
                bestDistance = distance;
            }

            // The near child is on the same side of the splitting hyperplane as the search location.
            var nearIsLeft = fieldComparer.Compare(location[dimension], nodeLocation[dimension]) < 0;
            var nearChild = nearIsLeft ? node.LeftChild : node.RightChild;
            var farChild = nearIsLeft ? node.RightChild : node.LeftChild;

            FindNearestNeighbor(location, nearChild, ref bestValue, ref bestDistance, depth + 1);

            // Check whether splitting hyperplane given by current node intersects with hypersphere of current
            // smallest distance around given location.
            var planeDistance = fieldArithmetic.Abs(
                fieldArithmetic.Subtract(nodeLocation[dimension], location[dimension]));
            if (fieldComparer.Compare(bestDistance, planeDistance) > 0)
                FindNearestNeighbor(location, farChild, ref bestValue, ref bestDistance, depth + 1);
        }

        /// <summary>
        /// Adds a node with the specified value to the tree.
        /// </summary>
        /// <param name="value">The value of the element to add.</param>
        /// <returns>The node that was added.</returns>
        /// <remarks>
        /// Nodes with duplicate values may be added to the tree. No re-balancing is performed.
        /// </remarks>
        /// <exception cref="ArgumentNullException"><paramref name="value"/> is <c>null</c>.</exception>
        public KdTreeNode<TValue> Add(TValue value)
        {
            if (value == null)
                throw new ArgumentNullException("value");

            var newNode = new KdTreeNode<TValue>(value);
            if (this.root == null)
            {
                this.root = newNode;
                return newNode;
            }

            // Walk down from the root until a free child slot is found on the correct side.
            var valueLocation = this.locationGetter(value);
            var current = this.root;
            for (var depth = 0; ; depth++)
            {
                var dimension = depth % this.dimensionality;
                var goLeft = fieldComparer.Compare(valueLocation[dimension],
                    this.locationGetter(current.Value)[dimension]) <= 0;
                var child = goLeft ? current.LeftChild : current.RightChild;
                if (child == null)
                {
                    if (goLeft)
                        current.LeftChild = newNode;
                    else
                        current.RightChild = newNode;
                    return newNode;
                }

                current = child;
            }
        }

        /// <summary>
        /// Removes a node with the specified value from the tree.
        /// </summary>
        /// <param name="value">The value of the node to remove.</param>
        /// <returns>A detached node holding the removed value, or <c>null</c> if the value was not found.
        /// Removal may be carried out by moving values between nodes, so the returned node is not one that
        /// is still part of the tree.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="value"/> is <c>null</c>.</exception>
        public KdTreeNode<TValue> Remove(TValue value)
        {
            if (value == null)
                throw new ArgumentNullException("value");

            KdTreeNode<TValue> removedNode = null;
            this.root = Remove(value, this.root, 0, ref removedNode);
            return removedNode;
        }

        // Removes one occurrence of value from the sub-tree rooted at node and returns the new root of that
        // sub-tree. removedNode is set to a detached node holding the value when an occurrence was removed.
        private KdTreeNode<TValue> Remove(TValue value, KdTreeNode<TValue> node, int depth,
            ref KdTreeNode<TValue> removedNode)
        {
            if (node == null)
                return null;

            var dimension = depth % this.dimensionality;

            if (valueComparer.Equals(node.Value, value))
            {
                removedNode = new KdTreeNode<TValue>(node.Value);
                KdTreeNode<TValue> discarded = null;

                if (node.RightChild != null)
                {
                    // Replace with the minimum of the right sub-tree along the splitting dimension, then
                    // remove that replacement from the right sub-tree.
                    node.Value = FindMinimum(node.RightChild, dimension, depth + 1);
                    node.RightChild = Remove(node.Value, node.RightChild, depth + 1, ref discarded);
                }
                else if (node.LeftChild != null)
                {
                    // No right sub-tree: replace with the minimum of the left sub-tree and move what remains
                    // of the left sub-tree to the right, where every value is >= the replacement.
                    node.Value = FindMinimum(node.LeftChild, dimension, depth + 1);
                    node.RightChild = Remove(node.Value, node.LeftChild, depth + 1, ref discarded);
                    node.LeftChild = null;
                }
                else
                {
                    // Leaf node: simply detach it.
                    return null;
                }

                return node;
            }

            // Not this node: descend towards the value. Equal components may be on either side.
            var comparison = fieldComparer.Compare(this.locationGetter(value)[dimension],
                this.locationGetter(node.Value)[dimension]);
            if (comparison <= 0)
                node.LeftChild = Remove(value, node.LeftChild, depth + 1, ref removedNode);
            if (removedNode == null && comparison >= 0)
                node.RightChild = Remove(value, node.RightChild, depth + 1, ref removedNode);
            return node;
        }

        /// <summary>
        /// Removes all nodes in the tree except for the root node.
        /// </summary>
        public void Clear()
        {
            // An empty tree (constructed from no elements) has no root to trim.
            if (this.root == null)
                return;

            this.root.LeftChild = null;
            this.root.RightChild = null;
        }

        /// <summary>
        /// Determines whether the specified value is the value of any node in the tree.
        /// </summary>
        /// <param name="value">The value to locate in the tree.</param>
        /// <returns><c>true</c> if <paramref name="value"/> was found in the tree; <c>false</c>, otherwise.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="value"/> is <c>null</c>.</exception>
        public bool Contains(TValue value)
        {
            return Find(value) != null;
        }

        /// <summary>
        /// Finds the node with the specified value.
        /// </summary>
        /// <param name="value">The value of the node to find.</param>
        /// <returns>The first node found whose value equals <paramref name="value"/>, or <c>null</c> if there
        /// is none.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="value"/> is <c>null</c>.</exception>
        public KdTreeNode<TValue> Find(TValue value)
        {
            if (value == null)
                throw new ArgumentNullException("value");

            return Find(value, this.root);
        }

        // Exhaustive pre-order search of the sub-tree; it does not rely on the ordering of the tree.
        private KdTreeNode<TValue> Find(TValue value, KdTreeNode<TValue> node)
        {
            if (node == null)
                return null;
            if (valueComparer.Equals(node.Value, value))
                return node;

            return Find(value, node.LeftChild) ?? Find(value, node.RightChild);
        }

        // Returns the value in the sub-tree rooted at node whose location has the smallest component along
        // splittingDimension. Returns default(TValue) only when node is null.
        private TValue FindMinimum(KdTreeNode<TValue> node, int splittingDimension, int depth)
        {
            if (node == null)
                return default(TValue);

            var dimension = depth % this.dimensionality;
            if (dimension == splittingDimension)
            {
                // This node splits on the dimension of interest, so the minimum is in the left sub-tree
                // when there is one.
                return node.LeftChild == null
                    ? node.Value
                    : FindMinimum(node.LeftChild, splittingDimension, depth + 1);
            }

            // Otherwise the minimum may be this node or lie in either sub-tree: compare all candidates.
            var bestValue = node.Value;
            var bestComponent = this.locationGetter(bestValue)[splittingDimension];
            var children = new[] { node.LeftChild, node.RightChild };
            foreach (var child in children)
            {
                if (child == null)
                    continue;

                var candidate = FindMinimum(child, splittingDimension, depth + 1);
                var candidateComponent = this.locationGetter(candidate)[splittingDimension];
                if (fieldComparer.Compare(candidateComponent, bestComponent) < 0)
                {
                    bestValue = candidate;
                    bestComponent = candidateComponent;
                }
            }

            return bestValue;
        }

        // Distance between two locations, as given by the norm of their difference.
        // NOTE(review): Norm is called with the tree's dimensionality as its argument; if that argument is
        // the p of a p-norm, this is not the Euclidean distance -- confirm that this is intended.
        private TField GetDistance(Vector<TField> from, Vector<TField> to)
        {
            return (from - to).Norm(this.dimensionality);
        }

        #region ICollection<TValue> Members

        bool ICollection<TValue>.IsReadOnly
        {
            get { return false; }
        }

        // The tree keeps no running count, so the nodes are counted by enumeration.
        int ICollection<TValue>.Count
        {
            get
            {
                var count = 0;
                foreach (var value in this)
                    count++;
                return count;
            }
        }

        void ICollection<TValue>.Add(TValue item)
        {
            Add(item);
        }

        bool ICollection<TValue>.Remove(TValue item)
        {
            return Remove(item) != null;
        }

        /// <summary>
        /// Copies the values of all the nodes in the tree to the specified array, starting at the specified
        /// index.
        /// </summary>
        /// <param name="array">The array that is the destination of the copied elements.</param>
        /// <param name="arrayIndex">The zero-based index in <paramref name="array"/> at which copying begins.
        /// </param>
        /// <exception cref="ArgumentNullException"><paramref name="array"/> is <c>null</c>.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="arrayIndex"/> is negative.
        /// </exception>
        public void CopyTo(TValue[] array, int arrayIndex)
        {
            if (array == null)
                throw new ArgumentNullException("array");
            if (arrayIndex < 0)
                throw new ArgumentOutOfRangeException("arrayIndex", "Index must not be negative.");

            var index = arrayIndex;
            foreach (var value in this)
                array[index++] = value;
        }

        #endregion

        #region IEnumerable<TValue> Members

        /// <summary>
        /// Returns an enumerator that iterates through the values of the nodes in the tree.
        /// </summary>
        /// <returns>An enumerator for the tree.</returns>
        public IEnumerator<TValue> GetEnumerator()
        {
            if (this.root == null)
                yield break;

            // Perform depth-first traversal of tree using an explicit stack, yielding every value found.
            // The left child is pushed first, so the right sub-tree is visited before the left one.
            var pendingNodes = new Stack<KdTreeNode<TValue>>();
            pendingNodes.Push(this.root);
            while (pendingNodes.Count > 0)
            {
                var node = pendingNodes.Pop();
                yield return node.Value;

                if (node.LeftChild != null)
                    pendingNodes.Push(node.LeftChild);
                if (node.RightChild != null)
                    pendingNodes.Push(node.RightChild);
            }
        }

        #endregion

        #region IEnumerable Members

        IEnumerator IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }

        #endregion

        // Compares node values by the component of their locations along a selectable dimension.
        private class ValueLocationComparer : Comparer<TValue>
        {
            // Function that returns location of element in kd-space.
            private readonly Func<TValue, Vector<TField>> locationGetter;

            public ValueLocationComparer(Func<TValue, Vector<TField>> locationGetter)
            {
                Debug.Assert(locationGetter != null);

                this.locationGetter = locationGetter;
            }

            // Index of dimension of components to compare.
            public int Dimension { get; set; }

            public override int Compare(TValue x, TValue y)
            {
                return fieldComparer.Compare(
                    this.locationGetter(x)[this.Dimension],
                    this.locationGetter(y)[this.Dimension]);
            }
        }

        // Compares node values by their distances from a fixed location.
        private class ValueDistanceComparer : Comparer<TValue>
        {
            // Function that returns location of element in kd-space.
            private readonly Func<TValue, Vector<TField>> locationGetter;

            // Dimensionality of kd-tree.
            private readonly int dimensionality;

            // Location from which to calculate distance.
            private readonly Vector<TField> location;

            public ValueDistanceComparer(Func<TValue, Vector<TField>> locationGetter, int dimensionality,
                Vector<TField> location)
            {
                Debug.Assert(locationGetter != null);

                this.locationGetter = locationGetter;
                this.dimensionality = dimensionality;
                this.location = location;
            }

            public override int Compare(TValue x, TValue y)
            {
                // Must use the same distance measure as the tree's own searches.
                return fieldComparer.Compare(
                    (this.locationGetter(x) - this.location).Norm(this.dimensionality),
                    (this.locationGetter(y) - this.location).Norm(this.dimensionality));
            }
        }
    }
}