Recently a friend of mine got a TL verdict on some problem, and another friend tried to help him overcome it. Then I joined in, was shocked by the results they were getting, and together we stabilized the approach they had come up with.
By "deep recursion" we will mean a recursion with depth >= 100000. If it's less than 1000, this approach doesn't work at all, if it's between 1000 and 100000, most likely the speedup won't be so high so that will be decisive.
Let's consider this simple DFS function (it just counts the number of vertices):
private int dfs(int x, int p, List<Integer>[] g) {
    int result = 1;
    for (int y : g[x]) {
        if (y == p) {
            continue;
        }
        result += dfs(y, x, g);
    }
    return result;
}
Now let's generate a chain tree with 500000 vertices:
int n = 500000;
List<Integer>[] g = new List[n];
for (int i = 0; i < n; i++) {
    g[i] = new ArrayList<>();
}
for (int i = 1; i < n; i++) {
    int p = i - 1;
    g[i].add(p);
    g[p].add(i);
}
and run it:
long t1 = System.nanoTime();
int result = dfs(0, -1, g);
long t2 = System.nanoTime();
System.out.println("time = " + (t2 - t1) / 1.0e9 + " sec, result = " + result);
On my computer it runs in 1.8 seconds. Too slow for such a simple function, isn't it?
Let's speed it up. First, add two dummy parameters curDepth and maxDepth to the DFS:
private int dfs(int x, int p, List<Integer>[] g, int curDepth, int maxDepth) {
    if (curDepth > maxDepth) {
        return 0;
    }
    int result = 1;
    for (int y : g[x]) {
        if (y == p) {
            continue;
        }
        result += dfs(y, x, g, curDepth + 1, maxDepth);
    }
    return result;
}
and then... OMG WTF
long t1 = System.nanoTime();
for (int i = 0; i < 100000; i++) {
    dfs(0, -1, g, 0, 0); // JIT please
}
int result = dfs(0, -1, g, 0, Integer.MAX_VALUE);
long t2 = System.nanoTime();
System.out.println("time = " + (t2 - t1) / 1.0e9 + " sec, result = " + result);
Now it runs in 0.1 seconds :)
When a method in Java is called frequently enough, the JIT compiler optimizes it: the method gets recompiled at runtime, and every subsequent call uses the recompiled version. However, a method cannot be recompiled while it is being executed. See this StackOverflow thread for more info. Maybe some Java expert can add something?
So in the first example we enter the unoptimized version of the DFS and keep running it the whole time. In the second example we first make a lot of very short DFS calls, so the method definitely gets optimized, and only then do we use the optimized version to solve the actual problem.
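If you want to check that the warm-up really triggers the JIT, here is a minimal sketch (my addition, not from the original post) that reads the JVM's accumulated JIT compilation time around the warm-up loop via the standard java.lang.management API; running with -XX:+PrintCompilation also prints compilation events as they happen. It assumes the same dfs and g as above.
// At the top of the file:
import java.lang.management.CompilationMXBean;
import java.lang.management.ManagementFactory;

// Around the warm-up loop:
CompilationMXBean jit = ManagementFactory.getCompilationMXBean();
if (jit != null && jit.isCompilationTimeMonitoringSupported()) {
    long before = jit.getTotalCompilationTime(); // milliseconds spent in the JIT so far
    for (int i = 0; i < 100000; i++) {
        dfs(0, -1, g, 0, 0);
    }
    long after = jit.getTotalCompilationTime();
    System.out.println("JIT time during warm-up: " + (after - before) + " ms");
}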
Make sure the warm-up makes many calls but doesn't do many operations in total. In this example it's wise to make these fake DFS calls from a leaf and stop right after the first recursive call.
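A warm-up along those lines might look like the following sketch (my addition, not from the original post); it reuses n, g, and dfs from above, and vertex n - 1 is a leaf of the chain:
for (int i = 0; i < 100000; i++) {
    // Vertex n - 1 has a single neighbour, so this call runs the adjacency loop,
    // makes exactly one recursive call, and that call returns immediately
    // because curDepth (1) exceeds maxDepth (0).
    dfs(n - 1, -1, g, 0, 0);
}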
Make sure all branches of your recursive function actually run during the warm-up. It would be wrong to call dfs(0, -1, g, 0, -1), because execution finishes immediately at if (curDepth > maxDepth), so the remaining code that iterates over the adjacency list never runs and therefore never gets optimized.
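To make the contrast concrete, here is a tiny illustration (mine, not from the original post; same dfs and g as above):
// Good warm-up call: the depth check is false at the root, so the adjacency loop,
// the parent check and one recursive call are all executed during the warm-up.
dfs(0, -1, g, 0, 0);

// Bad warm-up call: returns at "curDepth > maxDepth" right away, so the rest of
// the method body never runs during the warm-up and doesn't get optimized.
dfs(0, -1, g, 0, -1);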
It doesn't work very well on Codeforces, but it works just fine on Ideone:
- Java 8: https://ideone.com/2dXc6b vs https://ideone.com/EG2opW
- Java 12: https://ideone.com/IZ1olp vs https://ideone.com/CWde2I
and also works on:
- Atcoder
- CSES
- Yandex.Contest with the Java 8 compiler
- ...
Maybe it's a Windows/Linux, 32-bit/64-bit, or OpenJDK/Oracle JDK difference; I haven't tested that yet.
Once more: this works only when the recursion depth is >= 100000, and only on specific sites. But maybe it will help someone, and it's funny anyway.