After trying many implementations of suffix trees and suffix arrays in Python, the fastest one I managed to get is based on adamant's implementation of Ukkonen's algorithm: https://codeforces.net/blog/entry/16780
Here is my Python version, enhanced with memoization (I tested it on UVA 10679 — I Love Strings!!). I'm happy to receive feedback on how to improve the algorithm:
from sys import stdin, stdout, stderr, setrecursionlimit
from functools import lru_cache
# Raise the recursion limit: `traverse` calls itself once per node along the
# query's path through the suffix tree.
setrecursionlimit(100000)
def read():
    """Return the next line of standard input with trailing whitespace stripped."""
    line = stdin.readline()
    return line.rstrip()
def readint():
    """Read the next input line and return it parsed as an ``int``."""
    raw = read()
    return int(raw)
def make_node(_pos, _len):
    """Allocate the next free node and return its index.

    The new node's incoming edge starts at position ``_pos`` of the string and
    has length ``_len``. These are stored in ``fpos`` and ``slen`` at index
    ``sz``, and ``sz`` is then advanced.
    """
    global sz  # the only global rebound here; fpos/slen are mutated in place
    new_id = sz
    fpos[new_id], slen[new_id] = _pos, _len
    sz = new_id + 1
    return new_id
def go_edge():
    """Walk down from ``node`` while the pending length ``pos`` spans a whole edge.

    The child is chosen by the character ``s[n - pos]``. A missing child maps
    to node 0, and the root's ``slen`` (inf) stops the descent. Each full edge
    passed over is subtracted from ``pos``.
    """
    global node, pos
    while True:
        child = to[node].get(s[n - pos], 0)
        if pos <= slen[child]:
            break
        node = child
        pos -= slen[node]
def add_letter(c):
    # Append character code `c` to the string and extend the suffix tree by
    # one online step of Ukkonen's construction.
    #
    # State kept in module globals:
    #   s[:n]      - character codes inserted so far
    #   node, pos  - current position: `pos` is the length (measured from
    #                `node`) of the pending suffix, whose first character
    #                below `node` is s[n - pos]
    #   to[v]      - children of v, keyed by the first character code of the edge
    #   fpos/slen  - start index in `s` and length of the edge entering a node
    #   link[v]    - suffix link of internal node v
    #   sz         - next free node index (advanced by make_node)
    global s, n, sz, to, link, fpos, slen, pos, node
    s[n] = c
    n += 1
    pos += 1
    # Most recently created internal node still waiting for its suffix link.
    # 0 means "none"; writes to link[0] are harmless because the root's link
    # is never followed (see the `node == 0` branch at the end of the loop).
    last = 0
    while(pos > 0):
        # Walk down so the pending suffix ends inside the edge below `node`.
        go_edge()
        edge = s[n - pos]  # first character of the pending suffix below `node`
        v = to[node].get(edge, 0)  # child along that character; 0 = no such edge
        # Character on v's edge that `c` must be compared against. When v == 0
        # this reads an unrelated slot, and the value is not used in that branch.
        t = s[fpos[v] + pos - 1]
        if (v == 0):
            # No edge starts with `edge`: hang a new leaf whose length is `inf`.
            to[node][edge] = make_node(n - pos, inf)
            link[last] = node
            last = 0
        elif (t == c):
            # `c` already follows on this edge, so this suffix is already
            # present: record the pending suffix link and end this step.
            link[last] = node
            return
        else:
            # Mismatch inside v's edge: split it with a new internal node `u`
            # covering its first pos - 1 characters. Then attach a new leaf
            # for `c`, and re-hang v (shortened) under u by its character `t`.
            u = make_node(fpos[v], pos - 1)
            to[u][c] = make_node(n - 1, inf)
            to[u][t] = v
            fpos[v] += pos - 1
            slen[v] -= pos - 1
            to[node][edge] = u
            link[last] = u
            last = u
        # Move on to the next shorter suffix: at the root just shorten the
        # pending length; otherwise follow the suffix link.
        if(node == 0):
            pos -= 1
        else:
            node = link[node]
def init_tree(st):
    """Build a fresh suffix tree for ``st`` in the module globals.

    Resets all global construction state: the character buffer, the
    children maps, edge starts/lengths, suffix links, the current position
    and the node counter. It then feeds ``st`` to ``add_letter`` one
    character code at a time.
    """
    global inf, maxn, ans, s, n, to, fpos, slen, link, node, pos, sz
    inf = 10 ** 9  # length used for leaf edges and for the root
    maxn = 2 * len(st) + 1  # node capacity: a suffix tree over m characters has at most 2m nodes
    s = [0] * maxn
    to = [dict() for _ in range(maxn)]
    fpos = [0] * maxn
    slen = [0] * maxn
    link = [0] * maxn
    node = pos = 0
    sz = 1  # node 0 is the root, so the next free index is 1
    n = 0
    # A missing child is looked up as node 0 in go_edge; giving the root an
    # infinite length guarantees the descent stops there.
    slen[0] = inf
    ans = 0
    for code in map(ord, st):
        add_letter(code)
def traverse_edge(st, idx, start, end):
    """Match ``st[idx:]`` against the text characters ``text[start..end]``.

    ``end`` is inclusive, and the range is clipped to the length of the
    global ``text``.

    The lengths are taken from ``st`` and ``text`` themselves rather than from
    the ``len_st`` / ``len_text`` globals. This way the result depends only on
    the arguments and ``text``, and cannot be corrupted by a stale global.

    Returns:
        -1 on the first mismatching character.
        ``len(st)`` if the query is exhausted on this edge (for an empty query
        that value is 0, which callers cannot distinguish from "continue").
        0 if the edge (or the text) ran out first, so matching must continue
        below this edge.
    """
    len_query = len(st)
    stop = min(end + 1, len(text))  # first position past the usable edge range
    k = start
    while k < stop and idx < len_query:
        if text[k] != st[idx]:
            return -1
        k += 1
        idx += 1
    if idx == len_query:
        return idx
    return 0
def edgelen(v, init, e):
    """Return the size of the inclusive position range ``init..e``, or 0 for node 0."""
    return 0 if v == 0 else e - init + 1
@lru_cache(maxsize=10000001)
def traverse(v, st, idx):
    """Check whether ``st[idx:]`` can be read in the suffix tree starting at node ``v``.

    For a non-root ``v``, the incoming edge of ``v`` (starting at ``fpos[v]``,
    length ``slen[v]``) is matched first via ``traverse_edge``. Matching then
    continues into the child selected by the next query character.

    Returns a non-empty (truthy) list when the whole query has been matched,
    and ``[]`` otherwise; ``solve`` only tests truthiness.

    NOTE(review): ``solve`` clears this cache before every query, so the memo
    only deduplicates calls made while answering a single query.
    """
    # Use len(st) rather than the global len_st: the global is not part of the
    # lru_cache key, so a stale value could poison cached results.
    len_query = len(st)
    init = fpos[v]
    e = fpos[v] + slen[v] - 1  # last position of v's edge (inclusive)
    if v != 0:
        r = traverse_edge(st, idx, init, e)
        if r == -1:
            return []  # mismatch inside this edge
        if r != 0:
            return [r]  # query ends inside (or exactly at the end of) this edge
    idx = idx + edgelen(v, init, e)
    if idx == len_query:
        # Nothing left to match, e.g. an empty query at the root. Previously
        # this fell through to st[idx] and raised IndexError; the empty string
        # occurs in every text.
        return [idx]
    if idx > len_query:
        # e.g. a leaf edge (length inf) on which the text ran out first
        return []
    children = to[v]
    k = ord(st[idx])
    if k in children:
        return traverse(children[k], st, idx)
    return []
@lru_cache(maxsize=10 * 1001)
def solve(T, query):
    """Return the output line for one query.

    The line is 'y' plus a newline when ``traverse`` finds ``query`` from the
    root of the current suffix tree, and 'n' plus a newline otherwise.

    ``T`` is never read in the body. It only makes the indexed text part of
    the memo key, so a repeated (text, query) pair is answered from the cache.
    The ``traverse`` memo is reset before each uncached query.
    """
    traverse.cache_clear()
    if traverse(0, query, 0):
        return "y\n"
    return "n\n"
def main():
    """Process every test case from standard input.

    Input format: the number of cases. Each case then has the text line, the
    number of queries, and one query per line. Each answer from ``solve`` is
    written to stdout as it is computed.

    Also publishes ``text``, ``len_text`` and ``len_st`` as globals, which the
    matching helpers read.
    """
    global text, len_st, len_text
    cases = readint()
    for _ in range(cases):
        # "$" is appended as a terminator.
        # NOTE(review): assumes "$" never occurs in the input text.
        text = read() + "$"
        len_text = len(text)
        init_tree(text)
        for _ in range(readint()):
            query = read()
            len_st = len(query)
            stdout.write(solve(text, query))
if __name__ == "__main__":
    # Run the solver only when executed as a script, so the module can be
    # imported (e.g. to test the helpers) without consuming stdin.
    main()