I am requesting help from Python experts (e.g. pajenegod) with this issue.
Article: https://usaco.guide/general/choosing-lang
The article says that "A comparable Python solution only passes the first five test cases", that "After some optimizations, this solution still passes only the first seven test cases", and that "We are not sure whether it is possible to modify the above approach to receive full credit (please let us know if you manage to succeed)! Of course, it is possible to pass this problem in Python using DSU (a Gold topic)".
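For comparison, here is a minimal sketch of the DSU approach the article mentions. It is not the article's code. The function name `wormsort_dsu` and its in-memory arguments are hypothetical, and it assumes the input can always be sorted, so the loop never runs out of wormholes. The idea is to add wormholes from widest to narrowest until every position `i` is in the same set as position `loc[i] - 1`. The width of the last wormhole that had to be added is the answer.

```python
def wormsort_dsu(n, loc, wormholes):
    """Hypothetical helper: widest achievable minimum width, or -1 if already sorted.

    loc: 1-indexed cow labels by position (same meaning as `loc` in the post).
    wormholes: list of (a, b, w) with 1-indexed endpoints a, b and width w.
    Assumes the permutation is sortable with all wormholes present.
    """
    parent = list(range(n))

    def find(x):
        # iterative find with path halving
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    order = sorted(wormholes, key=lambda e: e[2], reverse=True)  # widest first
    j = 0          # index of the next wormhole to add
    answer = -1    # width of the last wormhole that had to be added
    for i in range(n):
        # connectivity only grows, so positions already checked stay satisfied
        while find(i) != find(loc[i] - 1):
            a, b, w = order[j]
            j += 1
            ra, rb = find(a - 1), find(b - 1)
            if ra != rb:
                parent[ra] = rb
            answer = w
    return answer


print(wormsort_dsu(4, [3, 2, 1, 4], [(1, 2, 9), (1, 3, 7), (2, 3, 10), (2, 4, 3)]))  # 9
```

Each wormhole is unioned at most once, so this replaces the O(log m) repeated DFS passes of the binary search with a single near-linear pass.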
So I tried to optimise the binary search + DFS approach, but I can only get 9/10 test cases to pass reliably. Very rarely all 10 pass, with the 10th test case taking ~3990 ms. Is it possible to make all 10 test cases pass reliably?
I have tried many optimisations: faster I/O, list comprehensions instead of for loops wherever possible, and packing each edge into a single integer with bitwise operators to avoid slow tuple sorting.
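The bit-packing trick can be illustrated in isolation. This is a sketch with hypothetical helper names `pack`/`unpack`. It assumes both endpoints are below 2**17 (true for n ≤ 10^5), so the three fields never overlap. Under that assumption the `^` used in the post and the `|` used here give the same integer, and sorting the packed integers orders edges exactly like sorting `(w, a, b)` tuples.

```python
A_BITS = 17                    # bits reserved for each endpoint (assumes a, b < 2**17)
MASK = (1 << A_BITS) - 1       # same value as the post's 0b11111111111111111

def pack(w, a, b):
    # width in the high bits, then a, then b; << binds tighter than |
    return w << (2 * A_BITS) | a << A_BITS | b

def unpack(v):
    return v >> (2 * A_BITS), (v >> A_BITS) & MASK, v & MASK

triples = [(7, 1, 3), (9, 1, 2), (10, 2, 3), (3, 2, 4)]
packed = sorted((pack(*t) for t in triples), reverse=True)
# integer order matches tuple order because the width field dominates
assert [unpack(v) for v in packed] == sorted(triples, reverse=True)
```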
I have also profiled the code, and the valid() function is the bottleneck.
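One way to produce this kind of profile is with the standard `cProfile` and `pstats` modules. In the sketch below, `valid` and `solve` are hypothetical stand-ins that only burn time. They are not the post's functions. The report is sorted by cumulative time, so the functions the run spends the most time inside appear first.

```python
import cProfile
import io
import pstats

def valid(k):
    # hypothetical stand-in for the real valid(); just does some work
    return sum(i * i for i in range(k)) >= 0

def solve():
    # hypothetical stand-in for the binary search that calls valid()
    return all(valid(k) for k in range(1, 300))

profiler = cProfile.Profile()
profiler.enable()
solve()
profiler.disable()

report = io.StringIO()
pstats.Stats(profiler, stream=report).sort_stats("cumulative").print_stats(10)
print(report.getvalue())
```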
Here is the code:
```python
from operator import methodcaller

def main():
    lines = open("wormsort.in","rb").readlines()
    n,m = map(int,lines[0].split())
    loc = [*map(int,lines[1].split())]
    edges = [[] for _ in range(n)]  # edges[u] = list of (neighbour, rank of wormhole by width)
    lo,hi,mask = 0,m,0b11111111111111111  # mask keeps the low 17 bits

    # True if the cows can be sorted using only the `mid` widest wormholes
    def valid(loc, mid):
        component = [-1]*n
        numcomps = 0
        for i in range(n):
            if component[i] < 0:
                # iterative DFS labelling i's component, using only edges of rank < mid
                todo = [i]
                component[i] = numcomps
                while todo:
                    for child in [x[0] for x in edges[todo.pop()] if component[x[0]] < 0 and x[1] < mid]:
                        component[child] = numcomps
                        todo.append(child)
                numcomps += 1
            # position i and position loc[i] - 1 must be in the same component
            if component[i] != component[loc[i] - 1]:
                return False
        return True

    # bitwise to avoid tuple sort: width << 34 ^ a << 17 ^ b
    all_edges = [*map(lambda x: int(x[2]) << 34 ^ int(x[0]) << 17 ^ int(x[1]),
                      map(methodcaller("split", b" "), lines[2:]))]
    all_edges.sort(reverse=True)  # widest wormhole first
    for i, val in enumerate(all_edges):
        rhs = (val & mask) - 1
        lhs = ((val >> 17) & mask) - 1
        edges[lhs].append((rhs,i))
        edges[rhs].append((lhs,i))

    # binary search for the smallest number of widest wormholes that suffices
    while lo != hi:
        mid = (lo + hi) // 2
        if valid(loc, mid):
            hi = mid
        else:
            lo = mid+1
    # lo == 0 means the cows are already sorted
    open("wormsort.out","w").write(f"{-1 if lo == 0 else all_edges[lo-1] >> 34}\n")

main()
```
Any tips or advice on how to speed this up would be greatly appreciated. Thank you!