# k-d Tree — Practice Tasks
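All tasks must be solved in Go, Java, and Python. Always test against a brute-force reference (O(n) per query) on random points. Nearest-neighbor distances and the sets of range/radius results must match exactly; on ties, nearest may legitimately return a different point at the same distance.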
## Beginner Tasks
Task 1: Implement a balanced 2D k-d tree from scratch (no libraries). Build by median split with alternating x/y axes, then implement nearest(q) returning the closest stored point and its squared distance.
### Go

```go
package main

import "fmt"

type Point struct{ X, Y float64 }

// Node is a k-d tree node; Axis is 0 for x and 1 for y.
type Node struct {
	P           Point
	Left, Right *Node
	Axis        int
}

func build(pts []Point, depth int) *Node {
	// implement: median split, axis = depth % 2
	return nil
}

func nearest(root *Node, q Point) (Point, float64) {
	// implement: descend near, prune far with diff*diff < bestD
	return Point{}, 0
}

func main() {
	pts := []Point{{7, 2}, {5, 4}, {9, 6}, {2, 3}, {4, 7}, {8, 1}, {9, 9}}
	root := build(pts, 0)
	p, d := nearest(root, Point{6, 5})
	fmt.Println(p, d) // expect {5 4} 2 once implemented
}
```
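### Java

```java
public class Task1 {
    static final class Node { double x, y; Node left, right; int axis; }

    // build over pts[lo..hi) (suggested half-open range): median split, axis = depth % 2
    static Node build(double[][] pts, int lo, int hi, int depth) {
        // implement
        return null;
    }

    // suggested return layout: {x, y, squaredDistance} of the closest stored point
    static double[] nearest(Node root, double qx, double qy) {
        // implement
        return null;
    }

    public static void main(String[] args) {
        double[][] pts = {{7, 2}, {5, 4}, {9, 6}, {2, 3}, {4, 7}, {8, 1}, {9, 9}};
        Node root = build(pts, 0, pts.length, 0);
        double[] best = nearest(root, 6, 5); // expect (5,4) with squared distance 2
    }
}
```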
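### Python

```python
class Node:
    __slots__ = ("p", "left", "right", "axis")

    def __init__(self, p, left=None, right=None, axis=0):
        self.p, self.left, self.right, self.axis = p, left, right, axis

def build(pts, depth=0):
    # implement: median split, axis = depth % 2
    pass

def nearest(root, q):
    # implement: descend near, prune far with diff*diff < best
    pass

if __name__ == "__main__":
    pts = [(7, 2), (5, 4), (9, 6), (2, 3), (4, 7), (8, 1), (9, 9)]
    root = build(pts)
    print(nearest(root, (6, 5)))  # expect ((5, 4), 2) once implemented
```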
- Constraints: a correct median-split build (sorting at every level gives O(n log² n); Task 9 brings it to O(n log n)) and correct pruning. Test with n=1, duplicate points, and a query equal to a stored point.
- Expected Output: nearest((6,5)) returns (5,4) with squared distance 2.
- Evaluation: correctness vs brute force, pruning present, alternating axes.
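A minimal Python reference sketch for Task 1, handy for cross-checking your own solutions. It sorts at every level (so the build is O(n log² n)), and the `Node` constructor and the `sq_dist` helper are assumptions rather than part of the starter code.

```python
class Node:
    __slots__ = ("p", "left", "right", "axis")

    def __init__(self, p, left, right, axis):
        self.p, self.left, self.right, self.axis = p, left, right, axis


def build(pts, depth=0):
    """Median split on axis depth % 2. Sorting at each level makes this O(n log^2 n)."""
    if not pts:
        return None
    axis = depth % 2
    pts = sorted(pts, key=lambda p: p[axis])
    mid = len(pts) // 2
    # left holds coordinates <= the median, right holds coordinates >= it
    return Node(pts[mid], build(pts[:mid], depth + 1), build(pts[mid + 1:], depth + 1), axis)


def sq_dist(a, b):
    return (a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2


def nearest(root, q):
    """Return (closest point, squared distance). Assumes root is not None."""
    best_p, best_d = root.p, sq_dist(root.p, q)

    def visit(node):
        nonlocal best_p, best_d
        if node is None:
            return
        d = sq_dist(node.p, q)
        if d < best_d:
            best_p, best_d = node.p, d
        diff = q[node.axis] - node.p[node.axis]
        near, far = (node.left, node.right) if diff < 0 else (node.right, node.left)
        visit(near)
        # every point on the far side is at least |diff| away along this axis
        if diff * diff < best_d:
            visit(far)

    visit(root)
    return best_p, best_d
```

On the 7-point example, `nearest(build(pts), (6, 5))` returns `((5, 4), 2)`.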
Task 2: Add an insert(p) that descends the tree and attaches p as a new leaf (accepting that the tree may become unbalanced). Verify that nearest-neighbor search still returns correct results after several inserts. Provide starter code in all 3 languages.
- Constraints: must preserve the alternating-axis BST invariant.
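A minimal Python sketch of insert. The task leaves tie-handling open, so this assumes the convention "strictly smaller coordinate goes left, equal or larger goes right". The `points` and `is_valid` helpers are hypothetical checkers for the invariant constraint; run your nearest against brute force on top of them.

```python
class Node:
    __slots__ = ("p", "left", "right", "axis")

    def __init__(self, p, axis):
        self.p, self.left, self.right, self.axis = p, None, None, axis


def insert(root, p, k=2):
    """Descend like a BST on each node's axis and attach p as a new leaf.
    Returns the (possibly new) root."""
    if root is None:
        return Node(p, 0)
    node = root
    while True:
        a = node.axis
        if p[a] < node.p[a]:          # strictly smaller goes left
            if node.left is None:
                node.left = Node(p, (a + 1) % k)
                return root
            node = node.left
        else:                         # equal or larger goes right
            if node.right is None:
                node.right = Node(p, (a + 1) % k)
                return root
            node = node.right


def points(node):
    return [] if node is None else [node.p] + points(node.left) + points(node.right)


def is_valid(node, k=2):
    """Alternating-axis invariant: left < split <= right on node.axis."""
    if node is None:
        return True
    a = node.axis
    return (all(p[a] < node.p[a] for p in points(node.left))
            and all(p[a] >= node.p[a] for p in points(node.right))
            and all(c.axis == (a + 1) % k for c in (node.left, node.right) if c is not None)
            and is_valid(node.left, k) and is_valid(node.right, k))
```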
Task 3: Implement printTree that draws the tree with each node's point and split axis ([x]/[y]), indented by depth. Use it to visually confirm your build is balanced.
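A minimal Python sketch of a pre-order tree printer, assuming 2D points (the `[x]`/`[y]` tag is looked up from the string `"xy"`). `tree_lines` is a hypothetical helper that returns the lines, so the layout is easy to test.

```python
class Node:
    def __init__(self, p, axis, left=None, right=None):
        self.p, self.axis, self.left, self.right = p, axis, left, right


def build(pts, depth=0, k=2):
    if not pts:
        return None
    axis = depth % k
    pts = sorted(pts, key=lambda p: p[axis])
    mid = len(pts) // 2
    return Node(pts[mid], axis, build(pts[:mid], depth + 1, k), build(pts[mid + 1:], depth + 1, k))


def tree_lines(node, depth=0):
    """Pre-order listing: node, then left subtree, then right, two spaces per level."""
    if node is None:
        return []
    line = "  " * depth + f"{node.p} [{'xy'[node.axis]}]"
    return [line] + tree_lines(node.left, depth + 1) + tree_lines(node.right, depth + 1)


def print_tree(root):
    print("\n".join(tree_lines(root)))


if __name__ == "__main__":
    print_tree(build([(7, 2), (5, 4), (9, 6), (2, 3), (4, 7), (8, 1), (9, 9)]))
```

For the 7-point example this prints a perfectly balanced tree of height 3:

```
(7, 2) [x]
  (5, 4) [y]
    (2, 3) [x]
    (4, 7) [x]
  (9, 6) [y]
    (8, 1) [x]
    (9, 9) [x]
```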
Task 4: Implement squared-distance comparison correctly: write sqDist and confirm nearest never calls sqrt. Add an optional nearestDistance(q) that returns the true (sqrt) distance only for the final answer.
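A small Python sketch of the idea: because sqrt is increasing, comparing squared distances orders points exactly as true distances would, and with integer coordinates the squared values stay exact. Here `nearest_brute` is a hypothetical stand-in for your tree's nearest (any function returning `(point, squared distance)` can be passed in).

```python
import math


def sq_dist(a, b):
    """Squared Euclidean distance over all coordinates; no sqrt anywhere."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def nearest_brute(pts, q):
    """Stand-in for nearest(): returns (point, squared distance)."""
    best = min(pts, key=lambda p: sq_dist(p, q))
    return best, sq_dist(best, q)


def nearest_distance(pts, q, nearest=nearest_brute):
    """Take the single sqrt only for the final answer."""
    p, d2 = nearest(pts, q)
    return p, math.sqrt(d2)
```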
Task 5: Write a brute-force nearestBrute(pts, q) and a randomized test harness that generates 1000 random point sets and 100 random queries each, asserting nearest == nearestBrute. This harness is reused by all later tasks.
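A minimal Python sketch of the harness. The `check` function and its parameters are hypothetical; it takes your `build` and `nearest` as arguments and compares squared distances rather than points, since ties may legitimately return different points. It uses a small integer grid on purpose, to force duplicates and ties.

```python
import random


def nearest_brute(pts, q):
    """O(n) reference: the point minimizing squared distance to q, plus that distance."""
    best = min(pts, key=lambda p: sum((a - b) ** 2 for a, b in zip(p, q)))
    return best, sum((a - b) ** 2 for a, b in zip(best, q))


def check(build, nearest, trials=1000, queries=100, max_n=50, seed=1):
    """Raise AssertionError on the first query where nearest() disagrees with brute force."""
    rng = random.Random(seed)
    for _ in range(trials):
        n = rng.randint(1, max_n)
        # a small integer grid forces duplicates and ties, the usual bug sources
        pts = [(rng.randint(0, 20), rng.randint(0, 20)) for _ in range(n)]
        root = build(pts)
        for _ in range(queries):
            q = (rng.randint(-2, 22), rng.randint(-2, 22))
            got, want = nearest(root, q), nearest_brute(pts, q)
            assert got[1] == want[1], (pts, q, got, want)
    return True
```

Usage with your own tree would be `check(build, nearest)`; the defaults give the 1000 × 100 runs the task asks for.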
## Intermediate Tasks
Task 6: Implement k-nearest-neighbors knn(q, k) using a bounded max-heap of size k. Prune the far subtree only when the heap is full and diff*diff >= heap.top.dist. Return the k points sorted by ascending distance. Provide starter code in all 3 languages.
- Constraints: expected O(k + log n) nodes visited on well-distributed low-dimensional data (each heap update adds O(log k)); treat the pruning bound as +∞ until the heap fills.
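A minimal Python sketch of bounded-heap kNN. Python's `heapq` is a min-heap, so this assumes negated distances to emulate the max-heap; a counter breaks ties so points are never compared. Returning fewer than k results when the tree holds fewer points is a design choice of this sketch.

```python
import heapq
import itertools


class Node:
    __slots__ = ("p", "axis", "left", "right")

    def __init__(self, p, axis, left, right):
        self.p, self.axis, self.left, self.right = p, axis, left, right


def build(pts, depth=0):
    if not pts:
        return None
    axis = depth % len(pts[0])
    pts = sorted(pts, key=lambda p: p[axis])
    m = len(pts) // 2
    return Node(pts[m], axis, build(pts[:m], depth + 1), build(pts[m + 1:], depth + 1))


def knn(root, q, k):
    """Return up to k (point, squared distance) pairs in ascending distance order."""
    if k <= 0:
        return []
    # entries are (-dist, tie, point): heap[0] is the worst of the kept candidates
    heap, tie = [], itertools.count()

    def visit(node):
        if node is None:
            return
        d = sum((a - b) ** 2 for a, b in zip(node.p, q))
        if len(heap) < k:
            heapq.heappush(heap, (-d, next(tie), node.p))
        elif d < -heap[0][0]:
            heapq.heapreplace(heap, (-d, next(tie), node.p))
        diff = q[node.axis] - node.p[node.axis]
        near, far = (node.left, node.right) if diff < 0 else (node.right, node.left)
        visit(near)
        # the bound is effectively +inf until the heap is full
        if len(heap) < k or diff * diff < -heap[0][0]:
            visit(far)

    visit(root)
    return sorted(((p, -nd) for nd, _, p in heap), key=lambda t: t[1])
```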
Task 7: Implement orthogonal range search rangeSearch(lo, hi) returning all points inside the inclusive axis-aligned box. Skip a child subtree when the box lies entirely on the other side of the split line. Verify against brute force.
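A minimal Python sketch, assuming nodes are plain `(point, axis, left, right)` tuples from a median-split build in which the left subtree holds coordinates <= the split and the right holds coordinates >= it. Both children are visited when the box straddles the split value.

```python
def build(pts, depth=0):
    """Nodes are (point, axis, left, right) tuples; median split, axis = depth % k."""
    if not pts:
        return None
    axis = depth % len(pts[0])
    pts = sorted(pts, key=lambda p: p[axis])
    m = len(pts) // 2
    return (pts[m], axis, build(pts[:m], depth + 1), build(pts[m + 1:], depth + 1))


def range_search(node, lo, hi, out=None):
    """All stored points p with lo[i] <= p[i] <= hi[i] on every axis (inclusive box)."""
    if out is None:
        out = []
    if node is None:
        return out
    p, axis, left, right = node
    if all(lo_i <= x <= hi_i for lo_i, x, hi_i in zip(lo, p, hi)):
        out.append(p)
    # left holds coordinates <= p[axis]: visit only if the box reaches down that far
    if lo[axis] <= p[axis]:
        range_search(left, lo, hi, out)
    # right holds coordinates >= p[axis]: visit only if the box reaches up that far
    if hi[axis] >= p[axis]:
        range_search(right, lo, hi, out)
    return out
```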
Task 8: Implement radius search radiusSearch(q, r) returning all points within Euclidean distance r of q. Prune the far side when diff*diff > r*r. Test with r=0 (only exact matches) and r covering the whole cloud.
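A minimal Python sketch using the same assumed tuple nodes. Distances are compared squared against r*r, and the far side is skipped only when diff*diff > r*r, so points exactly on the boundary are kept.

```python
def build(pts, depth=0):
    """Nodes are (point, axis, left, right) tuples; median split, axis = depth % k."""
    if not pts:
        return None
    axis = depth % len(pts[0])
    pts = sorted(pts, key=lambda p: p[axis])
    m = len(pts) // 2
    return (pts[m], axis, build(pts[:m], depth + 1), build(pts[m + 1:], depth + 1))


def radius_search(node, q, r, out=None):
    """All stored points within Euclidean distance r of q (boundary included)."""
    if out is None:
        out = []
    if node is None:
        return out
    p, axis, left, right = node
    r2 = r * r  # compare squared values; no sqrt needed
    if sum((a - b) ** 2 for a, b in zip(p, q)) <= r2:
        out.append(p)
    diff = q[axis] - p[axis]
    near, far = (left, right) if diff < 0 else (right, left)
    radius_search(near, q, r, out)
    if diff * diff <= r2:  # prune the far side only when diff^2 > r^2
        radius_search(far, q, r, out)
    return out
```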
Task 9: Replace the per-level sort with quickselect (nth_element) median finding so build is O(n log n) instead of O(n log² n). Benchmark both builds on n = 10⁵ points and report the speedup.
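A minimal Python sketch of the quickselect build. It is an in-place quickselect with Hoare partitioning and a random pivot (the same contract as C++ `std::nth_element`, not a call to it), giving expected O(n) work per level and therefore an expected O(n log n) build. Nodes are assumed `(point, axis, left, right)` tuples, and `points`/`is_valid` are hypothetical checkers.

```python
import random


def select(pts, lo, hi, k, axis):
    """Reorder pts[lo:hi] so pts[k] holds what sorting would put there, with
    pts[lo:k] <= pts[k] <= pts[k+1:hi] on `axis`. Expected O(hi - lo)."""
    hi -= 1  # inclusive upper index from here on
    while lo < hi:
        pivot = pts[random.randint(lo, hi)][axis]
        i, j = lo, hi
        while i <= j:
            while pts[i][axis] < pivot:
                i += 1
            while pts[j][axis] > pivot:
                j -= 1
            if i <= j:
                pts[i], pts[j] = pts[j], pts[i]
                i += 1
                j -= 1
        # now pts[lo..j] <= pivot, pts[j+1..i-1] == pivot, pts[i..hi] >= pivot
        if k <= j:
            hi = j
        elif k >= i:
            lo = i
        else:
            return


def build(pts, lo=0, hi=None, depth=0):
    """In-place build over pts[lo:hi]; nodes are (point, axis, left, right) tuples."""
    if hi is None:
        hi = len(pts)
    if lo >= hi:
        return None
    axis = depth % len(pts[lo])
    mid = (lo + hi) // 2
    select(pts, lo, hi, mid, axis)
    return (pts[mid], axis, build(pts, lo, mid, depth + 1), build(pts, mid + 1, hi, depth + 1))


def points(node):
    return [] if node is None else [node[0]] + points(node[2]) + points(node[3])


def is_valid(node):
    """Left coordinates <= split <= right coordinates on each node's axis."""
    if node is None:
        return True
    p, a, left, right = node
    return (all(x[a] <= p[a] for x in points(left)) and all(x[a] >= p[a] for x in points(right))
            and is_valid(left) and is_valid(right))
```

For the benchmark, time `build(list(pts))` against a sort-per-level build on the same 10⁵ points (for example with `time.perf_counter`).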
Task 10: Generalize your tree from 2D to k-D (arbitrary dimension). Use axis = depth mod k and loop over all coordinates in sqDist and the split comparison. Confirm correctness in 3D against brute force.
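A minimal Python sketch of the k-D version. `sq_dist` loops over all coordinates and the build cycles `axis = depth % k` with k taken from the first point. As a variation, this `nearest` is iterative: each stack entry carries a lower bound on the squared distance to its subtree, which avoids Python's recursion limit on deep trees.

```python
import math


class Node:
    __slots__ = ("p", "axis", "left", "right")

    def __init__(self, p, axis, left, right):
        self.p, self.axis, self.left, self.right = p, axis, left, right


def sq_dist(a, b):
    # loops over all k coordinates instead of hard-coding x and y
    return sum((x - y) ** 2 for x, y in zip(a, b))


def build(pts, depth=0):
    if not pts:
        return None
    axis = depth % len(pts[0])  # k = len(pts[0]); cycle through all k axes
    pts = sorted(pts, key=lambda p: p[axis])
    m = len(pts) // 2
    return Node(pts[m], axis, build(pts[:m], depth + 1), build(pts[m + 1:], depth + 1))


def nearest(root, q):
    """Iterative k-D nearest neighbour returning (point, squared distance)."""
    best_p, best_d = None, math.inf
    stack = [(root, 0.0)]  # (subtree, lower bound on its squared distance to q)
    while stack:
        node, bound = stack.pop()
        if node is None or bound >= best_d:
            continue  # nothing in this subtree can beat best_d
        d = sq_dist(node.p, q)
        if d < best_d:
            best_p, best_d = node.p, d
        diff = q[node.axis] - node.p[node.axis]
        near, far = (node.left, node.right) if diff < 0 else (node.right, node.left)
        stack.append((far, max(bound, diff * diff)))  # popped after the near side
        stack.append((near, bound))
    return best_p, best_d
```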
## Advanced Tasks
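Task 11: Implement deletion correctly: to delete a node, find the point with the minimum coordinate on that node's split axis in the appropriate subtree (the right subtree; if it is empty, the left subtree, which then moves to the right), copy it into the node, and delete it recursively from that subtree (the k-d analogue of BST successor splicing). Verify that the tree remains a valid k-d tree after a sequence of deletes. Provide starter code in all 3 languages.
- Constraints: must not naively splice; use a recursive find-min-on-axis helper, which has to search both children at levels split on a different axis.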
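A minimal Python sketch of deletion. It assumes the "strictly smaller goes left, equal or larger goes right" convention, which is why the min of the left subtree is moved to the right when there is no right subtree (using the max of the left side would break under duplicate coordinates). Raising `KeyError` for a missing point, and the `points`/`is_valid` checkers, are choices of this sketch.

```python
class Node:
    __slots__ = ("p", "axis", "left", "right")

    def __init__(self, p, axis):
        self.p, self.axis, self.left, self.right = p, axis, None, None


def insert(node, p, depth=0, k=2):
    """Strictly smaller coordinate goes left, equal or larger goes right."""
    if node is None:
        return Node(p, depth % k)
    if p[node.axis] < node.p[node.axis]:
        node.left = insert(node.left, p, depth + 1, k)
    else:
        node.right = insert(node.right, p, depth + 1, k)
    return node


def find_min(node, axis):
    """Node holding the smallest coordinate on `axis` in this subtree."""
    if node is None:
        return None
    if node.axis == axis:  # everything smaller lives on the left
        return node if node.left is None else find_min(node.left, axis)
    # a different split axis says nothing about `axis`: search both children
    cands = [n for n in (node, find_min(node.left, axis), find_min(node.right, axis)) if n]
    return min(cands, key=lambda n: n.p[axis])


def delete(node, p):
    """Remove one copy of point p; returns the new subtree root."""
    if node is None:
        raise KeyError(p)
    a = node.axis
    if node.p == p:
        if node.right is not None:     # replace with the min of the right subtree
            m = find_min(node.right, a)
            node.p = m.p
            node.right = delete(node.right, m.p)
        elif node.left is not None:    # min of the left subtree, which moves right
            m = find_min(node.left, a)
            node.p = m.p
            node.right, node.left = delete(node.left, m.p), None
        else:
            return None                # leaf: just drop it
    elif p[a] < node.p[a]:
        node.left = delete(node.left, p)
    else:
        node.right = delete(node.right, p)
    return node


def points(node):
    return [] if node is None else [node.p] + points(node.left) + points(node.right)


def is_valid(node):
    if node is None:
        return True
    a = node.axis
    return (all(x[a] < node.p[a] for x in points(node.left))
            and all(x[a] >= node.p[a] for x in points(node.right))
            and is_valid(node.left) and is_valid(node.right))
```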
Task 12: Implement a double-buffered dynamic k-d tree: serve queries from the "live" tree while a background goroutine/thread rebuilds a fresh tree from the updated point set, then atomically swap. Demonstrate that no query ever observes a partially-built tree.
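A minimal Python sketch of the double-buffering pattern, assuming the tree is treated as an immutable snapshot and `build` is any function that returns a complete tree (both hypothetical). Queries read one reference; a background thread builds a fresh tree and publishes it with a single reference assignment, which is atomic in CPython. In Go the same role is played by `atomic.Pointer`, in Java by a `volatile` field or `AtomicReference`. A generation counter keeps a slow, older rebuild from overwriting a newer one.

```python
import threading


class DoubleBufferedIndex:
    def __init__(self, pts, build):
        self._build = build
        self._live = build(list(pts))   # the only tree queries ever see
        self._lock = threading.Lock()   # guards the generation counters only
        self._started = self._published = 0

    def query(self, fn, *args):
        tree = self._live  # read once: a concurrent swap cannot affect this query
        return fn(tree, *args)

    def rebuild_async(self, new_pts):
        with self._lock:
            self._started += 1
            gen = self._started

        def work():
            fresh = self._build(list(new_pts))  # fully built before anyone can see it
            with self._lock:
                if gen > self._published:       # never replace a newer tree with an older one
                    self._live, self._published = fresh, gen

        t = threading.Thread(target=work, daemon=True)
        t.start()
        return t
```

Usage: `idx.query(nearest, q)` while `idx.rebuild_async(updated_points)` runs; join the returned thread in tests to wait for the swap.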
Task 13: Demonstrate the curse of dimensionality empirically. For dimensions d ∈ {2, 4, 8, 16, 32}, build a k-d tree on n = 10⁵ uniform random points, instrument nodesVisited per NN query, and plot/print the fraction of nodes visited vs d. Confirm it rises toward 1.0 as d grows.
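A minimal Python sketch of the instrumentation, assuming tuple nodes and a counter incremented on every node a nearest query touches. Because pure Python is slow, `fraction_visited` defaults to a scaled-down n = 10_000; pass `n=100_000` to match the task if you can wait.

```python
import random


def build(pts, depth=0):
    if not pts:
        return None
    axis = depth % len(pts[0])
    pts = sorted(pts, key=lambda p: p[axis])
    m = len(pts) // 2
    return (pts[m], axis, build(pts[:m], depth + 1), build(pts[m + 1:], depth + 1))


def nn_visits(root, q):
    """Run one nearest-neighbour query and return how many nodes it touched."""
    best, visited = float("inf"), 0

    def visit(node):
        nonlocal best, visited
        if node is None:
            return
        visited += 1
        p, axis, left, right = node
        best = min(best, sum((a - b) ** 2 for a, b in zip(p, q)))
        diff = q[axis] - p[axis]
        near, far = (left, right) if diff < 0 else (right, left)
        visit(near)
        if diff * diff < best:
            visit(far)

    visit(root)
    return visited


def fraction_visited(d, n=10_000, queries=100, seed=0):
    """Average fraction of the n nodes touched per query on uniform data in [0, 1)^d."""
    rng = random.Random(seed)
    pts = [tuple(rng.random() for _ in range(d)) for _ in range(n)]
    root = build(pts)
    total = sum(nn_visits(root, tuple(rng.random() for _ in range(d))) for _ in range(queries))
    return total / (queries * n)


if __name__ == "__main__":
    for d in (2, 4, 8, 16, 32):
        print(f"d={d:2d}: {fraction_visited(d):.3f} of nodes visited per query")
```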
Task 14: Implement the widest-spread split rule: at each node choose the axis with the largest coordinate range among the subtree's points (instead of cyclic depth mod k). Compare nodesVisited against the cyclic rule on clustered (non-uniform) data.
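A minimal Python sketch, assuming tuple nodes that store their own split axis so the same query code serves both rules. The `rule` parameter and the clustered data generator (thin blobs stretched along x) are hypothetical; the script prints the two averages rather than asserting which rule wins.

```python
import random


def widest_axis(pts):
    """Axis with the largest coordinate range (max - min) among pts."""
    return max(range(len(pts[0])), key=lambda a: max(p[a] for p in pts) - min(p[a] for p in pts))


def build(pts, depth=0, rule="widest"):
    """rule='cyclic' splits on depth % k; rule='widest' splits on widest_axis(pts)."""
    if not pts:
        return None
    axis = widest_axis(pts) if rule == "widest" else depth % len(pts[0])
    pts = sorted(pts, key=lambda p: p[axis])
    m = len(pts) // 2
    return (pts[m], axis, build(pts[:m], depth + 1, rule), build(pts[m + 1:], depth + 1, rule))


def nn_visits(root, q):
    """(best squared distance, nodes visited) for one nearest-neighbour query."""
    best, visited = float("inf"), 0

    def visit(node):
        nonlocal best, visited
        if node is None:
            return
        visited += 1
        p, axis, left, right = node
        best = min(best, sum((a - b) ** 2 for a, b in zip(p, q)))
        diff = q[axis] - p[axis]
        near, far = (left, right) if diff < 0 else (right, left)
        visit(near)
        if diff * diff < best:
            visit(far)

    visit(root)
    return best, visited


if __name__ == "__main__":
    rng = random.Random(0)
    centers = [(rng.random(), rng.random()) for _ in range(20)]
    pts = [(cx + rng.gauss(0, 0.05), cy + rng.gauss(0, 0.002))
           for cx, cy in centers for _ in range(250)]
    qs = [(rng.random(), rng.random()) for _ in range(200)]
    for rule in ("cyclic", "widest"):
        tree = build(pts, rule=rule)
        print(rule, sum(nn_visits(tree, q)[1] for q in qs) / len(qs), "nodes/query")
```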
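Task 15: Build a 2D nearest-café service: load 100k random (lat, lon) points, convert them to 3D Cartesian coordinates on the unit sphere, build a 3D k-d tree, and answer "nearest café to this location" queries. This works because the chord between two points on the unit sphere has length 2 sin(θ/2) for central angle θ, which increases monotonically on [0, π], so the Euclidean nearest neighbor in 3D is also the great-circle nearest. Verify this on a sample of queries by comparing against a brute-force Haversine nearest.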
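A minimal Python sketch of the geometry, assuming a spherical Earth with mean radius 6371 km. `to_xyz` feeds the 3D tree, `haversine_km` is the brute-force reference, and `chord_to_km` turns a unit-sphere chord back into a great-circle distance for the final answer. All three names are hypothetical.

```python
import math

EARTH_RADIUS_KM = 6371.0  # mean Earth radius; a spherical-Earth assumption


def to_xyz(lat_deg, lon_deg):
    """Latitude/longitude in degrees -> point on the unit sphere."""
    lat, lon = math.radians(lat_deg), math.radians(lon_deg)
    return (math.cos(lat) * math.cos(lon), math.cos(lat) * math.sin(lon), math.sin(lat))


def haversine_km(lat1, lon1, lat2, lon2):
    """Great-circle distance between two (lat, lon) points, in km."""
    p1, p2 = math.radians(lat1), math.radians(lat2)
    dp, dl = p2 - p1, math.radians(lon2 - lon1)
    h = math.sin(dp / 2) ** 2 + math.cos(p1) * math.cos(p2) * math.sin(dl / 2) ** 2
    return 2 * EARTH_RADIUS_KM * math.asin(math.sqrt(min(1.0, h)))


def chord_to_km(chord):
    """Invert c = 2 sin(theta / 2): theta = 2 asin(c / 2), then scale by the radius."""
    return 2 * EARTH_RADIUS_KM * math.asin(min(1.0, chord / 2))
```

Only the query result needs `chord_to_km`; inside the tree, plain (squared) Euclidean distances on `to_xyz` points are enough.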
## Benchmark Task
Compare nearest-neighbor query performance across all 3 languages and across dimensions, to feel both the O(log n) win in low dim and the curse-of-dimensionality degradation.
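### Go

```go
package main

import (
	"fmt"
	"math/rand"
	"time"
)

// randomPoints returns count points with d uniform coordinates in [0, 1).
func randomPoints(count, d int) [][]float64 {
	pts := make([][]float64, count)
	for i := range pts {
		pts[i] = make([]float64, d)
		for j := range pts[i] {
			pts[i][j] = rand.Float64()
		}
	}
	return pts
}

func main() {
	const n, queries = 100000, 10000
	for _, d := range []int{2, 4, 8, 16, 32} {
		pts := randomPoints(n, d) // d-dimensional points as []float64
		_ = pts                   // build tree of dimension d here, outside the timed region
		qs := randomPoints(queries, d)
		start := time.Now()
		for _, q := range qs {
			_ = q // nearest(tree, q)
		}
		elapsed := time.Since(start)
		fmt.Printf("d=%2d n=%d: %.2f us/query\n", d, n,
			float64(elapsed.Nanoseconds())/1e3/queries)
	}
}
```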
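### Java

```java
import java.util.Random;

public class Benchmark {
    public static void main(String[] args) {
        int[] dims = {2, 4, 8, 16, 32};
        int n = 100000, queries = 10000;
        Random rng = new Random(42);
        for (int d : dims) {
            double[][] pts = new double[n][d];
            for (double[] p : pts)
                for (int j = 0; j < d; j++) p[j] = rng.nextDouble();
            // build tree of dimension d from pts here, outside the timed region
            double[][] qs = new double[queries][d];
            for (double[] q : qs)
                for (int j = 0; j < d; j++) q[j] = rng.nextDouble();
            long start = System.nanoTime();
            for (double[] q : qs) {
                // nearest(tree, q)
            }
            long elapsed = System.nanoTime() - start;
            System.out.printf("d=%2d n=%d: %.2f us/query%n", d, n,
                    elapsed / (double) queries / 1000.0);
        }
    }
}
```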
### Python

```python
import random
import time


def benchmark(n=100_000, queries=10_000):
    for d in (2, 4, 8, 16, 32):
        pts = [tuple(random.random() for _ in range(d)) for _ in range(n)]
        # tree = build(pts)  (build outside the timed region)
        qs = [tuple(random.random() for _ in range(d)) for _ in range(queries)]
        start = time.perf_counter()
        for q in qs:
            pass  # nearest(tree, q)
        elapsed = time.perf_counter() - start
        print(f"d={d:2d} n={n}: {elapsed / queries * 1e6:.2f} us/query")


if __name__ == "__main__":
    benchmark()
```
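Expected observation: query time is low at d = 2–4, where each query touches roughly O(log n) nodes, then climbs sharply toward the brute-force O(n) cost by d = 32. A common rule of thumb is that a k-d tree prunes well only when n ≫ 2^d; with n = 10⁵ that already fails around d = 16 and fails badly at d = 32. This is the curse of dimensionality made measurable. Cross-reference professional.md §6 for the theory.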