Last CF contest, I solved A~C really quickly, but got stuck for over an hour on a rerooting dp problem. In this blog, I want to learn how to do reroot dp!

When to reroot dp?

(Disclaimer: I will refer to $u$ as the parent node, and to $v$, $c$ as child nodes.)
Reroot DP applies when the problem wants an answer for every node, where each answer requires treating that node as the root of the tree.
You should be able to calculate the answer for one root in about $O(n)$ time, and then move the subtree / outside-subtree information from a node to its neighbor in less time than it takes to rebuild it from scratch, with the help of information gathered while calculating the first answer (subtree information, depth, etc.).

Let's check out a basic problem to understand reroot dp better:

LeetCode 834. Sum of Distances in Tree

The problem is basically: for every node $i$, return the sum of the depths of all nodes if $i$ is the root of the tree.
The first step in a reroot dp problem is to determine the answer for one root, so let's find the answer for node $0$.
This is quite trivial: denote $sum[u]$ as the depth sum of subtree $u$. We can maintain $depth[v] = depth[u] + 1$ with $depth[0] = 0$, and $sum[u] = depth[u] + \sum sum[v]$, with a simple dfs like this (the code stores $sum$ in the array dp):

function<void(int, int)> dfs1 = [&](int u, int p) {
    for(auto v : graph[u]) {
        if(v != p) {
            depth[v] = depth[u] + 1;
            dfs1(v, u);
            dp[u] += dp[v];
        }
    }
    dp[u] += depth[u];
};
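To make the first pass concrete, here is a minimal self-contained sketch of it. The function name `depthSumRootedAtZero` and the edge-list input format are assumptions made for this illustration, not part of the original solution.

```cpp
#include <cassert>
#include <functional>
#include <utility>
#include <vector>
using namespace std;

// Sketch: sum of the depths of all nodes when the tree is rooted at node 0.
// `edges` holds the n - 1 undirected edges of a tree on nodes 0..n-1.
long long depthSumRootedAtZero(int n, const vector<pair<int, int>>& edges) {
    assert(n >= 1 && (int)edges.size() == n - 1);
    vector<vector<int>> graph(n);
    for (const auto& e : edges) {
        graph[e.first].push_back(e.second);
        graph[e.second].push_back(e.first);
    }
    vector<long long> depth(n, 0), sum(n, 0);
    function<void(int, int)> dfs1 = [&](int u, int p) {
        for (int v : graph[u]) {
            if (v == p) continue;
            depth[v] = depth[u] + 1;
            dfs1(v, u);
            sum[u] += sum[v];   // add the depth sums of the children
        }
        sum[u] += depth[u];     // then the depth of u itself
    };
    dfs1(0, -1);
    return sum[0];
}
```

For the tree with edges 0-1, 0-2, 2-3, 2-4, 2-5, the depths are 0, 1, 1, 2, 2, 2, so the function returns 8.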

While building the answer for the first root, you should also start thinking about what information the rerooting process will need that can be maintained in the first dfs.
In this problem, we can work out the transition first, and then we will know what to track in the first dfs.

[Figure: depthchange]

The numbers labeled in red are the depths when $0$ is the root, and the green ones are the depths when $2$ is the root. Notice that the nodes inside the yellow circle (nodes that aren't in the subtree of $2$) all went up by 1 in depth, while the nodes inside the purple circle (nodes that are in the subtree of $2$) all went down by 1.

[Figure: depthchange2]

Another example is going from $2$ to $3$, where green is the depth when $2$ is the root, and blue when $3$ is. You can see the same transition. We can thus determine the dp transition between nodes:

Denote $dp[i]$ as the answer with $i$ as the root, then

$dp[v] = dp[u] + (n - subtree[v]) - subtree[v] = dp[u] + n - (2 \cdot subtree[v])$

With $dp[0] = sum[0]$. We can do this transition with another dfs.

Notice that we can also precalculate $subtree[v]$ (the number of nodes in the subtree of $v$ when $0$ is the root) during the first dfs.
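As a quick sanity check of the transition, here is a worked example on a small tree picked for illustration: $n = 6$ with edges 0-1, 0-2, 2-3, 2-4, 2-5, rooted at $0$, so that $subtree[1] = 1$, $subtree[2] = 4$ and $subtree[3] = 1$.

$$
\begin{aligned}
dp[0] &= 1 + 1 + 2 + 2 + 2 = 8 \\
dp[1] &= dp[0] + n - 2 \cdot subtree[1] = 8 + 6 - 2 = 12 \\
dp[2] &= dp[0] + n - 2 \cdot subtree[2] = 8 + 6 - 8 = 6 \\
dp[3] &= dp[2] + n - 2 \cdot subtree[3] = 6 + 6 - 2 = 10
\end{aligned}
$$

Nodes 4 and 5 are symmetric to node 3, so they get 10 as well.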

Code:

vector<int> sumOfDistancesInTree(int n, vector<vector<int>>& edges) {
    vector<vector<int>> graph(n);
    for(const auto& e : edges) {
        graph[e[0]].push_back(e[1]);
        graph[e[1]].push_back(e[0]);
    }
    vector<int> depth(n, 0);
    vector<int> dp(n, 0);      // after dfs1: depth sum inside each subtree
    vector<int> subtree(n, 0); // subtree sizes when node 0 is the root
    function<void(int, int)> dfs1 = [&](int u, int p) {
        for(auto v : graph[u]) {
            if(v != p) {
                depth[v] = depth[u] + 1;
                dfs1(v, u);
                dp[u] += dp[v];
                subtree[u] += subtree[v];
            }
        }
        dp[u] += depth[u];
        subtree[u] += 1;
    };
    dfs1(0, -1);
    // only dp[0] is a final answer; the others are rebuilt by the rerooting pass
    for(int i = 1; i < n; i++) dp[i] = 0;
    function<void(int, int)> dfs2 = [&](int u, int p) {
        for(auto v : graph[u]) {
            if(v != p) {
                dp[v] = dp[u] + (n - subtree[v]) - subtree[v];
                // = dp[u] + n - (2 * subtree[v]);
                dfs2(v, u);
            }
        }
    };
    dfs2(0, -1);
    return dp;
}

Time Complexity: $O(n)$
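When debugging a rerooting solution, it helps to compare it against a slow reference on small trees. Below is a minimal sketch of such a checker, assuming the same `(n, edges)` input as the LeetCode function; the name `sumOfDistancesBruteForce` is made up for this illustration. It runs one BFS per node, so it is $O(n^2)$ and only meant for testing.

```cpp
#include <cassert>
#include <queue>
#include <vector>
using namespace std;

// O(n^2) reference: BFS from every node and add up the distances.
vector<int> sumOfDistancesBruteForce(int n, const vector<vector<int>>& edges) {
    assert(n >= 1);
    vector<vector<int>> graph(n);
    for (const auto& e : edges) {
        graph[e[0]].push_back(e[1]);
        graph[e[1]].push_back(e[0]);
    }
    vector<int> ans(n, 0);
    for (int root = 0; root < n; root++) {
        vector<int> dist(n, -1);   // -1 marks "not visited yet"
        queue<int> q;
        dist[root] = 0;
        q.push(root);
        while (!q.empty()) {
            int u = q.front();
            q.pop();
            ans[root] += dist[u];
            for (int v : graph[u]) {
                if (dist[v] == -1) {
                    dist[v] = dist[u] + 1;
                    q.push(v);
                }
            }
        }
    }
    return ans;
}
```

Comparing its output with the rerooting version on random small trees is a cheap way to catch a wrong transition.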

Atcoder Edu DP Contest V - Subtree

This problem is much less trivial than the first one I would say, but let's think about how to get the answer for an initial root.
We can let $dp_1[i]$ denote the number of ways to color the subtree of $i$ satisfying the condition, plus one for leaving everything uncolored. The reason we want to include the empty coloring is that it makes the $dp_1$ transition much easier, as we can notice that
$dp_1[u] = 1 + \prod dp_1[v]$ with $dp_1[leaf] = 2$ (color it or not).
The answer for the root is then $dp_1[root] - 1$ (removing the one where nothing is colored).
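Here is a minimal sketch of this first pass on its own, assuming nodes are numbered from 1, the tree is rooted at node 1, and the modulus `m` is at least 2. The function name `subtreeWays` is an assumption for this illustration.

```cpp
#include <cassert>
#include <functional>
#include <utility>
#include <vector>
using namespace std;

// Sketch: dp1[u] = 1 + product of dp1[v] over the children v of u, modulo m,
// for the tree rooted at node 1. Index 0 of the result is unused.
vector<long long> subtreeWays(int n, long long m, const vector<pair<int, int>>& edges) {
    assert(n >= 1 && m >= 2);
    vector<vector<int>> graph(n + 1);
    for (const auto& e : edges) {
        graph[e.first].push_back(e.second);
        graph[e.second].push_back(e.first);
    }
    vector<long long> dp1(n + 1, 0);
    function<void(int, int)> dfs = [&](int u, int p) {
        long long prod = 1;               // empty product for a leaf
        for (int v : graph[u]) {
            if (v == p) continue;
            dfs(v, u);
            prod = prod * dp1[v] % m;
        }
        dp1[u] = (prod + 1) % m;          // +1 for the "nothing colored" case
    };
    dfs(1, 0);
    return dp1;
}
```

On the path 1-2-3 with `m = 100` this gives `dp1[3] = 2`, `dp1[2] = 3`, `dp1[1] = 4`, so the answer for root 1 is `dp1[1] - 1 = 3`.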

Now, let's think of the transition in our second dfs:

Denote $dp_2[i]$ as the number of ways to color the tree after removing $i$'s subtree. This may seem pretty sudden, but it makes sense once you fully understand reroot dp (which I didn't, so I struggled on this question for a long time).

In a reroot dp problem, you can treat the whole tree as two components: the subtree of a node $i$, and the other parts of the tree. When you make $i$ the root instead, all the parts that didn't belong to its subtree will then be inside the subtree (with the original parent node as its child), so if we can maintain the value for this newly added part, we can calculate the answer similarly to how we got it for the first root (as we would have the values for all of $i$'s children).

$dp_2[v]$ here basically means the $dp_1$ we had, but for node $u$ when $v$ is the root instead.

In the picture, the red circle is $dp_2[2]$, and the green circle is $dp_1[2]$.

[Figure: subtreecomp1]

[Figure: subtreecomp2]

In this problem, the answer for node $u$ is just $(\prod dp_1[v]) \cdot dp_2[u] = (dp_1[u] - 1) \cdot dp_2[u]$, where the product runs over the children $v$ of $u$
(try relating it to how we got the answer for root $1$, it really helps with understanding the concept!)

Now, $dp_2[root] = 1$, how about others?

[Figure: subtree3]

$dp_2[2]$ is circled in red, and $dp_2[3]$ is circled in green. We can see that $dp_2[3]$ added the nodes that were in the subtree before, but not in the current subtree, which is all the siblings of $3$. We can write out the transition as:

$\displaystyle dp_2[v] = dp_2[u] \cdot (\prod_{c, c \neq v} dp_1[c]) + 1 = dp_2[u] \cdot (\frac{dp_1[u] - 1}{dp_1[v]}) + 1$

(Again, I want you to try relating this transition to what we did with $dp_1$.)
Unfortunately, computing the first form directly for every child would TLE, and we cannot use the second one because of modular division ($m$ isn't guaranteed to be prime, so it's hard to find an inverse).

Fortunately, we can calculate $\prod_{c, c \neq v} dp_1[c]$ by making prefix/suffix products; this way we can avoid the troublesome division, yay! (And we can also calculate them during the first dfs, which is pretty nice!)
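The prefix/suffix trick can be sketched in isolation like this. The helper `productExcludingEach` is hypothetical (it is not part of the solution below) and assumes non-negative values and a modulus `m` of at least 1 that fits comfortably in 32 bits, so that products fit in `long long`.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Sketch: res[i] = product of all vals[j] with j != i, modulo m, without division.
vector<long long> productExcludingEach(const vector<long long>& vals, long long m) {
    assert(m >= 1);
    int k = (int)vals.size();
    // prefix[i] = vals[0] * ... * vals[i-1], suffix[i] = vals[i] * ... * vals[k-1]
    vector<long long> prefix(k + 1, 1 % m), suffix(k + 1, 1 % m);
    for (int i = 0; i < k; i++) prefix[i + 1] = prefix[i] * (vals[i] % m) % m;
    for (int i = k - 1; i >= 0; i--) suffix[i] = suffix[i + 1] * (vals[i] % m) % m;
    vector<long long> res(k);
    for (int i = 0; i < k; i++) res[i] = prefix[i] * suffix[i + 1] % m;
    return res;
}
```

With `vals = {2, 3, 4}` and `m = 100` the result is `{12, 8, 6}`. In the actual problem, `vals` would be the $dp_1$ values of the children of one node, and the total work over all nodes stays linear because every child appears in exactly one such list.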

code:

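// Modular helpers. Everything is done in long long so that the product of
// two values below mod does not overflow when mod is close to 1e9.
long long mabs(long long a, long long mod) {
    return (a % mod + mod) % mod;
}
long long mmul(long long a, long long b, long long mod) {
    return mabs((a % mod) * (b % mod), mod);
}
long long madd(long long a, long long b, long long mod) { // a + b
    return mabs(a % mod + b % mod, mod);
}
long long mmin(long long a, long long b, long long mod) { // a - b
    return mabs(a % mod - b % mod, mod);
}
long long fastpow(long long a, long long n, long long mod) { // calculate a^n % mod
    if(n == 0) return 1;
    long long half = fastpow(a, n >> 1, mod);
    if(n & 1) return mmul(mmul(half, half, mod), a, mod);
    else return mmul(half, half, mod);
}
// (a / b) % mod through Fermat's little theorem: only valid when mod is prime,
// so it is not used in this problem.
long long mdiv(long long a, long long b, long long mod) {
    return mmul(a, fastpow(b, mod - 2, mod), mod);
}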
void solve() {
    int n, m;
    cin >> n >> m;
    vector<vector<int>> graph(n + 1);
    for(int i = 0; i < n - 1; i++) {
        int x, y;
        cin >> x >> y;
        graph[x].push_back(y);
        graph[y].push_back(x);
    }
    vector<long long> dp1(n + 1, 1);
    // prefix[u][j]: product of dp1 over the first j children of u
    // suffix[u][j]: product of dp1 over children j..k of u (1-indexed)
    // both are padded with a 1 on each side
    vector<vector<long long>> prefix(n + 1), suffix(n + 1);
    function<void(int, int)> dfs1 = [&](int u, int p) {
        if(graph[u].size() == 1 && graph[u][0] == p) { // leaf
            dp1[u] = 2;
            return;
        }
        prefix[u].push_back(1);
        suffix[u].push_back(1);
        for(auto v : graph[u]) {
            if(v != p) {
                dfs1(v, u);
                dp1[u] = mmul(dp1[u], dp1[v], m);
                prefix[u].push_back(dp1[u]);
            }
        }
        long long tmp = 1;
        for(int i = (int)graph[u].size() - 1; i >= 0; i--) {
            int v = graph[u][i];
            if(v != p) {
                tmp = mmul(tmp, dp1[v], m);
                suffix[u].push_back(tmp);
            }
        }
        prefix[u].push_back(1);
        suffix[u].push_back(1);
        reverse(suffix[u].begin(), suffix[u].end());
        dp1[u] += 1;
    };
    dfs1(1, -1);
    vector<long long> dp2(n + 1, 0);
    dp2[1] = 1;
    function<void(int, int)> dfs2 = [&](int u, int p) {
        int flg = 0; // becomes 1 once the parent has been skipped in graph[u]
        for(int i = 1; i <= (int)graph[u].size(); i++) {
            int v = graph[u][i - 1];
            if(v != p) {
                // dp2[v] = dp2[u] * (product of dp1[c], c is u's child && c != v) + 1
                dp2[v] = madd(mmul(dp2[u], mmul(prefix[u][i - 1 - flg], suffix[u][i + 1 - flg], m), m), 1, m);
                dfs2(v, u);
            } else flg = 1;
        }
    };
    dfs2(1, -1);
    for(int i = 1; i <= n; i++) cout << mmul(dp1[i] - 1, dp2[i], m) << "\n";
    return;
}

Time Complexity: $O(n)$
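To see both passes working together, here is a hand trace on a tiny tree picked for illustration: the path 1-2-3 rooted at node 1, with $m = 100$. It uses $dp_1[u] = 1 + \prod dp_1[v]$ over the children $v$ of $u$, and $dp_2[v] = dp_2[u] \cdot \prod_{c \neq v} dp_1[c] + 1$ over the siblings $c$ of $v$ (an empty product is 1).

$$
\begin{aligned}
dp_1[3] &= 2, & dp_1[2] &= 1 + dp_1[3] = 3, & dp_1[1] &= 1 + dp_1[2] = 4 \\
dp_2[1] &= 1, & dp_2[2] &= dp_2[1] \cdot 1 + 1 = 2, & dp_2[3] &= dp_2[2] \cdot 1 + 1 = 3 \\
ans[1] &= (4 - 1) \cdot 1 = 3, & ans[2] &= (3 - 1) \cdot 2 = 4, & ans[3] &= (2 - 1) \cdot 3 = 3
\end{aligned}
$$

Counting by hand agrees: the connected sets containing node 2 are $\{2\}$, $\{1,2\}$, $\{2,3\}$ and $\{1,2,3\}$, which is 4.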

Note: you can also write the first LeetCode problem in a similar way: let $dp_2[u]$ be the sum over the nodes outside of the subtree of $u$, and $dp[u]$ the sum over the subtree of $u$, both with $u$ as root. It's uglier because you need to update $dp$ here too, since it changes, but I think writing it like this makes the steps of reroot dp clearer.

// the same as before
dfs1(0, -1);
vector<int> dp2(n, 0);
function<void(int, int)> dfs2 = [&](int u, int p) {
    for(auto v : graph[u]) {
        if(v != p) {
            dp[v] -= depth[v] * subtree[v];
            dp2[v] = dp2[u] + (dp[u] - dp[v] + (n - (2 * subtree[v])));
            dfs2(v, u);
        }
    }
};
dfs2(0, -1);
vector<int> ans;
for(int i = 0; i < n; i++) ans.push_back(dp[i] + dp2[i]);
return ans;

Now, let's actually solve the problem I was stuck on in contest.

CF 1882D. Tree XOR

Let's first think about how to obtain the answer for the root.
From our root, we can greedily change every child $v$ into the value our root has, at a cost of $(val[u] \oplus val[v]) \cdot subtree[v]$.
We also need to remember that every node $c$ below the child also changes to $(val[u] \oplus val[v]) \oplus val[c]$ when calculating.
Let's denote the value that we want to apply to the subtree of $u$ as $xor\_val[u]$, where $xor\_val[root]$ is $0$.
Calculating $xor\_val$ during the dfs is quite easy: $xor\_val[v] = (xor\_val[u] \oplus val[v]) \oplus val[1]$.
But there is actually a pretty nice observation here:

[Figure: treexor]

The part circled in yellow is the original value after the operation on node $2$, and the red part is the value after node $4$.
The value of every $xor\_val[v]$ is just $val[u] \oplus val[v]$!
This tells us the value to apply does not change with the root, which makes our lives much easier.
So, the first part of dfs should look like this:

function<void(int, int)> dfs1 = [&](int u, int p) {
    for(auto v : graph[u]) {
        if(v != p) {
            xor_val[v] = (val[u] ^ val[v]);
            //xor_val[v] = (xor_val[u] ^ val[v]) ^ val[1];
            dfs1(v, u);
            subtree[u] += subtree[v];
        }
    }
    subtree[u] += 1;
};

and $ans[root] = \sum_{i = 1}^{n} xor\_val[i] \cdot subtree[i]$.
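The formula for a single root can be sketched as a standalone function. The name `treeXorCostForRoot`, the 1-indexed `val` vector, and the edge-list input are assumptions for this illustration. Calling it once per root gives an $O(n^2)$ reference that is handy for checking a rerooting solution on small inputs.

```cpp
#include <cassert>
#include <functional>
#include <utility>
#include <vector>
using namespace std;

// Sketch: cost for one fixed root, summing (val[parent] ^ val[v]) * subtree[v]
// over every non-root node v. `val` is 1-indexed (val[0] is unused).
long long treeXorCostForRoot(int n, const vector<long long>& val,
                             const vector<pair<int, int>>& edges, int root) {
    assert(n >= 1 && (int)val.size() == n + 1 && 1 <= root && root <= n);
    vector<vector<int>> graph(n + 1);
    for (const auto& e : edges) {
        graph[e.first].push_back(e.second);
        graph[e.second].push_back(e.first);
    }
    vector<long long> subtree(n + 1, 0);
    long long cost = 0;
    function<void(int, int)> dfs = [&](int u, int p) {
        subtree[u] = 1;
        for (int v : graph[u]) {
            if (v == p) continue;
            dfs(v, u);
            subtree[u] += subtree[v];
            cost += (val[u] ^ val[v]) * subtree[v];  // xor_val[v] * subtree[v]
        }
    };
    dfs(root, 0);
    return cost;
}
```

For example, with values 3, 2, 1, 0 on nodes 1 to 4 and edges 1-2, 2-3, 2-4, rooting at node 1 gives $1 \cdot 3 + 3 \cdot 1 + 2 \cdot 1 = 8$.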

Now, for the rerooting part:

For every other node $i$, $(xor\_val[i] \cdot subtree[i])$ does not change with the root. The only changing ones are the current root ($u$), and the next child ($v$).

$u$ will become the child of $v$, so $xor\_val[u] = val[v] \oplus val[u] = xor\_val[v]$, and its subtree size is just $n - subtree[v]$.
For $v$, $(xor\_val[v] \cdot subtree[v])$ is just $0$ since it's the new root.
Combining them, we have the following transition:

$ans[v] = ans[u] + (xor\_val[v] \cdot (n - subtree[v])) - (xor\_val[v] \cdot subtree[v])$

Which equals

$ans[v] = ans[u] + xor\_val[v] \cdot (n - 2 \cdot subtree[v])$
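A small worked example of this transition, on a tree picked for illustration: $n = 4$, values $3, 2, 1, 0$ on nodes 1 to 4, and edges 1-2, 2-3, 2-4. Rooted at node 1 we have $subtree[2] = 3$, $subtree[3] = subtree[4] = 1$, and $xor\_val[2] = 3 \oplus 2 = 1$, $xor\_val[3] = 2 \oplus 1 = 3$, $xor\_val[4] = 2 \oplus 0 = 2$.

$$
\begin{aligned}
ans[1] &= 1 \cdot 3 + 3 \cdot 1 + 2 \cdot 1 = 8 \\
ans[2] &= ans[1] + xor\_val[2] \cdot (n - 2 \cdot subtree[2]) = 8 + 1 \cdot (4 - 6) = 6 \\
ans[3] &= ans[2] + xor\_val[3] \cdot (n - 2 \cdot subtree[3]) = 6 + 3 \cdot (4 - 2) = 12 \\
ans[4] &= ans[2] + xor\_val[4] \cdot (n - 2 \cdot subtree[4]) = 6 + 2 \cdot (4 - 2) = 10
\end{aligned}
$$

Note that the correction term can be negative, as it is for node 2, whenever the subtree of $v$ holds more than half of the nodes.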

Code:

void solve() {
    int n;
    cin >> n;
    vector<long long> val(n + 1, 0);
    for(int i = 1; i <= n; i++) {
        cin >> val[i];
    }
    vector<vector<int>> graph(n + 1);
    for(int i = 0; i < n - 1; i++) {
        int x, y;
        cin >> x >> y;
        graph[x].push_back(y);
        graph[y].push_back(x);
    }
    vector<long long> subtree(n + 1, 0);
    vector<long long> xor_val(n + 1, 0); // xor_val[1] stays 0: nothing is applied at the root
    function<void(int, int)> dfs1 = [&](int u, int p) {
        for(auto v : graph[u]) {
            if(v != p) {
                xor_val[v] = (val[u] ^ val[v]);
                //xor_val[v] = (xor_val[u] ^ val[v]) ^ val[1];
                dfs1(v, u);
                subtree[u] += subtree[v];
            }
        }
        subtree[u] += 1;
    };
    dfs1(1, -1);
    // long long: the sum of xor_val * subtree can overflow 32 bits on large inputs
    vector<long long> ans(n + 1, 0);
    for(int i = 1; i <= n; i++) ans[1] += (xor_val[i] * subtree[i]);
    function<void(int, int)> dfs2 = [&](int u, int p) {
        for(auto v : graph[u]) {
            if(v != p) {
                ans[v] = ans[u] + xor_val[v] * (n - (2 * subtree[v]));
                dfs2(v, u);
            }
        }
    };
    dfs2(1, -1);
    for(int i = 1; i <= n; i++) cout << ans[i] << " ";
    cout << endl;
    return;
}

Time Complexity: $O(n)$