K-d Trees
Introduction
K-d trees (short for k-dimensional trees) are space-partitioning data structures that organize points in k-dimensional space. They are particularly useful for applications that involve multidimensional search keys, such as nearest neighbor searches and range searches.
Imagine you have a large number of points in a 2D or 3D space, and you frequently need to find the closest point to a given query point. A naive approach would require checking the distance to every point, which is inefficient for large datasets. K-d trees solve this problem by partitioning the space in a way that allows for much faster searches.
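For concreteness, here is what that naive linear scan looks like in Python (this snippet reuses the six example points introduced in the worked example later in this article):
import math

points = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
query = (6, 5)

# Brute force: measure the distance to every point, O(n) per query
nearest = min(points, key=lambda p: math.dist(p, query))
print(nearest)  # (5, 4)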
Understanding K-d Trees
What is a K-d Tree?
A k-d tree is a binary tree where:
- Each node represents a point in k-dimensional space
- Each non-leaf node divides the space into two half-spaces along one dimension
- Points in the left subtree have a smaller value in the splitting dimension
- Points in the right subtree have a greater value in the splitting dimension
The splitting dimension typically cycles through all dimensions as we move down the tree. For example, in a 2D space (k=2), the root node might split along the x-axis, its children along the y-axis, its grandchildren along the x-axis again, and so on.
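A quick way to see this cycling rule, which the implementation below expresses as depth % k:
# Splitting dimension used at each depth, for k = 2 and k = 3
for k in (2, 3):
    print(k, [depth % k for depth in range(6)])
# 2 [0, 1, 0, 1, 0, 1]  -> x, y, x, y, x, y
# 3 [0, 1, 2, 0, 1, 2]  -> x, y, z, x, y, z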
Visual Representation
In 2D, the root splits the plane with a vertical line through its x-coordinate. Each of its children then splits its own half-plane with a horizontal line through its y-coordinate, their children split with vertical lines again, and so on. The result is a hierarchy of nested, axis-aligned rectangles, where each rectangle contains exactly the points stored in the corresponding subtree. (A small plotting sketch after the worked build example below draws these splits for a concrete point set.)
Building a K-d Tree
Building a k-d tree involves recursively partitioning the point set along alternating dimensions:
- Choose a splitting dimension (typically cycling through dimensions)
- Find the median point along that dimension
- Create a node with that median point
- Recursively build the left subtree with points on the "less than" side
- Recursively build the right subtree with points on the "greater than" side
Let's implement a k-d tree in Python; the build function below works for any dimension, though our examples use 2D points:
class KdNode:
    def __init__(self, point, dim, left=None, right=None):
        self.point = point  # The point, e.g. (x, y)
        self.dim = dim      # The dimension to split on (0 for x, 1 for y)
        self.left = left    # Left subtree
        self.right = right  # Right subtree

def build_kdtree(points, depth=0):
    if not points:
        return None

    k = len(points[0])  # Dimensionality of points
    dim = depth % k     # Current dimension to split on

    # Sort points based on the current dimension
    points.sort(key=lambda x: x[dim])

    # Find median point
    median_idx = len(points) // 2

    # Create node and recursively build subtrees
    return KdNode(
        point=points[median_idx],
        dim=dim,
        left=build_kdtree(points[:median_idx], depth + 1),
        right=build_kdtree(points[median_idx + 1:], depth + 1)
    )
Let's see this in action with a simple example:
# Example points in 2D space
points = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]

# Build the k-d tree
root = build_kdtree(points)

# Print tree structure
def print_tree(node, level=0):
    if node is not None:
        print(" " * level + f"({node.point[0]}, {node.point[1]}) - split on {'x' if node.dim == 0 else 'y'}")
        print_tree(node.left, level + 1)
        print_tree(node.right, level + 1)

print_tree(root)
Output:
(7, 2) - split on x
 (5, 4) - split on y
  (2, 3) - split on x
  (4, 7) - split on x
 (9, 6) - split on y
  (8, 1) - split on x
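To see how these splits partition the plane, here is an optional plotting sketch (it assumes matplotlib is installed and is not part of the k-d tree itself). Each node draws its splitting line inside the rectangle it was handed, then recurses into the two halves:
import matplotlib.pyplot as plt

def draw_splits(node, xmin, xmax, ymin, ymax):
    """Recursively draw each node's splitting line within its cell."""
    if node is None:
        return
    x, y = node.point
    if node.dim == 0:  # vertical split at this node's x-coordinate
        plt.plot([x, x], [ymin, ymax], "b-")
        draw_splits(node.left, xmin, x, ymin, ymax)
        draw_splits(node.right, x, xmax, ymin, ymax)
    else:              # horizontal split at this node's y-coordinate
        plt.plot([xmin, xmax], [y, y], "r-")
        draw_splits(node.left, xmin, xmax, ymin, y)
        draw_splits(node.right, xmin, xmax, y, ymax)

plt.scatter(*zip(*points))
draw_splits(root, 0, 10, 0, 10)  # bounds chosen to enclose the example points
plt.show()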
Searching in K-d Trees
Finding the Nearest Neighbor
One of the most common applications of k-d trees is finding the nearest neighbor to a query point. The algorithm works as follows:
- Start with the root node
- Recursively traverse the tree, choosing the appropriate branch based on the query point
- When reaching a leaf, save it as the "current best"
- Backtrack, and for each node:
- Check if the node's point is closer than the current best
- Check if the opposite branch could contain a closer point by examining the distance to the splitting plane
Here's an implementation of nearest neighbor search:
import math

def distance(p1, p2):
    """Calculate Euclidean distance between two points"""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(p1, p2)))

def nearest_neighbor(root, query_point, best=None):
    if root is None:
        return best

    # Update best if current point is closer
    if best is None or distance(query_point, root.point) < distance(query_point, best):
        best = root.point

    # Current splitting dimension
    dim = root.dim

    # Determine which subtree to search first (closer branch)
    if query_point[dim] < root.point[dim]:
        first, second = root.left, root.right
    else:
        first, second = root.right, root.left

    # Search the closer branch
    best = nearest_neighbor(first, query_point, best)

    # Check if we need to search the other branch
    # by comparing distance to splitting plane with current best distance
    if second is not None:
        dist_to_plane = abs(query_point[dim] - root.point[dim])
        if dist_to_plane < distance(query_point, best):
            # The other branch could contain a closer point
            best = nearest_neighbor(second, query_point, best)

    return best
Let's test our nearest neighbor search:
# Example usage
query_point = (6, 5)
nearest = nearest_neighbor(root, query_point)
print(f"Nearest point to {query_point} is {nearest}, with distance {distance(query_point, nearest)}")
Output:
Nearest point to (6, 5) is (5, 4), with distance 1.4142135623730951
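As a quick sanity check (not part of the original example), the tree-based answer should always agree with a brute-force scan over all points:
# Compare the k-d tree search against a brute-force scan for a few queries
for q in [(6, 5), (1, 1), (9, 9)]:
    assert nearest_neighbor(root, q) == min(points, key=lambda p: distance(p, q))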
Range Search
Another common operation is finding all points within a certain distance or region. The k-d tree allows us to prune branches that can't contain points in our target range:
def range_search(root, query_point, radius):
    """Find all points within a given radius of the query point"""
    result = []

    def search(node):
        if node is None:
            return

        # Check if current point is within radius
        if distance(query_point, node.point) <= radius:
            result.append(node.point)

        # Current splitting dimension
        dim = node.dim

        # Check if we need to search left subtree
        if query_point[dim] - radius <= node.point[dim]:
            search(node.left)

        # Check if we need to search right subtree
        if query_point[dim] + radius >= node.point[dim]:
            search(node.right)

    search(root)
    return result
Usage example:
# Find points within radius 3 of (6, 5)
points_in_range = range_search(root, (6, 5), 3)
print(f"Points within radius 3 of (6, 5): {points_in_range}")
Output:
Points within radius 3 of (6, 5): [(5, 4), (4, 7)]
Time Complexity
- Building a k-d tree: O(n log n) when medians are found in linear time; the sort-at-every-level approach used above is O(n log² n), where n is the number of points
- Inserting a new point: O(log n) in the average case (a minimal sketch follows this list)
- Finding the nearest neighbor: O(log n) in the average case, but can degrade to O(n) in the worst case
- Range search: O(√n + m) for a balanced tree in 2D, where m is the number of reported points; in k dimensions this generalizes to O(n^(1 - 1/k) + m)
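The examples above build the tree in one pass, but single-point insertion is straightforward. Here is a minimal sketch consistent with the KdNode class above (the helper name and test point are illustrative); note that repeated insertions without rebalancing are exactly what can push the tree toward the worst-case bounds above:
def insert(root, point, depth=0):
    """Insert a point, cycling the splitting dimension with depth."""
    k = len(point)
    dim = depth % k
    if root is None:
        return KdNode(point, dim)
    if point[dim] < root.point[dim]:
        root.left = insert(root.left, point, depth + 1)
    else:
        root.right = insert(root.right, point, depth + 1)
    return root

# Usage: insert a new point into the tree built earlier
root = insert(root, (3, 6))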
Practical Applications
K-d trees are widely used in various applications, including: