Bring's blog

By Bring, history, 22 months ago, In English

Link to the question: Luogu, AtCoder

Preface

The very first generating function and polynomial problem solved in my life!

This blog is a detailed explanation and extension of the official editorial. I will try my best to explain the mathematical expressions and their deeper meanings so that you may understand them even if you are also new to generating functions and polynomials.


Our Goal

Let $$$X$$$ be our random variable, which is the number of rolls after which all $$$N$$$ sides have shown up for the first time. Its probability mass function $$$p_i=\mathbb P(X=i)$$$ is just the probability that all $$$N$$$ sides have shown up for the first time at exactly the $$$i$$$-th roll.

Then, what we are looking for is

$$$\mathbb E(X)=\sum_{i=0}^\infty ip_i=0p_0+1p_1+2p_2+\cdots$$$

$$$\mathbb E(X^2)=\sum_{i=0}^\infty i^2p_i=0^2p_0+1^2p_1+2^2p_2+\cdots$$$

$$$\mathbb E(X^3)=\sum_{i=0}^\infty i^3p_i=0^3p_0+1^3p_1+2^3p_2+\cdots$$$

$$$\vdots$$$

$$$\mathbb E(X^M)=\sum_{i=0}^\infty i^Mp_i$$$

(Actually, $$$p_i=0$$$ if $$$i<N,$$$ but it doesn't matter.)


Derivation: Generating Functions


Ordinary and Exponential Generating Functions

For a sequence $$$\{a_i\}_{i=0}^\infty,$$$ there are two formal power series associated with it:

  • Its Ordinary Generating Function (OGF) is
$$$f(x)=\sum_{i=0}^\infty a_ix^i,$$$
and we denote it $$$\{a_i\}_{i=0}^\infty\xrightarrow{\text{OGF}}f(x).$$$

  • Its Exponential Generating Function (EGF) is
$$$F(x)=\sum_{i=0}^\infty a_i\frac{x^i}{i!},$$$
and we denote it $$$\{a_i\}_{i=0}^\infty\xrightarrow{\text{EGF}}F(x).$$$

We can see that the EGF of $$$\{a_i\}_{i=0}^\infty$$$ is just the OGF of $$$\left\{\frac{a_i}{i!}\right\}_{i=0}^\infty.$$$

Probability Generating Functions

In particular, if the sequence $$$\{p_i\}_{i=0}^\infty$$$ is the probability mass function of a discrete random variable $$$X$$$ taking non-negative integer values, then its OGF is also called the Probability Generating Function (PGF) of the random variable $$$X,$$$ written

$$$G_X(x)=\mathbb E(x^X)=\sum_{i=0}^\infty p_ix^i.$$$

Our first goal is to find the PGF of our random variable $$$X,$$$ and then we will show how to use that function to derive the final answer.


Finding the PGF of $$$X$$$

It is difficult to deal with "**the first time** when all $$$N$$$ sides have shown up", so we drop that condition. We continue rolling, not stopping when all $$$N$$$ sides have already shown up, and let $$$a_i$$$ be the probability that all $$$N$$$ sides have shown up within the first $$$i$$$ rolls.

Then, we have $$$p_i=a_i-a_{i-1}$$$ for $$$i\geqslant 1.$$$ This is because subtracting $$$a_{i-1}$$$ removes the probability that all $$$N$$$ sides had already shown up before the $$$i$$$-th roll, and what remains is the probability that all $$$N$$$ sides have shown up for the first time at exactly the $$$i$$$-th roll.

We try to find the OGF of $$$\{a_i\}_{i=0}^\infty.$$$

(A subtlety: although $$$a_i$$$ stores the probability of something, its OGF is not a PGF because $$$\{a_i\}$$$ is not a probability mass function, but just a tool for us to find $$$p_i.$$$)

However, it is easier to find its EGF first than to find its OGF directly. This is due to the properties of products of OGFs and EGFs.


Products of OGFs and EGFs

Let $$$\{a_i\}_{i=0}^\infty$$$ and $$$\{b_i\}_{i=0}^\infty$$$ be sequences.

OGFs

Let $$$f(x)=\sum_{i=0}^\infty a_ix^i,g(x)=\sum_{i=0}^\infty b_ix^i$$$ be their OGFs; then their product

$$$f(x)g(x)=\left(\sum_{i=0}^\infty a_ix^i\right)\left(\sum_{i=0}^\infty b_ix^i\right)=\sum_{i=0}^\infty \left(\sum_{k=0}^i a_kb_{i-k}\right)x^i$$$

is the OGF of $$$\{c_i\}_{i=0}^\infty,$$$ where $$$c_i=\sum_{k=0}^i a_kb_{i-k}.$$$

The way to understand its meaning is: let $$$a_i$$$ be the number of ways to take $$$i$$$ balls from a box, and $$$b_i$$$ be the number of ways to take $$$i$$$ balls from another box, then $$$c_i$$$ is the number of ways to take a total of $$$i$$$ balls from the two boxes.

Indeed, you can take $$$k$$$ balls from the first box, with $$$a_k$$$ ways, and $$$i-k$$$ balls from the second box, with $$$b_{i-k}$$$ ways. So, the number of ways to take $$$k$$$ balls from the first box and $$$i-k$$$ balls from the second box is $$$a_kb_{i-k},$$$ and you sum over all possible $$$k,$$$ which is from $$$0$$$ to $$$i.$$$
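
As a small illustrative sketch (the helper name ogf_product is mine; the actual solution later uses convolution() from the AtCoder Library instead), the coefficient rule is a plain convolution:

#include<vector>
using namespace std;

//Naive O(n^2) product of two OGFs given as coefficient vectors
//(assumes non-empty inputs): c[i] = sum_{k=0}^{i} a[k]*b[i-k].
vector<long long> ogf_product(const vector<long long>&a,const vector<long long>&b){
    vector<long long> c(a.size()+b.size()-1,0);
    for(size_t i=0;i<a.size();++i)
        for(size_t j=0;j<b.size();++j)
            c[i+j]+=a[i]*b[j];
    return c;
}

For example, ogf_product({1,1},{1,1}) returns {1,2,1}, i.e. $$$(1+x)(1+x)=1+2x+x^2.$$$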

EGFs

Let $$$F(x)=\sum_{i=0}^\infty a_i\frac{x^i}{i!},G(x)=\sum_{i=0}^\infty b_i\frac{x^i}{i!}$$$ be their EGFs; then their product

$$$F(x)G(x)=\left(\sum_{i=0}^\infty a_i\frac{x^i}{i!}\right)\left(\sum_{i=0}^\infty b_i\frac{x^i}{i!}\right)=\sum_{i=0}^\infty \left(\sum_{k=0}^i \frac{i!}{k!(i-k)!}a_kb_{i-k}\right)\frac{x^i}{i!}$$$
$$$=\sum_{i=0}^\infty \left(\sum_{k=0}^i\binom{i}{k}a_kb_{i-k}\right)\frac{x^i}{i!}$$$

is the EGF of $$$\{d_i\}_{i=0}^\infty,$$$ where $$$d_i=\sum_{k=0}^i \binom{i}{k}a_kb_{i-k}.$$$

The difference between the products of OGFs and EGFs is a binomial coefficient. The way to understand its meaning is: let $$$a_i$$$ be the number of ways to take $$$i$$$ balls from a box and align them in some order, and $$$b_i$$$ be the number of ways to take $$$i$$$ balls from another box and align them in some order; then $$$d_i$$$ is the number of ways to take a total of $$$i$$$ balls from the two boxes and align them in some order.

Similarly, the number of ways to take $$$k$$$ balls from the first box in some order and $$$i-k$$$ balls from the second box in some order is $$$a_kb_{i-k}.$$$ Next, you have $$$\binom{i}{k}$$$ ways to choose $$$k$$$ slots from the $$$i$$$ slots for the balls from the first box. Thus, the total number of ways to choose and align them is $$$\binom{i}{k}a_kb_{i-k}.$$$
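
Again as a small sketch (egf_product is my own illustrative helper), the same product on plain sequences, now with the extra binomial factor:

#include<vector>
using namespace std;

//Naive product of two EGFs given as plain sequences (assumes non-empty inputs):
//d[i] = sum_{k=0}^{i} C(i,k)*a[k]*b[i-k], accumulated as d[i+j] += C(i+j,i)*a[i]*b[j].
vector<long long> egf_product(const vector<long long>&a,const vector<long long>&b){
    size_t n=a.size()+b.size()-1;
    vector<vector<long long>> C(n,vector<long long>(n,0)); //Pascal's triangle
    for(size_t i=0;i<n;++i){
        C[i][0]=1;
        for(size_t k=1;k<=i;++k)C[i][k]=C[i-1][k-1]+C[i-1][k];
    }
    vector<long long> d(n,0);
    for(size_t i=0;i<a.size();++i)
        for(size_t j=0;j<b.size();++j)
            d[i+j]+=C[i+j][i]*a[i]*b[j];
    return d;
}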


When we roll the dice, we get a sequence of the sides that show up, one at each roll, so there is an order. That's why it is easier to find the EGF of $$$\{a_i\}_{i=0}^\infty.$$$

When we roll the dice once, each side shows up with probability $$$\frac{1}{N}.$$$ If we consider a specific side, for example, side $$$1,$$$ then the probability that all $$$i$$$ rolls show side $$$1$$$ is $$$\frac{1}{N^i}.$$$ The EGF of this sequence of probabilities is therefore

$$$\sum_{i=1}^\infty \frac{1}{N^i}\cdot\frac{x^i}{i!}=e^{\frac{x}{N}}-1.$$$

Note that we drop the case $$$i=0$$$ because we want that side to show up at least once.

Symmetrically, all $$$N$$$ sides have the same EGF, and the EGF of the probability that all $$$N$$$ sides have shown up within $$$i$$$ rolls is

$$$F(x)=\sum_{i=0}^\infty a_i\frac{x^i}{i!}=(e^{\frac{x}{N}}-1)^N.$$$

We are just taking the product of the EGF corresponding to each side. As they are EGFs, their product automatically deals with the order of the sides that show up.


An example

If the derivation above seems a bit non-intuitive, we may verify it with $$$N=2,$$$ a die with two sides.

Trivially, $$$a_0=a_1=0.$$$

If we roll the dice twice, then $$$12,21$$$ are two ways that make both sides show up. There are in total $$$4$$$ equally possible results ($$$11,12,21,22$$$), so $$$a_2=\frac{2}{4}=\frac{1}{2}.$$$

If we roll the dice three times, then $$$112,121,211,221,212,122$$$ are the ways to get both sides showing up, so $$$a_3=\frac{6}{8}=\frac{3}{4}.$$$

Similarly, $$$a_4=\frac{14}{16}=\frac{7}{8}.$$$

Therefore, $$$\{a_i\}_{i=0}^\infty \xrightarrow{\text{EGF}}F(x)=\sum_{i=0}^\infty a_i\frac{x^i}{i!}$$$

$$$=0+0x+\frac{1}{2!}\cdot \frac{1}{2}x^2+\frac{1}{3!}\cdot \frac{3}{4}x^3+\frac{1}{4!}\cdot\frac{7}{8}x^4+\cdots$$$

$$$=\frac{1}{4}x^2+\frac{1}{8}x^3+\frac{7}{192}x^4+\cdots.$$$

Using our formula, $$$(e^\frac{x}{2}-1)^2=(\frac{1}{2}x+\frac{1}{4}\cdot\frac{x^2}{2!}+\frac{1}{8}\cdot\frac{x^3}{3!}+\cdots)^2$$$

$$$=(\frac{1}{2}x+\frac{1}{8}x^2+\frac{1}{48}x^3+\cdots)^2$$$

$$$=\frac{1}{4}x^2+\frac{1}{8}x^3+\frac{7}{192}x^4+\cdots$$$

which matches the $$$F(x)$$$ we calculated by brute force.


Now that we have the EGF of $$$\{a_i\},$$$ we convert it to its OGF.


Conversion between OGFs and EGFs

There are two laws:

  1. If
$$$f(x)\xleftarrow{\text{OGF}} \{a_i\}_{i=0}^\infty \xrightarrow{\text{EGF}} F(x)$$$

($$$f(x)$$$ and $$$F(x)$$$ are the OGF and EGF of the same sequence) and

$$$g(x)\xleftarrow{\text{OGF}} \{b_i\}_{i=0}^\infty \xrightarrow{\text{EGF}} G(x),$$$

then

$$$\lambda f(x)+\mu g(x)\xleftarrow{\text{OGF}} \{\lambda a_i+\mu b_i\}_{i=0}^\infty \xrightarrow{\text{EGF}}\lambda F(x)+\mu G(x).$$$

This law tells us there is a sense of 'linearity' between sequences and their GFs. When doing conversions, we can deal with separate terms and add them up.

  2. For any constant $$$k,$$$
$$$\frac{1}{1-kx}\xleftarrow{\text{OGF}}\{k^i\}_{i=0}^\infty \xrightarrow{\text{EGF}} e^{kx}.$$$

The OGF is a geometric series and the EGF is the exponential function.


With the two rules above, we have $$$F(x)=(e^\frac{x}{N}-1)^N$$$

$$$=\sum_{r=0}^N\binom{N}{r}(-1)^{N-r}e^{\frac{rx}{N}}\xleftarrow{\text{EGF}} \{a_i\}_{i=0}^\infty\xrightarrow{\text{OGF}}f(x)=\sum_{r=0}^N\binom{N}{r}(-1)^{N-r}\frac{1}{1-\frac{rx}{N}}.$$$

And finally, we compute the PGF of $$$\{p_i\}_{i=0}^\infty,$$$ which is $$$g(x)=\sum_{i=0}^\infty p_ix^i$$$

$$$=a_0+\sum_{i=1}^\infty (a_i-a_{i-1})x^i$$$ (since $$$p_i=a_i-a_{i-1}$$$)

$$$=\sum_{i=0}^\infty a_ix^i-\sum_{i=0}^\infty a_ix^{i+1}$$$

$$$=f(x)-xf(x)$$$

$$$=(1-x)f(x)$$$

$$$=(1-x)\sum_{r=0}^N\binom{N}{r}(-1)^{N-r}\frac{1}{1-\frac{rx}{N}}.$$$

(Note: multiplying an OGF by $$$1-x$$$ is the same as subtracting from each term of the sequence the term before it. On the other hand, its inverse action, multiplying by $$$\frac{1}{1-x},$$$ is the same as taking prefix sums of the sequence.)
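
As a tiny hedged sketch of this note (the function names are mine), both actions on a coefficient vector:

#include<vector>
using namespace std;

//Multiplying by 1-x: subtract from each coefficient its predecessor
//(the degree grows by one).
vector<long long> mul_one_minus_x(const vector<long long>&f){
    vector<long long> g(f.size()+1,0);
    for(size_t i=0;i<f.size();++i)g[i]+=f[i],g[i+1]-=f[i];
    return g;
}

//Multiplying by 1/(1-x) = 1+x+x^2+...: take prefix sums of the coefficients.
vector<long long> div_one_minus_x(vector<long long> f){
    for(size_t i=1;i<f.size();++i)f[i]+=f[i-1];
    return f;
}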

Though it is a 'nasty' formula, we will show later how to compute it in code.

Spoiler alert: there is a much easier derivation of $$$g(x)$$$ at the end of this blog.

Here is the final step: Calculating the expected value of $$$X,X^2,X^3,\cdots$$$ from the PGF.


Moment Generating Functions

Just as the PGF is the OGF of a probability mass function, the Moment Generating Function (MGF) is the EGF of something else: the moments of $$$X.$$$

The MGF of a random variable $$$X$$$ is

$$$M_X(x)=\mathbb E(e^{Xx})=\sum_{i=0}^\infty p_ie^{ix}.$$$

Here are some algebraic manipulations:

$$$M_X(x)=\sum_{i=0}^\infty p_ie^{ix}=\sum_{i=0}^\infty p_i\sum_{j=0}^\infty \frac{(ix)^j}{j!}=\sum_{j=0}^\infty \frac{x^j}{j!}\sum_{i=0}^\infty p_ii^j $$$
$$$=\sum_{j=0}^\infty \frac{x^j}{j!}\mathbb E(X^j),$$$

which is exactly the EGF of our answer!

(Note: the definition via the expected value is actually the more general one, as the MGF can be defined for random variables that do not only take non-negative integer values.)

Lastly, for a random variable $$$X$$$ taking non-negative integer values, like in our problem, we have

$$$G_X(e^x)=\mathbb E((e^x)^X)=\mathbb E(e^{Xx})=M_X(x)$$$

by definition.


Therefore, our final goal is to find the coefficients up to $$$x^M$$$ of the MGF of $$$X,$$$ which is $$$g(e^x).$$$


Implementation: Convolutions

Prerequisites: Convolution and inverse series.

In the implementation, I used the class modint998244353 and convolution() from the AtCoder Library for calculations in $$$\bmod 998244353$$$ and FFT.

For how FFT works and more, see this blog.

We do this by explicitly calculating the PGF $$$g(x),$$$ and then the MGF $$$g(e^x).$$$

Calculating $$$g(x)$$$

We have the explicit formula

$$$g(x)=(1-x)\sum_{r=0}^N\binom{N}{r}(-1)^{N-r}\frac{1}{1-\frac{rx}{N}}.$$$

The summation $$$\sum_{r=0}^N\binom{N}{r}(-1)^{N-r}\frac{1}{1-\frac{rx}{N}}$$$ can be written as a rational function $$$\frac{p(x)}{q(x)},$$$ with $$$p(x)$$$ and $$$q(x)$$$ each a polynomial of degree at most $$$N$$$ (the common denominator is $$$\prod_{r=0}^N(1-\frac{rx}{N}),$$$ whose $$$r=0$$$ factor is constant).

As it is the sum of a bunch of fractions in the form $$$\frac{a}{1-bx},$$$ we may combine them in some order using convolution().

By FFT, the time complexity of multiplying two polynomials is $$$O(n\log n),$$$ where $$$n$$$ is the larger of the two degrees. So, the best way to combine the fractions is by Divide-and-Conquer: divide the summation in half, calculate each half as a rational function, and then combine the two rational functions.
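
A quick note on why balanced merging is the right choice: if $$$T(n)$$$ denotes the cost of combining $$$n$$$ such fractions this way, then

$$$T(n)=2T\left(\frac{n}{2}\right)+O(n\log n)=O(n\log^2n),$$$

whereas folding the fractions in one at a time would cost $$$O(n^2\log n)$$$ in total.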

Here is the class of rational functions and its addition method:

using mint=modint998244353; //calculation in mod 998244353
using ply=vector<mint>; //polynomials

ply add(ply a,ply b); //forward declaration; definition below

struct R{ply p,q; //numerator and denominator
    R operator+(R b){
        ply rp(add(convolution(q,b.p),convolution(p,b.q))),
            rq(convolution(q,b.q));
        return{rp,rq};
    }
};

ply add(ply a,ply b){ //adding two polynomials coefficient-wise
    if(a.size()<b.size())swap(a,b);
    Frn0(i,0,b.size())a[i]+=b[i]; //Frn0(i,a,b): for i from a to b-1 (macro from the full code)
    return a;
}

Here is the divide-and-conquer summation of rational functions, stored in vector<R>a.

R sum(vector<R>&a,int l,int r){ //summing from a[l] to a[r]
    if(l==r)return a[l];
    int md((l+r)/2);
    return sum(a,l,md)+sum(a,md+1,r);
}

The summation is done. For the remaining factor $$$1-x,$$$ there are two ways:

  1. Multiply it into the numerator. This can be done by subtracting from each coefficient the one before it. Note that the degree will increase by $$$1.$$$
  2. (used here) As the denominator already has a $$$1-x$$$ factor (the $$$r=N$$$ summand), we can remove this factor by taking prefix sums of the denominator's coefficients, which is the same as multiplying it by $$$\frac{1}{1-x}=1+x+x^2+x^3+\cdots.$$$

And now, we obtain the PGF $$$g(x)$$$ as a rational function.

From $$$g(x)$$$ to $$$g(e^x)$$$

As $$$g(x)=\frac{p(x)}{q(x)}$$$ is a rational function, we calculate $$$p(e^x)$$$ and $$$q(e^x)$$$ separately and use inverse series to combine them. As we only need the coefficients from $$$x$$$ to $$$x^M,$$$ we may take the results $$$\bmod x^{M+1}.$$$

For a polynomial $$$P(x)=\sum_{i=0}^n c_ix^i,$$$ we have $$$P(e^x)=\sum_{i=0}^n c_ie^{ix}.$$$

Using our trick of conversion between EGFs and OGFs again:

$$$\sum_{i=0}^n c_ie^{ix}\xleftarrow{\text{EGF}}\left\{\sum_{i=0}^n c_ii^j\right\}_{j=0}^\infty\xrightarrow{\text{OGF}}\sum_{i=0}^n \frac{c_i}{1-ix}.$$$

So we may calculate the summation on the right hand side by the same Divide-and-Conquer technique. Use inverse series to get its coefficients in power series, and then divide the $$$i$$$-th term by $$$i!$$$ to obtain the left hand side.


The following is an implementation of inverse series $$$\bmod x^m.$$$

ply pinv(ply f,int m){
    ply g({f[0].inv()});
    f.resize(m);
    for(int s(2);s<2*m;s<<=1){
        auto tmp(convolution(convolution(g,g),
            ply(f.begin(),f.begin()+min(s,m))));
        g.resize(min(s,m));
        Frn0(i,0,min(s,m))g[i]=2*g[i]-tmp[i];
    }
    return g;
}
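
A sketch of why this loop works (it is the standard Newton iteration for power series; $$$g_t$$$ is my notation for the inverse $$$\bmod x^t$$$): if $$$fg_t\equiv 1\pmod{x^t},$$$ then $$$(fg_t-1)^2\equiv 0\pmod{x^{2t}},$$$ which rearranges to

$$$f(2g_t-fg_t^2)\equiv 1\pmod{x^{2t}},$$$

so $$$g_{2t}=2g_t-fg_t^2$$$ doubles the precision; this is exactly the line g[i]=2*g[i]-tmp[i] above.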

The following computes $$$f(e^x)\bmod x^m.$$$

ply fex(ply f,int m){
    vector<R>a(f.size());
    Frn0(i,0,f.size())a[i].p={f[i]},a[i].q={1,-i};
    R s(sum(a,0,a.size()-1)); //DC summation
    auto re(convolution(s.p,pinv(s.q,m)));
    re.resize(m);
    Frn0(i,0,m)re[i]/=fc[i]; //dividing by i!
    return re;
}

Code

Time Complexity: $$$O(n\log ^2n +m\log m)$$$ (DC summation and inverse series)

Memory Complexity: $$$O(n+m)$$$

Further details are annotated.

//This program is written by Brian Peng.
#include<bits/stdc++.h>
#include<atcoder/convolution>
using namespace std;
using namespace atcoder;
#define Rd(a) (a=rd())
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
int rd(){
	int x;char c(getchar());bool k;
	while(!isdigit(c)&&c^'-')if(Gc(c)==EOF)exit(0);
	c^'-'?(k=1,x=c&15):k=x=0;
	while(isdigit(Gc(c)))x=x*10+(c&15);
	return k?x:-x;
}
void wr(int a){
	if(a<0)Pc('-'),a=-a;
	if(a<=9)Pc(a|'0');
	else wr(a/10),Pc((a%10)|'0');
}
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
#define Ps Pc(' ')
#define Pe Pc('\n')
#define Frn0(i,a,b) for(int i(a);i<(b);++i)
#define Frn1(i,a,b) for(int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(int i(a);i>=(b);--i)
#define Mst(a,b) memset(a,b,sizeof(a))
#define All(a) (a).begin(),(a).end()
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
using mint=modint998244353;
using ply=vector<mint>;
#define N (200010)
int n,m;
mint fc[N]{1};
ply ans;
ply pinv(ply f,int m);
ply add(ply a,ply b);
struct R{ply p,q;
    R operator+(R b){
        ply rp(add(convolution(q,b.p),convolution(p,b.q))),
            rq(convolution(q,b.q));
        return{rp,rq};
    }
}g;
vector<R>a;
R sum(vector<R>&a,int l,int r);
mint cmb(int n,int r){return fc[n]/(fc[r]*fc[n-r]);} //binomial numbers
ply fex(ply f,int m);
signed main(){
    Rd(n),Rd(m);
    Frn1(i,1,max(n,m))fc[i]=fc[i-1]*i; //factorials
    a.resize(n+1);
    mint niv(mint(n).inv());
    Frn1(i,0,n){
        a[i].p={(((n-i)&1)?-1:1)*cmb(n,i)};
        a[i].q={1,-niv*i};
    } //the terms of the summation in g(x)
    g=sum(a,0,n);
    Frn0(i,1,g.q.size())g.q[i]+=g.q[i-1]; //dividing the denominator by 1-x
    //by taking prefix sums, obtaining PGF
    ans=convolution(fex(g.p,m+1),pinv(fex(g.q,m+1),m+1));
    //obtaining MGF
    Frn1(i,1,m)wr((ans[i]*fc[i]).val()),Pe;
    //remember to multiply by i! as it is an EGF
	exit(0);
}
ply pinv(ply f,int m){
    ply g({f[0].inv()});
    f.resize(m);
    for(int s(2);s<2*m;s<<=1){
        auto tmp(convolution(convolution(g,g),
            ply(f.begin(),f.begin()+min(s,m))));
        g.resize(min(s,m));
        Frn0(i,0,min(s,m))g[i]=2*g[i]-tmp[i];
    }
    return g;
}
ply add(ply a,ply b){
    if(a.size()<b.size())swap(a,b);
    Frn0(i,0,b.size())a[i]+=b[i];
    return a;
}
R sum(vector<R>&a,int l,int r){
    if(l==r)return a[l];
    int md((l+r)/2);
    return sum(a,l,md)+sum(a,md+1,r);
}
ply fex(ply f,int m){
    vector<R>a(f.size());
    Frn0(i,0,f.size())a[i].p={f[i]},a[i].q={1,-i};
    R s(sum(a,0,a.size()-1));
    auto re(convolution(s.p,pinv(s.q,m)));
    re.resize(m);
    Frn0(i,0,m)re[i]/=fc[i];
    return re;
}

Extensions

An alternative way to find the PGF of $$$X$$$

We may track the number of rolls needed to get a new side to show up when $$$i$$$ sides have already shown up.

When $$$i$$$ sides have already shown up, the probability of getting a new side in a roll is $$$\frac{n-i}{n}.$$$ Let $$$X_i$$$ be the random variable of the number of rolls, then $$$X_i\sim \text{Geo}(\frac{n-i}{n}).$$$

As the PGF of $$$\text{Geo}(p)$$$ is $$$\sum_{i=1}^\infty p(1-p)^{i-1}x^i=\frac{px}{1-(1-p)x},$$$ the PGF of $$$X_i$$$ is $$$G_{X_i}(x)=\frac{\frac{n-i}{n}x}{1-\frac{i}{n}x}=\frac{(n-i)x}{n-ix}.$$$ By the convolution theorem for PGFs (the PGF of a sum of independent random variables is the product of their PGFs), the PGF of the total number of rolls $$$X=\sum_{i=0}^{n-1} X_i$$$ is

$$$g(x)=G_X(x)=\prod_{i=0}^{n-1} \frac{(n-i)x}{n-ix}=\frac{n!x^n}{\prod_{i=0}^{n-1} (n-ix)}.$$$

It seems a lot easier this way... The product of these small polynomials can still be done by a similar Divide-and-Conquer method.
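
As a quick sanity check with $$$n=2$$$ (the die from the example earlier), the formula gives

$$$g(x)=\frac{2!x^2}{(2-0x)(2-x)}=\frac{x^2}{2-x}=\frac{x^2}{2}\left(1+\frac{x}{2}+\frac{x^2}{4}+\cdots\right),$$$

so $$$p_2=\frac{1}{2},p_3=\frac{1}{4},p_4=\frac{1}{8},$$$ which matches $$$a_2-a_1,a_3-a_2,a_4-a_3$$$ computed earlier.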

//This program is written by Brian Peng.
#include<bits/stdc++.h>
#include<atcoder/convolution>
using namespace std;
using namespace atcoder;
#define Rd(a) (a=rd())
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
int rd(){
	int x;char c(getchar());bool k;
	while(!isdigit(c)&&c^'-')if(Gc(c)==EOF)exit(0);
	c^'-'?(k=1,x=c&15):k=x=0;
	while(isdigit(Gc(c)))x=x*10+(c&15);
	return k?x:-x;
}
void wr(int a){
	if(a<0)Pc('-'),a=-a;
	if(a<=9)Pc(a|'0');
	else wr(a/10),Pc((a%10)|'0');
}
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
#define Ps Pc(' ')
#define Pe Pc('\n')
#define Frn0(i,a,b) for(int i(a);i<(b);++i)
#define Frn1(i,a,b) for(int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(int i(a);i>=(b);--i)
#define Mst(a,b) memset(a,b,sizeof(a))
#define All(a) (a).begin(),(a).end()
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
using mint=modint998244353;
using ply=vector<mint>;
#define N (200010)
int n,m;
mint fc[N]{1};
ply ans;
ply pinv(ply f,int m);
ply add(ply a,ply b);
struct R{ply p,q;
    R operator+(R b){
        ply rp(add(convolution(q,b.p),convolution(p,b.q))),
            rq(convolution(q,b.q));
        return{rp,rq};
    }
}g;
vector<ply>a;
R sum(vector<R>&a,int l,int r);
ply prod(vector<ply>&a,int l,int r); //DC Multiplication
ply fex(ply f,int m);
signed main(){
    Rd(n),Rd(m);
    Frn1(i,1,max(n,m))fc[i]=fc[i-1]*i;
    g.p.resize(n+1),g.p[n]=fc[n],a.resize(n);
    Frn0(i,0,n)a[i]={n,-i};
    g.q=prod(a,0,n-1);
    ans=convolution(fex(g.p,m+1),pinv(fex(g.q,m+1),m+1));
    Frn1(i,1,m)wr((ans[i]*fc[i]).val()),Pe;
	exit(0);
}
ply pinv(ply f,int m){
    ply g({f[0].inv()});
    f.resize(m);
    for(int s(2);s<2*m;s<<=1){
        auto tmp(convolution(convolution(g,g),
            ply(f.begin(),f.begin()+min(s,m))));
        g.resize(min(s,m));
        Frn0(i,0,min(s,m))g[i]=2*g[i]-tmp[i];
    }
    return g;
}
ply add(ply a,ply b){
    if(a.size()<b.size())swap(a,b);
    Frn0(i,0,b.size())a[i]+=b[i];
    return a;
}
R sum(vector<R>&a,int l,int r){
    if(l==r)return a[l];
    int md((l+r)/2);
    return sum(a,l,md)+sum(a,md+1,r);
}
ply prod(vector<ply>&a,int l,int r){
    if(l==r)return a[l];
    int md((l+r)/2);
    return convolution(prod(a,l,md),prod(a,md+1,r));
}
ply fex(ply f,int m){
    vector<R>a(f.size());
    Frn0(i,0,f.size())a[i].p={f[i]},a[i].q={1,-i};
    R s(sum(a,0,a.size()-1));
    auto re(convolution(s.p,pinv(s.q,m)));
    re.resize(m);
    Frn0(i,0,m)re[i]/=fc[i];
    return re;
}

It is really easier to implement, and it took 300 ms less time than the previous one...

THANKS FOR READING!

By Bring, history, 2 years ago, In English

Solved the first five questions with brute force. Still 19 points away from reclaiming purple. Hang on!

Solution: CF1774G Segment Covering

Link to the question: CF Luogu

Preface

A brilliant and tricky question (tricky because modding $$$998244353$$$ is 'almost' of no use). This blog is an explanation and extension of the official tutorial.

Notations: We use $$$[l_i,r_i]$$$ to denote an interval with index $$$i$$$ (or interval $$$i$$$) and $$$[x,y]$$$ to denote a query interval.

Analysis

The question asks about the difference between the numbers of ways of covering $$$[x,y]$$$ with even and with odd numbers of existing intervals. As the question does not ask for the even and odd counts separately, but for the difference between them, we need to take advantage of this.

Property 1

Suppose an interval contains another interval, i.e. there exist indices $$$i,j$$$ such that $$$l_i\leqslant l_j\leqslant r_j\leqslant r_i,$$$ or $$$[l_j,r_j]\subseteq [l_i,r_i].$$$ If we use the interval $$$[l_i,r_i]$$$ but not $$$[l_j,r_j]$$$ in a covering, then we can always pair it with another covering that is the same as the previous one except that $$$[l_j,r_j]$$$ is also used. The two coverings differ by $$$[l_j,r_j]$$$ only, so one contributes to $$$f$$$ and the other to $$$g$$$ (the even and odd counts in the statement). In the end, they contribute zero to the final answer.

Therefore, to have a non-zero contribution to the answer, we cannot use the interval $$$[l_i,r_i].$$$ In other words, we can remove $$$[l_i,r_i]$$$ from our list of intervals.

After removing all 'useless' intervals, if we sort the remaining intervals by their left boundaries, their right boundaries will also be sorted (otherwise, some interval would contain another).

Property 2

Suppose $$$[x,y]$$$ is the query interval and the intervals $$$[l_i,r_i]$$$ are trimmed (by Property 1) and sorted in a list. Then, the intervals that might have a chance to be chosen are those contained by $$$[x,y]$$$ and are consecutive in the sorted list. We let them be $$$\{[l_i,r_i],[l_{i+1},r_{i+1}],\cdots,[l_{j},r_{j}]\}.$$$

If the list is empty or $$$l_i\ne x$$$ or $$$r_i<l_{i+1}$$$ or $$$r_j\ne y,$$$ the answer is obviously $$$0.$$$ So, we suppose $$$l_i=x,r_i\geqslant l_{i+1},r_j=y,$$$ and we know that the interval $$$[l_i,r_i]$$$ must be chosen.

We consider the following case:

Here, the black line represents the query interval $$$[x,y],$$$ and the colored lines are intervals in our list. We know that the interval $$$i$$$ ($$$[l_i,r_i]$$$) must be chosen, so we color it green.

Next, we see that in the picture, $$$l_{i+2}\leqslant r_i,$$$ which means that the interval $$$i+2$$$ intersects with the interval $$$i.$$$ Thus, if we choose the interval $$$i+2$$$ in a covering, then choosing $$$i+1$$$ or not does not affect the covering. Similar to Property 1, this means that if we choose $$$i+2,$$$ the net contribution to the answer is $$$0.$$$ Therefore, the interval $$$i+2$$$ is useless in this case, and we color it red.

In a similar argument, all the intervals that intersect with the interval $$$i$$$ (except the interval $$$i+1$$$) are useless. We let the interval $$$k$$$ be the left-most interval that does not intersect with the interval $$$i.$$$

As we need to cover $$$[x,y]$$$ and the only interval that intersects with $$$i$$$ is $$$i+1,$$$ then we must choose $$$i+1,$$$ so we color it green. Now, the interval $$$i+1$$$ is a must-be-chosen interval and $$$k$$$ is a possible interval next to it. In the picture, we may see that $$$l_{k+1}\leqslant r_{i+1},$$$ so the interval $$$k+1$$$ is useless (Why?). In fact, every interval with an index greater than $$$k$$$ that intersects with the interval $$$i+1$$$ is useless, and then $$$k$$$ must be chosen, being the only 'non-useless' interval that intersects with $$$i+1$$$.

To conclude, a must-be-chosen interval and the possible interval next to it make all the other intervals intersecting with the first interval useless, and the second interval becomes must-be-chosen.

From the above statement, we may show inductively that every interval is either must-be-chosen or useless, so there is essentially at most one 'useful' covering. If an even number of intervals are used, the answer is $$$1;$$$ if odd, then $$$-1;$$$ and if the must-be-chosen intervals cannot cover $$$[x,y],$$$ then the answer is $$$0.$$$ The covering is like the following:

Note that we split the intervals into two "layers," and for every interval, only the intervals next to it intersect with it. In this picture, the interval $$$j$$$ is on the same layer as $$$i+1,$$$ so there are an even number of intervals and the answer is $$$1.$$$ If $$$j$$$ is on the layer of $$$i$$$, the answer is $$$-1.$$$

A Hidden Tree Structure

Property 2 already gives us a method of finding the useful covering for a query $$$[x,y]$$$: recursively seek the must-be-chosen intervals and delete useless ones until an interval whose right boundary equals $$$y$$$ is chosen. However, as there are many queries, optimization is needed.

Let's look closer at the picture above. The intervals are divided into two layers, one starting with $$$i$$$ and another starting with $$$i+1.$$$ Also, for every interval that is not the end of its layer, the next interval on the same layer is always the first interval on its right that is disjoint from it.

Therefore, if we link each interval to the first interval on its right that is disjoint from it, a tree is formed. For simplicity, we link all the intervals that have no "parents" to a virtual root node.

Here is our final "theorem" of the question:

$$$\text{Theorem. }$$$There is a valid choice of intervals if and only if the interval $$$j$$$ is an ancestor of exactly one of $$$i$$$ and $$$i+1.$$$ If it is an ancestor of $$$i,$$$ then the answer is $$$-1.$$$ If of $$$i+1,$$$ then the answer is $$$1.$$$

P.s. Please prove this theorem independently. There are two points worth noting. Firstly, if $$$j$$$ is a common ancestor of both $$$i$$$ and $$$i+1,$$$ then there is a point where the intervals on both layers are disjoint from $$$j,$$$ so $$$[x,y]$$$ cannot be fully covered. Secondly, the official tutorial calculates the answer $$$\pm 1$$$ by counting the number of intervals used, but actually, we only need to check whose ancestor $$$j$$$ is.

Implementation

A trick of STL set: removing 'useless' intervals by Property 1

If an input interval contains another, we remove the larger one.

This can be done in multiple ways. We may sort the intervals in some manner and label the useless ones, which is the method in the official tutorial.

Here is another way: we may maintain a set of intervals such that no interval is contained by another, through a specifically designed $$$<$$$ relation.

We define that $$$[l_1,r_1]<[l_2,r_2]$$$ if $$$l_1<l_2$$$ and $$$r_1<r_2.$$$ We may see that if an interval contains another, they are considered 'equal' by set (because neither $$$<$$$ nor $$$>$$$ holds); for example, $$$[3,5]$$$ and $$$[2,8]$$$ are 'equal' under this relation.

The algorithm is: when we try to add $$$[l,r]$$$ into the set, we use the find() method to look for an interval that is 'equal' to $$$[l,r].$$$ If it does not exist, we simply insert $$$[l,r]$$$ into the set.

Otherwise, suppose find() returns $$$[a,b].$$$ If $$$[l,r]$$$ contains $$$[a,b],$$$ then we discard $$$[l,r].$$$

Lastly, if $$$[a,b]$$$ contains $$$[l,r],$$$ then we remove $$$[a,b]$$$ from the set and check whether other intervals in the set contain $$$[l,r].$$$ After removing all of them, we insert $$$[l,r].$$$

upper_bound: looking for 'parents'

We may use vector to store and index the ordered intervals remaining in the set. Now, for an interval $$$[l_i,r_i]$$$, its 'parent' is the first interval to its right that is disjoint from it. We may use binary search or two pointers to achieve this (a two-pointer sketch follows below). Here is another way:

We may still use STL and the $$$<$$$ relation we have designed. For the interval $$$[l_i,r_i],$$$ its 'parent' $$$[l_k,r_k]$$$ is exactly the first interval 'greater than' $$$[r_i,r_i],$$$ which may be found by upper_bound.

Similarly, we may use upper_bound to find the left-most interval for the query $$$[x,y],$$$ which is the first interval 'greater than' $$$[x-1,x-1].$$$
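
For reference, here is a minimal two-pointer sketch of the same parent computation (the function parents and its local names are mine; T and the sorted vector v are as in the code below). Since both boundaries are sorted, the parent index never moves backwards, so the whole pass is linear:

vector<int> parents(const vector<T>&v){
    vector<int> par(v.size()); //par[i]: first interval right of i disjoint from it,
                               //or v.size(), i.e. the virtual root
    int k=0;
    for(int i=0;i<(int)v.size();++i){
        while(k<(int)v.size()&&v[k].l<=v[i].r)++k;
        par[i]=k; //k only increases over the whole loop, so this is O(n) overall
    }
    return par;
}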

Binary lifting: ancestor?

Note that we need to check whether the interval $$$j$$$ (the unique interval $$$[l_j,r_j]$$$ with $$$r_j=y$$$) exists and is an ancestor of $$$i$$$ or $$$i+1.$$$ This can be done by binary lifting, the most commonly used method for LCA. Starting from node $$$i$$$ (or $$$i+1$$$), we 'lift' it to its last ancestor with right boundary $$$\leqslant y.$$$ Then, if that interval's right boundary $$$=y,$$$ it is the interval $$$j$$$ and is an ancestor of $$$i$$$ (or $$$i+1$$$).

Code

Here is a sample code integrating all the ideas above. The binary-lifting array is f and the vector v stores all non-useless intervals in order, with indices from $$$0$$$ to v.size()-1. We let v.size() be the index of the virtual root node.

We use li to denote the first interval with its left boundary $$$\geqslant x,$$$ which is the interval $$$i$$$ in our analysis section. Note that there are many special cases; please read the code and make sure you understand all of them.

//This program is written by Brian Peng.
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define Rd(a) (a=rd())
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
int rd(){
	int x;char c(getchar());bool k;
	while(!isdigit(c)&&c^'-')if(Gc(c)==EOF)exit(0);
	c^'-'?(k=1,x=c&15):k=x=0;
	while(isdigit(Gc(c)))x=x*10+(c&15);
	return k?x:-x;
}
void wr(int a){
	if(a<0)Pc('-'),a=-a;
	if(a<=9)Pc(a|'0');
	else wr(a/10),Pc((a%10)|'0');
}
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
#define Ps Pc(' ')
#define Pe Pc('\n')
#define Frn0(i,a,b) for(int i(a);i<(b);++i)
#define Frn1(i,a,b) for(int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(int i(a);i>=(b);--i)
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
#define N (200010)
int m,q,x,y,f[N][20];
struct T{int l,r;
	bool operator<(T b)const{return l<b.l&&r<b.r;}
};//The structure of intervals, with < relation defined.
set<T>st;
signed main(){
	Rd(m),Rd(q);
	while(m--){
		Rd(x),Rd(y);
		auto it(st.find({x,y}));//Find the interval in the set that contains
        //or is contained by the input [x,y]
		if(it==st.end())st.insert({x,y});
		else if(x<=it->l&&it->r<=y)continue;//[x,y] contains a smaller interval
		else{
			st.erase(it);//[x,y] is contained by a larger interval
			while((it=st.find({x,y}))!=st.end())st.erase(it);
            //Remove all the intervals that contain it
			st.insert({x,y});
		}
	}
	vector<T>v(st.begin(),st.end());
	Frn0(i,0,v.size())
        *f[i]=upper_bound(v.begin()+i+1,v.end(),T({v[i].r,v[i].r}))-v.begin();
        //Use upper_bound and < relation to find the parent node
	int lg(log2(v.size()));
	*f[v.size()]=v.size();
	Frn1(j,1,lg)Frn1(i,0,v.size())f[i][j]=f[f[i][j-1]][j-1];//binary lifting
	while(q--){
		Rd(x),Rd(y);
		int li(upper_bound(v.begin(),v.end(),T({x-1,x-1}))-v.begin());
        //li is the index of the first interval with left boundary >= x
		if(li==v.size()||v[li].l!=x||v[li].r>y){wr(0),Pe;continue;}
        //Special cases: li not existing, li not covering x, li exceeding y
		if(v[li].r==y){wr(998244352),Pe;continue;}
        //Special case: li is just [x,y]
		if(li+1==v.size()||v[li+1].r>y||v[li+1].l>v[li].r){wr(0),Pe;continue;}
        //Special cases concerning li+1
		int u(li),u2(li+1);
		Frn_(i,lg,0)if(f[u][i]<v.size()&&v[f[u][i]].r<=y)u=f[u][i];
		Frn_(i,lg,0)if(f[u2][i]<v.size()&&v[f[u2][i]].r<=y)u2=f[u2][i];
        //Binary lifting li and li+1 to the last intervals with right boundary <= y
		if(u==u2||v[u].r!=y&&v[u2].r!=y)wr(0),Pe;
        //Common ancestor or both not reaching y
		else if(v[u].r==y)wr(998244352),Pe;//Ancestor of li
		else wr(1),Pe;//Ancestor of li+1.
	}
	exit(0);
}

Time Complexity: $$$O((m+q)\log m)$$$

Memory Complexity: $$$O(m\log m)$$$

Extension

STL is a very powerful tool.

In this question, we use STL set to maintain a set of intervals in which no interval contains another, via a specifically designed $$$<$$$ relation.

By designing different $$$<$$$ relations, we may maintain a set of intervals with different properties conveniently. For example, my blog

Solution: CF731D 80-th Level Archeology -- Letter, Interval, and Reverse Thinking

solves another CF problem by maintaining a set of mutually disjoint intervals by defining another $$$<$$$ relation between intervals.

Thanks for reading! See you next round!

By Bring, history, 2 years ago, In English

Solution: CF731D 80-th Level Archeology -- Letter, Interval, and Reverse Thinking

Link to the question: CF Luogu

Preface

Assertion: "STL set is the most 'powerful' tool throughout C++."

"Not because it is capable of various tasks, but because we don't need to code it ourselves."

What is lexicographical order?

Lexicographical order is an ordering of strings based on the ordering of characters. A string is less than another string when, at the first position of difference, the letter in the first string is less than the corresponding letter in the second.

It is the ordering of words in a dictionary.

E.g. In the normal alphabet, 'abceb'<'abdca' since 'c'<'d' at their third position, the first position of difference. Note that we only consider the first position of difference. Even though 'e' is greater than 'c' at the fourth position and 'b' is greater than 'a' at the fifth, they cannot affect our comparison.

P.s. We let the 'empty character' be the smallest. E.g. To compare 'an' and 'and', we may consider the third position of 'an' as the empty character, then 'an'<'and' as the empty character is less than 'd'.

Back to the question

"One operation replaces each letter with the next letter in the alphabet."

We try to find the number of operations to make all the words in lexicographical order.

Reverse Thinking: One operation replaces each letter in the alphabet with its previous one.

E.g. Consider the first example (We use letters to make it more intuitive). The strings are

cb
a
bca
bcab

and the alphabet is a<b<c.

After one operation, the strings become

ac
b
cab
cabc

which are in lexicographical order.

Reverse Thinking: after one operation, we 'rotate the alphabet' to c<a<b. And now, the strings are in lexicographical order based on our new alphabet.

The solution is to find a reordering of the alphabet to make the strings in lexicographical order.

How to reorder the alphabet?

"From the order of characters, we develop an order of strings."

Reverse Thinking: From the order of strings, we deduce the order of characters.

E.g. Given that 'ace'<'abd' in some alphabet, we know that 'c'<'b' in that alphabet (and this is the only information we can extract).

The method is: Find the first position of difference, then the character in the first string at that position is less than the corresponding character in the second string.

P.s. As we know that the empty character is invariantly the smallest, there is no alphabet satisfying 'and'<'an'.

To deal with this question, we assume that the strings are in lexicographical order in some alphabet, and by comparing strings, we learn some orders between characters in that alphabet. At last, we check whether some rotation of the alphabet satisfies all conditions.

We know that after rotation, the new alphabet must be in the form of

$$$a,a+1,\cdots,c-1,c,1,2,\cdots,a-1$$$

for some $$$a\in\{1,\cdots,c\},$$$ which is the smallest character (except the empty character).

In the following text, we use '$$$<$$$' to denote the normal comparison between numbers, and '$$$\prec$$$' to denote the comparison in alphabetical order.

Case 1:

By comparing two strings, we know that $$$u\prec v,$$$ while $$$u<v.$$$

E.g. If the strings 1 2 3 < 1 3 2, then we know that $$$2\prec 3$$$ in this alphabet, while $$$2<3$$$ in number order.

We may use $$$u$$$ and $$$v$$$ to restrict the range of $$$a,$$$ the smallest character in the new alphabet.

The restriction is: $$$a\notin [u+1,v]$$$ because contrarily, if $$$a$$$ is in that range, the alphabet is

$$$a,a+1,\cdots,v,\cdots,c,1,\cdots,u,\cdots,a-1,$$$ where $$$v\prec u$$$ in the alphabet, a contradiction.

We may check that $$$a$$$ outside the interval satisfies $$$u\prec v.$$$

Case 2:

By comparing two strings, we know that $$$u\prec v,$$$ while $$$u>v.$$$

E.g. If the strings 1 3 2 < 1 2 3, then we know that $$$3\prec 2$$$ in this alphabet, while $$$3>2$$$ in number order.

The restriction now is: $$$a\in [v+1,u].$$$ (Why?)

Implementation

To combine all the restrictions, $$$a$$$ is

  1. not in the union of all the intervals of case 1;
  2. in the intersection of all the intervals of case 2.

For case 2, maintaining the intersection of intervals is easy. Suppose the current interval is $$$[l,r],$$$ then its intersection with $$$[u,v]$$$ is just $$$[\max(l,u),\min(r,v)].$$$ (If the right boundary is less than the left, then we regard it as the empty interval.)

For case 1, maintaining the union of intervals is not so easy. One offline method is to sort all the intervals by left boundaries, and then combine them if they have non-empty intersections.

However, I'd like to introduce an online method of maintaining unions of intervals, which is, sometimes, even easier to implement than the offline method.

Use STL set to maintain the union of intervals.

That's why it is the most powerful tool in C++

We let the set store disjoint intervals (i.e. intervals that don't intersect) in order. Whenever we try to add a new interval, we combine it with the existing intervals that intersect with it.

Our data structure:

struct T{int l,r;
	bool operator<(T b)const{return r<b.l;}
};
set<T>s;//our data structure

Note the way we define the order between intervals (which is necessary when using set). We define an interval to be less than another only when the first is on the left of the second and they are disjoint. Thus, if two intervals intersect, the set considers them equal (more exactly, neither less than nor greater than).

Another important feature of set is its find() function, which returns the iterator (the pointer) to the first element that is 'equal' to our input.

Now, suppose we have a set with intervals $$$\{[1,3],[5,5],[7,9],[10,11]\}$$$ and we want to add the interval $$$[2,8]$$$. What should we do?

  1. We use find() to search for the first interval that intersects with $$$[2,8],$$$ which is $$$[1,3].$$$ Combine them to get $$$[1,8].$$$ Erase $$$[1,3]$$$ from the set. Now, our task becomes "adding $$$[1,8]$$$ to the set $$$\{[5,5],[7,9],[10,11]\}$$$".
  2. We find $$$[5,5]$$$ in the set. As it is already contained in $$$[1,8],$$$ we erase it from the set. Now, our task becomes "adding $$$[1,8]$$$ to the set $$$\{[7,9],[10,11]\}$$$".
  3. We find $$$[7,9]$$$ in the set. Combine it with $$$[1,8]$$$ to get $$$[1,9]$$$. Now, our task becomes "adding $$$[1,9]$$$ to the set $$$\{[10,11]\}$$$".
  4. As there is no interval in the set that intersects with $$$[1,9],$$$ we add it to the set. The set becomes $$$\{[1,9],[10,11]\}.$$$ Terminate.

The general procedure is as follows (a code sketch follows the list): when adding an interval $$$[u,v]$$$ to the set $$$S$$$,

  1. if no interval in $$$S$$$ intersects with $$$[u,v],$$$ then add it by insert() method.
  2. otherwise, find the interval $$$[l,r]\in S$$$ that intersects with $$$[u,v].$$$
  3. combine them to get $$$[\min(u,l),\max(v,r)]$$$ and make it the new $$$[u,v].$$$
  4. remove $$$[l,r]$$$ from $$$S.$$$ Repeat step 1.
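
As a minimal standalone sketch of this procedure (the function name addInterval is mine; the struct T and its $$$<$$$ relation are the ones defined above, repeated for self-containment):

#include<bits/stdc++.h>
using namespace std;

struct T{int l,r;
	bool operator<(T b)const{return r<b.l;}
};
set<T>s;

void addInterval(int u,int v){
	set<T>::iterator it;
	while((it=s.find({u,v}))!=s.end()){ //step 2: an interval intersecting [u,v]
		u=min(u,it->l),v=max(v,it->r);  //step 3: combine them
		s.erase(it);                    //step 4: remove the old one and repeat
	}
	s.insert({u,v});                    //step 1: nothing intersects any more
}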

Derive the answer

  • Let $$$[l,r]$$$ be the intersection of intervals from case 2, then $$$a\in[l,r].$$$

  • Let $$$S$$$ be the union of intervals from case 1, then $$$a$$$ is not contained by any interval in $$$S$$$.

We may reformulate the second point: "$$$[a,a]$$$ does not intersect with any interval in $$$S.$$$" This makes our lives easier as we may use find() method to check if there is an interval in $$$S$$$ that is 'equal' to $$$[a,a],$$$ which means there is an interval that intersects with $$$[a,a].$$$

Lastly, if $$$a$$$ really exists, the number of operations is $$$(c+1-a)\bmod c.$$$ (Why?)

Complexity

Time: as there are $$$n$$$ strings, there are $$$O(n)$$$ restrictions (every comparison between adjacent strings gives at most one restriction). Each restriction in case 1 is an interval, which can only be inserted into the set, found by find(), and erased once. Each of these set operations takes $$$O(\log n)$$$ time (as the set has $$$O(n)$$$ elements), so overall the time complexity is $$$O(n\log n).$$$

Memory: $$$O(n).$$$

Code

In the following code, we use $$$[pl,pr]$$$ as the interval of 'possibility' (case 2). The set for case 1 is named ip for 'impossibility'.

//This program is written by Brian Peng.
#include<bits/stdc++.h>
using namespace std;
#define Rd(a) (a=rd())
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
int rd(){
	int x;char c(getchar());bool k;
	while(!isdigit(c)&&c^'-')if(Gc(c)==EOF)exit(0);
	c^'-'?(k=1,x=c&15):k=x=0;
	while(isdigit(Gc(c)))x=x*10+(c&15);
	return k?x:-x;
}
void wr(int a){
	if(a<0)Pc('-'),a=-a;
	if(a<=9)Pc(a|'0');
	else wr(a/10),Pc((a%10)|'0');
}
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
#define Ps Pc(' ')
#define Pe Pc('\n')
#define Frn0(i,a,b) for(int i(a);i<(b);++i)
#define Frn1(i,a,b) for(int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(int i(a);i>=(b);--i)
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
#define N (1000010)
int n,c,w[2][N],l[2],pl(1),pr;
bool t;
struct T{int l,r;
	bool operator<(T b)const{return r<b.l;}
};
set<T>ip;//set of case 1
set<T>::iterator it;
void mdf(int u,int v);//adding a restriction "u is less than v in the alphabet"
signed main(){
	Rd(n),pr=Rd(c);//the possibility interval is originally set as [1,c]
	Rd(l[t]);
	Frn1(i,1,l[t])Rd(w[t][i]);
	Frn1(i,2,n){
		t^=1,cin>>l[t];
		Frn1(i,1,l[t])Rd(w[t][i]);
		w[t][l[t]+1]=0;//set the empty character
		Frn1(i,1,l[t^1])if(w[t][i]!=w[t^1][i]){
			if(!w[t][i])wr(-1),exit(0);//the case when the empty character
			//is greater than another character, which is impossible
			mdf(w[t^1][i],w[t][i]);//add restriction
			break;
		}
	}
	Frn1(i,pl,pr)if(ip.find({i,i})==ip.end())wr((c+1-i)%c),exit(0);
	//check whether i is in the interval [pl,pr] for case 2
	//while not contained in the set for case 1
	wr(-1),exit(0);
}
void mdf(int u,int v){
	if(u<v){//case 1
		++u;//the interval added is [u+1,v]
		while((it=ip.find({u,v}))!=ip.end())//find the interval in the set
		//that intersects with it
			u=min(u,it->l),v=max(v,it->r),ip.erase(it);//combine and remove
		ip.insert({u,v});//insert the combined interval
	}else pl=max(pl,v+1),pr=min(pr,u);//case 2
}

Extension

What if the question becomes that you can reorder the alphabet to any permutation of $$$\{1,\cdots,c\}$$$?

We still need to find the restrictions between characters. But now, as we don't have a specific reordering pattern, we may treat each restriction $$$u\prec v$$$ as a directed edge $$$u\to v$$$ in a graph with vertices $$$\{1,\cdots, c\}.$$$ Then, we may find the reordering of the alphabet as a topological ordering of the vertices, if it exists (i.e. the graph is a DAG).

Do you want to ask me why I thought about this? Because I originally thought this was the solution, but I found it too difficult to find a topological order that forms a rotation of $$$(1,2,\cdots,c).$$$ Eventually, I came up with this solution, which makes use of the reordering being a rotation.

A slight change in the question leads to a totally different method; maybe this is just why algorithms are so intriguing...

Thanks for reading. ありがとう (thank you)!

By Bring, history, 2 years ago, In English

First CF round at Cambridge. Solved A,B,D1 in the round. Dropped from purple to blue...

Still a long way to go...

Solution: CF Round #830 (Div. 2) D1&D2 Balance

Easy Version

Brute-force

Evidently, the most brute-force way is to create a set to collect the $$$x$$$'s added. Then, for each query with $$$k,$$$ check $$$k,2k,3k,\cdots$$$ until the first multiple of $$$k$$$ that is not contained in the set, and output it.

Obviously, it is doomed to TLE, especially when the same $$$k$$$ with a very large $$$k\text{-mex}$$$ is queried multiple times.

Becoming Elegant

We try to optimize the brute force by reducing the time cost when the same $$$k$$$ is queried repeatedly. As there is no remove operation, if a query with $$$k$$$ finds the $$$k\text{-mex},$$$ then the next time the same $$$k$$$ is queried, the answer must be greater than or equal to the previous one.

Therefore, we can memorize all the "previous answers." If a $$$k$$$ that has a memorized answer is queried, we start checking the set from its previous answer instead of from $$$1\cdot k.$$$

Wait, do we avoid TLE just by this "subtle" optimization?

Calculation of time complexity (not rigorous):

For queries with the same $$$k,$$$ every multiple of $$$k$$$ in the set is checked at most once. So the time complexity is the same as if every query were moved to the end of the operations and every $$$k$$$ were queried at most once.

Then, the worst case happens (intuitively) when the first $$$q/2$$$ operations fill the set with numbers between $$$1$$$ and $$$q/2,$$$ and the next $$$q/2$$$ operations query for $$$k=1,2,\cdots,q/2.$$$ In this case, the number of times checking the set is $$$O(\sum_{k=1}^{q/2} \frac{q}{2k})=O(q\log q)$$$ by harmonic series. As every check of the set takes $$$O(\log q)$$$ of time, the overall time complexity is $$$O(q\log^2 q).$$$

Code (795 ms / 12600 KB)

We use a map to memorize the previous answers. The function $$$\mathtt{srch(x,fs)}$$$ takes $$$\mathtt x$$$ as the queried $$$k$$$ and $$$\mathtt{fs}$$$ as its starting number.

//This program is written by Brian Peng.
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define Rd(a) (a=rd())
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
int rd(){
	int x;char c(getchar());bool k;
	while(!isdigit(c)&&c^'-')if(Gc(c)==EOF)exit(0);
	c^'-'?(k=1,x=c&15):k=x=0;
	while(isdigit(Gc(c)))x=x*10+(c&15);
	return k?x:-x;
}
void wr(int a){
	if(a<0)Pc('-'),a=-a;
	if(a<=9)Pc(a|'0');
	else wr(a/10),Pc((a%10)|'0');
}
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
#define Ps Pc(' ')
#define Pe Pc('\n')
#define Frn0(i,a,b) for(int i(a);i<(b);++i)
#define Frn1(i,a,b) for(int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(int i(a);i>=(b);--i)
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
int q,x;
char opt;
set<int>s;
map<int,int>ans; //Memorization
int srch(int x,int fs);
signed main(){
	Rd(q);
	while(q--){
		cin>>opt,Rd(x);
		if(opt=='+')s.insert(x);
		else{
			if(ans.find(x)!=ans.end())ans[x]=srch(x,ans[x]);
			else ans[x]=srch(x,x);
			wr(ans[x]),Pe;
		}
	}
	exit(0);
}
int srch(int x,int fs){
	while(s.find(fs)!=s.end())fs+=x;
	return fs;
}

Hard Version

Now the "remove" operation is added, and we can no longer memorize the previous answers simply.

Maybe we can use something more powerful, which is able to record "removed" $$$x$$$'s?

The most useful tool to record and query the existence of numbers in a given range is

Segment Tree

For every queried $$$k,$$$ instead of memorizing the previous $$$k\text{-mex}$$$, we build a segment tree of $$$k$$$ recording the checked and not removed multiples of $$$k.$$$ In the following text, we let $$$\text{St}_k$$$ denote the "SegTree of $$$k$$$", and use $$$x\in \text{St}_k$$$ to denote that $$$x$$$ is recorded in the SegTree of $$$k.$$$

For a query with $$$k$$$, if $$$\text{St}_k$$$ is not set up yet (i.e. $$$\text{St}_k$$$ is empty), we go through the multiples of $$$k$$$ in the set, which are $$$k,2k,3k,\cdots$$$ till the first multiple of $$$k$$$ (say $$$nk$$$) that is not in the set. Then, the SegTree of $$$k$$$ is built with the entries from $$$1$$$ to $$$n-1$$$ set as $$$1,$$$ meaning that $$$k,2k,\cdots,(n-1)k\in\text{St}_k.$$$ (As we only insert multiples of $$$k$$$ into $$$\text{St}_k,$$$ we let the $$$i$$$ th entry of $$$\text{St}_k$$$ represent the number $$$ik$$$ to make the tree more compact.)

Thus, if a number is recorded in a SegTree, it is in the set.

Then, for removal of $$$x$$$, we need to remove $$$x$$$ not only from the set, but also from every SegTree that records it. To achieve this, we create a list of $$$x$$$ (say $$$\text{Tk}_x$$$) that consists of all the $$$k$$$'s such that $$$x\in\text{St}_k.$$$ In other words, whenever $$$x$$$ is recorded in the SegTree of $$$k,$$$ we add $$$k$$$ into the list $$$\text{Tk}_x.$$$ When $$$x$$$ is removed from the set, we remove $$$x$$$ from all SegTrees recording it by going through every $$$k$$$ in $$$\text{Tk}_x$$$ and setting the $$$x/k$$$ th entry in $$$\text{St}_k$$$ to $$$0.$$$ We clear $$$\text{Tk}_x$$$ at the end of the process as $$$x$$$ is no longer recorded in any SegTree.

Now, if a $$$k$$$ is queried a second time, we find the least entry in $$$\text{St}_k$$$ that is $$$0.$$$ (This is why we need to use a SegTree instead of an array, as we may check whether a sub-interval is set all $$$1$$$ by checking if the sum of the interval is equal to its length.) Say this entry is the $$$n$$$ th. If $$$nk$$$ is not in the set, we output $$$nk$$$ as $$$k\text{-mex}.$$$ Otherwise, if $$$nk$$$ is in the set, we update the $$$n$$$ th entry in $$$\text{St}_k$$$ to be $$$1,$$$ add $$$k$$$ into the list $$$\text{Tk}_{nk},$$$ and repeat the process of seeking the least entry in $$$\text{St}_k$$$ that is $$$0.$$$

Code Implementation

As the range of $$$k$$$ and $$$x$$$ in the input is very large, I use #define int long long (a wicked trick) for convenience and signed is used in place of int if such a large range is not needed.

Lazy Created Segment Tree

We may see that most of the entries in a SegTree are $$$0,$$$ and most of the $$$k$$$'s do not even have a SegTree if they are never queried. Thus, we need a Lazy Created SegTree to reduce the time and memory complexity.

The following is the structure of a node in a lazy created SegTree:

struct SgT{signed ls,rs,v;}t[10000000];

Here, $$$\mathtt {ls,rs}$$$ represent the ids of the left/right son respectively, and $$$\mathtt v$$$ represents the sum of the interval the node represents. (The interval is not stored in the nodes explicitly, but it will be clear from the functions.)

How lazy creation is achieved
  1. We use a counter $$$\mathtt{tn}$$$ (initial value $$$0$$$) to record the highest id of the SegTree nodes. Then whenever a new node is created, we add $$$1$$$ to $$$\mathtt{tn}$$$ and use it as the id of the new node.

  2. Particularly, the node with id $$$0$$$ represents an interval with entries all $$$0,$$$ and at the beginning every SegTree has only node $$$0.$$$ If a son of a node is $$$0,$$$ it means that its half-interval is filled with $$$0.$$$

  3. We use a map map<int,signed>rt to store the root of $$$\text{St}_k$$$ (rt[k]). For every SegTree, we set its root interval to be $$$[1,q]$$$ as any number greater than or equal to $$$(q+1)k$$$ can never be $$$k\text{-mex}.$$$ (Why?)

  4. We also use a map map<int,list<int>>tk to store the lists $$$\text{Tk}_x$$$ (tk[x]).

Note: apart from the SegTree, the use of $$$\mathtt{map}$$$ for roots and lists is also Lazy Creation.

For convenience, we use $$$\mathtt u$$$ as the id of the node we are dealing with in a function, and we use #define to simplify the id of its two sons:

#define Ls (t[u].ls)
#define Rs (t[u].rs)

Now let's look at how these operations are implemented.

Query: Lazy Creation, Updating, and Query in one function

Suppose we are dealing with the SegTree $$$\text{St}_k.$$$ The qry function $$$\mathtt{qry(u,k,l,r)}$$$ returns the least $$$n$$$ in the interval $$$[l,r]$$$ such that $$$nk$$$ is not in the set. If the interval is all filled with $$$1,$$$ return $$$0$$$ as the default value.

int qry(signed&u,int k,int l,int r){
	if(!u)u=++tn;
	//Lazy Creation happens here. !!!IMPORTANT: Pass u by Reference!!!
	if(t[u].v==r-l+1)return 0;
	//If the sum is equal to length, then every entry is 1.
	if(l==r){
		//Check if l*k is in the set.
		if(st.find(l*k)!=st.end()){
			t[u].v=1,tk[l*k].push_back(k);
			//l*k is in the set. Update the SegTree and add k into the list tk[l*k].
			return 0;
		}
		else return l;
		//l*k is not in the set, return l (meaning the current k-mex is l*k).
	}
	int md((l+r)/2),ql(qry(Ls,k,l,md));
	//Query the left half-interval first
	if(ql){
		//Found the k-mex, update the SegTree and return.
		t[u].v=t[Ls].v+t[Rs].v;
		return ql;
	}
	//Left half-interval filled with 1. Query the right-interval.
	int qr(qry(Rs,k,md+1,r));
	t[u].v=t[Ls].v+t[Rs].v;
	return qr;
}

Removal

The modification function $$$\mathtt{mdf(u,l,r,x)}$$$ sets the $$$\mathtt x$$$ th entry (note that the $$$\mathtt x$$$ th entry represents the number $$$\mathtt xk$$$ recorded in $$$\text{St}_k$$$) to $$$0$$$ in the SegTree with root $$$\mathtt u.$$$ For the SegTree $$$\text{St}_k,$$$ if we want to remove the number $$$x,$$$ we call mdf(rt[k],1,q,x/k).

// When implementing, always set l=1 and r=q.
void mdf(signed u,int l,int r,int x){
	while(1){
		//Descending from the root to the leaf.
		--t[u].v;
		if(l==r)return;
		int md((l+r)/2);
		x<=md?(r=md,u=Ls):(l=md+1,u=Rs);
		//Direction chosen by x.
	}
}

Time complexity Calculation (Not rigorous)

As every checking of set ($$$O(\log q)$$$) is accompanied by a SegTree search ($$$O(\log q)$$$ as the SegTree interval is $$$[1,q]$$$) and possibly a SegTree modification (also $$$O(\log q)$$$) for a "remove" later, the time complexity is the same as the easy version: $$$O(q\log^2 q).$$$

Code (1200 ms / 162500 KB)
//This program is written by Brian Peng.
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define Rd(a) (a=rd())
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
int rd(){
	int x;char c(getchar());bool k;
	while(!isdigit(c)&&c^'-')if(Gc(c)==EOF)exit(0);
	c^'-'?(k=1,x=c&15):k=x=0;
	while(isdigit(Gc(c)))x=x*10+(c&15);
	return k?x:-x;
}
void wr(int a){
	if(a<0)Pc('-'),a=-a;
	if(a<=9)Pc(a|'0');
	else wr(a/10),Pc((a%10)|'0');
}
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
#define Ps Pc(' ')
#define Pe Pc('\n')
#define Frn0(i,a,b) for(int i(a);i<(b);++i)
#define Frn1(i,a,b) for(int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(int i(a);i>=(b);--i)
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
#define Ls (t[u].ls)
#define Rs (t[u].rs)
int q,x,tn;
char opt;
map<int,signed>rt;
map<int,list<int>>tk;
set<int>st;
struct SgT{signed ls,rs,v;}t[10000000];
int qry(signed&u,int k,int l,int r);
void mdf(signed u,int l,int r,int x);
signed main(){
	Rd(q);
	Frn1(i,1,q){
		cin>>opt,Rd(x);
		if(opt=='+')st.insert(x);
		else if(opt=='-'){
			st.erase(x);
			//Remove x from the set, and go through the list tk[x] if exists.
			if(tk.find(x)!=tk.end()){
				for(int k:tk[x])mdf(rt[k],1,q,x/k);
				tk.erase(x);
				//Remove the list tk[x].
			}
		}else{
			signed tmp(rt[x]?rt[x]:(rt[x]=++tn));
			//As Pass by Reference cannot be used with map,
			//we do lazy creation manually
			wr(qry(tmp,x,1,q)*x),Pe;
		}
	}
	exit(0);
}
int qry(signed&u,int k,int l,int r){
	if(!u)u=++tn;
	if(t[u].v==r-l+1)return 0;
	if(l==r){
		if(st.find(l*k)!=st.end()){
			t[u].v=1,tk[l*k].push_back(k);
			return 0;
		}
		else return l;
	}
	int md((l+r)/2),ql(qry(Ls,k,l,md));
	if(ql){
		t[u].v=t[Ls].v+t[Rs].v;
		return ql;
	}
	int qr(qry(Rs,k,md+1,r));
	t[u].v=t[Ls].v+t[Rs].v;
	return qr;
}
void mdf(signed u,int l,int r,int x){
	while(1){
		--t[u].v;
		if(l==r)return;
		int md((l+r)/2);
		x<=md?(r=md,u=Ls):(l=md+1,u=Rs);
	}
}

Conclusion

Why is Lazy Created SegTree effective in the hard version problem? An intuitive explanation:

In the easy version of the problem, there is no remove, so the non-decreasing nature of $$$k\text{-mex}$$$ over time for a fixed $$$k$$$ leads us to the idea of storing answers, so that we can "move up" from the previous answer in a later query of the same $$$k$$$.

In the hard problem, the non-decreasing nature of $$$k\text{-mex}$$$ is destroyed by the remove operation, and we can no longer record only the previous answer. Recording the "checked numbers" in a SegTree, on the other hand, provides us with an efficient way to "move back" to a removed $$$x,$$$ and "jump up" if the removed $$$x$$$ is inserted into the set again.

Last but not least, the idea of Lazy Creation speeds our code up by creating the data structures only when they are needed. This idea is extremely useful when the data range ($$$1\leqslant x,k\leqslant 10^{18}$$$) is a lot larger than the number of operations ($$$1\leqslant q\leqslant 2\cdot 10^5$$$).

Thanks for reading! See you next round!
