K-d Trees
Introduction
K-d trees (short for k-dimensional trees) are a space-partitioning data structure that organizes points in a k-dimensional space. They are particularly useful for applications involving multidimensional search keys, such as nearest-neighbor and range searches.
Imagine you have a large number of points in a 2D or 3D space and frequently need to find the closest point to a given query point. A naive approach checks the distance to every point, which costs O(n) per query and quickly becomes expensive for large datasets. K-d trees solve this problem by partitioning the space in a way that lets most points be skipped during a search.
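For reference, the naive linear scan looks like this (a minimal sketch; math.dist requires Python 3.8+):

import math

def brute_force_nearest(points, query):
    """O(n) baseline: measure the distance to every point."""
    return min(points, key=lambda p: math.dist(p, query))

The rest of this section builds the tree-based alternative.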
Understanding K-d Trees
What is a K-d Tree?
A k-d tree is a binary tree where:
- Each node represents a point in k-dimensional space
- Each non-leaf node divides the space into two half-spaces along one dimension
- Points in the left subtree have a smaller value in the splitting dimension
- Points in the right subtree have a greater (or equal) value in the splitting dimension; exactly where ties go depends on the implementation
The splitting dimension typically cycles through all dimensions as we move down the tree. For example, in a 2D space (k=2), the root node might split along the x-axis, its children along the y-axis, its grandchildren along the x-axis again, and so on.
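To make the cycling concrete, this is the rule the construction code below uses (the depth % k expression is the same one that appears in build_kdtree):

k = 2  # 2D points
for depth in range(4):
    print(f"depth {depth} splits on {'x' if depth % k == 0 else 'y'}")
# depth 0 splits on x
# depth 1 splits on y
# depth 2 splits on x
# depth 3 splits on y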
Visual Representation
Each node's splitting line cuts its region of the plane in two, so descending the tree carves the plane into progressively smaller nested rectangles: vertical cuts at x-splitting nodes, horizontal cuts at y-splitting nodes.
[Figures omitted: the binary tree for a sample point set, and the corresponding rectangular partition of the plane. A small plotting sketch appears after we build the tree below.]
Building a K-d Tree
Building a k-d tree involves recursively partitioning the point set along alternating dimensions:
- Choose a splitting dimension (typically cycling through dimensions)
- Find the median point along that dimension
- Create a node with that median point
- Recursively build the left subtree with points on the "less than" side
- Recursively build the right subtree with points on the "greater than" side
Let's implement a k-d tree in Python for 2D points (the construction code actually works for any k):
class KdNode:
    def __init__(self, point, dim, left=None, right=None):
        self.point = point  # The point (x, y)
        self.dim = dim      # The dimension to split on (0 for x, 1 for y)
        self.left = left    # Left subtree
        self.right = right  # Right subtree
def build_kdtree(points, depth=0):
    if not points:
        return None
    
    k = len(points[0])  # Dimensionality of points
    dim = depth % k     # Current dimension to split on
    
    # Sort points by the current dimension (note: sorts the input list in place)
    points.sort(key=lambda x: x[dim])
    
    # Find median point
    median_idx = len(points) // 2
    
    # Create node and recursively build subtrees
    return KdNode(
        point=points[median_idx],
        dim=dim,
        left=build_kdtree(points[:median_idx], depth+1),
        right=build_kdtree(points[median_idx+1:], depth+1)
    )
Let's see this in action with a simple example:
# Example points in 2D space
points = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
# Build the k-d tree
root = build_kdtree(points)
# Print tree structure
def print_tree(node, level=0):
    if node is not None:
        print("  " * level + f"({node.point[0]}, {node.point[1]}) - split on {'x' if node.dim == 0 else 'y'}")
        print_tree(node.left, level + 1)
        print_tree(node.right, level + 1)
print_tree(root)
Output:
(7, 2) - split on x
  (5, 4) - split on y
    (2, 3) - split on x
    (4, 7) - split on x
  (9, 6) - split on y
    (8, 1) - split on x
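To turn this into the partition diagram described earlier, we can draw each node's splitting line clipped to its cell. This is a minimal sketch assuming matplotlib is installed; the 0 to 10 bounds are chosen by hand to enclose the example points:

import matplotlib.pyplot as plt

def draw_partition(node, x_min, x_max, y_min, y_max):
    """Recursively draw each node's point and its splitting line."""
    if node is None:
        return
    x, y = node.point
    if node.dim == 0:
        # Vertical splitting line at x, spanning the cell's y-extent
        plt.plot([x, x], [y_min, y_max], 'r-')
        draw_partition(node.left, x_min, x, y_min, y_max)
        draw_partition(node.right, x, x_max, y_min, y_max)
    else:
        # Horizontal splitting line at y, spanning the cell's x-extent
        plt.plot([x_min, x_max], [y, y], 'b-')
        draw_partition(node.left, x_min, x_max, y_min, y)
        draw_partition(node.right, x_min, x_max, y, y_max)
    plt.plot(x, y, 'ko')  # the point stored at this node

draw_partition(root, 0, 10, 0, 10)
plt.show()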
Searching in K-d Trees
Finding the Nearest Neighbor
One of the most common applications of k-d trees is finding the nearest neighbor to a query point. The algorithm works as follows:
- Start at the root node
- Recursively descend the tree, choosing the branch on the same side of the split as the query point
- When reaching a leaf, save its point as the "current best"
- Backtrack, and at each node:
  - Check whether the node's point is closer than the current best
  - Check whether the opposite branch could contain a closer point, by comparing the current best distance with the distance from the query point to the node's splitting plane; if it could, search that branch as well
Here's an implementation of nearest neighbor search:
import math
def distance(p1, p2):
    """Calculate Euclidean distance between two points"""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(p1, p2)))
def nearest_neighbor(root, query_point, best=None):
    if root is None:
        return best
    
    # Update best if current point is closer
    if best is None or distance(query_point, root.point) < distance(query_point, best):
        best = root.point
    
    # Current splitting dimension
    dim = root.dim
    
    # Determine which subtree to search first (closer branch)
    if query_point[dim] < root.point[dim]:
        first, second = root.left, root.right
    else:
        first, second = root.right, root.left
    
    # Search the closer branch
    best = nearest_neighbor(first, query_point, best)
    
    # Check if we need to search the other branch
    # by comparing distance to splitting plane with current best distance
    if second is not None:
        dist_to_plane = abs(query_point[dim] - root.point[dim])
        if dist_to_plane < distance(query_point, best):
            # The other branch could contain a closer point
            best = nearest_neighbor(second, query_point, best)
    
    return best
Let's test our nearest neighbor search:
# Example usage
query_point = (6, 5)
nearest = nearest_neighbor(root, query_point)
print(f"Nearest point to {query_point} is {nearest}, with distance {distance(query_point, nearest)}")
Output:
Nearest point to (6, 5) is (5, 4), with distance 1.4142135623730951
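As a quick sanity check, a brute-force scan over the original list (reusing the distance helper above) should agree with the tree search:

# Compare against the O(n) linear scan
brute = min(points, key=lambda p: distance(p, query_point))
print(brute == nearest)  # True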
Range Search
Another common operation is finding all points within a certain distance or region. The k-d tree allows us to prune branches that can't contain points in our target range:
def range_search(root, query_point, radius):
    """Find all points within a given radius of the query point"""
    result = []
    
    def search(node):
        if node is None:
            return
        
        # Check if current point is within radius
        if distance(query_point, node.point) <= radius:
            result.append(node.point)
        
        # Current splitting dimension
        dim = node.dim
        
        # Check if we need to search left subtree
        if query_point[dim] - radius <= node.point[dim]:
            search(node.left)
        
        # Check if we need to search right subtree
        if query_point[dim] + radius >= node.point[dim]:
            search(node.right)
    
    search(root)
    return result
Usage example:
# Find points within radius 3 of (6, 5)
points_in_range = range_search(root, (6, 5), 3)
print(f"Points within radius 3 of (6, 5): {points_in_range}")
Output:
Points within radius 3 of (6, 5): [(5, 4), (4, 7)]
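The same brute-force cross-check works here; a simple filter over all points should return the same set:

# Points whose distance to (6, 5) is at most 3
brute = [p for p in points if distance(p, (6, 5)) <= 3]
print(sorted(brute) == sorted(points_in_range))  # True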
Time Complexity
- Building a k-d tree: O(n log n) with a linear-time median-selection algorithm; the sort-based construction above costs O(n log² n)
- Inserting a new point: O(log n) on average (see the sketch after this list); repeated insertions can unbalance the tree
- Finding the nearest neighbor: O(log n) on average, but this can degrade to O(n) in the worst case, notably when the tree is unbalanced or the dimensionality is high
- Range search: O(n^(1-1/k) + m) for a balanced tree, where m is the number of reported points; for k = 2 this is O(√n + m)
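The insertion routine itself isn't shown above; a minimal sketch consistent with the KdNode class might look like this (like the complexity bound, it does no rebalancing):

def insert(node, point, depth=0):
    """Insert a point by walking down to an empty slot.

    The tree is not rebalanced, so the O(log n) bound holds only
    if insertions arrive in a roughly random order."""
    k = len(point)
    if node is None:
        return KdNode(point, depth % k)
    if point[node.dim] < node.point[node.dim]:
        node.left = insert(node.left, point, depth + 1)
    else:
        node.right = insert(node.right, point, depth + 1)
    return node

# Usage: root = insert(root, (3, 6))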
Practical Applications
K-d trees are widely used in various applications, including: