Hello,
I tried the problem Segment Sum. The DP solution seemed pretty straightforward to me, and I decided to implement it in Python (the Python code is at the end of this post).
I used memoization. The recursive function returns two numbers, so each entry of my DP table holds two values: the current sum and the number of ways that satisfy the criteria in the current state. After hours of debugging, I couldn't find any problem with my code, so I decided to remove the memo lookup
    if dp[smaller][start][pos][mask][0] != -1:
        return dp[smaller][start][pos][mask]
to see whether the DP table was the issue. Surprisingly, the program outputs correct results once these two lines are removed. It seems that something goes wrong when a tuple (here, a list of size 2) is returned from a recursive function in Python.
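For comparison, the (sum, ways) memoization described above can be sketched with `functools.lru_cache` and immutable tuples instead of a preallocated table. This is only an illustrative sketch, not the code from this post: the name `sum_up_to`, the state order `(pos, mask, tight, started)`, and processing digits from the most significant one are all assumptions made for the example.

```python
from functools import lru_cache

MOD = 998244353

def sum_up_to(n, k):
    """Sum (mod MOD) of all x in [0, n] that use at most k distinct digits."""
    if n < 0:
        return 0
    digits = list(map(int, str(n)))  # most significant digit first
    length = len(digits)
    pow10 = [pow(10, length - 1 - i, MOD) for i in range(length)]

    @lru_cache(maxsize=None)
    def rec(pos, mask, tight, started):
        # Returns an immutable (sum, ways) tuple for the digits from pos onward.
        if bin(mask).count("1") > k:
            return (0, 0)
        if pos == length:
            return (0, 1)
        total, ways = 0, 0
        limit = digits[pos] if tight else 9
        for d in range(limit + 1):
            new_started = started or d > 0
            # Leading zeros do not count as a used digit.
            new_mask = (mask | (1 << d)) if new_started else 0
            s, w = rec(pos + 1, new_mask, tight and d == limit, new_started)
            total = (total + d * pow10[pos] * w + s) % MOD
            ways = (ways + w) % MOD
        return (total, ways)

    return rec(0, 0, True, False)[0]

def segment_sum(l, r, k):
    # Answer for [l, r] as a difference of two prefix sums.
    return (sum_up_to(r, k) - sum_up_to(l - 1, k)) % MOD
```

Because every cached value is a fresh tuple, no two states can share the same result object, which makes this version convenient for cross-checking the table-based one on small inputs.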
To make sure that the method is correct, I reimplemented it in C++ and it got Accepted, as expected: 60403673. Could you please help me fix the issue in Python?
Python code:
import sys
mod = 998244353
MAX_LENGTH = 20
bound = [0] * MAX_LENGTH

def mul(a, b):
    return (a * b) % mod

def add(a, b):
    a += b
    if a < 0:
        a += mod
    if a >= mod:
        a -= mod
    return a

def digitize(num):
    for i in range(MAX_LENGTH):
        bound[i] = num % 10
        num //= 10

def rec(smaller, start, pos, mask):
    global k
    if bit_count[mask] > k:
        return [0, 0]
    if pos == -1:
        return [0, 1]
    # if the two following lines are removed, the code returns correct results
    if dp[smaller][start][pos][mask][0] != -1:
        return dp[smaller][start][pos][mask]
    res_sum = res_ways = 0
    for digit in range(0, 10):
        if smaller == 0 and digit > bound[pos]:
            continue
        new_smaller = smaller | (digit < bound[pos])
        new_start = start | (digit > 0) | (pos == 0)
        new_mask = (mask | (1 << digit)) if new_start == 1 else 0
        cur_sum, cur_ways = rec(new_smaller, new_start, pos - 1, new_mask)
        res_sum = add(res_sum, add(mul(mul(digit, ten_pow[pos]), cur_ways), cur_sum))
        res_ways = add(res_ways, cur_ways)
    dp[smaller][start][pos][mask][0], dp[smaller][start][pos][mask][1] = res_sum, res_ways
    return dp[smaller][start][pos][mask]

def solve(upper_bound):
    global dp
    dp = 2 * [2 * [MAX_LENGTH * [(1 << 10) * [[-1, -1]]]]]
    digitize(upper_bound)
    ans = rec(0, 0, MAX_LENGTH - 1, 0)
    print(ans)
    return ans[0]

inp = [int(x) for x in sys.stdin.read().split()]
l, r, k = inp[0], inp[1], inp[2]
bit_count = [0] * (1 << 10)
for i in range(1, 1 << 10): bit_count[i] = bit_count[i & (i - 1)] + 1
ten_pow = [(10 ** i) % mod for i in range(0, MAX_LENGTH)]
print(add(solve(r), -solve(l - 1)))
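
One thing that may be worth checking, stated here as an assumption rather than a confirmed diagnosis: the table in `solve` is built with list multiplication (`dp = 2 * [...]`), and in Python `n * [x]` repeats references to the same object `x` instead of copying it. If that applies here, every `dp[...][...][...][...]` cell would be the same two-element list, so a write to one state would be visible from all states. The small sketch below shows the effect on a toy 2 x 3 table; the helper names `make_shared` and `make_independent` are made up for the illustration.

```python
def make_shared(rows, cols):
    # Same construction style as the dp table above: list multiplication
    # repeats references, so all cells point at one [-1, -1] list.
    return rows * [cols * [[-1, -1]]]

def make_independent(rows, cols):
    # Nested comprehensions build a new list object for every cell.
    return [[[-1, -1] for _ in range(cols)] for _ in range(rows)]

shared = make_shared(2, 3)
shared[0][0][0] = 7           # write to a single "state"
print(shared[1][2])           # a different state sees the write as well

independent = make_independent(2, 3)
independent[0][0][0] = 7
print(independent[1][2])      # other cells keep their initial value

# Identity check: are two different cells literally the same object?
print(shared[0][0] is shared[1][2])
print(independent[0][0] is independent[1][2])
```

If the same sharing happens in the real table, building `dp` with nested comprehensions (or storing results in a dict keyed by the state) would be one way to give each state its own entry.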