Tree DP (Dynamic Programming on Trees) — Practice Tasks
All tasks must be solved in Go, Java, and Python. Each task ships with a precise I/O spec and starter code in all three languages. Implement the post-order DFS (and, where noted, the second rerooting pass). Always test against a brute-force oracle on small trees (try all subsets / all node pairs) before trusting the dp. Reminder: a tree on n nodes has n-1 edges; guard the parent in the DFS; watch recursion depth on path-shaped trees (raise the limit — and, in Java and Python, the thread's stack size — or go iterative).
Beginner Tasks (5)
Task 1 — Subtree sizes
Problem. Given a rooted tree (root 0), compute size[v] = number of nodes in the subtree of v (including v).
Input / Output spec. - Read n, then n-1 edges u v. - Print size[0] size[1] … size[n-1] space-separated.
Constraints. 1 ≤ n ≤ 2·10^5. The graph is a tree, 0-indexed.
Hint. size[v] = 1 + Σ size[c] over children. One post-order DFS.
Starter — Go.
package main

import (
	"bufio"
	"fmt"
	"io"
	"os"
)

var adj [][]int
var size []int

// dfs fills size[v] for the subtree rooted at v; p is v's parent (-1 for the root).
func dfs(v, p int) {
	// TODO: size[v] = 1 + sum of children sizes
}

// solve reads n and n-1 edges from in, runs dfs from root 0 and writes
// size[0..n-1] space-separated to out. in should be buffered: fmt.Fscan pulls
// one rune per Read call from an unbuffered reader such as a raw *os.File.
func solve(in io.Reader, out io.Writer) {
	var n int
	fmt.Fscan(in, &n)
	adj = make([][]int, n)
	size = make([]int, n)
	for i := 0; i < n-1; i++ {
		var u, v int
		fmt.Fscan(in, &u, &v)
		adj[u] = append(adj[u], v)
		adj[v] = append(adj[v], u)
	}
	dfs(0, -1)
	for i, s := range size {
		if i > 0 {
			fmt.Fprint(out, " ")
		}
		fmt.Fprint(out, s)
	}
	fmt.Fprintln(out)
}

// main wires solve to buffered stdin/stdout. Unbuffered fmt.Scan/fmt.Print cost a
// syscall per rune read and per value printed — millions for 2·10^5 edges.
func main() {
	out := bufio.NewWriter(os.Stdout)
	defer out.Flush()
	solve(bufio.NewReader(os.Stdin), out)
}
Starter — Java.
import java.util.*;
public class SubtreeSizes {
    static List<Integer>[] adj;
    static int[] size;

    /** Post-order DFS: size[v] = 1 + sizes of v's children; p is v's parent (-1 for the root). */
    static void dfs(int v, int p) {
        // TODO
    }

    /**
     * Entry point. The recursive dfs can be n = 2e5 frames deep on a path-shaped tree,
     * more than the JVM's default thread stack (typically about 1 MB) holds, so the solver
     * runs on a thread created with a large explicit stack size, and main waits for it.
     */
    public static void main(String[] args) {
        Thread worker = new Thread(null, SubtreeSizes::solve, "solver", 1L << 27);
        worker.start();
        try {
            worker.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /** Reads n and n-1 edges from stdin, runs dfs from root 0 and prints size[0..n-1]. */
    @SuppressWarnings("unchecked")
    static void solve() {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        adj = new List[n];
        size = new int[n];
        for (int i = 0; i < n; i++) adj[i] = new ArrayList<>();
        for (int i = 0; i < n - 1; i++) {
            int u = sc.nextInt(), v = sc.nextInt();
            adj[u].add(v); adj[v].add(u);
        }
        dfs(0, -1);
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < n; i++) { if (i > 0) sb.append(' '); sb.append(size[i]); }
        System.out.println(sb);
    }
}
Starter — Python.
import sys
import threading

sys.setrecursionlimit(1 << 25)
input = sys.stdin.readline


def main():
    """Read a tree rooted at 0 and print size[v] (nodes in v's subtree) for every v."""
    n = int(input())
    adj = [[] for _ in range(n)]
    for _ in range(n - 1):
        u, v = map(int, input().split())
        adj[u].append(v)
        adj[v].append(u)
    size = [0] * n

    def dfs(v, p):
        # TODO: size[v] = 1 + sum of children
        pass

    dfs(0, -1)
    print(*size)


if __name__ == "__main__":
    # setrecursionlimit only raises the interpreter's depth counter; it does not grow
    # the C stack. On CPython before 3.11 every Python call also uses C stack, so a
    # 2*10^5-deep dfs on a path can crash the process. Run main on a big-stack thread.
    threading.stack_size(1 << 27)
    threading.Thread(target=main).start()
Task 2 — Subtree value sums
Problem. Each node has a value a[v]. Compute sum[v] = sum of values in the subtree of v.
I/O spec. Read n, then a[0..n-1], then n-1 edges. Print sum[0..n-1].
Constraints. 1 ≤ n ≤ 2·10^5, |a[v]| ≤ 10^9. Use 64-bit.
Hint. sum[v] = a[v] + Σ sum[c]. Identical shape to Task 1 with a value instead of 1.
Starter — Go.
package main

import (
	"bufio"
	"fmt"
	"io"
	"os"
)

// solve reads n, a[0..n-1] and n-1 edges from in, computes sum[v] (total of a
// over v's subtree, rooted at 0) and writes sum[0..n-1] to out. in should be
// buffered: fmt.Fscan pulls one rune per Read call from an unbuffered reader.
func solve(in io.Reader, out io.Writer) {
	var n int
	fmt.Fscan(in, &n)
	a := make([]int64, n)
	for i := range a {
		fmt.Fscan(in, &a[i])
	}
	adj := make([][]int, n)
	for i := 0; i < n-1; i++ {
		var u, v int
		fmt.Fscan(in, &u, &v)
		adj[u] = append(adj[u], v)
		adj[v] = append(adj[v], u)
	}
	sum := make([]int64, n)
	var dfs func(v, p int)
	dfs = func(v, p int) {
		// TODO: sum[v] = a[v] + children sums
	}
	dfs(0, -1)
	for i, s := range sum {
		if i > 0 {
			fmt.Fprint(out, " ")
		}
		fmt.Fprint(out, s)
	}
	fmt.Fprintln(out)
}

// main wires solve to buffered stdin/stdout. Unbuffered fmt.Scan/fmt.Print cost a
// syscall per rune read and per value printed — millions for 2·10^5 edges.
func main() {
	out := bufio.NewWriter(os.Stdout)
	defer out.Flush()
	solve(bufio.NewReader(os.Stdin), out)
}
Starter — Java.
import java.util.*;
public class SubtreeSums {
    static List<Integer>[] adj;
    static long[] a, sum;

    /** Post-order DFS: sum[v] = a[v] + sums of v's children; p is v's parent (-1 for the root). */
    static void dfs(int v, int p) {
        // TODO
    }

    /**
     * Entry point. The recursive dfs can be n = 2e5 frames deep on a path-shaped tree,
     * more than the JVM's default thread stack (typically about 1 MB) holds, so the solver
     * runs on a thread created with a large explicit stack size, and main waits for it.
     */
    public static void main(String[] args) {
        Thread worker = new Thread(null, SubtreeSums::solve, "solver", 1L << 27);
        worker.start();
        try {
            worker.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /** Reads n, a[0..n-1] and n-1 edges from stdin, runs dfs from root 0, prints sum[0..n-1]. */
    @SuppressWarnings("unchecked")
    static void solve() {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        a = new long[n]; sum = new long[n];
        for (int i = 0; i < n; i++) a[i] = sc.nextLong();
        adj = new List[n];
        for (int i = 0; i < n; i++) adj[i] = new ArrayList<>();
        for (int i = 0; i < n - 1; i++) {
            int u = sc.nextInt(), v = sc.nextInt();
            adj[u].add(v); adj[v].add(u);
        }
        dfs(0, -1);
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < n; i++) { if (i > 0) sb.append(' '); sb.append(sum[i]); }
        System.out.println(sb);
    }
}
Starter — Python.
import sys
import threading

sys.setrecursionlimit(1 << 25)
input = sys.stdin.readline


def main():
    """Read n, a[0..n-1] and n-1 edges; print sm[v], the sum of a over v's subtree (root 0)."""
    n = int(input())
    a = list(map(int, input().split()))
    adj = [[] for _ in range(n)]
    for _ in range(n - 1):
        u, v = map(int, input().split())
        adj[u].append(v)
        adj[v].append(u)
    sm = [0] * n

    def dfs(v, p):
        # TODO
        pass

    dfs(0, -1)
    print(*sm)


if __name__ == "__main__":
    # setrecursionlimit only raises the interpreter's depth counter; it does not grow
    # the C stack. On CPython before 3.11 every Python call also uses C stack, so a
    # 2*10^5-deep dfs on a path can crash the process. Run main on a big-stack thread.
    threading.stack_size(1 << 27)
    threading.Thread(target=main).start()
Task 3 — Maximum independent set weight (House Robber III)
Problem. Each node has value a[v] ≥ 0. Pick a set with no parent–child pair, maximizing total value.
I/O spec. Read n, a[0..n-1], n-1 edges. Print the maximum.
Constraints. 1 ≤ n ≤ 2·10^5, 0 ≤ a[v] ≤ 10^9. 64-bit.
Hint. Return the pair (excl, incl) from each DFS call: incl = a[v] + Σ excl[c], excl = Σ max(excl[c], incl[c]) over the children c. The answer is max(excl, incl) at the root.
Starter — Go.
package main

import (
	"bufio"
	"fmt"
	"io"
	"os"
)

var adj [][]int
var a []int64

// dfs returns (best weight in v's subtree with v excluded, best with v included);
// p is v's parent (-1 for the root).
func dfs(v, p int) (int64, int64) {
	// TODO: return (best excluding v, best including v)
	return 0, 0
}

// solve reads n, a[0..n-1] and n-1 edges from in and writes the maximum-weight
// independent set to out. in should be buffered: fmt.Fscan pulls one rune per
// Read call from an unbuffered reader.
func solve(in io.Reader, out io.Writer) {
	var n int
	fmt.Fscan(in, &n)
	a = make([]int64, n)
	for i := range a {
		fmt.Fscan(in, &a[i])
	}
	adj = make([][]int, n)
	for i := 0; i < n-1; i++ {
		var u, v int
		fmt.Fscan(in, &u, &v)
		adj[u] = append(adj[u], v)
		adj[v] = append(adj[v], u)
	}
	excl, incl := dfs(0, -1)
	best := excl
	if incl > best {
		best = incl
	}
	fmt.Fprintln(out, best)
}

// main wires solve to buffered stdin/stdout; unbuffered fmt.Scan costs a syscall
// per rune, which is far too slow for 2·10^5 edges.
func main() {
	out := bufio.NewWriter(os.Stdout)
	defer out.Flush()
	solve(bufio.NewReader(os.Stdin), out)
}
Starter — Java.
import java.util.*;
public class TreeMISWeight {
    static List<Integer>[] adj;
    static long[] a;

    /** Returns {best weight in v's subtree without v, best with v}; p is v's parent (-1 at the root). */
    static long[] dfs(int v, int p) {
        // TODO return new long[]{excl, incl}
        return new long[]{0, 0};
    }

    /**
     * Entry point. The recursive dfs can be n = 2e5 frames deep on a path-shaped tree,
     * more than the JVM's default thread stack (typically about 1 MB) holds, so the solver
     * runs on a thread created with a large explicit stack size, and main waits for it.
     */
    public static void main(String[] args) {
        Thread worker = new Thread(null, TreeMISWeight::solve, "solver", 1L << 27);
        worker.start();
        try {
            worker.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /** Reads n, a[0..n-1] and n-1 edges from stdin and prints the maximum-weight independent set. */
    @SuppressWarnings("unchecked")
    static void solve() {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        a = new long[n];
        for (int i = 0; i < n; i++) a[i] = sc.nextLong();
        adj = new List[n];
        for (int i = 0; i < n; i++) adj[i] = new ArrayList<>();
        for (int i = 0; i < n - 1; i++) {
            int u = sc.nextInt(), v = sc.nextInt();
            adj[u].add(v); adj[v].add(u);
        }
        long[] r = dfs(0, -1);
        System.out.println(Math.max(r[0], r[1]));
    }
}
Starter — Python.
import sys
import threading

sys.setrecursionlimit(1 << 25)
input = sys.stdin.readline


def main():
    """Read n, a[0..n-1] and n-1 edges; print the maximum-weight independent set."""
    n = int(input())
    a = list(map(int, input().split()))
    adj = [[] for _ in range(n)]
    for _ in range(n - 1):
        u, v = map(int, input().split())
        adj[u].append(v)
        adj[v].append(u)

    def dfs(v, p):
        # TODO return (excl, incl)
        return 0, 0

    print(max(dfs(0, -1)))


if __name__ == "__main__":
    # setrecursionlimit only raises the interpreter's depth counter; it does not grow
    # the C stack. On CPython before 3.11 every Python call also uses C stack, so a
    # 2*10^5-deep dfs on a path can crash the process. Run main on a big-stack thread.
    threading.stack_size(1 << 27)
    threading.Thread(target=main).start()
Task 4 — Count leaves in each subtree
Problem. Compute leaves[v] = number of leaves in the subtree of v. (A leaf has no children; for n = 1, node 0 is a leaf.)
I/O spec. Read n, n-1 edges. Print leaves[0..n-1].
Constraints. 1 ≤ n ≤ 2·10^5.
Hint. If v has no children (besides parent), leaves[v] = 1; else leaves[v] = Σ leaves[c].
Starter — Go.
package main

import (
	"bufio"
	"fmt"
	"io"
	"os"
)

// solve reads n and n-1 edges from in, computes leaves[v] (number of leaves in
// v's subtree, rooted at 0) and writes leaves[0..n-1] to out. in should be
// buffered: fmt.Fscan pulls one rune per Read call from an unbuffered reader.
func solve(in io.Reader, out io.Writer) {
	var n int
	fmt.Fscan(in, &n)
	adj := make([][]int, n)
	for i := 0; i < n-1; i++ {
		var u, v int
		fmt.Fscan(in, &u, &v)
		adj[u] = append(adj[u], v)
		adj[v] = append(adj[v], u)
	}
	leaves := make([]int, n)
	var dfs func(v, p int)
	dfs = func(v, p int) {
		// TODO: count children; if none, leaves[v]=1
	}
	dfs(0, -1)
	for i, l := range leaves {
		if i > 0 {
			fmt.Fprint(out, " ")
		}
		fmt.Fprint(out, l)
	}
	fmt.Fprintln(out)
}

// main wires solve to buffered stdin/stdout. Unbuffered fmt.Scan/fmt.Print cost a
// syscall per rune read and per value printed — millions for 2·10^5 edges.
func main() {
	out := bufio.NewWriter(os.Stdout)
	defer out.Flush()
	solve(bufio.NewReader(os.Stdin), out)
}
Starter — Java.
import java.util.*;
public class SubtreeLeaves {
    static List<Integer>[] adj;
    static int[] leaves;

    /** Post-order DFS: leaves[v] = 1 if v has no children, else the sum over children; p is v's parent. */
    static void dfs(int v, int p) {
        // TODO
    }

    /**
     * Entry point. The recursive dfs can be n = 2e5 frames deep on a path-shaped tree,
     * more than the JVM's default thread stack (typically about 1 MB) holds, so the solver
     * runs on a thread created with a large explicit stack size, and main waits for it.
     */
    public static void main(String[] args) {
        Thread worker = new Thread(null, SubtreeLeaves::solve, "solver", 1L << 27);
        worker.start();
        try {
            worker.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /** Reads n and n-1 edges from stdin, runs dfs from root 0 and prints leaves[0..n-1]. */
    @SuppressWarnings("unchecked")
    static void solve() {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        adj = new List[n];
        leaves = new int[n];
        for (int i = 0; i < n; i++) adj[i] = new ArrayList<>();
        for (int i = 0; i < n - 1; i++) {
            int u = sc.nextInt(), v = sc.nextInt();
            adj[u].add(v); adj[v].add(u);
        }
        dfs(0, -1);
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < n; i++) { if (i > 0) sb.append(' '); sb.append(leaves[i]); }
        System.out.println(sb);
    }
}
Starter — Python.
import sys
import threading

sys.setrecursionlimit(1 << 25)
input = sys.stdin.readline


def main():
    """Read n and n-1 edges; print leaves[v], the number of leaves in v's subtree (root 0)."""
    n = int(input())
    adj = [[] for _ in range(n)]
    for _ in range(n - 1):
        u, v = map(int, input().split())
        adj[u].append(v)
        adj[v].append(u)
    leaves = [0] * n

    def dfs(v, p):
        # TODO
        pass

    dfs(0, -1)
    print(*leaves)


if __name__ == "__main__":
    # setrecursionlimit only raises the interpreter's depth counter; it does not grow
    # the C stack. On CPython before 3.11 every Python call also uses C stack, so a
    # 2*10^5-deep dfs on a path can crash the process. Run main on a big-stack thread.
    threading.stack_size(1 << 27)
    threading.Thread(target=main).start()
Task 5 — Tree diameter (edges)
Problem. Print the number of edges on the longest path in the tree.
I/O spec. Read n, n-1 edges. Print the diameter.
Constraints. 1 ≤ n ≤ 2·10^5. For n = 1, diameter is 0.
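Hint. Let down[v] be the number of edges on the longest downward chain from v. Over the children c, take the two largest values b1 ≥ b2 of down[c] + 1 (use 0 for a missing one); update the global best with b1 + b2 and return down[v] = b1. Return one chain, record two.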
Starter — Go.
package main

import (
	"bufio"
	"fmt"
	"io"
	"os"
)

var g [][]int
var best int

// down returns the number of edges on the longest downward chain from v and
// folds the longest path that turns at v into best; p is v's parent (-1 for the root).
func down(v, p int) int {
	// TODO: track two largest child chains, update best, return largest
	return 0
}

// solve reads n and n-1 edges from in and writes the tree's diameter (in edges)
// to out. best is reset first so repeated calls never report a stale value.
// in should be buffered: fmt.Fscan pulls one rune per Read call otherwise.
func solve(in io.Reader, out io.Writer) {
	var n int
	fmt.Fscan(in, &n)
	g = make([][]int, n)
	for i := 0; i < n-1; i++ {
		var u, v int
		fmt.Fscan(in, &u, &v)
		g[u] = append(g[u], v)
		g[v] = append(g[v], u)
	}
	best = 0
	down(0, -1)
	fmt.Fprintln(out, best)
}

// main wires solve to buffered stdin/stdout; unbuffered fmt.Scan costs a syscall
// per rune, which is far too slow for 2·10^5 edges.
func main() {
	out := bufio.NewWriter(os.Stdout)
	defer out.Flush()
	solve(bufio.NewReader(os.Stdin), out)
}
Starter — Java.
import java.util.*;
public class DiameterEdges {
    static List<Integer>[] g;
    static int best = 0;

    /** Returns the longest downward chain (edges) from v and records the best path through v in best. */
    static int down(int v, int p) {
        // TODO
        return 0;
    }

    /**
     * Entry point. The recursive down() can be n = 2e5 frames deep on a path-shaped tree,
     * more than the JVM's default thread stack (typically about 1 MB) holds, so the solver
     * runs on a thread created with a large explicit stack size, and main waits for it.
     */
    public static void main(String[] args) {
        Thread worker = new Thread(null, DiameterEdges::solve, "solver", 1L << 27);
        worker.start();
        try {
            worker.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /** Reads n and n-1 edges from stdin and prints the diameter in edges. */
    @SuppressWarnings("unchecked")
    static void solve() {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        g = new List[n];
        for (int i = 0; i < n; i++) g[i] = new ArrayList<>();
        for (int i = 0; i < n - 1; i++) {
            int u = sc.nextInt(), v = sc.nextInt();
            g[u].add(v); g[v].add(u);
        }
        best = 0; // reset so a second run never reports a stale diameter
        down(0, -1);
        System.out.println(best);
    }
}
Starter — Python.
import sys
import threading

sys.setrecursionlimit(1 << 25)
input = sys.stdin.readline


def main():
    """Read n and n-1 edges; print the tree's diameter in edges."""
    n = int(input())
    g = [[] for _ in range(n)]
    for _ in range(n - 1):
        u, v = map(int, input().split())
        g[u].append(v)
        g[v].append(u)
    best = 0

    def down(v, p):
        nonlocal best
        # TODO
        return 0

    down(0, -1)
    print(best)


if __name__ == "__main__":
    # setrecursionlimit only raises the interpreter's depth counter; it does not grow
    # the C stack. On CPython before 3.11 every Python call also uses C stack, so a
    # 2*10^5-deep recursion on a path can crash the process. Run main on a big-stack thread.
    threading.stack_size(1 << 27)
    threading.Thread(target=main).start()
Intermediate Tasks (3)
Task 6 — Tree MIS by node count (unweighted)
Problem. Maximum independent set by count of nodes (each weight 1).
I/O spec. Read n, n-1 edges. Print the maximum number of nodes.
Constraints. 1 ≤ n ≤ 2·10^5.
Hint. Same as Task 3 with a[v] = 1. For a path of n nodes the answer is ⌈n/2⌉.
Starter — Go.
package main

import (
	"bufio"
	"fmt"
	"io"
	"os"
)

var adj [][]int

// dfs returns (largest independent set in v's subtree without v, with v);
// p is v's parent (-1 for the root).
func dfs(v, p int) (int, int) {
	excl, incl := 0, 1
	for _, c := range adj[v] {
		if c == p {
			continue
		}
		// TODO combine
		_ = c
	}
	return excl, incl
}

// solve reads n and n-1 edges from in and writes the maximum independent set
// size to out. in should be buffered: fmt.Fscan pulls one rune per Read call
// from an unbuffered reader.
func solve(in io.Reader, out io.Writer) {
	var n int
	fmt.Fscan(in, &n)
	adj = make([][]int, n)
	for i := 0; i < n-1; i++ {
		var u, v int
		fmt.Fscan(in, &u, &v)
		adj[u] = append(adj[u], v)
		adj[v] = append(adj[v], u)
	}
	excl, incl := dfs(0, -1)
	best := excl
	if incl > best {
		best = incl
	}
	fmt.Fprintln(out, best)
}

// main wires solve to buffered stdin/stdout; unbuffered fmt.Scan costs a syscall
// per rune, which is far too slow for 2·10^5 edges.
func main() {
	out := bufio.NewWriter(os.Stdout)
	defer out.Flush()
	solve(bufio.NewReader(os.Stdin), out)
}
Starter — Java.
import java.util.*;
public class TreeMISCount {
    static List<Integer>[] adj;

    /** Returns {largest independent set in v's subtree without v, with v}; p is v's parent. */
    static int[] dfs(int v, int p) {
        int excl = 0, incl = 1;
        for (int c : adj[v]) {
            if (c == p) continue;
            // TODO combine
        }
        return new int[]{excl, incl};
    }

    /**
     * Entry point. The recursive dfs can be n = 2e5 frames deep on a path-shaped tree,
     * more than the JVM's default thread stack (typically about 1 MB) holds, so the solver
     * runs on a thread created with a large explicit stack size, and main waits for it.
     */
    public static void main(String[] args) {
        Thread worker = new Thread(null, TreeMISCount::solve, "solver", 1L << 27);
        worker.start();
        try {
            worker.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /** Reads n and n-1 edges from stdin and prints the maximum independent set size. */
    @SuppressWarnings("unchecked")
    static void solve() {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        adj = new List[n];
        for (int i = 0; i < n; i++) adj[i] = new ArrayList<>();
        for (int i = 0; i < n - 1; i++) {
            int u = sc.nextInt(), v = sc.nextInt();
            adj[u].add(v); adj[v].add(u);
        }
        int[] r = dfs(0, -1);
        System.out.println(Math.max(r[0], r[1]));
    }
}
Starter — Python.
import sys
import threading

sys.setrecursionlimit(1 << 25)
input = sys.stdin.readline


def main():
    """Read n and n-1 edges; print the size of a maximum independent set."""
    n = int(input())
    adj = [[] for _ in range(n)]
    for _ in range(n - 1):
        u, v = map(int, input().split())
        adj[u].append(v)
        adj[v].append(u)

    def dfs(v, p):
        excl, incl = 0, 1
        for c in adj[v]:
            if c == p:
                continue
            # TODO combine
        return excl, incl

    print(max(dfs(0, -1)))


if __name__ == "__main__":
    # setrecursionlimit only raises the interpreter's depth counter; it does not grow
    # the C stack. On CPython before 3.11 every Python call also uses C stack, so a
    # 2*10^5-deep dfs on a path can crash the process. Run main on a big-stack thread.
    threading.stack_size(1 << 27)
    threading.Thread(target=main).start()
Task 7 — Count independent sets mod p
Problem. Count the number of independent sets of the tree (including the empty set), modulo 10^9 + 7.
I/O spec. Read n, n-1 edges. Print the count mod 1_000_000_007.
Constraints. 1 ≤ n ≤ 2·10^5.
Hint. g1[v] = Π g0[c], g0[v] = Π (g0[c] + g1[c]); answer (g0[root] + g1[root]) mod p. Reduce after each operation.
Starter — Go.
package main

import (
	"bufio"
	"fmt"
	"io"
	"os"
)

const MOD = 1_000_000_007

var adj [][]int

// dfs returns (g0, g1): the number of independent sets of v's subtree with v
// excluded / included, both reduced mod MOD; p is v's parent (-1 for the root).
func dfs(v, p int) (int64, int64) {
	g0, g1 := int64(1), int64(1)
	for _, c := range adj[v] {
		if c == p {
			continue
		}
		// TODO: c0,c1 := dfs(c,v); g0 = g0*((c0+c1)%MOD)%MOD; g1 = g1*c0%MOD
		_ = c
	}
	return g0, g1
}

// solve reads n and n-1 edges from in and writes the number of independent sets
// mod MOD to out. in should be buffered: fmt.Fscan pulls one rune per Read call
// from an unbuffered reader.
func solve(in io.Reader, out io.Writer) {
	var n int
	fmt.Fscan(in, &n)
	adj = make([][]int, n)
	for i := 0; i < n-1; i++ {
		var u, v int
		fmt.Fscan(in, &u, &v)
		adj[u] = append(adj[u], v)
		adj[v] = append(adj[v], u)
	}
	g0, g1 := dfs(0, -1)
	fmt.Fprintln(out, (g0+g1)%MOD)
}

// main wires solve to buffered stdin/stdout; unbuffered fmt.Scan costs a syscall
// per rune, which is far too slow for 2·10^5 edges.
func main() {
	out := bufio.NewWriter(os.Stdout)
	defer out.Flush()
	solve(bufio.NewReader(os.Stdin), out)
}
Starter — Java.
import java.util.*;
public class CountIndependentSets {
    static final long MOD = 1_000_000_007L;
    static List<Integer>[] adj;

    /** Returns {sets of v's subtree with v excluded, with v included}, both mod MOD; p is v's parent. */
    static long[] dfs(int v, int p) {
        long g0 = 1, g1 = 1;
        for (int c : adj[v]) {
            if (c == p) continue;
            // TODO
        }
        return new long[]{g0, g1};
    }

    /**
     * Entry point. The recursive dfs can be n = 2e5 frames deep on a path-shaped tree,
     * more than the JVM's default thread stack (typically about 1 MB) holds, so the solver
     * runs on a thread created with a large explicit stack size, and main waits for it.
     */
    public static void main(String[] args) {
        Thread worker = new Thread(null, CountIndependentSets::solve, "solver", 1L << 27);
        worker.start();
        try {
            worker.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /** Reads n and n-1 edges from stdin and prints the number of independent sets mod MOD. */
    @SuppressWarnings("unchecked")
    static void solve() {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        adj = new List[n];
        for (int i = 0; i < n; i++) adj[i] = new ArrayList<>();
        for (int i = 0; i < n - 1; i++) {
            int u = sc.nextInt(), v = sc.nextInt();
            adj[u].add(v); adj[v].add(u);
        }
        long[] r = dfs(0, -1);
        System.out.println((r[0] + r[1]) % MOD);
    }
}
Starter — Python.
import sys
import threading

sys.setrecursionlimit(1 << 25)
input = sys.stdin.readline
MOD = 1_000_000_007


def main():
    """Read n and n-1 edges; print the number of independent sets (incl. empty) mod MOD."""
    n = int(input())
    adj = [[] for _ in range(n)]
    for _ in range(n - 1):
        u, v = map(int, input().split())
        adj[u].append(v)
        adj[v].append(u)

    def dfs(v, p):
        g0, g1 = 1, 1
        for c in adj[v]:
            if c == p:
                continue
            # TODO
        return g0, g1

    g0, g1 = dfs(0, -1)
    print((g0 + g1) % MOD)


if __name__ == "__main__":
    # setrecursionlimit only raises the interpreter's depth counter; it does not grow
    # the C stack. On CPython before 3.11 every Python call also uses C stack, so a
    # 2*10^5-deep dfs on a path can crash the process. Run main on a big-stack thread.
    threading.stack_size(1 << 27)
    threading.Thread(target=main).start()
Task 8 — Sum of distances in tree (rerooting)
Problem. For every node v, compute the sum of distances from v to all other nodes. Print all n values.
I/O spec. Read n, n-1 edges. Print ans[0..n-1].
Constraints. 1 ≤ n ≤ 2·10^5. Use 64-bit (ans can reach ~n²). An O(n²) solution will TLE.
Hint. DFS1: cnt[v], down[v] = Σ(down[c] + cnt[c]), ans[0] = down[0]. DFS2: ans[c] = ans[v] - cnt[c] + (n - cnt[c]).
Starter — Go.
package main

import (
	"bufio"
	"fmt"
	"io"
	"os"
)

var (
	adj  [][]int
	cnt  []int
	down []int64
	ans  []int64
	N    int
)

// dfs1 computes cnt[v] (nodes in v's subtree) and down[v] = Σ(down[c] + cnt[c])
// over v's children; p is v's parent (-1 for the root).
func dfs1(v, p int) {
	// TODO: cnt[v], down[v]
}

// dfs2 reroots from v to each child c using ans[c] = ans[v] - cnt[c] + (N - cnt[c]).
func dfs2(v, p int) {
	// TODO: ans[c] = ans[v] - cnt[c] + (N - cnt[c]); recurse
}

// solve reads N and N-1 edges from in and writes, for every node, the sum of its
// distances to all other nodes. in should be buffered: fmt.Fscan pulls one rune
// per Read call from an unbuffered reader.
func solve(in io.Reader, out io.Writer) {
	fmt.Fscan(in, &N)
	adj = make([][]int, N)
	cnt = make([]int, N)
	down = make([]int64, N)
	ans = make([]int64, N)
	for i := 0; i < N-1; i++ {
		var u, v int
		fmt.Fscan(in, &u, &v)
		adj[u] = append(adj[u], v)
		adj[v] = append(adj[v], u)
	}
	dfs1(0, -1)
	ans[0] = down[0]
	dfs2(0, -1)
	for i, x := range ans {
		if i > 0 {
			fmt.Fprint(out, " ")
		}
		fmt.Fprint(out, x)
	}
	fmt.Fprintln(out)
}

// main wires solve to buffered stdin/stdout. Unbuffered fmt.Scan/fmt.Print cost a
// syscall per rune read and per value printed — millions for 2·10^5 edges.
func main() {
	out := bufio.NewWriter(os.Stdout)
	defer out.Flush()
	solve(bufio.NewReader(os.Stdin), out)
}
Starter — Java.
import java.util.*;
public class SumDistances {
    static List<Integer>[] adj;
    static int[] cnt;
    static long[] down, ans;
    static int n;

    /** First pass: cnt[v] = subtree size, down[v] = Σ(down[c] + cnt[c]) over children. */
    static void dfs1(int v, int p) { /* TODO */ }

    /** Second pass: ans[c] = ans[v] - cnt[c] + (n - cnt[c]) for each child c, then recurse. */
    static void dfs2(int v, int p) { /* TODO */ }

    /**
     * Entry point. Both recursive passes can be n = 2e5 frames deep on a path-shaped tree,
     * more than the JVM's default thread stack (typically about 1 MB) holds, so the solver
     * runs on a thread created with a large explicit stack size, and main waits for it.
     */
    public static void main(String[] args) {
        Thread worker = new Thread(null, SumDistances::solve, "solver", 1L << 27);
        worker.start();
        try {
            worker.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /** Reads n and n-1 edges from stdin and prints every node's sum of distances. */
    @SuppressWarnings("unchecked")
    static void solve() {
        Scanner sc = new Scanner(System.in);
        n = sc.nextInt();
        adj = new List[n];
        for (int i = 0; i < n; i++) adj[i] = new ArrayList<>();
        cnt = new int[n]; down = new long[n]; ans = new long[n];
        for (int i = 0; i < n - 1; i++) {
            int u = sc.nextInt(), v = sc.nextInt();
            adj[u].add(v); adj[v].add(u);
        }
        dfs1(0, -1);
        ans[0] = down[0];
        dfs2(0, -1);
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < n; i++) { if (i > 0) sb.append(' '); sb.append(ans[i]); }
        System.out.println(sb);
    }
}
Starter — Python.
import sys
import threading

sys.setrecursionlimit(1 << 25)
input = sys.stdin.readline


def main():
    """Read n and n-1 edges; print, for every node, the sum of its distances to all others."""
    n = int(input())
    adj = [[] for _ in range(n)]
    for _ in range(n - 1):
        u, v = map(int, input().split())
        adj[u].append(v)
        adj[v].append(u)
    cnt = [0] * n
    down = [0] * n
    ans = [0] * n

    def dfs1(v, p):
        # TODO
        pass

    def dfs2(v, p):
        # TODO
        pass

    dfs1(0, -1)
    ans[0] = down[0]
    dfs2(0, -1)
    print(*ans)


if __name__ == "__main__":
    # setrecursionlimit only raises the interpreter's depth counter; it does not grow
    # the C stack. On CPython before 3.11 every Python call also uses C stack, so a
    # 2*10^5-deep dfs on a path can crash the process. Run main on a big-stack thread.
    threading.stack_size(1 << 27)
    threading.Thread(target=main).start()
Advanced Tasks (3)
Task 9 — Tree knapsack with parent dependency
Problem. Each node has weight w[v] and value val[v]. Selecting v requires selecting par(v) (root free). With budget W, maximize total value.
I/O spec. Read n W, then w[v] val[v] for each node, then n-1 edges. Print the maximum value.
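Constraints. 1 ≤ n ≤ 2000, 1 ≤ W ≤ 2000, weights/values up to 10^4. Target O(n·W) (cap loops by subtree size).
Hint. dp[v][j] = best value inside v's subtree using total weight j with v taken. Merge children by (max,+) convolution, and bound j by min(cnt[v], W) and b by min(cnt[c], W-j), where cnt[v] is the total weight merged into v's subtree so far. With weights larger than 1, these bounds alone can degrade to Θ(n·W²), for example with a star of heavy leaves. For a guaranteed O(n·W), run the knapsack over DFS pre-order instead. From position i, either take node i (pay w[i] and move to i+1) or skip it, which jumps past its whole subtree to i + size[i].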
Starter — Go.
package main

import (
	"bufio"
	"fmt"
	"io"
	"os"
)

const NEG = -1 << 30

var (
	adj    [][]int
	w, val []int
	cnt    []int
	dp     [][]int
	W      int
)

// dfs builds dp[v] (best value in v's subtree by weight used, with v taken);
// p is v's parent (-1 for the root).
func dfs(v, p int) {
	// TODO: init dp[v] (v taken), merge each child with min-bounded loops, grow cnt[v]
}

// solve reads n W, the (w, val) pairs and n-1 edges from in, and writes the best
// value within budget W to out (0 when nothing is affordable). in should be
// buffered: fmt.Fscan pulls one rune per Read call from an unbuffered reader.
func solve(in io.Reader, out io.Writer) {
	var n int
	fmt.Fscan(in, &n, &W)
	w = make([]int, n)
	val = make([]int, n)
	for i := 0; i < n; i++ {
		fmt.Fscan(in, &w[i], &val[i])
	}
	adj = make([][]int, n)
	cnt = make([]int, n)
	dp = make([][]int, n)
	for i := 0; i < n-1; i++ {
		var u, v int
		fmt.Fscan(in, &u, &v)
		adj[u] = append(adj[u], v)
		adj[v] = append(adj[v], u)
	}
	dfs(0, -1)
	best := 0
	for _, x := range dp[0] {
		if x > best {
			best = x
		}
	}
	fmt.Fprintln(out, best)
}

// main wires solve to buffered stdin/stdout; unbuffered fmt.Scan would cost one
// syscall per rune of input.
func main() {
	out := bufio.NewWriter(os.Stdout)
	defer out.Flush()
	solve(bufio.NewReader(os.Stdin), out)
}
Starter — Java.
import java.util.*;
public class TreeKnapsackTask {
    static List<Integer>[] adj;
    static int[] w, val, cnt;
    static int[][] dp;
    static int W;
    static final int NEG = Integer.MIN_VALUE / 2;

    /** Builds dp[v] (best value in v's subtree by weight used, with v taken); p is v's parent. */
    static void dfs(int v, int p) {
        // TODO
    }

    /**
     * Entry point. The recursive dfs can be n frames deep on a path-shaped tree; the solver
     * runs on a thread created with a large explicit stack size so deep trees cannot overflow
     * the JVM's default thread stack, and main waits for it.
     */
    public static void main(String[] args) {
        Thread worker = new Thread(null, TreeKnapsackTask::solve, "solver", 1L << 27);
        worker.start();
        try {
            worker.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /** Reads n W, the (w, val) pairs and n-1 edges; prints the best value within budget W (0 if none). */
    @SuppressWarnings("unchecked")
    static void solve() {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt(); W = sc.nextInt();
        w = new int[n]; val = new int[n]; cnt = new int[n]; dp = new int[n][];
        for (int i = 0; i < n; i++) { w[i] = sc.nextInt(); val[i] = sc.nextInt(); }
        adj = new List[n];
        for (int i = 0; i < n; i++) adj[i] = new ArrayList<>();
        for (int i = 0; i < n - 1; i++) {
            int u = sc.nextInt(), v = sc.nextInt();
            adj[u].add(v); adj[v].add(u);
        }
        dfs(0, -1);
        int best = 0;
        // dp[0] stays null until dfs is implemented; report 0 instead of throwing.
        if (dp[0] != null) for (int x : dp[0]) best = Math.max(best, x);
        System.out.println(best);
    }
}
Starter — Python.
import sys
import threading

sys.setrecursionlimit(1 << 25)
input = sys.stdin.readline
NEG = float("-inf")


def main():
    """Read n W, the (w, val) pairs and n-1 edges; print the best value within budget W."""
    n, W = map(int, input().split())
    w = [0] * n
    val = [0] * n
    for i in range(n):
        w[i], val[i] = map(int, input().split())
    adj = [[] for _ in range(n)]
    for _ in range(n - 1):
        u, v = map(int, input().split())
        adj[u].append(v)
        adj[v].append(u)
    dp = [None] * n
    cnt = [0] * n

    def dfs(v, p):
        # TODO: init row (v taken), merge children with min-bounded loops
        pass

    dfs(0, -1)
    # If the root itself does not fit (w[0] > W) every entry is NEG and a bare max()
    # over the filtered row raised ValueError; like the Go/Java starters, print 0 then.
    # dp[0] stays None until dfs is implemented.
    row = dp[0] or []
    print(max([0] + [x for x in row if x != NEG]))


if __name__ == "__main__":
    # setrecursionlimit only raises the interpreter's depth counter; it does not grow
    # the C stack. On CPython before 3.11 every Python call also uses C stack, so a
    # deep dfs on a path can crash the process. Run main on a big-stack thread.
    threading.stack_size(1 << 27)
    threading.Thread(target=main).start()
Task 10 — Maximum distance from each node (rerooting, max not sum)
Problem. For each node v, compute the maximum distance (in edges) to any other node (the eccentricity of v). Print all n values.
I/O spec. Read n, n-1 edges. Print ecc[0..n-1].
Constraints. 1 ≤ n ≤ 2·10^5. Must be O(n).
Hint. This needs a max reroot: keep down1, down2 (two longest downward chains) per node so a child can use the parent's best chain that does not go through itself. max is not invertible — use the two-best trick. up[c] = 1 + max(up[v], best downward chain of v avoiding c).
Starter — Go.
package main

import (
	"bufio"
	"fmt"
	"io"
	"os"
)

var (
	adj          [][]int
	down1, down2 []int // two longest downward chains
	who          []int // child giving down1
	up           []int
	ans          []int
)

// dfs1 fills down1[v], down2[v] (two longest downward chains, in edges) and
// who[v] (the child that gives down1[v]); p is v's parent (-1 for the root).
func dfs1(v, p int) {
	// TODO: fill down1[v], down2[v], who[v]
}

// dfs2 pushes the "upward" chain from v to each child c.
func dfs2(v, p int) {
	// TODO: for each child c, up[c] = 1 + max(up[v], (c==who[v]? down2[v] : down1[v]))
	// recurse; ans is assembled from up and down1 in solve
}

// solve reads n and n-1 edges from in and writes every node's eccentricity.
// in should be buffered: fmt.Fscan pulls one rune per Read call otherwise.
func solve(in io.Reader, out io.Writer) {
	var n int
	fmt.Fscan(in, &n)
	adj = make([][]int, n)
	down1 = make([]int, n)
	down2 = make([]int, n)
	who = make([]int, n)
	up = make([]int, n)
	ans = make([]int, n)
	for i := 0; i < n-1; i++ {
		var u, v int
		fmt.Fscan(in, &u, &v)
		adj[u] = append(adj[u], v)
		adj[v] = append(adj[v], u)
	}
	dfs1(0, -1)
	dfs2(0, -1)
	// ecc(v) = max(longest chain down from v, longest path leaving v upward).
	// Built here for every node, as the Java and Python starters do; up[0] stays 0
	// because dfs2 only writes up[] for children.
	for i := range ans {
		ans[i] = up[i]
		if down1[i] > ans[i] {
			ans[i] = down1[i]
		}
	}
	for i, x := range ans {
		if i > 0 {
			fmt.Fprint(out, " ")
		}
		fmt.Fprint(out, x)
	}
	fmt.Fprintln(out)
}

// main wires solve to buffered stdin/stdout. Unbuffered fmt.Scan/fmt.Print cost a
// syscall per rune read and per value printed — millions for 2·10^5 edges.
func main() {
	out := bufio.NewWriter(os.Stdout)
	defer out.Flush()
	solve(bufio.NewReader(os.Stdin), out)
}
Starter — Java.
import java.util.*;
public class Eccentricity {
    static List<Integer>[] adj;
    static int[] down1, down2, who, up, ans;

    /** First pass: two longest downward chains per node and the child giving down1. */
    static void dfs1(int v, int p) { /* TODO */ }

    /** Second pass: up[c] = 1 + max(up[v], best chain of v avoiding c), then recurse. */
    static void dfs2(int v, int p) { /* TODO */ }

    /**
     * Entry point. Both recursive passes can be n = 2e5 frames deep on a path-shaped tree,
     * more than the JVM's default thread stack (typically about 1 MB) holds, so the solver
     * runs on a thread created with a large explicit stack size, and main waits for it.
     */
    public static void main(String[] args) {
        Thread worker = new Thread(null, Eccentricity::solve, "solver", 1L << 27);
        worker.start();
        try {
            worker.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /** Reads n and n-1 edges from stdin and prints every node's eccentricity. */
    @SuppressWarnings("unchecked")
    static void solve() {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        adj = new List[n];
        for (int i = 0; i < n; i++) adj[i] = new ArrayList<>();
        down1 = new int[n]; down2 = new int[n]; who = new int[n];
        up = new int[n]; ans = new int[n];
        for (int i = 0; i < n - 1; i++) {
            int u = sc.nextInt(), v = sc.nextInt();
            adj[u].add(v); adj[v].add(u);
        }
        dfs1(0, -1);
        dfs2(0, -1);
        for (int i = 0; i < n; i++) ans[i] = Math.max(up[i], down1[i]);
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < n; i++) { if (i > 0) sb.append(' '); sb.append(ans[i]); }
        System.out.println(sb);
    }
}
Starter — Python.
import sys
import threading

sys.setrecursionlimit(1 << 25)
input = sys.stdin.readline


def main():
    """Read n and n-1 edges; print every node's eccentricity (max distance in edges)."""
    n = int(input())
    adj = [[] for _ in range(n)]
    for _ in range(n - 1):
        u, v = map(int, input().split())
        adj[u].append(v)
        adj[v].append(u)
    down1 = [0] * n
    down2 = [0] * n
    who = [-1] * n
    up = [0] * n

    def dfs1(v, p):
        # TODO fill down1, down2, who
        pass

    def dfs2(v, p):
        # TODO up[c] = 1 + max(up[v], down2[v] if c==who[v] else down1[v]); recurse
        pass

    dfs1(0, -1)
    dfs2(0, -1)
    ans = [max(up[i], down1[i]) for i in range(n)]
    print(*ans)


if __name__ == "__main__":
    # setrecursionlimit only raises the interpreter's depth counter; it does not grow
    # the C stack. On CPython before 3.11 every Python call also uses C stack, so a
    # 2*10^5-deep dfs on a path can crash the process. Run main on a big-stack thread.
    threading.stack_size(1 << 27)
    threading.Thread(target=main).start()
Task 11 — Weighted MIS on a tree with iterative DFS
Problem. Same as Task 3 (maximum-weight independent set), but n can be up to 10^6 and the tree may be a path — a recursive solution will overflow the stack. Implement it iteratively (explicit stack, reverse pre-order).
I/O spec. Read n, a[0..n-1], n-1 edges. Print the maximum-weight independent set.
Constraints. 1 ≤ n ≤ 10^6, 0 ≤ a[v] ≤ 10^9. 64-bit. Recursion is not allowed (assume the stack is small).
Hint. Build a pre-order with an explicit stack, record parent[], then iterate the order in reverse, accumulating excl[v]/incl[v] from already-processed children.
Starter — Go.
package main

import (
	"bufio"
	"fmt"
	"os"
)

// main reads a weighted tree (n, then a[0..n-1], then n-1 edges) through a
// buffered reader and prints its maximum-weight independent set. The dp is
// still a TODO and must be iterative: recursion would overflow on a 10^6 path.
func main() {
	in := bufio.NewReader(os.Stdin)

	var n int
	fmt.Fscan(in, &n)

	weights := make([]int64, n)
	for v := 0; v < n; v++ {
		fmt.Fscan(in, &weights[v])
	}

	graph := make([][]int, n)
	for e := 1; e < n; e++ {
		var x, y int
		fmt.Fscan(in, &x, &y)
		graph[x] = append(graph[x], y)
		graph[y] = append(graph[y], x)
	}

	// TODO: explicit-stack pre-order -> order[], parent[]
	// then reverse loop computing excl[], incl[]
	fmt.Println(0) // replace
}
Starter — Java.
import java.util.*;
import java.io.*;
public class IterativeMIS {
    /**
     * Reads n, a[0..n-1] and n-1 edges with a StreamTokenizer (fast bulk parsing for n up
     * to 1e6) and prints the maximum-weight independent set. The dp must be iterative:
     * recursion would overflow the stack on a 1e6-node path.
     */
    @SuppressWarnings("unchecked") // generic array creation for adj, as in the other starters
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StreamTokenizer st = new StreamTokenizer(br);
        st.nextToken(); int n = (int) st.nval;
        long[] a = new long[n];
        // nval is a double; values up to 1e9 are represented exactly (well below 2^53).
        for (int i = 0; i < n; i++) { st.nextToken(); a[i] = (long) st.nval; }
        List<Integer>[] adj = new List[n];
        for (int i = 0; i < n; i++) adj[i] = new ArrayList<>();
        for (int i = 0; i < n - 1; i++) {
            st.nextToken(); int u = (int) st.nval;
            st.nextToken(); int v = (int) st.nval;
            adj[u].add(v); adj[v].add(u);
        }
        // TODO: iterative pre-order, then reverse-order dp
        System.out.println(0); // replace
    }
}
Starter — Python.
import sys

input = sys.stdin.buffer.read


def main():
    """Read n, a[0..n-1] and n-1 edges in one bulk read; print the max-weight independent set.

    The dp must be iterative (explicit stack, reverse pre-order): n can be 10^6 and
    the tree a path, far beyond any safe recursion depth.
    """
    data = input().split()
    idx = 0
    n = int(data[idx]); idx += 1
    a = [int(data[idx + i]) for i in range(n)]; idx += n
    adj = [[] for _ in range(n)]
    for _ in range(n - 1):
        u = int(data[idx]); v = int(data[idx + 1]); idx += 2
        adj[u].append(v)
        adj[v].append(u)
    # TODO: explicit-stack pre-order -> order, parent; reverse loop -> excl, incl
    print(0)  # replace


# Guarded so importing this module (e.g. from a test) does not consume stdin.
if __name__ == "__main__":
    main()
Evaluation Criteria
- Correctness: match a brute-force oracle (subset/pair enumeration) on random trees with n ≤ 14.
- Complexity: one-pass tasks O(n); rerooting O(n); tree knapsack O(n·W); no accidental O(n²) reroot.
- Robustness: handle n = 1, a path tree (depth stress), and a star tree (wide combine).
- Arithmetic: 64-bit for sums/distances; modulus reduced after every operation in counting tasks.
- Stack safety: Task 11 must run iteratively without overflow on a 10^6-node path.