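Given integers $N$ and $K$, a permutation $a_1, a_2, \ldots, a_N$ of $1, 2, \ldots, N$ is valid if and only if $|a_i - i| \neq K$ for every $1 \le i \le N$. Count the number of valid permutations (the program below prints this count modulo 924844033).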
Constraints: $2 \le N \le 2000$,
$1 \le K \le N - 1$.
EDIT: Thanks to TAhmed33 for the explanation.
I implemented that approach as follows:
C++ code:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef vector<ll> vll;
typedef pair<ll,ll> pll;
typedef vector<vll> vvll;
// Loop helper: x runs over the half-open range [s, e).
// Arguments are parenthesized so expressions like `a ? b : c` bind correctly.
#define forl(x,s,e) for(int x=(s);x<(e);x++)
// Expands to the begin/end iterator pair of a container expression.
#define all(x) (x).begin(),(x).end()
// Prints one value followed by a newline.
#define print(x) cout<<(x)<<'\n'
// Prints every element followed by a space, then a newline. Wrapped in
// do/while(0) so the macro is a single statement and is safe as the body
// of an unbraced if/else.
#define printv(x) do { for(auto c: (x)) cout<<c<<' '; cout<<'\n'; } while(0)
// Modulus used for all counting arithmetic (two spellings kept for callers).
const ll mod = 924844033;
const ll MOD = 924844033;
// Default capacity of the factorial table.
const ll N = 2101;
// fact[i] = i! mod `mod`; starts as all ones and is filled in by main().
vll fact(N,1);
// Inclusion-exclusion coefficients for permutations of n elements
// (0-indexed) with |a_i - i| != k.
//
// Entry m of the returned vector (size n + 1) is, modulo MOD, the number of
// ways to pick m positions p and force each onto a value in {p - k, p + k}
// lying in [0, n), with no value used twice. main() combines these as
//   answer = sum_m (-1)^m * result[m] * (n - m)!.
//
// Positions and values split into independent chains by residue r = p mod 2k:
//     value r-k - pos r - value r+k - pos r+2k - value r+3k - ...
// Consecutive positions on a chain share exactly one value, so they conflict
// only when the earlier one takes its right value and the later one its left
// value. Each chain is solved with a rolling DP (at most `cnt` forced
// positions per chain), and the per-chain polynomials are multiplied.
//
// Time O(n^2), extra memory O(n). If k <= 0 no chain is processed and the
// result is {1, 0, ..., 0}.
vll preprocess(ll n,ll k){
    if (n < 0) return vll();            // no meaningful size; avoid indexing an empty table
    vll total(n + 1, 0);                // product of the chain polynomials processed so far
    total[0] = 1;
    ll used = 0;                        // positions already folded into `total`
    const ll kk = 2 * k;
    // Residues r >= n contain no positions, so only r < min(2k, n) matter.
    for (ll r = 0; r < kk && r < n; ++r) {
        // Number of positions r, r + 2k, r + 4k, ... that are < n.
        const ll cnt = (n - 1 - r) / kk + 1;

        // Rolling DP indexed by the number of forced positions so far; the
        // three arrays record what the most recent chain position did:
        // left unforced, forced to p - k, or forced to p + k.
        vll freeLast(cnt + 1, 0), leftLast(cnt + 1, 0), rightLast(cnt + 1, 0);
        vll freeNext(cnt + 1), leftNext(cnt + 1), rightNext(cnt + 1);
        freeLast[0] = 1;
        for (ll t = 0; t < cnt; ++t) {
            const ll p = r + t * kk;
            const bool canLeft = p - k >= 0;
            const bool canRight = p + k < n;
            for (ll m = 0; m <= cnt; ++m) {
                freeNext[m] = (freeLast[m] + leftLast[m] + rightLast[m]) % MOD;
                ll viaLeft = 0, viaRight = 0;
                if (m > 0) {
                    // Taking p - k clashes with a predecessor that took its
                    // right value (p - 2k) + k = p - k.
                    if (canLeft)
                        viaLeft = (freeLast[m - 1] + leftLast[m - 1]) % MOD;
                    if (canRight)
                        viaRight = (freeLast[m - 1] + leftLast[m - 1] + rightLast[m - 1]) % MOD;
                }
                leftNext[m] = viaLeft;
                rightNext[m] = viaRight;
            }
            freeLast.swap(freeNext);
            leftLast.swap(leftNext);
            rightLast.swap(rightNext);
        }

        // chain[m] = ways to force exactly m positions of this chain.
        vll chain(cnt + 1);
        for (ll m = 0; m <= cnt; ++m)
            chain[m] = (freeLast[m] + leftLast[m] + rightLast[m]) % MOD;

        // total *= chain (polynomial product, truncated at degree n).
        vll next(n + 1, 0);
        for (ll j = 0; j <= used; ++j) {
            if (total[j] == 0) continue;
            for (ll part = 0; part <= cnt && j + part <= n; ++part)
                next[j + part] = (next[j + part] + chain[part] * total[j]) % MOD;
        }
        total.swap(next);
        used += cnt;
    }
    return total;
}
int main() {
forl(i,2,N)
fact[i] = (fact[i-1]*i)%mod;
ll n,k;
cin>>n>>k;
vll dp = preprocess(n,k);
ll ans = fact[n];
forl(i,1,n+1){
ll val = (dp[i]*fact[n-i])%mod;
if(i%2==0)
ans = (ans + val)%mod;
else
ans = (ans + mod - val)%mod;
}
print(ans);
return 0;
}