Recently, while doing Universal Cup Round 5, I got stuck on tree problem A because my solution required way too much memory. After the contest, however, I realised that HLD (heavy-light decomposition) can reduce the memory usage a lot. So here I am with my idea...
Structure of Tree DP
Most tree DP solutions follow this structure:
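```cpp
#include <bits/stdc++.h>
using namespace std;

const int MAXN = 100005;
vector<int> adj[MAXN];  // adjacency list of the tree

struct S {
    // return value of DP
};

S init(int u) {
    // initialise the base state of dp[u]
}

S merge(S left, S right) {
    // returns the new dp state where old state is left and transition using right
}

S dp(int u, int p) {
    S res = init(u);  // res stays in this stack frame until all children are merged
    for (int v : adj[u]) {
        if (v == p) continue;
        res = merge(res, dp(v, u));
    }
    return res;
}

int main() {
    // read the tree into adj first
    dp(1, -1);
}
```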
An example of a tree DP using this structure is maximum independent set (MIS).
Suppose struct $$$S$$$ requires $$$|S|$$$ bytes and our tree has $$$N$$$ vertices. Then this naive implementation of tree DP requires $$$O(N\cdot |S|)$$$ memory: as we recurse down to the leaves, every ancestor keeps its own `res` on the recursion stack, and on a path-like tree the recursion is $$$N$$$ levels deep. This is fine for most problems, where $$$|S| = O(1)$$$. In problem A, however, $$$|S| = 25^2\cdot 24 = 15000$$$ bytes and $$$N = 10^5$$$, so the worst case needs about $$$1.5\cdot 10^9$$$ bytes (roughly $$$1.5$$$ gigabytes), far more than the memory limit of $$$256$$$ megabytes.
Optimization
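We borrow the idea of HLD: visit the child with the largest subtree size (the heavy child) first, and initialise `res` only after the heavy child has returned. A frame then holds a DP state only while it recurses into a light child. A light child's subtree size is less than half of its parent's, so any root-to-leaf path contains at most $$$\log_2 N$$$ light edges; hence at most $$$O(\log N)$$$ states are alive at once and the memory drops to $$$O(N + |S| \log N)$$$. This assumes that an uninitialised $$$S$$$ is cheap, for example when its table is a `std::vector` that is only allocated in `init`; a fixed-size array member would still be reserved in every stack frame.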
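```cpp
#include <bits/stdc++.h>
using namespace std;

const int MAXN = 100005;
vector<int> adj[MAXN];  // adjacency list of the tree, 1-indexed

struct S {
    // return value of DP: size of the best independent set in the subtree
    // when u is taken / not taken
    int take, notTake;
};

S init(int u) {
    // initialise the base state of dp[u]
    return {1, 0};
}

S merge(S left, S right) {
    // returns the new dp state where old state is left and transition using right
    return {left.take + right.notTake, left.notTake + max(right.take, right.notTake)};
}

int sub[MAXN];
void getSub(int u, int p) {
    sub[u] = 1;
    pair<int, int> heavy = {-1, -1};
    for (int i = 0; i < (int)adj[u].size(); i++) {
        int v = adj[u][i];
        if (v == p) continue;
        getSub(v, u);  // the child's subtree size must be known before comparing
        sub[u] += sub[v];
        heavy = max(heavy, {sub[v], i});
    }
    // make the vertex with the largest subtree size the first
    if (heavy.first != -1) {
        swap(adj[u][0], adj[u][heavy.second]);
    }
}

S dp(int u, int p) {
    // do not initialise yet: this frame holds no state while visiting the heavy child
    S res;
    bool hasInit = false;
    for (int v : adj[u]) {
        if (v == p) continue;
        if (!hasInit) {
            // heavy child: recurse first, only then create the base state
            S heavyRes = dp(v, u);
            res = merge(init(u), heavyRes);
            hasInit = true;
        } else {
            // light child: res stays alive during this recursion
            res = merge(res, dp(v, u));
        }
    }
    if (!hasInit) {
        res = init(u);  // u is a leaf
    }
    return res;
}

int main() {
    // assumed input format: n, then n - 1 edges "a b"
    int n;
    scanf("%d", &n);
    for (int i = 0; i < n - 1; i++) {
        int a, b;
        scanf("%d %d", &a, &b);
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    getSub(1, -1);
    S res = dp(1, -1);
    printf("%d\n", max(res.take, res.notTake));
}
```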