Hi, Codeforces!
1529A - Eshag любит большие массивы
problem idea and solution: AmShZ
editorial
Tutorial is loading...
official solution
// khodaya khodet komak kon
# include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef long double ld;
typedef pair <int, int> pii;
typedef pair <pii, int> ppi;
typedef pair <int, pii> pip;
typedef pair <pii, pii> ppp;
typedef pair <ll, ll> pll;
# define A first
# define B second
# define endl '\n'
# define sep ' '
# define all(x) x.begin(), x.end()
# define kill(x) return cout << x << endl, 0
# define SZ(x) int(x.size())
# define lc id << 1
# define rc id << 1 | 1
// Modular exponentiation: returns (a^b) mod md, computed by binary exponentiation.
//   a  - base; may be negative or >= md, it is reduced into [0, md) first so the
//        intermediate products stay below md * md (must fit in signed 64 bits).
//   b  - exponent; a non-positive exponent yields 1 % md.
//   md - modulus, must be >= 1 (md == 1 yields 0).
long long power(long long a, long long b, long long md) {
    a %= md;
    if (a < 0) a += md;
    long long result = 1 % md;
    while (b > 0) {
        if (b & 1) result = result * a % md;
        a = a * a % md;
        b >>= 1;
    }
    return result;
}
const int xn = 1e2 + 10;
const int sq = 320;
const int inf = 1e9 + 10;
const ll INF = 1e18 + 10;
const int mod = 1e9 + 7;//998244353;
const int base = 257;
int qq, n, a[xn], mn, ans;
// Reads qq test cases. For each one reads n values into a[1..n], finds their
// minimum, and prints how many values differ from that minimum.
int main(){
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    cin >> qq;
    while (qq --){
        cin >> n;
        mn = inf;
        for (int idx = 1; idx <= n; ++ idx){
            cin >> a[idx];
            if (a[idx] < mn)
                mn = a[idx];
        }
        ans = 0;
        for (int idx = 1; idx <= n; ++ idx)
            if (a[idx] != mn)
                ++ ans;
        cout << ans << endl;
    }
    return 0;
}
1529B - Sifid и странные подпоследовательности
problem idea and solution: Davoth
editorial
Tutorial is loading...
official solution
// khodaya khodet komak kon
// Nightcall - London Grammer
# include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef long double ld;
typedef pair <int, int> pii;
typedef pair <pii, int> ppi;
typedef pair <int, pii> pip;
typedef pair <pii, pii> ppp;
typedef pair <ll, ll> pll;
# define A first
# define B second
# define endl '\n'
# define sep ' '
# define all(x) x.begin(), x.end()
# define kill(x) return cout << x << endl, 0
# define SZ(x) int(x.size())
# define lc id << 1
# define rc id << 1 | 1
// Modular exponentiation: returns (a^b) mod md, computed by binary exponentiation.
//   a  - base; may be negative or >= md, it is reduced into [0, md) first so the
//        intermediate products stay below md * md (must fit in signed 64 bits).
//   b  - exponent; a non-positive exponent yields 1 % md.
//   md - modulus, must be >= 1 (md == 1 yields 0).
long long power(long long a, long long b, long long md) {
    a %= md;
    if (a < 0) a += md;
    long long result = 1 % md;
    while (b > 0) {
        if (b & 1) result = result * a % md;
        a = a * a % md;
        b >>= 1;
    }
    return result;
}
const int xn = 1e5 + 10;
const int xm = - 20 + 10;
const int sq = 320;
const int inf = 1e9 + 10;
const ll INF = 1e18 + 10;
const int mod = 1e9 + 7;//998244353;
const int base = 257;
int qq, n, a[xn], ans, mn;
bool flag;
// Reads qq test cases. For each one:
//   ans  = number of elements <= 0,
//   mn   = smallest strictly positive element (stays at inf if there is none),
//   flag = a positive element exists AND every pair of neighbouring
//          non-positive elements (in sorted order) is at least mn apart.
// Prints ans + 1 when flag holds, otherwise ans.
int main(){
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    cin >> qq;
    while (qq --){
        cin >> n;
        ans = 0;
        for (int idx = 1; idx <= n; ++ idx){
            cin >> a[idx];
            if (a[idx] <= 0)
                ++ ans;
        }
        sort(a + 1, a + n + 1);
        mn = inf;
        for (int idx = 1; idx <= n; ++ idx)
            if (a[idx] > 0 && a[idx] < mn)
                mn = a[idx];
        flag = (mn < inf);
        for (int idx = 2; idx <= n; ++ idx)
            if (a[idx] <= 0 && a[idx] - a[idx - 1] < mn)
                flag = false;
        cout << (flag ? ans + 1 : ans) << endl;
    }
    return 0;
}
problem idea and solution: shokouf, AmShZ
editorial
Tutorial is loading...
official solution
// Call my Name and Save me from The Dark
#include <bits/stdc++.h>
using namespace std;
typedef long long int ll;
typedef pair<int, int> pii;
#define SZ(x) (int) x.size()
#define F first
#define S second
const int N = 2e5 + 10;
ll dp[2][N]; int A[2][N], n; vector<int> adj[N];
void DFS(int v, int p = -1) {
dp[0][v] = dp[1][v] = 0;
for (int u : adj[v]) {
if (u == p) continue;
DFS(u, v);
dp[0][v] += max(abs(A[0][v] - A[1][u]) + dp[1][u], dp[0][u] + abs(A[0][v] - A[0][u]));
dp[1][v] += max(abs(A[1][v] - A[1][u]) + dp[1][u], dp[0][u] + abs(A[1][v] - A[0][u]));
}
}
void Solve() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%d%d", &A[0][i], &A[1][i]);
fill(adj + 1, adj + n + 1, vector<int>());
for (int i = 1; i < n; i++) {
int u, v; scanf("%d%d", &u, &v);
adj[u].push_back(v);
adj[v].push_back(u);
}
DFS(1);
printf("%lld\n", max(dp[0][1], dp[1][1]));
}
// Entry point: reads the number of test cases and runs Solve() once per case.
int main() {
    int remaining;
    scanf("%d", &remaining);
    while (remaining--) {
        Solve();
    }
    return 0;
}
1528B - Kavi разбивает на пары
problem idea and solution: alireza_kaviani
editorial
Tutorial is loading...
official solution
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<ll, ll> pll;
typedef pair<int, int> pii;
#define X first
#define Y second
#define endl '\n'
const int N = 1e6 + 10;
const int MOD = 998244353;
int n, dp[N], S;
// Reads n and prints dp[n] modulo MOD, where
//   dp[0] = 1 and dp[i] = (number of divisors d of i with d < i) + sum_{j < i} dp[j].
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n;
    // Sieve: after this, dp[m] is the count of divisors of m that are smaller than m.
    for (int divisor = 1; divisor <= n; divisor++)
        for (int multiple = divisor + divisor; multiple <= n; multiple += divisor)
            dp[multiple] += 1;
    S = 1;      // running prefix sum dp[0] + ... + dp[i - 1]
    dp[0] = 1;
    for (int i = 1; i <= n; i++) {
        dp[i] = (dp[i] + S) % MOD;
        S = (S + dp[i]) % MOD;
    }
    cout << dp[n] << endl;
    return 0;
}
problem idea and solution: AliShahali1382
editorial
Tutorial is loading...
official solution
#include <bits/stdc++.h>
#pragma GCC optimize ("O2,unroll-loops")
//#pragma GCC optimize("no-stack-protector,fast-math")
using namespace std;
typedef long long ll;
typedef long double ld;
typedef pair<int, int> pii;
typedef pair<pii, int> piii;
typedef pair<ll, ll> pll;
#define debug(x) cerr<<#x<<'='<<(x)<<endl;
#define debugp(x) cerr<<#x<<"= {"<<(x.first)<<", "<<(x.second)<<"}"<<endl;
#define debug2(x, y) cerr<<"{"<<#x<<", "<<#y<<"} = {"<<(x)<<", "<<(y)<<"}"<<endl;
#define debugv(v) {cerr<<#v<<" : ";for (auto x:v) cerr<<x<<' ';cerr<<endl;}
#define all(x) x.begin(), x.end()
#define pb push_back
#define kill(x) return cout<<x<<'\n', 0;
const int inf=1000000010;
const ll INF=1000000000000001000LL;
const int mod=1000000007;
const int MAXN=300010, LOG=20;
int n, m, k, u, v, x, y, t, a, b, ans;
int par1[MAXN], par2[MAXN];
int stt[MAXN], fnt[MAXN], timer;
vector<int> G1[MAXN], G2[MAXN];
set<pii> st;
int Add(int v){
auto it=st.lower_bound({stt[v], v});
if (it!=st.end() && fnt[it->second]<=fnt[v]) return -2;
if (it==st.begin()) return -1;
--it;
int res=it->second;
if (fnt[v]>fnt[res]) return -1;
st.erase(it);
return res;
}
// Depth-first walk over G2 that stamps each vertex: stt[node] is the value of
// timer on entry (timer is then advanced), fnt[node] is the value of timer
// after all of node's children have been processed.
void dfs1(int node){
    stt[node] = timer;
    ++timer;
    for (int child : G2[node]) {
        dfs1(child);
    }
    fnt[node] = timer;
}
// Depth-first walk over G1. On entry the vertex is offered to st through Add();
// unless Add() answered -2 it is inserted. ans tracks the largest size st ever
// reaches. On exit the set is restored to its state before the visit: the
// vertex is removed and any element Add() erased is put back.
void dfs2(int node){
    const int res = Add(node);
    const bool inserted = (res != -2);
    if (inserted) st.insert({stt[node], node});
    ans = max(ans, (int)st.size());
    for (int child : G1[node]) {
        dfs2(child);
    }
    if (!inserted) return;
    st.erase({stt[node], node});
    if (res != -1) st.insert({stt[res], res});
}
// Handles one test case: reads n and the parent of every vertex 2..n in both
// trees (children lists go to G1 and G2), stamps G2 with dfs1, walks G1 with
// dfs2, and prints ans.
void Solve(){
    cin >> n;
    for (int v = 1; v <= n; v++){
        G1[v].clear();
        G2[v].clear();
    }
    for (int v = 2; v <= n; v++){
        cin >> par1[v];
        G1[par1[v]].push_back(v);
    }
    for (int v = 2; v <= n; v++){
        cin >> par2[v];
        G2[par2[v]].push_back(v);
    }
    timer = 1;
    ans = 0;
    dfs1(1);
    dfs2(1);
    cout << ans << "\n";
}
// Entry point: switches iostreams to unsynchronised mode, reads the number of
// test cases and calls Solve() for each.
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int casesLeft;
    cin >> casesLeft;
    while (casesLeft--) {
        Solve();
    }
    return 0;
}
better implementation of official solution: AaParsa
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<ll, ll> pll;
typedef pair<int, int> pii;
#define X first
#define Y second
#define endl '\n'
const int N = 3e5 + 10;
int t, n, L[N], R[N], timer, now, ans;
vector<int> Srsh[N], Msht[N];
set<pii> st;
// Depth-first walk over Msht that records L[u] = timer on entry (then advances
// timer) and R[u] = timer once every child of u has been handled.
void MshtDFS(int u) {
    L[u] = timer;
    timer += 1;
    for (int child : Msht[u])
        MshtDFS(child);
    R[u] = timer;
}
// Returns 1 when the interval [L[v], R[v]] lies inside [L[u], R[u]], else 0.
int ispar(int u, int v) {
    if (L[v] < L[u]) return 0;
    return R[v] <= R[u] ? 1 : 0;
}
void SrshDFS(int u) {
int last = now;
// add u
auto it = st.lower_bound({L[u], 0});
if (it != st.end())
now += 1 - ispar(u, it->Y);
if (it != st.begin()) {
auto nxt = it--;
now += 1 - ispar(it->Y, u);
if (nxt != st.end())
now -= 1 - ispar(it->Y, nxt->Y);
}
st.insert({L[u], u});
ans = max(ans, now);
// DFS
for (int v : Srsh[u]) {
SrshDFS(v);
}
// remove u
st.erase({L[u], u});
now = last;
}
// Entry point: for each of t test cases reads n and both parent arrays
// (children lists go to Srsh and Msht), stamps Msht with MshtDFS, walks Srsh
// with SrshDFS and prints ans + 1.
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> t;
    while (t-- > 0) {
        cin >> n;
        for (int v = 1; v <= n; v++) {
            Srsh[v].clear();
            Msht[v].clear();
        }
        for (int child = 2; child <= n; child++) {
            int parent;
            cin >> parent;
            Srsh[parent].push_back(child);
        }
        for (int child = 2; child <= n; child++) {
            int parent;
            cin >> parent;
            Msht[parent].push_back(child);
        }
        timer = 0;
        MshtDFS(1);
        now = 0;
        ans = 0;
        SrshDFS(1);
        cout << ans + 1 << endl;
    }
    return 0;
}
python solution: Tet
#!/usr/bin/env python
import os
import sys
from io import BytesIO, IOBase
# region fastio
BUFSIZE = 8192
class FastIO(IOBase):
    """Byte-level buffer in front of a file descriptor.

    Input is pulled from the descriptor with os.read() in large chunks and
    appended to an in-memory BytesIO; output is collected in the same BytesIO
    and written out by flush().
    """

    # Number of complete lines known to be waiting in the buffer.
    newlines = 0

    def __init__(self, file):
        self._fd = file.fileno()
        self.buffer = BytesIO()
        self.writable = "x" in file.mode or "r" not in file.mode
        self.write = self.buffer.write if self.writable else None

    def read(self):
        """Read the descriptor until os.read() returns nothing; return the unread bytes."""
        while True:
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            if not b:
                break
            # Append at the end of the buffer, then restore the read position.
            ptr = self.buffer.tell()
            self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
        self.newlines = 0
        return self.buffer.read()

    def readline(self):
        """Return the next line as bytes, reading more input only when none is buffered."""
        while self.newlines == 0:
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            # An empty read counts as one "line" so the loop also ends at EOF.
            self.newlines = b.count(b"\n") + (not b)
            ptr = self.buffer.tell()
            self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
        self.newlines -= 1
        return self.buffer.readline()

    def flush(self):
        """Write everything collected so far to the descriptor and empty the buffer."""
        if self.writable:
            os.write(self._fd, self.buffer.getvalue())
            self.buffer.truncate(0), self.buffer.seek(0)
class IOWrapper(IOBase):
    """Text facade over FastIO: str is encoded to / decoded from ASCII at the boundary."""

    def __init__(self, file):
        self.buffer = FastIO(file)
        self.flush = self.buffer.flush
        self.writable = self.buffer.writable
        self.write = lambda s: self.buffer.write(s.encode("ascii"))
        self.read = lambda: self.buffer.read().decode("ascii")
        self.readline = lambda: self.buffer.readline().decode("ascii")
# Replace both standard streams with IOWrapper instances built around them.
sys.stdin, sys.stdout = IOWrapper(sys.stdin), IOWrapper(sys.stdout)
# Shadow the builtin input(): read one line and strip trailing "\r" / "\n".
input = lambda: sys.stdin.readline().rstrip("\r\n")
class FenwickTree:
    """Binary indexed tree over a 0-indexed list (the list is reused as storage)."""

    def __init__(self, x):
        """transform list into BIT"""
        self.bit = x
        for i in range(len(x)):
            j = i | (i + 1)
            if j < len(x):
                x[j] += x[i]

    def update(self, idx, x):
        """updates bit[idx] += x"""
        while idx < len(self.bit):
            self.bit[idx] += x
            idx |= idx + 1

    def query(self, end):
        """Return the sum of the elements at indices 0..end, *inclusive*.

        The bound is inclusive because of the `end += 1` shift below; callers
        must pass end <= len(self.bit) - 1.
        """
        end += 1
        x = 0
        while end:
            x += self.bit[end - 1]
            end &= end - 1
        return x

    def findkth(self, k):
        """Find largest idx such that the sum of the first idx elements is <= k"""
        idx = -1
        for d in reversed(range(len(self.bit).bit_length())):
            right_idx = idx + (1 << d)
            if right_idx < len(self.bit) and k >= self.bit[right_idx]:
                idx = right_idx
                k -= self.bit[idx]
        return idx + 1
# Module-level placeholders; solve() rebinds all four names for every test case.
g1 = []  # children lists handed to dfs()
g2 = []  # children lists handed to dfs1()
st = []  # st[v]: timestamp dfs1() assigns when v is first visited
ft = []  # ft[v]: timestamp dfs1() records when v is popped from its stack
def dfs1(graph, start=0):
    """Iterative DFS over `graph` that fills the global st / ft timestamp arrays.

    st[v] is the counter value when v is first seen (the counter then advances);
    ft[v] is the counter value when v is popped, i.e. after all vertices pushed
    on top of it have been popped.
    """
    n = len(graph)
    tim = 1
    visited = [False] * n
    stack = [start]
    while stack:
        start = stack[-1]
        if not visited[start]:
            # First time on top of the stack: stamp entry and push the children.
            st[start] = tim
            tim = tim + 1
            visited[start] = True
            for child in graph[start]:
                if not visited[child]:
                    stack.append(child)
        else:
            # Second time on top: every vertex pushed above it is finished.
            stack.pop()
            ft[start] = tim
def Do(v):
    """Apply vertex v to the current state and push an undo record onto mem.

    fen gets +1 at st[v]. If fen already holds a mark strictly inside
    (st[v], ft[v] - 1], only the record [v, 0, 0] is pushed. Otherwise v is
    counted (ok[v] = 1, cnt += 1), fen2 is shifted by (v - u) on the positions
    st[v] .. ft[v] - 1 where u = fen2.query(st[v]), and u stops being counted
    if it was; the record [v, u, previous ok[u]] is pushed.
    """
    global ans, cnt, mem
    fen.update(st[v], 1)
    if fen.query(ft[v] - 1) - fen.query(st[v]):
        mem.append([v, 0, 0])
        return
    u = fen2.query(st[v])
    ok[v] = 1
    cnt += 1
    fen2.update(st[v], v - u)
    fen2.update(ft[v], u - v)
    mem.append([v, u, ok[u]])
    cnt -= ok[u]
    ok[u] = 0
    return
def undo():
    """Pop the latest record [v, u, a] from mem and reverse the matching Do(v).

    If a is non-zero, u is counted again. v is then uncounted, its mark is
    removed from fen, and the fen2 shift applied for v is reverted.
    NOTE(review): the same steps run for a [v, 0, 0] record, which Do() pushes
    without having counted v or touched fen2 — confirm that branch of Do() is
    never taken for the inputs this is run on.
    """
    global cnt, ans, mem
    [v, u, a] = mem[-1]
    if a != 0:
        ok[u] = 1
        cnt += 1
    cnt -= 1
    ok[v] = 0
    fen.update(st[v], -1)
    fen2.update(st[v], u - v)
    fen2.update(ft[v], v - u)
    mem.pop()
    return
def dfs(graph, start=0):
    """Iterative DFS over `graph`: Do(v) on entry, undo() on exit.

    After every Do() the global ans is raised to the current cnt if larger.
    """
    global ans, cnt
    n = len(graph)
    visited = [False] * n
    stack = [start]
    while stack:
        start = stack[-1]
        if not visited[start]:
            # Entering the vertex.
            Do(start)
            ans = max(ans, cnt)
            visited[start] = True
            for child in graph[start]:
                if not visited[child]:
                    stack.append(child)
        else:
            # Leaving the vertex: everything pushed above it is already undone.
            stack.pop()
            undo()
def solve():
    """Read one test case, build both children lists (0-based) and print ans."""
    global n, g1, g2, V, st, ft, fen, fen2, mem, ok, cnt, ans
    n = int(input())
    if n == 1:
        print(1)
        # NOTE(review): exit(0) ends the whole program, so any test cases after
        # this one would be skipped — confirm n == 1 cannot occur in the input.
        exit(0)
    g1 = []
    g2 = []
    for i in range(n):
        g1.append([])
        g2.append([])
    # 0 base
    V = list(map(int, input().split()))
    for i in range(1, n):
        v = V[i - 1] - 1
        g1[v].append(i)
    V = list(map(int, input().split()))
    for i in range(1, n):
        v = V[i - 1] - 1
        g2[v].append(i)
    st = [0] * (n + 100)
    ft = [0] * (n + 100)
    fen = FenwickTree([0] * (n + 3))
    fen2 = FenwickTree([0] * (n + 3))
    mem = []
    ok = [0] * (n + 100)
    cnt = 0
    ans = 0
    dfs1(g2)
    dfs(g1)
    print(ans)
# Driver: the first input line is the number of test cases.
q = int(input())
for i in range(q):
    solve()
another O(n log n) solution with a segment tree: N.N_2004
#include<bits/stdc++.h>
#define lc (id * 2)
#define rc (id * 2 + 1)
#define md ((s + e) / 2)
#define dm ((s + e) / 2 + 1)
#define ln (e - s + 1)
using namespace std;
typedef long long ll;
const ll MXN = 3e5 + 10;
const ll MXS = MXN * 4;
ll n, timer, ans, Ans;
ll lazy[MXS], seg[MXS];
ll Stm[MXN], Ftm[MXN];
vector<ll> adj[MXN][2];
// Resets the segment-tree node `id` (covering [s, e]) and all nodes below it:
// stored value 0, no pending assignment (lazy == -1).
void Build(ll id = 1, ll s = 1, ll e = n){
    seg[id] = 0;
    lazy[id] = -1;
    if(ln <= 1) return;
    Build(lc, s, md);
    Build(rc, dm, e);
}
// Applies the pending assignment of node `id`, if any: the node's value becomes
// the pending value, and for a node covering more than one position the same
// value is left pending on both children.
void Shift(ll id, ll s, ll e){
    const ll pending = lazy[id];
    if(pending == -1) return;
    lazy[id] = -1;
    seg[id] = pending;
    if(ln > 1){
        lazy[lc] = pending;
        lazy[rc] = pending;
    }
}
// Range assignment: sets every position in [l, r] to x. Node `id` covers [s, e];
// after the children are updated the node stores the maximum of their values.
void Upd(ll l, ll r, ll x, ll id = 1, ll s = 1, ll e = n){
    Shift(id, s, e);
    if(s > r || e < l) return;
    if(s >= l && r >= e){
        lazy[id] = x;
        Shift(id, s, e);
        return;
    }
    Upd(l, r, x, lc, s, md);
    Upd(l, r, x, rc, dm, e);
    seg[id] = max(seg[lc], seg[rc]);
}
// Range maximum over positions [l, r]; parts of the tree outside the range
// contribute 0. Node `id` covers [s, e].
ll Get(ll l, ll r, ll id = 1, ll s = 1, ll e = n){
    Shift(id, s, e);
    if(s > r || e < l) return 0;
    if(s >= l && r >= e) return seg[id];
    const ll fromLeft = Get(l, r, lc, s, md);
    const ll fromRight = Get(l, r, rc, dm, e);
    return max(fromLeft, fromRight);
}
// Depth-first walk over adjacency layer f that assigns Stm[u] (timer after
// being advanced on entry) and Ftm[u] (timer value once all neighbours other
// than par have been processed).
void prep(ll u, ll par, ll f){
    timer += 1;
    Stm[u] = timer;
    for(auto nxt : adj[u][f]){
        if(nxt == par) continue;
        prep(nxt, u, f);
    }
    Ftm[u] = timer;
}
// True when the interval [Stm[u], Ftm[u]] is contained in [Stm[j], Ftm[j]].
bool Is_Jad(ll j, ll u){
    if(Stm[u] < Stm[j]) return false;
    return Ftm[u] <= Ftm[j];
}
// Depth-first walk over adjacency layer f. On entry, j is the maximum stored
// over [Stm[u], Ftm[u]]:
//   - j == 0: u's range is assigned u and ans grows by one;
//   - Is_Jad(j, u): j's range is cleared and u's range is assigned u;
//   - otherwise nothing changes.
// Ans tracks the maximum of ans. On exit the change made on entry is reverted.
void DFS(ll u, ll par, ll f){
    const ll j = Get(Stm[u], Ftm[u]);
    const bool fresh = (j == 0);
    const bool replaces = !fresh && Is_Jad(j, u);
    if(fresh){
        Upd(Stm[u], Ftm[u], u);
        ans ++;
    }
    else if(replaces){
        Upd(Stm[j], Ftm[j], 0);
        Upd(Stm[u], Ftm[u], u);
    }
    Ans = max(Ans, ans);
    for(auto nxt : adj[u][f]){
        if(nxt == par) continue;
        DFS(nxt, u, f);
    }
    if(fresh){
        Upd(Stm[u], Ftm[u], 0);
        ans --;
    }
    else if(replaces){
        Upd(Stm[u], Ftm[u], 0);
        Upd(Stm[j], Ftm[j], j);
    }
}
// Handles one test case: reads n and two parent arrays (stored as undirected
// adjacency in layers 0 and 1), rebuilds the segment tree, stamps layer 1 with
// prep, walks layer 0 with DFS and prints Ans.
void solve(){
    cin >> n;
    timer = 0;
    ans = 0;
    Ans = 0;
    for(int i = 1; i <= n; i ++){
        adj[i][0].clear();
        adj[i][1].clear();
    }
    for(int layer = 0; layer < 2; layer ++){
        for(int child = 2; child <= n; child ++){
            int parent;
            cin >> parent;
            adj[parent][layer].push_back(child);
            adj[child][layer].push_back(parent);
        }
    }
    Build();
    prep(1, 0, 1);
    DFS(1, 0, 0);
    cout << Ans << '\n';
}
int main(){
ios::sync_with_stdio(0);cin.tie(0); cout.tie(0);
ll q; cin >> q;
while(q --) solve();
return 0;
}
/*!
HE'S AN INSTIGATOR,
ENEMY ELIMINATOR,
AND WHEN HE KNOCKS YOU BETTER
YOU BETTER LET HIM IN.
*/
//! N.N
1528D - Это птица! Нет, это самолет! Нет, это AaParsa!
problem idea and solution: AmShZ, AliShahali1382
editorial
Tutorial is loading...
official solution
#include <bits/stdc++.h>
#pragma GCC optimize ("O2,unroll-loops")
//#pragma GCC optimize("no-stack-protector,fast-math")
using namespace std;
typedef long long ll;
typedef long double ld;
typedef pair<int, int> pii;
typedef pair<pii, int> piii;
typedef pair<ll, ll> pll;
#define debug(x) cerr<<#x<<'='<<(x)<<endl;
#define debugp(x) cerr<<#x<<"= {"<<(x.first)<<", "<<(x.second)<<"}"<<endl;
#define debug2(x, y) cerr<<"{"<<#x<<", "<<#y<<"} = {"<<(x)<<", "<<(y)<<"}"<<endl;
#define debugv(v) {cerr<<#v<<" : ";for (auto x:v) cerr<<x<<' ';cerr<<endl;}
#define all(x) x.begin(), x.end()
#define pb push_back
#define kill(x) return cout<<x<<'\n', 0;
const int inf=1000000010;
const ll INF=1000000000000001000LL;
const int mod=1000000007;
const int N=605;
int n, m, k, u, v, x, y, t, a, b;
bool mark[N];
int G[N][N], D[N];
// In-place minimum: lowers x to y when y is strictly smaller.
inline void upd(int &x, int y){
    x = (y < x) ? y : x;
}
// Reads n, m and m triples (u, v, x), keeping the smallest x per (u, v) in G.
// Then, for every source i, runs an O(n^2) array-based Dijkstra:
//   D starts as row G[i]; the unmarked vertex v with the smallest D is settled,
//   then D[(v + 1) % n] is relaxed with D[v] + 1 and, for every u,
//   D[(u + D[v]) % n] is relaxed with D[v] + G[v][u].
// Prints one row per source, with 0 in the source's own column.
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    memset(G, 63, sizeof(G));
    cin >> n >> m;
    while (m--){
        cin >> u >> v >> x;
        G[u][v] = min(G[u][v], x);
    }
    for (int src = 0; src < n; src++){
        memset(D, 63, sizeof(D));
        memset(mark, 0, sizeof(mark));
        for (int to = 0; to < n; to++) upd(D[to], G[src][to]);
        for (;;){
            // Pick the unmarked vertex with the smallest tentative value.
            int cur = -1;
            for (int cand = 0; cand < n; cand++){
                if (mark[cand]) continue;
                if (cur == -1 || D[cur] > D[cand]) cur = cand;
            }
            if (cur == -1) break;
            mark[cur] = 1;
            upd(D[(cur + 1) % n], D[cur] + 1);
            for (int tgt = 0; tgt < n; tgt++)
                upd(D[(tgt + D[cur]) % n], D[cur] + G[cur][tgt]);
        }
        for (int j = 0; j < n; j++){
            if (j == src) cout << "0 ";
            else cout << D[j] << " ";
        }
        cout << "\n";
    }
    return 0;
}
a different O(n^3) solution: Atreus
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 600 + 5;
int n;
int dp[maxn];
bool mark[maxn];
vector<pair<int,int>> g[maxn];
int go[maxn];
// Number of +1 steps needed to get from s to t on a cycle of n positions.
int dis(int s, int t) {
    const int gap = t - s;
    return gap >= 0 ? gap : n + gap;
}
// Array-based Dijkstra from src over the n positions; results are left in dp.
// Runs n - 1 rounds. Each round settles the unmarked position `cur` with the
// smallest dp, then:
//   - unless cur is src, relaxes the next unmarked position after cur
//     (cyclically) with the walking distance dis(cur, that position);
//   - rebuilds go[p] = first unmarked position at or after p (cyclically);
//   - for each edge (u, w) of cur, lands on (u + dp[cur]) % n and relaxes the
//     first unmarked position from there, adding the extra walking distance.
void dijkstra(int src) {
    memset(dp, 63, sizeof dp);
    memset(mark, 0, sizeof mark);
    dp[src] = 0;
    set<int> S;
    for (int pos = 0; pos < n; pos++)
        S.insert(pos);
    for (int round = 0; round < n - 1; round++) {
        int cur = -1;
        for (int cand = 0; cand < n; cand++) {
            if (mark[cand])
                continue;
            if (cur == -1 or dp[cur] > dp[cand])
                cur = cand;
        }
        S.erase(cur);
        mark[cur] = 1;
        if (cur != src) {
            auto it = S.lower_bound(cur);
            if (it == S.end())
                it = S.lower_bound(0);  // wrap around to the smallest remaining position
            const int nex = *it;
            dp[nex] = min(dp[nex], dp[cur] + dis(cur, nex));
        }
        // Two passes over the cycle so that go[] is correct across the wrap-around.
        for (int step = 2 * n - 1; step >= 0; step--) {
            const int pos = step % n;
            if (!mark[pos])
                go[pos] = pos;
            else
                go[pos] = go[(pos + 1) % n];
        }
        for (auto [u, w] : g[cur]) {
            const int landing = (u + dp[cur]) % n;
            const int nex = go[landing];
            dp[nex] = min(dp[nex], dp[cur] + w + dis(landing, nex));
        }
    }
}
// Reads n, m and m directed edges (v, u, w) into g, then for every source runs
// dijkstra() and prints dp[0..n-1] on one line, separated by single spaces.
int main() {
    ios_base::sync_with_stdio(false);
    int m;
    cin >> n >> m;
    for (int edge = 0; edge < m; edge++) {
        int from, to, cost;
        cin >> from >> to >> cost;
        g[from].push_back({to, cost});
    }
    for (int src = 0; src < n; src++) {
        dijkstra(src);
        for (int j = 0; j < n; j++) {
            const char tail = (j == n - 1) ? '\n' : ' ';
            cout << dp[j] << tail;
        }
    }
}
1528E - Mashtali и деревья Hagh
problem idea and solution: AliShahali1382
Thanks to Kininarimasu for the editorial.
editorial
Tutorial is loading...
official solution
#include <bits/stdc++.h>
#pragma GCC optimize ("O2,unroll-loops")
//#pragma GCC optimize("no-stack-protector,fast-math")
using namespace std;
typedef long long ll;
typedef long double ld;
typedef pair<int, int> pii;
typedef pair<pii, int> piii;
typedef pair<ll, ll> pll;
#define debug(x) cerr<<#x<<'='<<(x)<<endl;
#define debugp(x) cerr<<#x<<"= {"<<(x.first)<<", "<<(x.second)<<"}"<<endl;
#define debug2(x, y) cerr<<"{"<<#x<<", "<<#y<<"} = {"<<(x)<<", "<<(y)<<"}"<<endl;
#define debugv(v) {cerr<<#v<<" : ";for (auto x:v) cerr<<x<<' ';cerr<<endl;}
#define all(x) x.begin(), x.end()
#define pb push_back
#define kill(x) return cout<<x<<'\n', 0;
const int inf=1000000010;
const ll INF=1000000000000001000LL;
const int mod=998244353, inv6=166374059;
const int MAXN=1000010, LOG=20;
ll n, m, k, u, v, x, y, t, a, b, ans;
ll dp[MAXN], ps[MAXN];
ll pd[MAXN], sp[MAXN];
// x * (x - 1) * (x - 2) / 6 modulo mod; the division is done by multiplying
// with inv6.
ll C3(ll x){
    ll res = x * (x - 1) % mod;
    res = res * (x - 2) % mod;
    return res * inv6 % mod;
}
// x * (x - 1) / 2 modulo mod; the product is halved exactly before reducing.
ll C2(ll x){
    const ll pairs = x * (x - 1) / 2;
    return pairs % mod;
}
int main(){
ios_base::sync_with_stdio(false);cin.tie(0);cout.tie(0);
//freopen("input.txt", "r", stdin);
//freopen("output.txt", "w", stdout);
cin>>n;
dp[0]=ps[0]=1;
pd[0]=sp[0]=1;
for (int i=1; i<=n; i++){
ps[i]=(1 + ps[i-1] + ps[i-1]*(ps[i-1]+1)/2)%mod;
dp[i]=ps[i]-ps[i-1];
pd[i]=dp[i]-dp[i-1];
sp[i]=(sp[i-1]+pd[i])%mod;
}
for (int i=0; i<n; i++) ans=(ans + pd[i]*sp[n-1-i])%mod;
ans=(ans + 2*C3(ps[n-1]+2))%mod;
if (n>=2) ans=(ans - 2*C3(ps[n-2]+2))%mod;
ans=(ans + 2*C2(ps[n-1]+1))%mod;
if (n>=2) ans=(ans - 2*C2(ps[n-2]+1))%mod;
if (ans<0) ans+=mod;
cout<<ans<<"\n";
return 0;
}
problem idea and solution: AmShZ, AliShahali1382
editorial
Tutorial is loading...
official solution
# include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef long double ld;
typedef pair <int, int> pii;
typedef pair <pii, int> ppi;
typedef pair <int, pii> pip;
typedef pair <pii, pii> ppp;
typedef pair <ll, ll> pll;
# define A first
# define B second
# define endl '\n'
# define sep ' '
# define all(x) x.begin(), x.end()
# define kill(x) return cout << x << endl, 0
# define SZ(x) int(x.size())
# define lc id << 1
# define rc id << 1 | 1
# define InTheNameOfGod ios::sync_with_stdio(0);cin.tie(0); cout.tie(0);
// Modular exponentiation: returns (a^b) mod md, computed by binary exponentiation.
//   a  - base; may be negative or >= md, it is reduced into [0, md) first so the
//        intermediate products stay below md * md (must fit in signed 64 bits).
//   b  - exponent; a non-positive exponent yields 1 % md.
//   md - modulus, must be >= 1 (md == 1 yields 0).
long long power(long long a, long long b, long long md) {
    a %= md;
    if (a < 0) a += md;
    long long result = 1 % md;
    while (b > 0) {
        if (b & 1) result = result * a % md;
        a = a * a % md;
        b >>= 1;
    }
    return result;
}
const int xn = 3e5 + 10;
const int xm = 18;
const int sq = 320;
const int inf = 1e9 + 10;
const ll INF = 1e18 + 10;
const int mod = 998244353;
const int base = 257;
int n, k, C[xn], ans, A[xn], B[xn], rev[xn];
int fact[xn], ifact[xn];
// In-place number-theoretic transform of length 2^xm on A, modulo mod.
//   A   - array with at least 2^xm entries; overwritten with the transform.
//   inv - false for the forward transform; true for the inverse, which uses
//         the inverse roots and finally multiplies by the inverse of 2^xm.
// Relies on rev[] already holding the bit-reversal permutation.
void NTT(int *A, bool inv){
    const int n = (1 << xm);
    for (int i = 0; i < n; ++ i)
        if (rev[i] < i)
            swap(A[i], A[rev[i]]);
    for (int half = 1; half < n; half <<= 1){
        // w is a primitive (2 * half)-th root derived from the base 3.
        int w = power(3, mod / 2 / half, mod);
        if (inv)
            w = power(w, mod - 2, mod);
        const int span = half + half;
        for (int start = 0; start < n; start += span){
            int wn = 1;
            for (int j = start; j < start + half; ++ j){
                const int x = A[j];
                const int y = 1ll * A[j + half] * wn % mod;
                A[j] = (x + y) % mod;
                A[j + half] = (x - y + mod) % mod;
                wn = 1ll * wn * w % mod;
            }
        }
    }
    if (inv){
        const int invn = power(1 << xm, mod - 2, mod);
        for (int i = 0; i < n; ++ i)
            A[i] = 1ll * A[i] * invn % mod;
    }
}
int main(){
InTheNameOfGod;
cin >> n >> k;
C[0] = fact[0] = 1;
for (int i = 1; i < xn; ++ i){
C[i] = 1ll * C[i - 1] * (n - i + 1) % mod * power(i, mod - 2, mod) % mod;
fact[i] = 1ll * fact[i - 1] * i % mod;
}
ifact[xn - 1] = power(fact[xn - 1], mod - 2, mod);
for (int i = xn - 2; i >= 0; -- i)
ifact[i] = 1ll * ifact[i + 1] * (i + 1) % mod;
for (int i = 1; i < (1 << xm); ++ i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (xm - 1));
for (int i = 0; i <= k; ++ i){
A[i] = 1ll * power(i, k, mod) * ifact[i] % mod;
B[i] = ifact[i];
if (i % 2)
B[i] = (mod - B[i]) % mod;
}
NTT(A, 0), NTT(B, 0);
for (int i = 0; i < xn; ++ i)
A[i] = 1ll * A[i] * B[i] % mod;
NTT(A, 1);
for (int i = 1; i <= k; ++ i)
ans = (ans + 1LL * A[i] * C[i] % mod * power(n + 1, n - i, mod) % mod * fact[i] % mod) % mod;
cout << ans << endl;
return 0;
}