Help needed! Python code optimisation for USACO Silver: Wormhole Sort
Difference between en1 and en2, changed 0 character(s)
I am requesting for help from python experts (e.g: [user:pajenegod,2022-11-29]) regarding this issue.↵

Article: [https://usaco.guide/general/choosing-lang](https://usaco.guide/general/choosing-lang)↵

Over here in this article, they mentioned that: "A comparable Python solution only passes the first five test cases:", "After some optimizations, this solution still passes only the first seven test cases:", "We are not sure whether it is possible to modify the above approach to receive full credit (please let us know if you manage to succeed)! Of course, it is possible to pass this problem in Python using DSU (a Gold topic):"↵

So I went ahead to try to optimise the approach (Binary Search with DFS) but I could only get 9/10 testcases to pass reliably. Very rarely, I will have 10/10 testcases passed, with the 10th testcase having ~3990 ms. I wonder if it is possible to get 10/10 testcases to pass reliably?↵

I have tried many approaches, including speeding up IO, using list comprehensions whenever possible instead of for loops, using bitwise operators to avoid slow tuple sorting.↵

I have also profiled my code and found that the valid() function is the one that is the bottleneck.↵


Here is the code:↵

~~~~~↵
from operator import methodcaller↵

def main():↵
lines = open("wormsort.in","rb").readlines()↵
n,m = map(int,lines[0].split())↵
loc = [*map(int,lines[1].split())]↵
edges = [[] for _ in range(n)]↵
lo,hi,mask = 0,m,0b11111111111111111↵

def valid(loc, mid):↵
component = [-1]*n↵
numcomps = 0↵
for i in range(n):↵
if component[i] < 0:↵
todo = [i]↵
component[i] = numcomps↵
while todo:↵
for child in [x[0] for x in edges[todo.pop()] if component[x[0]] < 0 and x[1] < mid]:↵
component[child] = numcomps↵
todo.append(child)↵
numcomps += 1↵
if component[i] != component[loc[i] - 1]:↵
return False↵
return True↵

# bitwise to avoid tuple sort↵
all_edges = [*map(lambda x: int(x[2]) << 34 ^ int(x[0]) << 17 ^ int(x[1]), ↵
  map(methodcaller("split", b" "), lines[2:]))]↵
all_edges.sort(reverse=True)↵

for i, val in enumerate(all_edges):↵
rhs = (val & mask) - 1↵
lhs = ((val >> 17) & mask) - 1↵
edges[lhs].append((rhs,i))↵
edges[rhs].append((lhs,i))↵

while lo != hi:↵
mid = (lo + hi) // 2↵
if valid(loc, mid):↵
hi = mid↵
else:↵
lo = mid+1↵

open("wormsort.out","w").write(f"{-1 if lo == 0 else all_edges[lo-1] >> 34}\n")↵

main()↵
~~~~~↵

Any tips or advice on how to speed this up would be greatly appreciated. Thank you!

History

 
 
 
 
Revisions
 
 
  Rev. Lang. By When Δ Comment
en3 English drugkeeper 2022-11-29 11:00:43 131 Tiny change: 'ng-lang)\nProblem:' -> 'ng-lang)\n\nProblem:'
en2 English drugkeeper 2022-11-29 10:58:33 0 (published)
en1 English drugkeeper 2022-11-29 10:57:56 2614 Initial revision (saved to drafts)