a permutation is valid only if |ai — i| != k for all 1<=i<=n. Count the number of valid permutations.↵
↵
↵
Constraints:↵
2 ≤ N ≤ 2000↵
↵
↵
1 ≤ K ≤ N − 1↵
↵
EDIT: Thanks for the explanation from [user:methanol,2024-11-09].↵
↵
I implemented that explanation: ↵
↵
<spoiler summary="cpp code">↵
~~~~~↵
#include<bits/stdc++.h>↵
using namespace std;↵
↵
typedef long long ll;↵
typedef vector<ll> vll;↵
typedef pair<ll,ll> pll;↵
typedef vector<vll> vvll;↵
#define forl(x,s,e) for(int x=s;x<e;x++)↵
#define all(x) x.begin(),x.end()↵
#define print(x) cout<<x<<'\n'↵
#define printv(x) for(auto c: x)cout<<c<<' ';cout<<'\n';↵
↵
ll mod = 924844033;↵
ll MOD = 924844033;↵
ll N = 2101;↵
vll fact(N,1);↵
↵
vll preprocess(ll n,ll k){↵
↵
ll kk = k*2;↵
vvll dp;↵
↵
forl(r,0,kk){↵
ll cnt = 0 ; ↵
forl(j,0,n)↵
if(j%kk==r)↵
cnt+=1;↵
if(cnt==0) break;↵
vector<vvll> loc_dp(cnt+1,vvll(n+1,vll(3,0ll)));↵
↵
forl(ind,0,cnt+1)↵
loc_dp[ind][0][0] = 1;↵
if(cnt >=1){↵
loc_dp[1][0][0] = 1; ↵
loc_dp[1][1][1] = r-k >=0 ? 1:0;↵
loc_dp[1][1][2] = r+k < n ? 1:0;↵
}↵
↵
forl(ind,2,cnt+1)↵
forl(m,1,n+1){ ↵
ll index = r+(ind-1)*kk;↵
loc_dp[ind][m][0] = (loc_dp[ind-1][m][0] + loc_dp[ind-1][m][1] + loc_dp[ind-1][m][2])%MOD;↵
if(index+k<n)↵
loc_dp[ind][m][2] = (loc_dp[ind-1][m-1][0] + loc_dp[ind-1][m-1][1] + loc_dp[ind-1][m-1][2])%MOD;↵
if(index-k>=0)↵
loc_dp[ind][m][1] = (loc_dp[ind-1][m-1][0] + loc_dp[ind-1][m-1][1])%MOD; ↵
}↵
↵
vll val;↵
forl(m,0,n+1)↵
val.push_back(loc_dp[cnt][m][0]+loc_dp[cnt][m][1]+loc_dp[cnt][m][2]);↵
dp.push_back(val) ; ↵
}↵
ll sz = dp.size();↵
vvll fdp(sz+1,vll(n+1,0ll));↵
fdp[0][0] = 1;↵
↵
// Time complexity of the next few lines = O(sz*n*n//sz) the break statement will activate after part = n//sz ↵
// as each residue group will have roughly n//sz elements in it↵
forl(ind,1,sz+1)↵
forl(m,0,n+1)↵
forl(part,0,m+1){↵
ll val1 = (dp[ind-1][part]*fdp[ind-1][m-part])%MOD;↵
if(dp[ind-1][part]==0)break;// this break statement is crucial to prevent TLE↵
fdp[ind][m] = (fdp[ind][m] + val1 )%MOD;↵
}↵
return fdp[sz]; ↵
}↵
↵
↵
int main() {↵
forl(i,2,N)↵
fact[i] = (fact[i-1]*i)%mod;↵
ll n,k;↵
cin>>n>>k;↵
↵
vll dp = preprocess(n,k);↵
ll ans = fact[n]; ↵
forl(i,1,n+1){↵
ll val = (dp[i]*fact[n-i])%mod;↵
if(i%2==0)↵
ans = (ans + val)%mod; ↵
else↵
ans = (ans + mod - val)%mod; ↵
}↵
print(ans);↵
return 0;↵
}↵
↵
~~~~~↵
</spoiler>↵
↵
↵
↵
↵
Constraints:↵
2 ≤ N ≤ 2000↵
↵
↵
1 ≤ K ≤ N − 1↵
↵
EDIT: Thanks for the explanation from [user:methanol,2024-11-09].↵
↵
I implemented that explanation: ↵
↵
<spoiler summary="cpp code">↵
~~~~~↵
#include<bits/stdc++.h>↵
using namespace std;↵
↵
typedef long long ll;↵
typedef vector<ll> vll;↵
typedef pair<ll,ll> pll;↵
typedef vector<vll> vvll;↵
#define forl(x,s,e) for(int x=s;x<e;x++)↵
#define all(x) x.begin(),x.end()↵
#define print(x) cout<<x<<'\n'↵
#define printv(x) for(auto c: x)cout<<c<<' ';cout<<'\n';↵
↵
ll mod = 924844033;↵
ll MOD = 924844033;↵
ll N = 2101;↵
vll fact(N,1);↵
↵
vll preprocess(ll n,ll k){↵
↵
ll kk = k*2;↵
vvll dp;↵
↵
forl(r,0,kk){↵
ll cnt = 0 ; ↵
forl(j,0,n)↵
if(j%kk==r)↵
cnt+=1;↵
if(cnt==0) break;↵
vector<vvll> loc_dp(cnt+1,vvll(n+1,vll(3,0ll)));↵
↵
forl(ind,0,cnt+1)↵
loc_dp[ind][0][0] = 1;↵
if(cnt >=1){↵
loc_dp[1][0][0] = 1; ↵
loc_dp[1][1][1] = r-k >=0 ? 1:0;↵
loc_dp[1][1][2] = r+k < n ? 1:0;↵
}↵
↵
forl(ind,2,cnt+1)↵
forl(m,1,n+1){ ↵
ll index = r+(ind-1)*kk;↵
loc_dp[ind][m][0] = (loc_dp[ind-1][m][0] + loc_dp[ind-1][m][1] + loc_dp[ind-1][m][2])%MOD;↵
if(index+k<n)↵
loc_dp[ind][m][2] = (loc_dp[ind-1][m-1][0] + loc_dp[ind-1][m-1][1] + loc_dp[ind-1][m-1][2])%MOD;↵
if(index-k>=0)↵
loc_dp[ind][m][1] = (loc_dp[ind-1][m-1][0] + loc_dp[ind-1][m-1][1])%MOD; ↵
}↵
↵
vll val;↵
forl(m,0,n+1)↵
val.push_back(loc_dp[cnt][m][0]+loc_dp[cnt][m][1]+loc_dp[cnt][m][2]);↵
dp.push_back(val) ; ↵
}↵
ll sz = dp.size();↵
vvll fdp(sz+1,vll(n+1,0ll));↵
fdp[0][0] = 1;↵
↵
// Time complexity of the next few lines = O(sz*n*n//sz) the break statement will activate after part = n//sz ↵
// as each residue group will have roughly n//sz elements in it↵
forl(ind,1,sz+1)↵
forl(m,0,n+1)↵
forl(part,0,m+1){↵
ll val1 = (dp[ind-1][part]*fdp[ind-1][m-part])%MOD;↵
if(dp[ind-1][part]==0)break;// this break statement is crucial to prevent TLE↵
fdp[ind][m] = (fdp[ind][m] + val1 )%MOD;↵
}↵
return fdp[sz]; ↵
}↵
↵
↵
int main() {↵
forl(i,2,N)↵
fact[i] = (fact[i-1]*i)%mod;↵
ll n,k;↵
cin>>n>>k;↵
↵
vll dp = preprocess(n,k);↵
ll ans = fact[n]; ↵
forl(i,1,n+1){↵
ll val = (dp[i]*fact[n-i])%mod;↵
if(i%2==0)↵
ans = (ans + val)%mod; ↵
else↵
ans = (ans + mod - val)%mod; ↵
}↵
print(ans);↵
return 0;↵
}↵
↵
~~~~~↵
</spoiler>↵
↵
↵