Skip to content

Centroid Decomposition — Practice Tasks

One-line summary: Fifteen graded tasks (5 beginner, 5 intermediate, 5 advanced) plus a benchmark, each with statement, constraints, hints, and reference solutions in Go, Java, and Python — building from "find a centroid" up to dynamic nearest-marked-node and large-scale performance.


Table of Contents

  1. Beginner Tasks (5)
  2. Intermediate Tasks (5)
  3. Advanced Tasks (5)
  4. Benchmark Task

Beginner Tasks (5)

B1 — Find a single centroid

Statement. Given a tree with N vertices (0-indexed) and N-1 edges, return any centroid: a vertex whose removal leaves every component with ≤ ⌊N/2⌋ vertices.

Constraints. 1 ≤ N ≤ 10⁵. Tree is connected.

Hints.
- Root anywhere; compute subtree sizes with one DFS.
- Starting at the root, descend toward the child whose subtree size is > ⌊N/2⌋; stop at the first vertex where no child qualifies — that vertex is a centroid.

Go

package main

import "fmt"

// Package-level state shared by the helpers below: the adjacency list, the
// subtree sizes written by computeSize, and the number of vertices.
var (
	adj  [][]int
	size []int
	N    int
)

// computeSize records in size[u] how many vertices lie in the subtree rooted
// at u (p is u's parent, -1 for the root) and returns that count.
func computeSize(u, p int) int {
	total := 1
	for _, child := range adj[u] {
		if child == p {
			continue
		}
		total += computeSize(child, u)
	}
	size[u] = total
	return total
}

// findCentroid starts at u (whose parent is p) and keeps stepping into the
// first child whose subtree holds more than N/2 vertices. The vertex where no
// such child exists is returned. size must already be filled by computeSize
// from the same root.
func findCentroid(u, p int) int {
	for {
		next := -1
		for _, child := range adj[u] {
			if child != p && size[child] > N/2 {
				next = child
				break
			}
		}
		if next < 0 {
			return u
		}
		p, u = u, next
	}
}

func main() {
	N = 7
	adj = make([][]int, N)
	size = make([]int, N)
	edges := [][2]int{{0, 1}, {1, 2}, {1, 3}, {2, 4}, {3, 5}, {5, 6}}
	for _, e := range edges {
		a, b := e[0], e[1]
		adj[a] = append(adj[a], b)
		adj[b] = append(adj[b], a)
	}
	computeSize(0, -1)
	root := findCentroid(0, -1)
	fmt.Println("centroid:", root) // 1
}

Java

import java.util.*;

public class FindCentroid {
    static List<List<Integer>> adj;
    static int[] size;
    static int N;

    static int computeSize(int u, int p) {
        size[u] = 1;
        for (int v : adj.get(u)) if (v != p) size[u] += computeSize(v, u);
        return size[u];
    }

    static int findCentroid(int u, int p) {
        for (int v : adj.get(u)) if (v != p && size[v] > N / 2) return findCentroid(v, u);
        return u;
    }

    public static void main(String[] args) {
        N = 7;
        adj = new ArrayList<>();
        for (int i = 0; i < N; i++) adj.add(new ArrayList<>());
        int[][] e = {{0,1},{1,2},{1,3},{2,4},{3,5},{5,6}};
        for (int[] x : e) { adj.get(x[0]).add(x[1]); adj.get(x[1]).add(x[0]); }
        size = new int[N];
        computeSize(0, -1);
        System.out.println("centroid: " + findCentroid(0, -1)); // 1
    }
}

Python

import sys
sys.setrecursionlimit(1 << 20)

N = 7
adj = [[] for _ in range(N)]
size = [0] * N
for a, b in [(0, 1), (1, 2), (1, 3), (2, 4), (3, 5), (5, 6)]:
    adj[a].append(b)
    adj[b].append(a)


def compute_size(u, p):
    """Store in size[u] the vertex count of u's subtree (parent p, -1 for the root) and return it."""
    total = 1
    for v in adj[u]:
        if v == p:
            continue
        total += compute_size(v, u)
    size[u] = total
    return total


def find_centroid(u, p):
    """Walk from u toward any child holding more than N // 2 vertices; return where none is left.

    Requires size[] to have been filled by compute_size from the same root.
    """
    while True:
        heavy = next((v for v in adj[u] if v != p and size[v] > N // 2), None)
        if heavy is None:
            return u
        u, p = heavy, u


compute_size(0, -1)
print("centroid:", find_centroid(0, -1))  # 1

B2 — Verify a vertex is a centroid

Statement. Given a tree and a vertex c, return true iff every component of T − c has size ≤ ⌊N/2⌋.

Constraints. 1 ≤ N ≤ 10⁵.

Hints.
- Root at c. Each child subtree of c is one component of T − c; the "rest" has N − size[c] vertices, which is 0 when the tree is rooted at c.
- So the component sizes are exactly the child-subtree sizes of c when c is the root.

Go

package main

import "fmt"

// isCentroid reports whether deleting c from the N-vertex tree adj leaves only
// components with at most ⌊N/2⌋ vertices. With c as the root, those components
// are exactly the subtrees hanging off c's neighbours.
func isCentroid(adj [][]int, N, c int) bool {
	var countFrom func(node, parent int) int
	countFrom = func(node, parent int) int {
		total := 1
		for _, next := range adj[node] {
			if next != parent {
				total += countFrom(next, node)
			}
		}
		return total
	}
	limit := N / 2
	for _, nb := range adj[c] {
		if countFrom(nb, c) > limit {
			return false
		}
	}
	return true
}

func main() {
	N := 7
	adj := make([][]int, N)
	for _, e := range [][2]int{{0, 1}, {1, 2}, {1, 3}, {2, 4}, {3, 5}, {5, 6}} {
		u, v := e[0], e[1]
		adj[u] = append(adj[u], v)
		adj[v] = append(adj[v], u)
	}
	for _, c := range []int{1, 0} { // prints true, then false
		fmt.Println(isCentroid(adj, N, c))
	}
}

Java

import java.util.*;

public class VerifyCentroid {
    static List<List<Integer>> adj;
    static int[] size;
    static int N;

    static int dfs(int u, int p) {
        size[u] = 1;
        for (int v : adj.get(u)) if (v != p) size[u] += dfs(v, u);
        return size[u];
    }

    static boolean isCentroid(int c) {
        size = new int[N];
        dfs(c, -1);
        for (int v : adj.get(c)) if (size[v] > N / 2) return false;
        return true;
    }

    public static void main(String[] args) {
        N = 7;
        adj = new ArrayList<>();
        for (int i = 0; i < N; i++) adj.add(new ArrayList<>());
        int[][] e = {{0,1},{1,2},{1,3},{2,4},{3,5},{5,6}};
        for (int[] x : e) { adj.get(x[0]).add(x[1]); adj.get(x[1]).add(x[0]); }
        System.out.println(isCentroid(1)); // true
        System.out.println(isCentroid(0)); // false
    }
}

Python

import sys
sys.setrecursionlimit(1 << 20)


def is_centroid(adj, N, c):
    """Return True iff every component of the tree minus c has at most N // 2 vertices.

    Rooting the tree at c turns each neighbour's subtree into one component.
    """
    size = [0] * N

    def dfs(u, p):
        total = 1
        for v in adj[u]:
            if v != p:
                total += dfs(v, u)
        size[u] = total
        return total

    dfs(c, -1)
    limit = N // 2
    for v in adj[c]:
        if size[v] > limit:
            return False
    return True


N = 7
adj = [[] for _ in range(N)]
for u, v in [(0, 1), (1, 2), (1, 3), (2, 4), (3, 5), (5, 6)]:
    adj[u].append(v)
    adj[v].append(u)
print(is_centroid(adj, N, 1))  # True
print(is_centroid(adj, N, 0))  # False

B3 — Build the centroid tree (return parents)

Statement. Build the centroid tree and return an array cparent[] where cparent[c] is the parent of c in the centroid tree (-1 for the root).

Constraints. 1 ≤ N ≤ 2·10⁵.

Hints.
- Reuse computeSize / findCentroid, but skip every vertex whose removed[] flag is set.
- Recompute sizes over the residual tree before each centroid choice; sizes from an earlier level are stale.

Go

package main

import "fmt"

// CT builds a centroid tree. removed marks vertices already chosen as
// centroids, size is scratch space for subtree sizes, and after decompose
// cparent[v] is v's parent in the centroid tree (-1 for the root).
type CT struct {
	adj           [][]int
	removed       []bool
	size, cparent []int
}

// computeSize fills size[] for the live (non-removed) vertices under u, with p
// as u's parent (-1 at the entry vertex), and returns size[u].
func (c *CT) computeSize(u, p int) int {
	total := 1
	for _, v := range c.adj[u] {
		if v == p || c.removed[v] {
			continue
		}
		total += c.computeSize(v, u)
	}
	c.size[u] = total
	return total
}

// findCentroid moves from u toward the first live child whose subtree holds
// more than n/2 vertices, and returns the vertex where no such child remains.
func (c *CT) findCentroid(u, p, n int) int {
	half := n / 2
	for {
		next := -1
		for _, v := range c.adj[u] {
			if v != p && !c.removed[v] && c.size[v] > half {
				next = v
				break
			}
		}
		if next == -1 {
			return u
		}
		u, p = next, u
	}
}

// decompose picks the centroid of entry's live component, records par as its
// centroid-tree parent, removes it, and recurses into each remaining piece.
func (c *CT) decompose(entry, par int) {
	centroid := c.findCentroid(entry, -1, c.computeSize(entry, -1))
	c.removed[centroid] = true
	c.cparent[centroid] = par
	for _, v := range c.adj[centroid] {
		if !c.removed[v] {
			c.decompose(v, centroid)
		}
	}
}

func main() {
	N := 7
	c := &CT{
		adj:     make([][]int, N),
		removed: make([]bool, N),
		size:    make([]int, N),
		cparent: make([]int, N),
	}
	for _, e := range [][2]int{{0, 1}, {1, 2}, {1, 3}, {2, 4}, {3, 5}, {5, 6}} {
		u, v := e[0], e[1]
		c.adj[u] = append(c.adj[u], v)
		c.adj[v] = append(c.adj[v], u)
	}
	c.decompose(0, -1)
	fmt.Println(c.cparent) // root is 1 -> cparent[1] = -1
}

Java

import java.util.*;

public class BuildCentroidTree {
    List<List<Integer>> adj;
    boolean[] removed;
    int[] size, cparent;

    BuildCentroidTree(int n) {
        adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        removed = new boolean[n];
        size = new int[n];
        cparent = new int[n];
        Arrays.fill(cparent, -1);
    }

    void addEdge(int u, int v) { adj.get(u).add(v); adj.get(v).add(u); }

    int computeSize(int u, int p) {
        size[u] = 1;
        for (int v : adj.get(u)) if (v != p && !removed[v]) size[u] += computeSize(v, u);
        return size[u];
    }

    int findCentroid(int u, int p, int n) {
        for (int v : adj.get(u)) if (v != p && !removed[v] && size[v] > n / 2) return findCentroid(v, u, n);
        return u;
    }

    void decompose(int entry, int par) {
        int n = computeSize(entry, -1);
        int c = findCentroid(entry, -1, n);
        removed[c] = true;
        cparent[c] = par;
        for (int v : adj.get(c)) if (!removed[v]) decompose(v, c);
    }

    public static void main(String[] args) {
        BuildCentroidTree t = new BuildCentroidTree(7);
        int[][] e = {{0,1},{1,2},{1,3},{2,4},{3,5},{5,6}};
        for (int[] x : e) t.addEdge(x[0], x[1]);
        t.decompose(0, -1);
        System.out.println(Arrays.toString(t.cparent));
    }
}

Python

import sys
sys.setrecursionlimit(1 << 20)


class BuildCentroidTree:
    """Centroid-tree builder: after decompose(root, -1), cparent[c] is c's parent (-1 for the root)."""

    def __init__(self, n):
        self.adj = [[] for _ in range(n)]
        self.removed = [False] * n  # True once a vertex has been used as a centroid
        self.size = [0] * n  # scratch subtree sizes for the current component
        self.cparent = [-1] * n

    def add_edge(self, u, v):
        """Insert the undirected edge u-v."""
        self.adj[u].append(v)
        self.adj[v].append(u)

    def compute_size(self, u, p):
        """Fill size[] for live vertices under u (parent p) and return size[u]."""
        total = 1
        for v in self.adj[u]:
            if v == p or self.removed[v]:
                continue
            total += self.compute_size(v, u)
        self.size[u] = total
        return total

    def find_centroid(self, u, p, n):
        """Walk toward any live child heavier than n // 2; return the vertex where none is left."""
        half = n // 2
        while True:
            for v in self.adj[u]:
                if v != p and not self.removed[v] and self.size[v] > half:
                    u, p = v, u
                    break
            else:
                return u

    def decompose(self, entry, par):
        """Attach the centroid of entry's component under par, remove it, and recurse."""
        c = self.find_centroid(entry, -1, self.compute_size(entry, -1))
        self.removed[c] = True
        self.cparent[c] = par
        for v in self.adj[c]:
            if not self.removed[v]:
                self.decompose(v, c)


t = BuildCentroidTree(7)
for u, v in [(0, 1), (1, 2), (1, 3), (2, 4), (3, 5), (5, 6)]:
    t.add_edge(u, v)
t.decompose(0, -1)
print(t.cparent)

B4 — Centroid-tree depth of each vertex

Statement. Output level[v] = depth of v in the centroid tree (root = 0). Verify that the maximum level is ≤ ⌊log₂ N⌋ + 1. In fact, ≤ ⌊log₂ N⌋ holds, because each level at least halves the component size.

Constraints. 1 ≤ N ≤ 2·10⁵.

Hints.
- During decompose, pass the current depth down and set level[centroid] = depth.
- Alternatively, compute the levels from cparent[] afterward.

Go

package main

import "fmt"

// CT holds the state for computing centroid-tree levels. removed marks
// vertices already chosen as centroids, size is scratch space for subtree
// sizes, and after decompose, level[v] is v's depth in the centroid tree
// (root = 0).
type CT struct {
	adj         [][]int
	removed     []bool
	size, level []int
}

// computeSize fills size[] for the live (non-removed) vertices under u, with p
// as u's parent (-1 at the entry vertex), and returns size[u].
func (c *CT) computeSize(u, p int) int {
	c.size[u] = 1
	for _, v := range c.adj[u] {
		if v != p && !c.removed[v] {
			c.size[u] += c.computeSize(v, u)
		}
	}
	return c.size[u]
}

// findCentroid walks from u toward any live child whose subtree holds more
// than n/2 vertices, and returns the first vertex with no such child.
func (c *CT) findCentroid(u, p, n int) int {
	for _, v := range c.adj[u] {
		if v != p && !c.removed[v] && c.size[v] > n/2 {
			return c.findCentroid(v, u, n)
		}
	}
	return u
}

// decompose gives the centroid of entry's live component the level depth,
// removes it, and recurses into every remaining piece one level deeper.
func (c *CT) decompose(entry, depth int) {
	n := c.computeSize(entry, -1)
	ce := c.findCentroid(entry, -1, n)
	c.removed[ce] = true
	c.level[ce] = depth
	for _, v := range c.adj[ce] {
		if !c.removed[v] {
			c.decompose(v, depth+1)
		}
	}
}

// maxLevel returns the deepest centroid-tree level (0 for an empty tree).
func (c *CT) maxLevel() int {
	deepest := 0
	for _, l := range c.level {
		if l > deepest {
			deepest = l
		}
	}
	return deepest
}

// floorLog2 returns ⌊log₂ n⌋ for n >= 1, and 0 for n <= 1.
func floorLog2(n int) int {
	k := 0
	for n > 1 {
		n >>= 1
		k++
	}
	return k
}

func main() {
	N := 7
	c := &CT{adj: make([][]int, N), removed: make([]bool, N), size: make([]int, N), level: make([]int, N)}
	for _, e := range [][2]int{{0, 1}, {1, 2}, {1, 3}, {2, 4}, {3, 5}, {5, 6}} {
		c.adj[e[0]] = append(c.adj[e[0]], e[1])
		c.adj[e[1]] = append(c.adj[e[1]], e[0])
	}
	c.decompose(0, 0)
	fmt.Println(c.level)
	// The task asks to verify the depth bound. Each centroid leaves pieces of
	// at most half its component, so the deepest level is at most ⌊log₂ N⌋.
	// That also satisfies the stated ⌊log₂ N⌋ + 1.
	deepest, bound := c.maxLevel(), floorLog2(N)
	fmt.Printf("max level = %d, floor(log2 N) = %d, within bound: %v\n", deepest, bound, deepest <= bound)
}

Java

import java.util.*;

public class CentroidLevels {
    List<List<Integer>> adj;
    boolean[] removed;
    int[] size, level;

    CentroidLevels(int n) {
        adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        removed = new boolean[n];
        size = new int[n];
        level = new int[n];
    }

    void addEdge(int u, int v) { adj.get(u).add(v); adj.get(v).add(u); }

    int computeSize(int u, int p) {
        size[u] = 1;
        for (int v : adj.get(u)) if (v != p && !removed[v]) size[u] += computeSize(v, u);
        return size[u];
    }

    int findCentroid(int u, int p, int n) {
        for (int v : adj.get(u)) if (v != p && !removed[v] && size[v] > n / 2) return findCentroid(v, u, n);
        return u;
    }

    void decompose(int entry, int depth) {
        int n = computeSize(entry, -1);
        int c = findCentroid(entry, -1, n);
        removed[c] = true;
        level[c] = depth;
        for (int v : adj.get(c)) if (!removed[v]) decompose(v, depth + 1);
    }

    public static void main(String[] args) {
        CentroidLevels t = new CentroidLevels(7);
        int[][] e = {{0,1},{1,2},{1,3},{2,4},{3,5},{5,6}};
        for (int[] x : e) t.addEdge(x[0], x[1]);
        t.decompose(0, 0);
        System.out.println(Arrays.toString(t.level));
    }
}

Python

import sys
sys.setrecursionlimit(1 << 20)


class CentroidLevels:
    """Computes level[v], the depth of v in the centroid tree (root = 0)."""

    def __init__(self, n):
        self.adj = [[] for _ in range(n)]
        self.removed = [False] * n  # True once a vertex has been used as a centroid
        self.size = [0] * n  # scratch subtree sizes for the current component
        self.level = [0] * n

    def add_edge(self, u, v):
        """Insert the undirected edge u-v."""
        self.adj[u].append(v)
        self.adj[v].append(u)

    def compute_size(self, u, p):
        """Fill size[] for live vertices under u (parent p) and return size[u]."""
        total = 1
        for v in self.adj[u]:
            if v == p or self.removed[v]:
                continue
            total += self.compute_size(v, u)
        self.size[u] = total
        return total

    def find_centroid(self, u, p, n):
        """Walk toward any live child heavier than n // 2; return the vertex where none is left."""
        half = n // 2
        while True:
            for v in self.adj[u]:
                if v != p and not self.removed[v] and self.size[v] > half:
                    u, p = v, u
                    break
            else:
                return u

    def decompose(self, entry, depth):
        """Give the centroid of entry's component level `depth`, then recurse one level deeper."""
        c = self.find_centroid(entry, -1, self.compute_size(entry, -1))
        self.removed[c] = True
        self.level[c] = depth
        for v in self.adj[c]:
            if not self.removed[v]:
                self.decompose(v, depth + 1)


t = CentroidLevels(7)
for u, v in [(0, 1), (1, 2), (1, 3), (2, 4), (3, 5), (5, 6)]:
    t.add_edge(u, v)
t.decompose(0, 0)
print(t.level)

B5 — Count both centroids when there are two

Statement. Return all centroids of the tree (one or two vertices).

Constraints. 1 ≤ N ≤ 10⁵.

Hints.
- Find one centroid c. If a child component of c has exactly N/2 vertices (possible only when N is even), the neighbor in that direction is the second centroid.
- Equivalently: a vertex u is a centroid iff max(child subtree sizes, N − size[u]) ≤ ⌊N/2⌋; collect all such vertices.

Go

package main

import "fmt"

// centroids returns every centroid of the N-vertex tree adj, in DFS pre-order
// from vertex 0. A vertex u qualifies when the largest component left after
// deleting it has at most ⌊N/2⌋ vertices. That component is either a child
// subtree or the N - size[u] vertices above u.
func centroids(adj [][]int, N int) []int {
	size := make([]int, N)
	var subtree func(u, p int) int
	subtree = func(u, p int) int {
		total := 1
		for _, v := range adj[u] {
			if v != p {
				total += subtree(v, u)
			}
		}
		size[u] = total
		return total
	}
	subtree(0, -1)

	var res []int
	var visit func(u, p int)
	visit = func(u, p int) {
		largest := N - size[u] // vertices outside u's subtree
		for _, v := range adj[u] {
			if v != p && size[v] > largest {
				largest = size[v]
			}
		}
		if largest <= N/2 {
			res = append(res, u)
		}
		for _, v := range adj[u] {
			if v != p {
				visit(v, u)
			}
		}
	}
	visit(0, -1)
	return res
}

func main() {
	// path 0-1-2-3 has two centroids: 1 and 2
	N := 4
	adj := make([][]int, N)
	for _, e := range [][2]int{{0, 1}, {1, 2}, {2, 3}} {
		u, v := e[0], e[1]
		adj[u] = append(adj[u], v)
		adj[v] = append(adj[v], u)
	}
	fmt.Println(centroids(adj, N)) // [1 2]
}

Java

import java.util.*;

public class TwoCentroids {
    static List<List<Integer>> adj;
    static int[] size;
    static int N;
    static List<Integer> res = new ArrayList<>();

    static int dfs(int u, int p) {
        size[u] = 1;
        for (int v : adj.get(u)) if (v != p) size[u] += dfs(v, u);
        return size[u];
    }

    static void check(int u, int p) {
        int w = N - size[u];
        for (int v : adj.get(u)) if (v != p) w = Math.max(w, size[v]);
        if (w <= N / 2) res.add(u);
        for (int v : adj.get(u)) if (v != p) check(v, u);
    }

    public static void main(String[] args) {
        N = 4;
        adj = new ArrayList<>();
        for (int i = 0; i < N; i++) adj.add(new ArrayList<>());
        int[][] e = {{0,1},{1,2},{2,3}};
        for (int[] x : e) { adj.get(x[0]).add(x[1]); adj.get(x[1]).add(x[0]); }
        size = new int[N];
        dfs(0, -1);
        check(0, -1);
        System.out.println(res); // [1, 2]
    }
}

Python

import sys
sys.setrecursionlimit(1 << 20)


def centroids(adj, N):
    """Return every centroid of the tree (one or two vertices) in DFS pre-order from vertex 0.

    A vertex u is a centroid when max(child subtree sizes, N - size[u]) <= N // 2.
    """
    size = [0] * N

    def dfs(u, p):
        total = 1
        for v in adj[u]:
            if v != p:
                total += dfs(v, u)
        size[u] = total
        return total

    dfs(0, -1)
    res = []

    def check(u, p):
        children = [v for v in adj[u] if v != p]
        worst = max([N - size[u]] + [size[v] for v in children])
        if worst <= N // 2:
            res.append(u)
        for v in children:
            check(v, u)

    check(0, -1)
    return res


N = 4
adj = [[] for _ in range(N)]
for u, v in [(0, 1), (1, 2), (2, 3)]:
    adj[u].append(v)
    adj[v].append(u)
print(centroids(adj, N))  # [1, 2]

Intermediate Tasks (5)

I1 — Count pairs at distance exactly K

Statement. Count unordered pairs (u, v) with exactly K edges between them.

Constraints. 1 ≤ N ≤ 2·10⁵, 1 ≤ K ≤ N.

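Hints.
- Per centroid, process the branches one at a time. For each distance d in a branch, query cnt[K − d] before registering that branch's distances, so that two vertices in the same branch are never paired.
- Reset only the touched buckets to keep the total at O(N log N).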

This is Challenge 2 in interview.md; see the full Go/Java/Python solutions there. The test tree has edges 0-1, 1-2, 1-3 and 3-4, with K=2. The expected answer is 4: pairs {0,2}, {0,3}, {2,3} and {1,4} each have exactly 2 edges between them, while {0,4} and {2,4} are at distance 3. Verify against brute force.


I2 — Count pairs at distance ≤ K (weighted)

Statement. Weighted tree; count unordered pairs with path weight ≤ K.

Constraints. 1 ≤ N ≤ 10⁵, weights ≥ 0, K ≤ 10⁹.

Hints.
- Sort each centroid's distance list and count pairs with sum ≤ K using two pointers. Then subtract the same count computed per branch, which removes pairs lying inside a single branch.

This is the Code Example in middle.md (unweighted) and Challenge 1 in interview.md (weighted). Reuse those reference solutions.


I3 — Distance to nearest special vertex via centroid ancestors (static)

Statement. A set S of "special" vertices is fixed. For every vertex x, output min_{s∈S} dist(x, s).

Constraints. 1 ≤ N ≤ 10⁵.

Hints.
- Per centroid c, store minSpecial[c], the minimum distance from c to any special vertex in c's component.
- For each x, the answer is the minimum over centroid ancestors c of x of dist(x, c) + minSpecial[c].

Go

package main

import "fmt"

// INF marks "no special vertex recorded".
const INF = 1 << 30

// S answers "distance to the nearest special vertex" queries with a centroid
// decomposition.
//   - After dec, ancDist[x] holds one (centroid c, dist(x, c)) pair for every
//     centroid-tree ancestor c of x, including x itself.
//   - best[c] is the smallest dist(c, s) over vertices s marked special in
//     c's component, or INF if none is marked.
type S struct {
	adj     [][]int
	removed []bool
	size    []int
	ancDist [][][2]int // (centroid, dist)
	best    []int
}

// newS returns an S for n vertices with no edges and nothing marked special.
func newS(n int) *S {
	s := &S{
		adj:     make([][]int, n),
		removed: make([]bool, n),
		size:    make([]int, n),
		ancDist: make([][][2]int, n),
		best:    make([]int, n),
	}
	for i := range s.best {
		s.best[i] = INF
	}
	return s
}

// addEdge inserts the undirected edge u-v.
func (s *S) addEdge(u, v int) {
	s.adj[u] = append(s.adj[u], v)
	s.adj[v] = append(s.adj[v], u)
}

// cs fills size[] for the live vertices under u (parent p) and returns size[u].
func (s *S) cs(u, p int) int {
	s.size[u] = 1
	for _, v := range s.adj[u] {
		if v != p && !s.removed[v] {
			s.size[u] += s.cs(v, u)
		}
	}
	return s.size[u]
}

// fc walks from u toward any live child heavier than n/2 and returns the first
// vertex without one.
func (s *S) fc(u, p, n int) int {
	for _, v := range s.adj[u] {
		if v != p && !s.removed[v] && s.size[v] > n/2 {
			return s.fc(v, u, n)
		}
	}
	return u
}

// rec appends (c, depth) to ancDist of every live vertex reachable from u,
// where u itself is at depth d from centroid c.
func (s *S) rec(u, p, d, c int) {
	s.ancDist[u] = append(s.ancDist[u], [2]int{c, d})
	for _, v := range s.adj[u] {
		if v != p && !s.removed[v] {
			s.rec(v, u, d+1, c)
		}
	}
}

// dec picks the centroid of entry's component, records distances from it to
// every vertex in the component, removes it, and recurses into the pieces.
func (s *S) dec(entry int) {
	n := s.cs(entry, -1)
	c := s.fc(entry, -1, n)
	s.removed[c] = true
	s.rec(c, -1, 0, c)
	for _, v := range s.adj[c] {
		if !s.removed[v] {
			s.dec(v)
		}
	}
}

// mark records x as special. Every centroid ancestor c of x lowers best[c] to
// dist(x, c) if that is smaller. It must be called after dec.
func (s *S) mark(x int) {
	for _, pr := range s.ancDist[x] {
		if pr[1] < s.best[pr[0]] {
			s.best[pr[0]] = pr[1]
		}
	}
}

// query returns the minimum over marked vertices m of dist(x, m), or INF when
// nothing is marked. The x-to-m path passes through their highest common
// centroid ancestor, and every other ancestor only gives an upper bound.
func (s *S) query(x int) int {
	ans := INF
	for _, pr := range s.ancDist[x] {
		if b := s.best[pr[0]]; b < INF && pr[1]+b < ans {
			ans = pr[1] + b
		}
	}
	return ans
}

func main() {
	N := 7
	s := newS(N)
	for _, e := range [][2]int{{0, 1}, {1, 2}, {1, 3}, {2, 4}, {3, 5}, {5, 6}} {
		s.addEdge(e[0], e[1])
	}
	s.dec(0)
	for _, x := range []int{4, 6} {
		s.mark(x)
	}
	for x := 0; x < N; x++ {
		fmt.Printf("dist(%d, S) = %d\n", x, s.query(x))
	}
}

Java

import java.util.*;

/**
 * I3 reference solution: nearest special vertex for every vertex via centroid
 * ancestors. After dec(root), ancDist[x] lists {centroid c, dist(x, c)} for
 * each centroid-tree ancestor c of x (x itself included).
 */
public class NearestSpecial {
    List<List<Integer>> adj;
    boolean[] removed;
    int[] size, best;
    List<int[]>[] ancDist;
    static final int INF = 1 << 30;

    @SuppressWarnings("unchecked")
    NearestSpecial(int n) {
        adj = new ArrayList<>();
        ancDist = new List[n];
        for (int i = 0; i < n; i++) {
            adj.add(new ArrayList<>());
            ancDist[i] = new ArrayList<>();
        }
        removed = new boolean[n];
        size = new int[n];
        best = new int[n];
        Arrays.fill(best, INF);
    }

    void addEdge(int u, int v) {
        adj.get(u).add(v);
        adj.get(v).add(u);
    }

    /** Live subtree size under u (parent p); also stored in size[u]. */
    int cs(int u, int p) {
        int total = 1;
        for (int v : adj.get(u)) {
            if (v != p && !removed[v]) total += cs(v, u);
        }
        size[u] = total;
        return total;
    }

    /** Walks from u toward any live child heavier than n/2; returns the first vertex with none. */
    int fc(int u, int p, int n) {
        for (int v : adj.get(u)) {
            boolean heavy = v != p && !removed[v] && size[v] > n / 2;
            if (heavy) return fc(v, u, n);
        }
        return u;
    }

    /** Appends {c, depth} to ancDist of every live vertex reachable from u (u at depth d). */
    void rec(int u, int p, int d, int c) {
        ancDist[u].add(new int[]{c, d});
        for (int v : adj.get(u)) {
            if (v != p && !removed[v]) rec(v, u, d + 1, c);
        }
    }

    /** Chooses the centroid, records its distances, removes it, and recurses into the pieces. */
    void dec(int entry) {
        int c = fc(entry, -1, cs(entry, -1));
        removed[c] = true;
        rec(c, -1, 0, c);
        for (int v : adj.get(c)) {
            if (!removed[v]) dec(v);
        }
    }

    public static void main(String[] args) {
        NearestSpecial s = new NearestSpecial(7);
        int[][] edges = {{0,1},{1,2},{1,3},{2,4},{3,5},{5,6}};
        for (int[] edge : edges) s.addEdge(edge[0], edge[1]);
        s.dec(0);
        int[] special = {4, 6};
        for (int x : special) {
            for (int[] pr : s.ancDist[x]) {
                int c = pr[0], d = pr[1];
                if (d < s.best[c]) s.best[c] = d;
            }
        }
        for (int x = 0; x < 7; x++) {
            int ans = INF;
            for (int[] pr : s.ancDist[x]) {
                int b = s.best[pr[0]];
                if (b < INF) ans = Math.min(ans, pr[1] + b);
            }
            System.out.println("dist(" + x + ", S) = " + ans);
        }
    }
}

Python

import sys
sys.setrecursionlimit(1 << 20)
INF = 1 << 30


class NearestSpecial:
    """Per-vertex centroid-ancestor distances for nearest-special-vertex queries."""

    def __init__(self, n):
        self.adj = [[] for _ in range(n)]
        self.removed = [False] * n
        self.size = [0] * n
        # best[c]: smallest distance from centroid c to a special vertex in its component
        self.best = [INF] * n
        # anc[x]: (centroid ancestor, distance to it) pairs, in decomposition order
        self.anc = [[] for _ in range(n)]

    def add_edge(self, u, v):
        """Insert the undirected edge u-v."""
        self.adj[u].append(v)
        self.adj[v].append(u)

    def cs(self, u, p):
        """Fill size[] for live vertices under u (parent p) and return size[u]."""
        total = 1
        for v in self.adj[u]:
            if v != p and not self.removed[v]:
                total += self.cs(v, u)
        self.size[u] = total
        return total

    def fc(self, u, p, n):
        """Walk toward any live child heavier than n // 2; return where none is left."""
        half = n // 2
        while True:
            heavy = next(
                (v for v in self.adj[u] if v != p and not self.removed[v] and self.size[v] > half),
                None,
            )
            if heavy is None:
                return u
            u, p = heavy, u

    def rec(self, u, p, d, c):
        """Append (c, distance) to anc of every live vertex reachable from u (u at distance d)."""
        stack = [(u, p, d)]
        while stack:
            x, parent, dist = stack.pop()
            self.anc[x].append((c, dist))
            for y in self.adj[x]:
                if y != parent and not self.removed[y]:
                    stack.append((y, x, dist + 1))

    def dec(self, entry):
        """Pick the centroid, record its distances, remove it, and recurse into the pieces."""
        c = self.fc(entry, -1, self.cs(entry, -1))
        self.removed[c] = True
        self.rec(c, -1, 0, c)
        for v in self.adj[c]:
            if not self.removed[v]:
                self.dec(v)


s = NearestSpecial(7)
for u, v in [(0, 1), (1, 2), (1, 3), (2, 4), (3, 5), (5, 6)]:
    s.add_edge(u, v)
s.dec(0)
for x in (4, 6):
    for c, d in s.anc[x]:
        if d < s.best[c]:
            s.best[c] = d
for x in range(7):
    candidates = [d + s.best[c] for c, d in s.anc[x] if s.best[c] < INF]
    print(f"dist({x}, S) = {min(candidates, default=INF)}")

I4 — Number of paths with length in range [L, R]

Statement. Count unordered pairs whose path length (edges) is in [L, R].

Constraints. 1 ≤ N ≤ 10⁵, 0 ≤ L ≤ R ≤ N.

Hints.
- Compute countLeq(R) − countLeq(L − 1) using the I2 machinery.
- Or, per centroid, sort the distances and, for each distance d, count partners in the complementary range [L − d, R − d] with two binary searches.

Go

package main

import (
    "fmt"
    "sort"
)

// Sol counts unordered vertex pairs whose path length (in edges) lies in
// [L, R] via centroid decomposition. ans accumulates the count across dec.
type Sol struct {
	adj     [][]int
	removed []bool
	size    []int
	ans     int64
}

// cs fills size[] for the live vertices under u (parent p) and returns size[u].
func (s *Sol) cs(u, p int) int {
	s.size[u] = 1
	for _, v := range s.adj[u] {
		if v != p && !s.removed[v] {
			s.size[u] += s.cs(v, u)
		}
	}
	return s.size[u]
}

// fc walks from u toward any live child heavier than n/2 and returns the first
// vertex without one.
func (s *Sol) fc(u, p, n int) int {
	for _, v := range s.adj[u] {
		if v != p && !s.removed[v] && s.size[v] > n/2 {
			return s.fc(v, u, n)
		}
	}
	return u
}

// gather appends to *out the depth of every live vertex under u, where u
// itself is at depth d.
func (s *Sol) gather(u, p, d int, out *[]int) {
	*out = append(*out, d)
	for _, v := range s.adj[u] {
		if v != p && !s.removed[v] {
			s.gather(v, u, d+1, out)
		}
	}
}

// countLeq returns the number of index pairs i < j with d[i]+d[j] <= K
// (0 when K < 0). It sorts a private copy, so the caller's slice keeps its
// order.
func countLeq(d []int, K int) int64 {
	if K < 0 {
		return 0
	}
	return countLeqSorted(sortedCopy(d), K)
}

// sortedCopy returns an ascending copy of d.
func sortedCopy(d []int) []int {
	out := append([]int(nil), d...)
	sort.Ints(out)
	return out
}

// countLeqSorted is countLeq for a slice already in ascending order
// (two-pointer scan).
func countLeqSorted(d []int, K int) int64 {
	if K < 0 {
		return 0
	}
	var c int64
	lo, hi := 0, len(d)-1
	for lo < hi {
		if d[lo]+d[hi] <= K {
			c += int64(hi - lo) // d[lo] pairs with every element in d[lo+1..hi]
			lo++
		} else {
			hi--
		}
	}
	return c
}

// countInRange returns the number of index pairs i < j with
// L <= d[i]+d[j] <= R, sorting only one copy of d.
func countInRange(d []int, L, R int) int64 {
	sorted := sortedCopy(d)
	return countLeqSorted(sorted, R) - countLeqSorted(sorted, L-1)
}

// contrib counts pairs in all whose sum is at most K; all is not modified.
func (s *Sol) contrib(all []int, K int) int64 { return countLeq(all, K) }

// dec handles the component containing entry.
//   - It adds the pairs whose two distances to the centroid sum into [L, R].
//   - It subtracts the pairs lying inside a single branch, since their real
//     path avoids the centroid and is counted deeper.
//   - It then recurses into each branch.
func (s *Sol) dec(entry, L, R int) {
	n := s.cs(entry, -1)
	c := s.fc(entry, -1, n)
	s.removed[c] = true
	all := []int{0} // the centroid itself, at distance 0
	var branches [][]int
	for _, v := range s.adj[c] {
		if !s.removed[v] {
			var br []int
			s.gather(v, c, 1, &br)
			all = append(all, br...)
			branches = append(branches, br)
		}
	}
	s.ans += countInRange(all, L, R)
	for _, br := range branches {
		s.ans -= countInRange(br, L, R)
	}
	for _, v := range s.adj[c] {
		if !s.removed[v] {
			s.dec(v, L, R)
		}
	}
}

func main() {
	N := 5
	s := &Sol{adj: make([][]int, N), removed: make([]bool, N), size: make([]int, N)}
	for _, e := range [][2]int{{0, 1}, {1, 2}, {1, 3}, {3, 4}} {
		s.adj[e[0]] = append(s.adj[e[0]], e[1])
		s.adj[e[1]] = append(s.adj[e[1]], e[0])
	}
	s.dec(0, 2, 3)
	fmt.Println("pairs with length in [2,3]:", s.ans)
}

Java

import java.util.*;

public class PathsInRange {
    List<List<Integer>> adj;
    boolean[] removed;
    int[] size;
    long ans = 0;

    PathsInRange(int n) {
        adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        removed = new boolean[n];
        size = new int[n];
    }
    void addEdge(int u, int v) { adj.get(u).add(v); adj.get(v).add(u); }
    int cs(int u, int p) { size[u] = 1; for (int v : adj.get(u)) if (v != p && !removed[v]) size[u] += cs(v, u); return size[u]; }
    int fc(int u, int p, int n) { for (int v : adj.get(u)) if (v != p && !removed[v] && size[v] > n / 2) return fc(v, u, n); return u; }
    void gather(int u, int p, int d, List<Integer> out) { out.add(d); for (int v : adj.get(u)) if (v != p && !removed[v]) gather(v, u, d + 1, out); }

    long countLeq(List<Integer> src, int K) {
        if (K < 0) return 0;
        List<Integer> d = new ArrayList<>(src);
        Collections.sort(d);
        long c = 0; int lo = 0, hi = d.size() - 1;
        while (lo < hi) { if (d.get(lo) + d.get(hi) <= K) { c += hi - lo; lo++; } else hi--; }
        return c;
    }

    void dec(int entry, int L, int R) {
        int n = cs(entry, -1);
        int c = fc(entry, -1, n);
        removed[c] = true;
        List<Integer> all = new ArrayList<>(); all.add(0);
        List<List<Integer>> branches = new ArrayList<>();
        for (int v : adj.get(c)) if (!removed[v]) {
            List<Integer> br = new ArrayList<>();
            gather(v, c, 1, br);
            all.addAll(br); branches.add(br);
        }
        ans += countLeq(all, R) - countLeq(all, L - 1);
        for (List<Integer> br : branches) ans -= countLeq(br, R) - countLeq(br, L - 1);
        for (int v : adj.get(c)) if (!removed[v]) dec(v, L, R);
    }

    public static void main(String[] args) {
        PathsInRange s = new PathsInRange(5);
        int[][] e = {{0,1},{1,2},{1,3},{3,4}};
        for (int[] x : e) s.addEdge(x[0], x[1]);
        s.dec(0, 2, 3);
        System.out.println("pairs with length in [2,3]: " + s.ans);
    }
}

Python

import sys
sys.setrecursionlimit(1 << 20)


class PathsInRange:
    """Counts unordered vertex pairs whose path length lies in [L, R]; the result accumulates in ans."""

    def __init__(self, n):
        self.adj = [[] for _ in range(n)]
        self.removed = [False] * n
        self.size = [0] * n
        self.ans = 0

    def add_edge(self, u, v):
        """Insert the undirected edge u-v."""
        self.adj[u].append(v)
        self.adj[v].append(u)

    def cs(self, u, p):
        """Fill size[] for live vertices under u (parent p) and return size[u]."""
        total = 1
        for v in self.adj[u]:
            if v != p and not self.removed[v]:
                total += self.cs(v, u)
        self.size[u] = total
        return total

    def fc(self, u, p, n):
        """Walk toward any live child heavier than n // 2; return where none is left."""
        half = n // 2
        while True:
            for v in self.adj[u]:
                if v != p and not self.removed[v] and self.size[v] > half:
                    u, p = v, u
                    break
            else:
                return u

    def gather(self, u, p, d, out):
        """Append the depth of every live vertex under u (u at depth d) to out, in pre-order."""
        out.append(d)
        for v in self.adj[u]:
            if v != p and not self.removed[v]:
                self.gather(v, u, d + 1, out)

    @staticmethod
    def count_leq(src, K):
        """Number of index pairs i < j with src[i] + src[j] <= K (0 when K < 0)."""
        if K < 0:
            return 0
        d = sorted(src)
        pairs = 0
        lo, hi = 0, len(d) - 1
        while lo < hi:
            if d[lo] + d[hi] > K:
                hi -= 1
                continue
            pairs += hi - lo
            lo += 1
        return pairs

    def dec(self, entry, L, R):
        """Count centroid-crossing pairs in [L, R] for entry's component, then recurse."""
        c = self.fc(entry, -1, self.cs(entry, -1))
        self.removed[c] = True

        def in_range(ds):
            return self.count_leq(ds, R) - self.count_leq(ds, L - 1)

        branches = []
        for v in self.adj[c]:
            if not self.removed[v]:
                br = []
                self.gather(v, c, 1, br)
                branches.append(br)
        all_d = [0] + [d for br in branches for d in br]
        # Pairs inside one branch do not pass through c; remove them here.
        self.ans += in_range(all_d) - sum(in_range(br) for br in branches)
        for v in self.adj[c]:
            if not self.removed[v]:
                self.dec(v, L, R)


s = PathsInRange(5)
for u, v in [(0, 1), (1, 2), (1, 3), (3, 4)]:
    s.add_edge(u, v)
s.dec(0, 2, 3)
print("pairs with length in [2,3]:", s.ans)

I5 — Validate decomposition against brute force

Statement. Write a tester: generate random trees up to N = 200, run your centroid decomposition pair-distance counter, and compare against an O(N²) BFS-from-every-node baseline.

Constraints. Random trees, many seeds.

Hints.
- BFS from each vertex gives all pairwise distances; tally the pairs that meet the target.
- Any mismatch usually points to stale subtree sizes or to double-counted pairs inside one branch.

Go

package main

import (
    "fmt"
    "math/rand"
)

// bruteCountLeq returns the number of unordered vertex pairs (s, t), s < t,
// joined by a path of at most K edges. It runs one BFS per source, O(N²) in
// total, and serves as ground truth for faster counters.
func bruteCountLeq(adj [][]int, N, K int) int64 {
	var total int64
	dist := make([]int, N)
	for s := 0; s < N; s++ {
		for i := range dist {
			dist[i] = -1
		}
		dist[s] = 0
		q := []int{s}
		for len(q) > 0 {
			u := q[0]
			q = q[1:]
			for _, v := range adj[u] {
				if dist[v] == -1 {
					dist[v] = dist[u] + 1
					q = append(q, v)
				}
			}
		}
		for t := s + 1; t < N; t++ {
			// dist[t] == -1 means t is unreachable (only possible in a forest).
			// No path exists, so the pair must not be counted.
			if dist[t] >= 0 && dist[t] <= K {
				total++
			}
		}
	}
	return total
}

// pairsLeq returns the number of index pairs i < j with ds[i]+ds[j] <= K.
// The entries are non-negative path lengths, so a prefix-count array over the
// values replaces sorting.
func pairsLeq(ds []int, K int) int64 {
	if K < 0 || len(ds) < 2 {
		return 0
	}
	maxD := 0
	for _, d := range ds {
		if d > maxD {
			maxD = d
		}
	}
	atMost := make([]int64, maxD+1) // atMost[x] = number of entries <= x
	for _, d := range ds {
		atMost[d]++
	}
	for x := 1; x <= maxD; x++ {
		atMost[x] += atMost[x-1]
	}
	var ordered int64 // ordered pairs (i, j) with i != j
	for _, d := range ds {
		lim := K - d
		if lim < 0 {
			continue
		}
		if lim > maxD {
			lim = maxD
		}
		ordered += atMost[lim]
		if 2*d <= K {
			ordered-- // atMost[lim] also counted this entry paired with itself
		}
	}
	return ordered / 2
}

// centroidCountLeq counts the same pairs as bruteCountLeq using centroid
// decomposition, in O(N log N). At each centroid it works in two steps:
//   - add the pairs in the whole component whose two distances to the
//     centroid sum to at most K;
//   - subtract the pairs lying inside a single branch, since their real path
//     avoids the centroid and is counted at a deeper level.
func centroidCountLeq(adj [][]int, N, K int) int64 {
	removed := make([]bool, N)
	size := make([]int, N)

	var calcSize func(u, p int) int
	calcSize = func(u, p int) int {
		size[u] = 1
		for _, v := range adj[u] {
			if v != p && !removed[v] {
				size[u] += calcSize(v, u)
			}
		}
		return size[u]
	}

	var findCentroid func(u, p, n int) int
	findCentroid = func(u, p, n int) int {
		for _, v := range adj[u] {
			if v != p && !removed[v] && size[v] > n/2 {
				return findCentroid(v, u, n)
			}
		}
		return u
	}

	var gather func(u, p, d int, out []int) []int
	gather = func(u, p, d int, out []int) []int {
		out = append(out, d)
		for _, v := range adj[u] {
			if v != p && !removed[v] {
				out = gather(v, u, d+1, out)
			}
		}
		return out
	}

	var total int64
	var decompose func(entry int)
	decompose = func(entry int) {
		c := findCentroid(entry, -1, calcSize(entry, -1))
		removed[c] = true
		all := []int{0} // the centroid itself, at distance 0
		for _, v := range adj[c] {
			if !removed[v] {
				branch := gather(v, c, 1, nil)
				total -= pairsLeq(branch, K)
				all = append(all, branch...)
			}
		}
		total += pairsLeq(all, K)
		for _, v := range adj[c] {
			if !removed[v] {
				decompose(v)
			}
		}
	}
	// Start a decomposition from every untouched vertex so that forests work.
	for v := 0; v < N; v++ {
		if !removed[v] {
			decompose(v)
		}
	}
	return total
}

func main() {
	mismatches := 0
	for seed := 0; seed < 5; seed++ {
		r := rand.New(rand.NewSource(int64(seed)))
		N := 2 + r.Intn(50)
		adj := make([][]int, N)
		for i := 1; i < N; i++ {
			p := r.Intn(i)
			adj[i] = append(adj[i], p)
			adj[p] = append(adj[p], i)
		}
		K := r.Intn(N)
		brute := bruteCountLeq(adj, N, K)
		fast := centroidCountLeq(adj, N, K)
		verdict := "ok"
		if brute != fast {
			verdict = "MISMATCH"
			mismatches++
		}
		fmt.Printf("seed=%d N=%d K=%d brute=%d centroid=%d %s\n", seed, N, K, brute, fast, verdict)
	}
	fmt.Println("mismatches:", mismatches)
}

Java

import java.util.*;

public class BruteTester {
    static long bruteCountLeq(List<List<Integer>> adj, int N, int K) {
        long total = 0;
        for (int s = 0; s < N; s++) {
            int[] dist = new int[N];
            Arrays.fill(dist, -1);
            dist[s] = 0;
            ArrayDeque<Integer> q = new ArrayDeque<>();
            q.add(s);
            while (!q.isEmpty()) {
                int u = q.poll();
                for (int v : adj.get(u)) if (dist[v] == -1) { dist[v] = dist[u] + 1; q.add(v); }
            }
            for (int t = s + 1; t < N; t++) if (dist[t] <= K) total++;
        }
        return total;
    }

    public static void main(String[] args) {
        for (int seed = 0; seed < 5; seed++) {
            Random r = new Random(seed);
            int N = 2 + r.nextInt(50);
            List<List<Integer>> adj = new ArrayList<>();
            for (int i = 0; i < N; i++) adj.add(new ArrayList<>());
            for (int i = 1; i < N; i++) {
                int p = r.nextInt(i);
                adj.get(i).add(p); adj.get(p).add(i);
            }
            int K = r.nextInt(N);
            System.out.printf("seed=%d N=%d K=%d brute=%d%n", seed, N, K, bruteCountLeq(adj, N, K));
        }
    }
}

Python

import random
from collections import deque


def brute_count_leq(adj, N, K):
    """Count unordered pairs {s, t}, s != t, whose BFS distance is <= K.

    Runs one BFS per source vertex (O(N^2) overall); this is the oracle the
    centroid-decomposition counter is validated against. Unreachable pairs
    (dist == -1) are never counted.
    """
    total = 0
    for s in range(N):
        dist = [-1] * N
        dist[s] = 0
        q = deque([s])
        while q:
            u = q.popleft()
            for v in adj[u]:
                if dist[v] == -1:
                    dist[v] = dist[u] + 1
                    q.append(v)
        total += sum(1 for t in range(s + 1, N) if 0 <= dist[t] <= K)
    return total


def cd_count_leq(adj, N, K):
    """Same count as brute_count_leq, computed with centroid decomposition.

    At each centroid, add every pair of component vertices whose legs sum to
    <= K, then subtract the pairs lying inside a single branch (their path
    does not pass through the centroid).
    """
    removed = [False] * N
    size = [0] * N
    total = 0

    def compute_size(u, p):
        size[u] = 1
        for v in adj[u]:
            if v != p and not removed[v]:
                size[u] += compute_size(v, u)
        return size[u]

    def find_centroid(u, p, n):
        for v in adj[u]:
            if v != p and not removed[v] and size[v] > n // 2:
                return find_centroid(v, u, n)
        return u

    def gather(u, p, d, out):
        out.append(d)
        for v in adj[u]:
            if v != p and not removed[v]:
                gather(v, u, d + 1, out)

    def count_pairs(ds):
        # Index pairs i < j with ds[i] + ds[j] <= K. Distances never exceed
        # len(ds), so a prefix-count array replaces sorting.
        max_d = max(ds)
        prefix = [0] * (max_d + 1)
        for d in ds:
            prefix[d] += 1
        for x in range(1, max_d + 1):
            prefix[x] += prefix[x - 1]
        ordered = 0
        for d in ds:
            lim = K - d
            if lim < 0:
                continue
            ordered += prefix[min(lim, max_d)]
            if 2 * d <= K:
                ordered -= 1  # drop d paired with itself
        return ordered // 2

    def decompose(entry):
        nonlocal total
        c = find_centroid(entry, -1, compute_size(entry, -1))
        removed[c] = True
        all_d = [0]  # the centroid itself
        for v in adj[c]:
            if not removed[v]:
                branch = []
                gather(v, c, 1, branch)
                total -= count_pairs(branch)
                all_d.extend(branch)
        total += count_pairs(all_d)
        for v in adj[c]:
            if not removed[v]:
                decompose(v)

    # Looping over roots also handles forests, matching brute_count_leq.
    for root in range(N):
        if not removed[root]:
            decompose(root)
    return total


def run_trials(trials=200):
    """Compare both counters on random trees with 2..200 vertices; return #mismatches."""
    mismatches = 0
    for seed in range(trials):
        rng = random.Random(seed)
        N = rng.randint(2, 200)
        adj = [[] for _ in range(N)]
        for i in range(1, N):
            p = rng.randint(0, i - 1)
            if seed % 4 == 0:
                p = i - 1  # every fourth tree is a path: the deepest shape
            adj[i].append(p)
            adj[p].append(i)
        K = rng.randint(0, N - 1)
        want = brute_count_leq(adj, N, K)
        got = cd_count_leq(adj, N, K)
        if got != want:
            mismatches += 1
            print(f"MISMATCH seed={seed} N={N} K={K} brute={want} centroid={got}")
    print(f"{trials} random trees checked, {mismatches} mismatches")
    return mismatches


if __name__ == "__main__":
    run_trials()

Advanced Tasks (5)

A1 — Dynamic nearest marked node

Statement. Support mark(x), unmark(x), query(x) = distance to the nearest currently-marked vertex.

Constraints. 1 ≤ N, Q ≤ 10⁵.

Hints. - Precompute dist(x, ancestor); per centroid keep a min-structure of marked distances; query over O(log N) ancestors.

The full reference implementation (Go/Java/Python) is in senior.md §7. Per centroid, use a balanced multiset or an indexed heap so that insert, delete and min each cost O(log N).


A2 — Count "good" paths (color constraint)

Statement. Each vertex has a color, one of which is RED. Count unordered pairs (u, v) with u ≠ v such that the path u…v (endpoints included) contains at most one RED vertex.

Constraints. 1 ≤ N ≤ 10⁵.

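Hints. - Per centroid c, for each vertex track the number of RED vertices on its leg to c, excluding c itself. - A path through c is good iff redA + redB + [c is RED] ≤ 1; keeping c out of both legs ensures it is counted exactly once. Bucket legs by red count (only 0 and 1 matter — a leg with two or more reds can never be part of a good path) and combine; subtract same-branch over-counts.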

Go

package main

import "fmt"

// redColor is the color value that marks a RED vertex; every other value is
// treated as non-red.
const redColor = 1

// Sol counts "good" paths: unordered pairs {u, v}, u != v, whose tree path
// (endpoints included) contains at most one RED vertex. Fill adj and color,
// then call dec(root) once; the answer accumulates in ans.
type Sol struct {
    adj     [][]int
    color   []int // color[v] == redColor marks v as RED
    removed []bool
    size    []int
    ans     int64
}

// redBit returns 1 if vertex u is RED and 0 otherwise. Comparing against
// redColor (rather than adding color[u]) keeps other color values from being
// mistaken for extra reds.
func (s *Sol) redBit(u int) int {
    if s.color[u] == redColor {
        return 1
    }
    return 0
}

// cs recomputes size[] for the live component containing u, rooted at u with
// parent p, and returns size[u].
func (s *Sol) cs(u, p int) int {
    s.size[u] = 1
    for _, v := range s.adj[u] {
        if v != p && !s.removed[v] {
            s.size[u] += s.cs(v, u)
        }
    }
    return s.size[u]
}

// fc returns the centroid of the live component of size n that contains u.
func (s *Sol) fc(u, p, n int) int {
    for _, v := range s.adj[u] {
        if v != p && !s.removed[v] && s.size[v] > n/2 {
            return s.fc(v, u, n)
        }
    }
    return u
}

// gather appends, for u and every live vertex below it (away from p), the
// number of RED vertices on the leg from the branch root down to that vertex.
// red is the count accumulated above u; the centroid is never included.
func (s *Sol) gather(u, p, red int, out *[]int) {
    r := red + s.redBit(u)
    *out = append(*out, r)
    for _, v := range s.adj[u] {
        if v != p && !s.removed[v] {
            s.gather(v, u, r, out)
        }
    }
}

// combine counts index pairs i < j with reds[i]+reds[j]+cr <= 1: pairs of legs
// whose joined path through a centroid (with red bit cr) has at most one RED
// vertex. Legs with two or more reds can never take part.
func combine(reds []int, cr int) int64 {
    limit := 1 - cr // red budget left for the two legs together
    if limit < 0 {
        return 0
    }
    var zero, one int64 // legs with exactly 0 / exactly 1 red vertex
    for _, r := range reds {
        switch r {
        case 0:
            zero++
        case 1:
            one++
        }
    }
    pairs := zero * (zero - 1) / 2 // 0+0 always fits once limit >= 0
    if limit >= 1 {
        pairs += zero * one // 0+1 fits only when the centroid is not red
    }
    return pairs
}

// dec handles the live component containing entry: it picks the centroid c,
// adds the good pairs whose path passes through c, and recurses into the
// remaining pieces.
func (s *Sol) dec(entry int) {
    n := s.cs(entry, -1)
    c := s.fc(entry, -1, n)
    s.removed[c] = true
    cr := s.redBit(c)
    all := []int{0} // c itself: empty leg; its own color is carried by cr
    var branches [][]int
    for _, v := range s.adj[c] {
        if !s.removed[v] {
            var br []int
            s.gather(v, c, 0, &br)
            all = append(all, br...)
            branches = append(branches, br)
        }
    }
    s.ans += combine(all, cr)
    // Pairs inside one branch do not really pass through c; remove them.
    for _, br := range branches {
        s.ans -= combine(br, cr)
    }
    for _, v := range s.adj[c] {
        if !s.removed[v] {
            s.dec(v)
        }
    }
}

func main() {
    N := 5
    s := &Sol{adj: make([][]int, N), color: []int{0, 1, 0, 0, 0}, removed: make([]bool, N), size: make([]int, N)}
    for _, e := range [][2]int{{0, 1}, {1, 2}, {1, 3}, {3, 4}} {
        s.adj[e[0]] = append(s.adj[e[0]], e[1])
        s.adj[e[1]] = append(s.adj[e[1]], e[0])
    }
    s.dec(0)
    fmt.Println("good pairs (<=1 red):", s.ans)
}

Java

import java.util.*;

public class GoodPaths {
    /** Color value that marks a RED vertex; every other color value is non-red. */
    static final int RED = 1;

    List<List<Integer>> adj;
    int[] color, size;
    boolean[] removed;
    long ans = 0;

    GoodPaths(int n, int[] color) {
        this.color = color;
        adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        removed = new boolean[n];
        size = new int[n];
    }
    void addEdge(int u, int v) { adj.get(u).add(v); adj.get(v).add(u); }
    /** Recomputes subtree sizes of the live component containing u; returns size[u]. */
    int cs(int u, int p) { size[u] = 1; for (int v : adj.get(u)) if (v != p && !removed[v]) size[u] += cs(v, u); return size[u]; }
    /** Returns the centroid of the live component of size n that contains u. */
    int fc(int u, int p, int n) { for (int v : adj.get(u)) if (v != p && !removed[v] && size[v] > n / 2) return fc(v, u, n); return u; }
    /** 1 if vertex u is RED, 0 otherwise; other color values never count as red. */
    int redBit(int u) { return color[u] == RED ? 1 : 0; }
    /**
     * Appends, for u and every live vertex below it, the number of RED vertices on
     * the leg from the branch root to that vertex. The centroid is excluded.
     */
    void gather(int u, int p, int red, List<Integer> out) {
        int r = red + redBit(u);
        out.add(r);
        for (int v : adj.get(u)) if (v != p && !removed[v]) gather(v, u, r, out);
    }
    /** Counts index pairs i < j with reds[i] + reds[j] + cr <= 1. */
    long combine(List<Integer> reds, int cr) {
        int limit = 1 - cr;            // red budget left for the two legs
        if (limit < 0) return 0;
        long c0 = 0, c1 = 0;           // legs with exactly 0 / exactly 1 red
        for (int r : reds) { if (r == 0) c0++; else if (r == 1) c1++; }
        long c = c0 * (c0 - 1) / 2;    // 0 + 0 always fits
        if (limit >= 1) c += c0 * c1;  // 0 + 1 fits only if the centroid is not red
        return c;
    }
    /** Counts good pairs through the centroid of entry's component, then recurses. */
    void dec(int entry) {
        int n = cs(entry, -1);
        int c = fc(entry, -1, n);
        removed[c] = true;
        int cr = redBit(c);
        List<Integer> all = new ArrayList<>(); all.add(0); // the centroid itself: empty leg
        List<List<Integer>> branches = new ArrayList<>();
        for (int v : adj.get(c)) if (!removed[v]) {
            List<Integer> br = new ArrayList<>();
            gather(v, c, 0, br);
            all.addAll(br); branches.add(br);
        }
        ans += combine(all, cr);
        // Same-branch pairs do not actually pass through c.
        for (List<Integer> br : branches) ans -= combine(br, cr);
        for (int v : adj.get(c)) if (!removed[v]) dec(v);
    }

    public static void main(String[] args) {
        GoodPaths s = new GoodPaths(5, new int[]{0,1,0,0,0});
        int[][] e = {{0,1},{1,2},{1,3},{3,4}};
        for (int[] x : e) s.addEdge(x[0], x[1]);
        s.dec(0);
        System.out.println("good pairs (<=1 red): " + s.ans);
    }
}

Python

import sys
sys.setrecursionlimit(1 << 20)


class GoodPaths:
    """Count unordered pairs {u, v}, u != v, whose tree path has at most one RED vertex.

    ``color[v] == GoodPaths.RED`` marks v as RED; any other color value is
    treated as non-red. Call ``dec(root)`` once after adding all edges; the
    result accumulates in ``self.ans``.
    """

    RED = 1  # color value that marks a RED vertex

    def __init__(self, n, color):
        self.adj = [[] for _ in range(n)]
        self.color = color
        self.removed = [False] * n
        self.size = [0] * n
        self.ans = 0

    def add_edge(self, u, v):
        self.adj[u].append(v); self.adj[v].append(u)

    def cs(self, u, p):
        """Recompute subtree sizes of the live component containing u; return size[u]."""
        self.size[u] = 1
        for v in self.adj[u]:
            if v != p and not self.removed[v]:
                self.size[u] += self.cs(v, u)
        return self.size[u]

    def fc(self, u, p, n):
        """Return the centroid of the live component of size n that contains u."""
        for v in self.adj[u]:
            if v != p and not self.removed[v] and self.size[v] > n // 2:
                return self.fc(v, u, n)
        return u

    def _red_bit(self, u):
        """Return 1 if vertex u is RED, else 0 (other color values never count)."""
        return 1 if self.color[u] == self.RED else 0

    def gather(self, u, p, red, out):
        """Append, for u and each live vertex below it, the RED count on its leg.

        ``red`` is the count accumulated above u; the centroid is excluded and
        accounted for separately by the caller.
        """
        r = red + self._red_bit(u)
        out.append(r)
        for v in self.adj[u]:
            if v != p and not self.removed[v]:
                self.gather(v, u, r, out)

    @staticmethod
    def combine(reds, cr):
        """Number of index pairs i < j with reds[i] + reds[j] + cr <= 1."""
        limit = 1 - cr  # red budget left for the two legs
        if limit < 0:
            return 0
        c0 = sum(1 for r in reds if r == 0)
        c1 = sum(1 for r in reds if r == 1)
        c = c0 * (c0 - 1) // 2  # 0 + 0 always fits
        if limit >= 1:
            c += c0 * c1  # 0 + 1 fits only when the centroid is not red
        return c

    def dec(self, entry):
        """Count good pairs through the centroid of entry's component, then recurse."""
        n = self.cs(entry, -1)
        c = self.fc(entry, -1, n)
        self.removed[c] = True
        cr = self._red_bit(c)
        all_r = [0]  # the centroid itself: empty leg
        branches = []
        for v in self.adj[c]:
            if not self.removed[v]:
                br = []
                self.gather(v, c, 0, br)
                all_r.extend(br)
                branches.append(br)
        self.ans += self.combine(all_r, cr)
        # Same-branch pairs do not actually pass through c.
        for br in branches:
            self.ans -= self.combine(br, cr)
        for v in self.adj[c]:
            if not self.removed[v]:
                self.dec(v)


# Demo: vertex 1 is the only RED vertex, so every one of the 10 pairs is good.
s = GoodPaths(5, [0, 1, 0, 0, 0])
demo_edges = [(0, 1), (1, 2), (1, 3), (3, 4)]
for a, b in demo_edges:
    s.add_edge(a, b)
s.dec(0)
print("good pairs (<=1 red):", s.ans)

A3 — IOI "Race": shortest edge-count path with total weight exactly K

Statement. Weighted tree; find the minimum number of edges on any path whose total weight is exactly K (or -1 if none).

Constraints. 1 ≤ N ≤ 2·10⁵, 1 ≤ K ≤ 10⁶.

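Hints. - Per centroid, maintain best[w] = the minimum edge count seen so far for a leg of weight w, in an array of size K+1 that is reset after each centroid by clearing only the touched keys (never the whole array). - For each new vertex at (weight w, edges e) with w ≤ K, the candidate answer is e + best[K − w]. Query an entire branch before registering it, so that both legs never come from the same branch.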

Go

package main

import "fmt"

const INF = 1 << 30

type edge struct{ to, w int }

// Race finds the minimum number of edges on a path whose total weight is
// exactly K (ans stays INF when no such path exists). best is a scratch
// array of size K+1: best[w] = fewest edges on a leg of weight w seen so far
// at the current centroid; touched lists the entries to reset afterwards.
type Race struct {
    adj     [][]edge
    removed []bool
    size    []int
    K       int
    best    []int // best[w] = min edges for weight w
    touched []int
    ans     int
}

// cs recomputes size[] for the live component containing u and returns size[u].
func (r *Race) cs(u, p int) int {
    total := 1
    for _, e := range r.adj[u] {
        if e.to == p || r.removed[e.to] {
            continue
        }
        total += r.cs(e.to, u)
    }
    r.size[u] = total
    return total
}

// fc walks towards the heavy child until none exceeds n/2; that vertex is the centroid.
func (r *Race) fc(u, p, n int) int {
    for {
        next := -1
        for _, e := range r.adj[u] {
            if e.to != p && !r.removed[e.to] && r.size[e.to] > n/2 {
                next = e.to
                break
            }
        }
        if next < 0 {
            return u
        }
        p, u = u, next
    }
}

// walk visits u's subtree (leg weight w, e edges from the centroid). In query
// mode it tries to close a path of weight K with a leg already in best; in
// register mode it records the leg in best.
func (r *Race) walk(u, p, w, e int, query bool) {
    if w > r.K {
        return // assuming non-negative weights, deeper legs only get heavier
    }
    switch {
    case query:
        if need := r.best[r.K-w]; need < INF && e+need < r.ans {
            r.ans = e + need
        }
    case e < r.best[w]:
        r.best[w] = e
        r.touched = append(r.touched, w)
    }
    for _, ed := range r.adj[u] {
        if ed.to != p && !r.removed[ed.to] {
            r.walk(ed.to, u, w+ed.w, e+1, query)
        }
    }
}

// dec processes the centroid of entry's component, then recurses into each piece.
func (r *Race) dec(entry int) {
    c := r.fc(entry, -1, r.cs(entry, -1))
    r.removed[c] = true
    r.best[0] = 0 // the centroid itself is a weight-0, 0-edge leg
    r.touched = append(r.touched, 0)
    for _, ed := range r.adj[c] {
        if r.removed[ed.to] {
            continue
        }
        // Query the whole branch before registering it, so both legs never
        // come from the same branch.
        r.walk(ed.to, c, ed.w, 1, true)
        r.walk(ed.to, c, ed.w, 1, false)
    }
    for _, w := range r.touched {
        r.best[w] = INF
    }
    r.touched = r.touched[:0]
    for _, ed := range r.adj[c] {
        if !r.removed[ed.to] {
            r.dec(ed.to)
        }
    }
}

func main() {
    const N, K = 4, 3
    r := &Race{
        adj:     make([][]edge, N),
        removed: make([]bool, N),
        size:    make([]int, N),
        K:       K,
        ans:     INF,
        best:    make([]int, K+1),
    }
    for i := range r.best {
        r.best[i] = INF
    }
    for _, e := range [][3]int{{0, 1, 1}, {1, 2, 2}, {1, 3, 3}} {
        u, v, w := e[0], e[1], e[2]
        r.adj[u] = append(r.adj[u], edge{v, w})
        r.adj[v] = append(r.adj[v], edge{u, w})
    }
    r.dec(0)
    if r.ans == INF {
        fmt.Println(-1)
        return
    }
    fmt.Println("min edges for weight 3:", r.ans) // edge 1-3 has weight 3 -> 1 edge
}

Java

import java.util.*;

public class Race {
    List<int[]>[] adj; // {to, w}
    boolean[] removed;
    int[] size, best;
    int K, ans = Integer.MAX_VALUE;
    List<Integer> touched = new ArrayList<>();

    /** best[w] = fewest edges on a leg of weight w at the current centroid (MAX_VALUE = none). */
    @SuppressWarnings("unchecked")
    Race(int n, int K) {
        this.K = K;
        adj = new List[n];
        for (int i = 0; i < n; i++) adj[i] = new ArrayList<>();
        removed = new boolean[n];
        size = new int[n];
        best = new int[K + 1];
        Arrays.fill(best, Integer.MAX_VALUE);
    }

    void addEdge(int u, int v, int w) {
        adj[u].add(new int[]{v, w});
        adj[v].add(new int[]{u, w});
    }

    /** Recomputes subtree sizes of the live component containing u; returns size[u]. */
    int cs(int u, int p) {
        int total = 1;
        for (int[] e : adj[u]) {
            if (e[0] == p || removed[e[0]]) continue;
            total += cs(e[0], u);
        }
        size[u] = total;
        return total;
    }

    /** Follows heavy children until none exceeds n / 2; that vertex is the centroid. */
    int fc(int u, int p, int n) {
        while (true) {
            int next = -1;
            for (int[] e : adj[u]) {
                if (e[0] != p && !removed[e[0]] && size[e[0]] > n / 2) { next = e[0]; break; }
            }
            if (next < 0) return u;
            p = u;
            u = next;
        }
    }

    /** Query mode closes paths of weight K against best[]; register mode records legs into best[]. */
    void walk(int u, int p, int w, int e, boolean query) {
        if (w > K) return;
        if (!query) {
            if (e < best[w]) { best[w] = e; touched.add(w); }
        } else {
            int need = best[K - w];
            if (need != Integer.MAX_VALUE) ans = Math.min(ans, e + need);
        }
        for (int[] ed : adj[u]) {
            int v = ed[0];
            if (v != p && !removed[v]) walk(v, u, w + ed[1], e + 1, query);
        }
    }

    /** Processes the centroid of entry's component, resets best[], then recurses. */
    void dec(int entry) {
        int c = fc(entry, -1, cs(entry, -1));
        removed[c] = true;
        best[0] = 0;
        touched.add(0);
        for (int[] ed : adj[c]) {
            if (removed[ed[0]]) continue;
            walk(ed[0], c, ed[1], 1, true);   // query the branch first ...
            walk(ed[0], c, ed[1], 1, false);  // ... then register it
        }
        for (int w : touched) best[w] = Integer.MAX_VALUE;
        touched.clear();
        for (int[] ed : adj[c]) if (!removed[ed[0]]) dec(ed[0]);
    }

    public static void main(String[] args) {
        Race r = new Race(4, 3);
        int[][] edges = {{0, 1, 1}, {1, 2, 2}, {1, 3, 3}};
        for (int[] ed : edges) r.addEdge(ed[0], ed[1], ed[2]);
        r.dec(0);
        if (r.ans == Integer.MAX_VALUE) System.out.println(-1);
        else System.out.println("min edges for weight 3: " + r.ans);
    }
}

Python

import sys
sys.setrecursionlimit(1 << 20)
INF = 1 << 30


class Race:
    """Fewest edges on a path of total weight exactly K (``ans`` stays INF if none).

    ``best[w]`` holds the fewest edges on a leg of weight w seen at the current
    centroid; ``touched`` lists the entries to reset afterwards.
    """

    def __init__(self, n, K):
        self.adj = [[] for _ in range(n)]
        self.removed = [False] * n
        self.size = [0] * n
        self.K = K
        self.best = [INF] * (K + 1)
        self.touched = []
        self.ans = INF

    def add_edge(self, u, v, w):
        self.adj[u].append((v, w))
        self.adj[v].append((u, w))

    def cs(self, u, p):
        """Recompute subtree sizes of the live component containing u; return size[u]."""
        total = 1
        for v, _ in self.adj[u]:
            if v == p or self.removed[v]:
                continue
            total += self.cs(v, u)
        self.size[u] = total
        return total

    def fc(self, u, p, n):
        """Follow heavy children until none exceeds n // 2; return that centroid."""
        half = n // 2
        while True:
            nxt = next(
                (v for v, _ in self.adj[u]
                 if v != p and not self.removed[v] and self.size[v] > half),
                None,
            )
            if nxt is None:
                return u
            p, u = u, nxt

    def walk(self, u, p, w, e, query):
        """Query mode closes paths of weight K; register mode records legs in best."""
        K = self.K
        if w > K:
            return
        if not query:
            if e < self.best[w]:
                self.best[w] = e
                self.touched.append(w)
        elif self.best[K - w] < INF:
            self.ans = min(self.ans, e + self.best[K - w])
        for v, ww in self.adj[u]:
            if v != p and not self.removed[v]:
                self.walk(v, u, w + ww, e + 1, query)

    def dec(self, entry):
        """Process the centroid of entry's component, reset best, then recurse."""
        c = self.fc(entry, -1, self.cs(entry, -1))
        self.removed[c] = True
        self.best[0] = 0
        self.touched.append(0)
        for v, w in self.adj[c]:
            if self.removed[v]:
                continue
            self.walk(v, c, w, 1, True)   # query the branch first ...
            self.walk(v, c, w, 1, False)  # ... then register it
        for w in self.touched:
            self.best[w] = INF
        self.touched.clear()
        for v, _ in self.adj[c]:
            if not self.removed[v]:
                self.dec(v)


# Demo: the single edge 1-3 already has weight K = 3, so the answer is 1 edge.
r = Race(4, 3)
for edge in [(0, 1, 1), (1, 2, 2), (1, 3, 3)]:
    r.add_edge(*edge)
r.dec(0)
if r.ans == INF:
    print(-1)
else:
    print(f"min edges for weight 3: {r.ans}")

A4 — Count vertices within radius R of each node (static)

Statement. For every vertex x, output the number of vertices within distance R (including x).

Constraints. 1 ≤ N ≤ 10⁵.

Hints. - Per centroid, store a sorted list of distances of all component vertices, plus a per-branch sorted list. - For query at x: over ancestors c, add count(dist(c,·) ≤ R − dist(x,c)) from the full list, subtract the same from the branch containing x.

Go

package main

import (
    "fmt"
    "sort"
)

// CT stores a centroid decomposition plus, for every vertex, its chain of
// centroid ancestors and its distance to each of them.
type CT struct {
    adj     [][]int
    removed []bool
    size    []int
    cpar    []int     // cpar[c]: parent of centroid c in the centroid tree, -1 at the root
    full    [][]int   // full[c]: sorted distances from c to every vertex of its component (c included)
    ancC    [][]int   // ancC[x]: centroid ancestors of x, outermost first
    ancD    [][]int   // ancD[x][i] = dist(x, ancC[x][i])
    branch  [][][]int // branch[c]: one sorted distance list per subtree hanging off c (for the same-branch subtraction)
}

// cs recomputes size[] for the live component containing u and returns size[u].
func (c *CT) cs(u, p int) int {
    c.size[u] = 1
    for _, v := range c.adj[u] {
        if v != p && !c.removed[v] {
            c.size[u] += c.cs(v, u)
        }
    }
    return c.size[u]
}

// fc returns the centroid of the live component of size n that contains u.
func (c *CT) fc(u, p, n int) int {
    for _, v := range c.adj[u] {
        if v != p && !c.removed[v] && c.size[v] > n/2 {
            return c.fc(v, u, n)
        }
    }
    return u
}

// collect appends d (the distance from centroid cen to u) to *arr, records
// (cen, d) as an ancestor entry of u, and recurses into u's live neighbours
// other than p.
func (c *CT) collect(u, p, d, cen int, arr *[]int) {
    *arr = append(*arr, d)
    c.ancC[u] = append(c.ancC[u], cen)
    c.ancD[u] = append(c.ancD[u], d)
    for _, v := range c.adj[u] {
        if v != p && !c.removed[v] {
            c.collect(v, u, d+1, cen, arr)
        }
    }
}

// dec decomposes the live component containing entry, records par as the
// centroid-tree parent of its centroid, fills full/branch for that centroid,
// recurses into the remaining pieces, and returns the centroid.
func (c *CT) dec(entry, par int) int {
    n := c.cs(entry, -1)
    cen := c.fc(entry, -1, n)
    c.removed[cen] = true
    c.cpar[cen] = par
    full := []int{0}
    // record centroid's own distance 0 to itself
    c.ancC[cen] = append(c.ancC[cen], cen)
    c.ancD[cen] = append(c.ancD[cen], 0)
    for _, v := range c.adj[cen] {
        if !c.removed[v] {
            branch := []int{}
            c.collect(v, cen, 1, cen, &branch)
            full = append(full, branch...)
            sort.Ints(branch)
            c.branch[cen] = append(c.branch[cen], branch)
        }
    }
    sort.Ints(full)
    c.full[cen] = full
    for _, v := range c.adj[cen] {
        if !c.removed[v] {
            c.dec(v, cen)
        }
    }
    return cen
}

// countLeqVal returns how many entries of the sorted slice a are <= x.
func countLeqVal(a []int, x int) int {
    return sort.SearchInts(a, x+1)
}

func main() {
    // Demonstrates the FULL-list contribution (branch subtraction omitted for brevity;
    // a complete solution also subtracts the same-branch over-count).
    N, R := 5, 2
    c := &CT{adj: make([][]int, N), removed: make([]bool, N), size: make([]int, N),
        cpar: make([]int, N), full: make([][]int, N),
        ancC: make([][]int, N), ancD: make([][]int, N), branch: make([][][]int, N)}
    for _, e := range [][2]int{{0, 1}, {1, 2}, {1, 3}, {3, 4}} {
        c.adj[e[0]] = append(c.adj[e[0]], e[1])
        c.adj[e[1]] = append(c.adj[e[1]], e[0])
    }
    c.dec(0, -1)
    for x := 0; x < N; x++ {
        cnt := 0
        for i, cen := range c.ancC[x] {
            d := c.ancD[x][i]
            if R-d >= 0 {
                cnt += countLeqVal(c.full[cen], R-d)
            }
        }
        fmt.Printf("approx within R of %d (incl over-count): %d\n", x, cnt)
    }
}

Note: the Go version above shows the full-list contribution; a complete solution subtracts, at each ancestor, the count from the branch that contains x. The Java/Python versions below include that subtraction.

Java

import java.util.*;

public class WithinRadius {
    List<List<Integer>> adj;
    boolean[] removed;
    int[] size;
    // Per vertex x: one {centroid, dist(x, centroid), branchId} entry per centroid
    // ancestor; branchId is -1 when x is that centroid itself.
    List<int[]>[] anc;
    // centroid -> sorted distances from the centroid to every vertex of its component
    Map<Integer, int[]> fullArrays = new HashMap<>();
    // centroid -> one sorted distance array per branch, indexed by branchId
    Map<Integer, List<int[]>> branchArrays = new HashMap<>();

    @SuppressWarnings("unchecked")
    WithinRadius(int n) {
        adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        removed = new boolean[n];
        size = new int[n];
        anc = new List[n];
        for (int i = 0; i < n; i++) { anc[i] = new ArrayList<>(); }
    }
    void addEdge(int u, int v) { adj.get(u).add(v); adj.get(v).add(u); }
    /** Recomputes subtree sizes of the live component containing u; returns size[u]. */
    int cs(int u, int p) { size[u] = 1; for (int v : adj.get(u)) if (v != p && !removed[v]) size[u] += cs(v, u); return size[u]; }
    /** Returns the centroid of the live component of size n that contains u. */
    int fc(int u, int p, int n) { for (int v : adj.get(u)) if (v != p && !removed[v] && size[v] > n / 2) return fc(v, u, n); return u; }

    /** Appends dist(cen, u) = d to arr, records the ancestor entry for u, and recurses. */
    void collect(int u, int p, int d, int cen, int branchId, List<Integer> arr) {
        arr.add(d);
        anc[u].add(new int[]{cen, d, branchId});
        for (int v : adj.get(u)) if (v != p && !removed[v]) collect(v, u, d + 1, cen, branchId, arr);
    }

    /** Decomposes the live component containing entry and stores its distance arrays. */
    void dec(int entry) {
        int n = cs(entry, -1);
        int cen = fc(entry, -1, n);
        removed[cen] = true;
        List<Integer> allD = new ArrayList<>(); allD.add(0);
        anc[cen].add(new int[]{cen, 0, -1});
        List<int[]> branches = new ArrayList<>();
        int bid = 0;
        for (int v : adj.get(cen)) if (!removed[v]) {
            List<Integer> br = new ArrayList<>();
            collect(v, cen, 1, cen, bid, br);
            allD.addAll(br);
            int[] arr = br.stream().mapToInt(Integer::intValue).sorted().toArray();
            branches.add(arr);
            bid++;
        }
        int[] fullArr = allD.stream().mapToInt(Integer::intValue).sorted().toArray();
        fullStore(cen, fullArr, branches);
        for (int v : adj.get(cen)) if (!removed[v]) dec(v);
    }

    void fullStore(int cen, int[] f, List<int[]> br) { fullArrays.put(cen, f); branchArrays.put(cen, br); }

    /** Number of entries of the sorted array a that are <= x (0 when x < 0). */
    static int countLeq(int[] a, int x) {
        if (x < 0) return 0;
        int lo = 0, hi = a.length;
        while (lo < hi) { int m = (lo + hi) / 2; if (a[m] <= x) lo = m + 1; else hi = m; }
        return lo;
    }

    /**
     * Counts vertices within distance R of x (x included): at each centroid ancestor,
     * all component vertices in range minus those in x's own branch.
     */
    int query(int x, int R) {
        int cnt = 0;
        for (int[] e : anc[x]) {
            int cen = e[0], d = e[1], bid = e[2];
            cnt += countLeq(fullArrays.get(cen), R - d);
            if (bid >= 0) cnt -= countLeq(branchArrays.get(cen).get(bid), R - d);
        }
        return cnt;
    }

    public static void main(String[] args) {
        WithinRadius s = new WithinRadius(5);
        int[][] e = {{0,1},{1,2},{1,3},{3,4}};
        for (int[] x : e) s.addEdge(x[0], x[1]);
        s.dec(0);
        for (int x = 0; x < 5; x++) System.out.println("within 2 of " + x + ": " + s.query(x, 2));
    }
}

Python

import sys
from bisect import bisect_right
sys.setrecursionlimit(1 << 20)


class WithinRadius:
    """Count vertices within distance R of a vertex, via centroid decomposition.

    After ``dec(root)``, ``query(x, R)`` sums, over x's centroid ancestors, the
    component vertices in range minus those inside x's own branch.
    """

    def __init__(self, n):
        self.adj = [[] for _ in range(n)]
        self.removed = [False] * n
        self.size = [0] * n
        self.full = {}  # centroid -> sorted distances to every vertex of its component
        self.branch = {}  # centroid -> list of sorted per-branch distance lists
        self.anc = [[] for _ in range(n)]  # x -> [(centroid, dist, branch_id), ...]

    def add_edge(self, u, v):
        self.adj[u].append(v)
        self.adj[v].append(u)

    def cs(self, u, p):
        """Recompute subtree sizes of the live component containing u; return size[u]."""
        total = 1
        for v in self.adj[u]:
            if v == p or self.removed[v]:
                continue
            total += self.cs(v, u)
        self.size[u] = total
        return total

    def fc(self, u, p, n):
        """Follow heavy children until none exceeds n // 2; return that centroid."""
        while True:
            for v in self.adj[u]:
                if v != p and not self.removed[v] and self.size[v] > n // 2:
                    p, u = u, v
                    break
            else:
                return u

    def collect(self, u, p, d, cen, bid, arr):
        """Append dist(cen, u) = d to arr, record the ancestor entry, and recurse."""
        arr.append(d)
        self.anc[u].append((cen, d, bid))
        for v in self.adj[u]:
            if v != p and not self.removed[v]:
                self.collect(v, u, d + 1, cen, bid, arr)

    def dec(self, entry):
        """Decompose the live component containing entry and store its distance lists."""
        cen = self.fc(entry, -1, self.cs(entry, -1))
        self.removed[cen] = True
        self.anc[cen].append((cen, 0, -1))
        everything = [0]
        per_branch = []
        live = [v for v in self.adj[cen] if not self.removed[v]]
        for bid, v in enumerate(live):
            dists = []
            self.collect(v, cen, 1, cen, bid, dists)
            everything += dists
            per_branch.append(sorted(dists))
        self.full[cen] = sorted(everything)
        self.branch[cen] = per_branch
        for v in self.adj[cen]:
            if not self.removed[v]:
                self.dec(v)

    def query(self, x, R):
        """Return the number of vertices within distance R of x (x included)."""
        total = 0
        for cen, d, bid in self.anc[x]:
            budget = R - d
            if budget < 0:
                continue
            total += bisect_right(self.full[cen], budget)
            if bid >= 0:
                total -= bisect_right(self.branch[cen][bid], budget)
        return total


# Demo: expected counts for x = 0..4 are 4, 5, 4, 5, 3.
s = WithinRadius(5)
for edge in [(0, 1), (1, 2), (1, 3), (3, 4)]:
    s.add_edge(*edge)
s.dec(0)
for x in range(5):
    print(f"within 2 of {x}: {s.query(x, 2)}")

A5 — Sum of distances over all pairs (decompose-and-aggregate)

Statement. Compute Σ_{u<v} dist(u, v) over all pairs (unweighted). (There is an O(N) rerooting DP for this; here implement it via centroid decomposition to practice the aggregation pattern.)

Constraints. 1 ≤ N ≤ 10⁵.

Hints. - Per centroid, every pair through it contributes dist(u,c) + dist(v,c). Use (count, sumDist) running aggregates across branches; subtract same-branch contributions.

Go

package main

import "fmt"

// Sol accumulates the sum of dist(u, v) over all unordered pairs of an
// unweighted tree into ans, using centroid decomposition.
type Sol struct {
    adj     [][]int
    removed []bool
    size    []int
    ans     int64
}

// cs recomputes size[] for the live component containing u and returns size[u].
func (s *Sol) cs(u, p int) int {
    total := 1
    for _, v := range s.adj[u] {
        if v == p || s.removed[v] {
            continue
        }
        total += s.cs(v, u)
    }
    s.size[u] = total
    return total
}

// fc walks towards the heavy child until none exceeds n/2; that vertex is the centroid.
func (s *Sol) fc(u, p, n int) int {
    for {
        next := -1
        for _, v := range s.adj[u] {
            if v != p && !s.removed[v] && s.size[v] > n/2 {
                next = v
                break
            }
        }
        if next < 0 {
            return u
        }
        p, u = u, next
    }
}

// gather appends the distance from the centroid to u and to every live vertex
// below u (d is u's own distance).
func (s *Sol) gather(u, p, d int, out *[]int) {
    *out = append(*out, d)
    for _, v := range s.adj[u] {
        if v == p || s.removed[v] {
            continue
        }
        s.gather(v, u, d+1, out)
    }
}

// pairSum returns the sum over index pairs i<j of d[i]+d[j]; each entry
// takes part in len(d)-1 pairs, so it equals (len(d)-1) * sum(d).
func pairSum(d []int) int64 {
    var total int64
    for i := range d {
        total += int64(d[i])
    }
    return (int64(len(d)) - 1) * total
}

// dec adds the distance of every pair whose path passes through the centroid
// of entry's component, then recurses into the remaining pieces.
func (s *Sol) dec(entry int) {
    c := s.fc(entry, -1, s.cs(entry, -1))
    s.removed[c] = true
    everything := []int{0} // the centroid itself
    for _, v := range s.adj[c] {
        if s.removed[v] {
            continue
        }
        var branch []int
        s.gather(v, c, 1, &branch)
        everything = append(everything, branch...)
        s.ans -= pairSum(branch) // same-branch pairs do not pass through c
    }
    s.ans += pairSum(everything)
    for _, v := range s.adj[c] {
        if !s.removed[v] {
            s.dec(v)
        }
    }
}

func main() {
    const N = 5
    s := &Sol{adj: make([][]int, N), removed: make([]bool, N), size: make([]int, N)}
    edges := [][2]int{{0, 1}, {1, 2}, {1, 3}, {3, 4}}
    for _, e := range edges {
        a, b := e[0], e[1]
        s.adj[a] = append(s.adj[a], b)
        s.adj[b] = append(s.adj[b], a)
    }
    s.dec(0)
    fmt.Println("sum of all pairwise distances:", s.ans)
}

Java

import java.util.*;

public class SumDistances {
    List<List<Integer>> adj;
    boolean[] removed;
    int[] size;
    long ans = 0;

    /** Sum of dist(u, v) over all unordered pairs, accumulated into ans by dec(root). */
    SumDistances(int n) {
        adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        removed = new boolean[n];
        size = new int[n];
    }

    void addEdge(int u, int v) {
        adj.get(u).add(v);
        adj.get(v).add(u);
    }

    /** Recomputes subtree sizes of the live component containing u; returns size[u]. */
    int cs(int u, int p) {
        int total = 1;
        for (int v : adj.get(u)) {
            if (v == p || removed[v]) continue;
            total += cs(v, u);
        }
        size[u] = total;
        return total;
    }

    /** Follows heavy children until none exceeds n / 2; that vertex is the centroid. */
    int fc(int u, int p, int n) {
        for (boolean moved = true; moved; ) {
            moved = false;
            for (int v : adj.get(u)) {
                if (v != p && !removed[v] && size[v] > n / 2) {
                    p = u;
                    u = v;
                    moved = true;
                    break;
                }
            }
        }
        return u;
    }

    /** Appends the distance from the centroid to u and every live vertex below it. */
    void gather(int u, int p, int d, List<Integer> out) {
        out.add(d);
        for (int v : adj.get(u)) {
            if (v == p || removed[v]) continue;
            gather(v, u, d + 1, out);
        }
    }

    /** Sum over pairs i < j of d[i] + d[j], i.e. (|d| - 1) * sum(d). */
    long pairSum(List<Integer> d) {
        long sum = d.stream().mapToLong(Integer::longValue).sum();
        return (d.size() - 1L) * sum;
    }

    /** Adds distances of all pairs through the centroid of entry's component, then recurses. */
    void dec(int entry) {
        int c = fc(entry, -1, cs(entry, -1));
        removed[c] = true;
        List<Integer> everything = new ArrayList<>();
        everything.add(0); // the centroid itself
        for (int v : adj.get(c)) {
            if (removed[v]) continue;
            List<Integer> branch = new ArrayList<>();
            gather(v, c, 1, branch);
            everything.addAll(branch);
            ans -= pairSum(branch); // same-branch pairs do not pass through c
        }
        ans += pairSum(everything);
        for (int v : adj.get(c)) if (!removed[v]) dec(v);
    }

    public static void main(String[] args) {
        SumDistances s = new SumDistances(5);
        int[][] edges = {{0, 1}, {1, 2}, {1, 3}, {3, 4}};
        for (int[] ed : edges) s.addEdge(ed[0], ed[1]);
        s.dec(0);
        System.out.println("sum of all pairwise distances: " + s.ans);
    }
}

Python

import sys
sys.setrecursionlimit(1 << 20)


class SumDistances:
    """Sum of dist(u, v) over all unordered pairs, via centroid decomposition.

    Call ``dec(root)`` once after adding edges; the total accumulates in ``ans``.
    """

    def __init__(self, n):
        self.adj = [[] for _ in range(n)]
        self.removed = [False] * n
        self.size = [0] * n
        self.ans = 0

    def add_edge(self, u, v):
        self.adj[u].append(v)
        self.adj[v].append(u)

    def cs(self, u, p):
        """Recompute subtree sizes of the live component containing u; return size[u]."""
        total = 1
        for v in self.adj[u]:
            if v != p and not self.removed[v]:
                total += self.cs(v, u)
        self.size[u] = total
        return total

    def fc(self, u, p, n):
        """Follow heavy children until none exceeds n // 2; return that centroid."""
        while True:
            heavy = next(
                (v for v in self.adj[u]
                 if v != p and not self.removed[v] and self.size[v] > n // 2),
                None,
            )
            if heavy is None:
                return u
            p, u = u, heavy

    def gather(self, u, p, d, out):
        """Append the distance from the centroid to u and every live vertex below it."""
        out.append(d)
        for v in self.adj[u]:
            if v == p or self.removed[v]:
                continue
            self.gather(v, u, d + 1, out)

    @staticmethod
    def pair_sum(d):
        """Sum over pairs i < j of d[i] + d[j]; each entry sits in len(d) - 1 pairs."""
        count = len(d)
        return (count - 1) * sum(d)

    def dec(self, entry):
        """Add distances of pairs through the centroid of entry's component, then recurse."""
        c = self.fc(entry, -1, self.cs(entry, -1))
        self.removed[c] = True
        everything = [0]  # the centroid itself
        for v in self.adj[c]:
            if self.removed[v]:
                continue
            branch = []
            self.gather(v, c, 1, branch)
            everything.extend(branch)
            self.ans -= self.pair_sum(branch)  # same-branch pairs do not pass through c
        self.ans += self.pair_sum(everything)
        for v in self.adj[c]:
            if not self.removed[v]:
                self.dec(v)


# Demo: the 10 pairwise distances of this 5-vertex tree sum to 18.
s = SumDistances(5)
for edge in [(0, 1), (1, 2), (1, 3), (3, 4)]:
    s.add_edge(*edge)
s.dec(0)
print("sum of all pairwise distances:", s.ans)

Benchmark Task

Statement. Build the centroid tree and count pairs at distance ≤ K for:

1. a random tree with N = 10⁶,
2. a path graph (line) with N = 10⁶ — the recursion-depth stress case,
3. a star graph with N = 10⁶ — the wide fan-out case.

Measure build time and total query time. Verify that the centroid-tree height is ≤ ⌊log₂ N⌋ + 1 in every case (20 for N = 10⁶; 21 for N = 2²⁰, the size used in the skeletons below).

Constraints. N up to 10⁶; you must avoid O(N) native recursion depth on the path graph.

Hints. - Convert computeSize and gather to iterative DFS with an explicit stack to survive the path graph. - Reuse buffers; avoid per-centroid allocations in hot loops. - For the path graph, confirm the height stays logarithmic — if it explodes, you have a stale-size bug. - Expected: the O(N log N) build should run in a few seconds in Go/Java for N = 10⁶; Python will need PyPy, or iterative DFS plus careful buffer reuse.

Measurement skeleton (Go):

package main

import (
    "fmt"
    "math/bits"
    "time"
)

// buildLine returns the adjacency lists of the path graph 0 - 1 - ... - (n-1).
func buildLine(n int) [][]int {
    adj := make([][]int, n)
    for v := 1; v < n; v++ {
        u := v - 1
        adj[u] = append(adj[u], v)
        adj[v] = append(adj[v], u)
    }
    return adj
}

// main is a timing skeleton: plug the iterative-DFS decomposition in where marked.
func main() {
    const n = 1 << 20 // ~10^6
    adj := buildLine(n)
    t0 := time.Now()
    // Run your iterative-DFS centroid decomposition + count here; record height.
    _ = adj
    maxHeight := bits.Len(uint(n)) // floor(log2 n) + 1, i.e. 21 for n = 2^20
    fmt.Printf("expected max height <= %d\n", maxHeight)
    fmt.Println("elapsed:", time.Since(t0))
}

Measurement skeleton (Java):

public class Benchmark {
    /** Timing skeleton: plug the iterative-DFS decomposition in where marked. */
    public static void main(String[] args) {
        final int n = 1 << 20;
        final long t0 = System.nanoTime();
        // Build line graph + iterative-DFS centroid decomposition; record height.
        int bitLength = Integer.SIZE - Integer.numberOfLeadingZeros(n); // floor(log2 n) + 1 = 21
        System.out.println("expected max height <= " + bitLength);
        long elapsedMs = (System.nanoTime() - t0) / 1_000_000;
        System.out.println("elapsed ms: " + elapsedMs);
    }
}

Measurement skeleton (Python):

import time

n = 1 << 20  # ~10^6
t0 = time.perf_counter()
# Build line graph + iterative-DFS centroid decomposition; record height.
expected_max_height = n.bit_length()  # floor(log2 n) + 1 == 21 for n = 2**20
print("expected max height <=", expected_max_height)
elapsed = time.perf_counter() - t0
print("elapsed s:", elapsed)

What to report: build time and query time for each of the three shapes, peak memory, and the observed maximum centroid-tree height (must be ≤ ⌊log₂ N⌋ + 1). A height far above 21 on any shape is a definitive signal of a stale-size or removed[] bug.