更快的 NTT
题意
给定两个序列 $A$和 $B$,$Alice$从 $A$序列中随机抽取一个数字 $x$,$Bob$从 $B$序列中随机抽取一个数字 $y$,求 $(x+y)^t$的期望在模 $998244353$意义下的值。
前置理论
期望的线性性
设 $a,b$是两个不相关的随机变量,$E(x)$为 $x$的期望取值,则:
$$
E(a\cdot b)=E(a)\cdot E(b)
$$
证明:
如果 $a,b$的取值是离散的(比如题目中的 $a,b$,只能从给定的有限个数中选择),那么可以设 $P_a(x)$为 $a$取到 $x$的概率
则我们要求的 $E(a\cdot b)=\sum_{\forall x,y}{P_a(x)P_b(y)}$
又因为 $E(a)\cdot E(b)=\sum_{\forall x} P_a(x)\sum_{\forall y} P_b(y)$
将下面的式子的 $P_a(x)$乘进后面的 $\sum$即可发现两个式子是相等的。
但是,要注意,如果 $a,b$是相关的,那么上述性质就不再适用了。
整数 $k$次幂和
对于一个数列,如果我们要求:对于所有的 $t\in [0,k]$,$A_i^t$的和,怎么做呢?
方法 1:
对于这个问题,我们有一个 $O(nlog^2n)$的做法。
写出这个和关于 $t$的生成函数 (函数中 $x^c$系数是 $c$次幂和)
$$
1+a_1^1x+a_1^2x^2+a_1^3x^3+\cdots+\\
1+a_2^1x+a_2^2x^2+a_2^3x^3+\cdots +\\
\cdots \\
$$
根据等比数列求和公式 (在生成函数中的推广),我们可以知道,上面的式子实际上是:
$$
\frac{1}{1-a_1x}+\frac{1}{1-a_2x}+\cdots+\frac{1}{1-a_nx}
$$
如果我们暴力从左往右合并这个式子,时间复杂度是:$O(n^2)$的,因为每进行一次通分,分子最高项次数必然增加。
如果逐层合并这个式子,时间复杂度可以变成 $O(nlog^2n)$。逐层的意思,即每一层合并所有相邻的分式,且不重复合并。这个过程类似于第二种方法中的” 分治 FFT“,所以复杂度证明将会与方法二相同,这里不再赘述。
方法 2:
一位大神坐在书桌前,对着眼前的《高等数学》略有所思,他观察 $ln(1+x)$的泰勒展开后,捻灭了手中快要燃尽的雪茄,又若有所悟地点点头,合上了桌上的书。
他发现了什么?
$$
ln(1+x)=x-\frac{x^2}{2}+\frac{x^3}{3}-\frac{x^4}{4}+\cdots \\
ln(1+ax)=ax-\frac{a^2x^2}{2}+\frac{a^3x^3}{3}-\frac{a^4x^4}{4}+\cdots
$$
也就是说,$ln(1+ax)$的泰勒展开里隐藏了 $a$的 $k$次幂和。
那么我们怎么利用这个求多个 $a$的 $k$次幂和呢?
$$
\sum_{i=1}^nln(1+a_ix)=ln(\Pi_{i=1}^{n}(1+a_ix))
$$
假设我们已经搞出了
$$
F=ln(\Pi_{i=1}^{n}(1+a_ix))
$$
由于取对数的函数是个多项式,所以我们需要用点巧办法:
$$
ln(x)’=\frac{1}{x} \\
ln(x)=\int \frac{1}{x}
$$
由于多项式求逆是很方便的,而多项式积分更方便 (具体实现请参考其他资料),因此多项式求对数也很方便。
现在的主要问题就是:如何求 $\Pi_{i=1}^{n}(1+a_ix)$
如果我们把所有要求积的括号分成两部分分别求积,然后合并,那么复杂度将是:
$$
T(n)=T(\frac{n}{2})+O(nlogn)
$$
故 $T(n)=O(nlog^2n)$
题目分析
首先,由于我们要求 $(x+y)^t$的期望,根据第一个前置知识,我们就可以先将这个式子展开
$$
E(x+y)^t=\sum_{i=0}^tC_t^iE(x^i)E(y^{t-i})
$$
所以,我有种预感,就是为了求对于所有 $t$ 的答案,我们只要分别求出对于所有的 $t$,$x^t$的期望和 $y^t$的期望,就可以用卷积的方式快速求出 $(x+y)^t$的期望。
实际上,的确是这样的:
$$
\begin{aligned}
F(t)& =\sum_{i=0}^tC_t^iA(i)B(t-i) \\
F(t)& =\sum_{i=0}^t\frac{t!}{i!(t-i)!}A(i)B(t-i) \\
\frac{1}{t!}F(t)& =\sum_{i=0}^t\frac{1}{i!}A(i)\cdot\frac{1}{t-i}B(t-i)
\end{aligned}
$$
这是一个卷积的形式,被卷积的两函数分别是 $\frac{1}{i!}A(i)$和 $\frac{1}{t-i}B(t-i)$
所以下面的任务就是求 $A(i),B(i)$,也就是 $a,b$的 $i$次幂期望。这个期望就是 $i$次幂和除以项数,用第二个前置知识解决即可,总复杂度 $O(nlog^2n)$。
更快的 NTT
考虑到一般的 NTT 都略慢于 FFT,优化 NTT 的常数以加速计算势在必行。
本题中由于需要进行分治 FFT,NTT 执行次数将非常多,所以优化的主要方法有两个:
- 1. 将一切可以预处理的东西预处理
- 2. 将一切可以省略的取模省略
那么原来的 NTT 是长这样的:
void dft(int f)
{
for(int i=0;i<len;i++)if(i<rev[i])swap(x[i],x[rev[i]]);
for(int w=1;w<len;w<<=1)
{
ll w1=ksm(G,(MOD-1)/(w<<1));
if(f==-1)w1=inv(w1);
for(int j=0;j<len;j+=(w<<1))
{
ll wn=1;
for(int i=j;i<j+w;i++)
{
ll a=x[i],b=x[i+w];
x[i]=(a+wn*b)%MOD;
x[i+w]=(a-wn*b%MOD+MOD)%MOD;
wn=wn*w1%MOD;
}
}
}
if(f==-1)
{
ll mul=ksm(len,MOD-2);
for(int i=0;i<len;i++)x[i]=x[i]*mul%MOD;
}
}
经过我比较彻底的常数优化后,成了这样:
void dft(int f)
{
for(int i=0;i<len;i++)if(i<rev[i])swap(x[i],x[rev[i]]);
for(int w=1,b=1;w<len;w<<=1,b++)
{
ll *wx=(f==1?wi[b]:wii[b]); //wi 和 wii 为原根的幂及幂的逆
for(int j=0;j<len;j+=(w<<1))
{
for(int i=j;i<j+w;i++)
{
ll a=x[i],b=wx[i-j]*x[i+w];
x[i]=(a+b)%MOD;
x[i+w]=(a-b+MOD*MOD)%MOD;
}
}
}
if(f==-1)
{
ll mul=ksm(len,MOD-2);
for(int i=0;i<len;i++)x[i]=x[i]*mul%MOD;
}
}
在洛谷”【模板】多项式乘法” 中,效率提升 20%
在洛谷” 玩游戏” 中,效率提升 48%
代码
下面的代码是优化过的方法一的代码。
不优化时无法通过本题。
#pragma GCC optimize("Ofast")
#pragma GCC target("sse3","sse2","sse")
#pragma GCC target("avx","sse4","sse4.1","sse4.2","ssse3")
#pragma GCC target("f16c")
#pragma GCC optimize("inline","fast-math","unroll-loops","no-stack-protector")
#pragma GCC diagnostic error "-fwhole-program"
#pragma GCC diagnostic error "-fcse-skip-blocks"
#pragma GCC diagnostic error "-funsafe-loop-optimizations"
#pragma GCC diagnostic error "-std=c++14"
#include <cstdlib>
#include <iostream>
#include <cstring>
#include <cstdio>
#define MX (524288+10)
#define MOD 998244353LL
#define G 3
using namespace std;
typedef long long ll;
int len,bit,rev[MX];
int n1,n2,m,A[MX],B[MX];
ll fac[MX];
ll EA[MX],EB[MX];
ll gp[MX],gpi[MX],wi[22][MX],wii[22][MX];
ll read()
{
ll x=0;char ch=getchar();
while(!isdigit(ch))ch=getchar();
while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
return x;
}
void qm(ll& x){x%=MOD;if(x<0)x+=MOD;}
void init(int l)
{
len=1,bit=0,rev[0]=0;
while(len<l)len<<=1,bit++;
for(int i=1;i<len;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
}
ll ksm(ll x,ll t)
{
ll ans=1;
while(t)
{
if(t&1)ans=ans*x%MOD;
x=x*x%MOD;
t>>=1;
}
return ans;
}
ll inv(ll x){return ksm(x,MOD-2);}
struct poly
{
ll *x;
void resize(){delete x;x=new ll[len];memset(x,0,len*sizeof(ll));}
void reset(){memset(x,0,len*sizeof(ll));}
void input(int s){for(int i=0;i<s;i++)scanf("%lld",&x[i]),qm(x[i]);}
void output(int s){for(int i=0;i<s;i++)printf("%lld ",x[i]);putchar('\n');}
void dft(int f)
{
for(int i=0;i<len;i++)if(i<rev[i])swap(x[i],x[rev[i]]);
for(int w=1,b=1;w<len;w<<=1,b++)
{
ll *wx=(f==1?wi[b]:wii[b]);
for(int j=0;j<len;j+=(w<<1))
{
for(int i=j;i<j+w;i++)
{
ll a=x[i],b=wx[i-j]*x[i+w];
x[i]=(a+b)%MOD;
x[i+w]=(a-b+MOD*MOD)%MOD;
}
}
}
if(f==-1)
{
ll mul=ksm(len,MOD-2);
for(int i=0;i<len;i++)x[i]=x[i]*mul%MOD;
}
}
};
poly up[MX],dn[MX];
poly dl,ul,dr,ur,dv,f1,f2;
poly tmp;
void get_inv(poly &a,poly &b,int nx)
{
if(nx==1)b.x[0]=inv(a.x[0]);
else
{
get_inv(a,b,nx>>1);
for(int i=0;i<nx;i++)tmp.x[i]=a.x[i],tmp.x[i+nx]=0;
init(nx<<1);
tmp.dft(1);
b.dft(1);
for(int i=0;i<len;i++)
tmp.x[i]=(b.x[i]*2-tmp.x[i]*b.x[i]%MOD*b.x[i]%MOD+MOD)%MOD;
tmp.dft(-1);
for(int i=0;i<nx;i++)b.x[i]=tmp.x[i],b.x[i+nx]=0;
}
}
void work(int *val,ll *tar,int sz,int ori)
{
init(2);
for(int i=0;i<sz;i++)
{
up[i].resize();
dn[i].resize();
if(i<ori)up[i].x[0]=1;
dn[i].x[0]=1;
dn[i].x[1]=-val[i],qm(dn[i].x[1]);
}
for(int i=1;i<sz;i<<=1)
{
init(min(i<<2,sz<<1));
dl.resize();
dr.resize();
ul.resize();
ur.resize();
for(int j=0;j+i<sz;j+=(i<<1))
{
dl.reset();
dr.reset();
ul.reset();
ur.reset();
for(int k=0;k<(len>>1);k++)dl.x[k]=dn[j].x[k];
for(int k=0;k<(len>>1);k++)dr.x[k]=dn[j+i].x[k];
for(int k=0;k<(len>>1);k++)ul.x[k]=up[j].x[k];
for(int k=0;k<(len>>1);k++)ur.x[k]=up[j+i].x[k];
dn[j].resize();
up[j].resize();
dl.dft(1);
dr.dft(1);
ul.dft(1);
ur.dft(1);
for(int k=0;k<len;k++)dn[j].x[k]=dl.x[k]*dr.x[k]%MOD;
for(int k=0;k<len;k++)up[j].x[k]=(ul.x[k]*dr.x[k]+dl.x[k]*ur.x[k])%MOD;
dn[j].dft(-1);
up[j].dft(-1);
}
}
dv.resize();
tmp.resize();
get_inv(dn[0],dv,len>>1);
dv.dft(1);
up[0].dft(1);
for(int i=0;i<len;i++)tmp.x[i]=dv.x[i]*up[0].x[i]%MOD;
tmp.dft(-1);
for(int i=0;i<sz;i++)tar[i]=tmp.x[i];
}
void calc_ans()
{
init(max(max(n1,m),n2));
int slen=len;
work(A,EA,slen,n1);
work(B,EB,slen,n2);
init(slen*2);
f1.resize();
f2.resize();
for(int i=0;i<len>>1;i++)f1.x[i]=EA[i]*inv(fac[i])%MOD*inv(n1)%MOD,f2.x[i]=EB[i]*inv(fac[i])%MOD*inv(n2)%MOD;
f1.dft(1);
f2.dft(1);
for(int i=0;i<len;i++)f1.x[i]=f1.x[i]*f2.x[i]%MOD;
f1.dft(-1);
for(int i=1;i<m;i++)printf("%lld\n",f1.x[i]*fac[i]%MOD);putchar('\n');
}
void prework()
{
fac[0]=1;
for(int i=1;i<MX;i++)fac[i]=fac[i-1]*(ll)i%MOD;
for(int i=0;i<MX;i++)gp[i]=ksm(G,(MOD-1)>>i),gpi[i]=inv(gp[i]);
for(int w=1,i=1;w<MX;i++,w<<=1)
{
wi[i][0]=1;
wii[i][0]=1;
for(int j=1;j<w;j++)
{
wi[i][j]=wi[i][j-1]*gp[i]%MOD;
wii[i][j]=wii[i][j-1]*gpi[i]%MOD;
}
}
}
int main()
{
prework();
scanf("%d%d",&n1,&n2);
for(int i=0;i<n1;i++)A[i]=read();
for(int i=0;i<n2;i++)B[i]=read();
scanf("%d",&m);m++;
calc_ans();
return 0;
}
2 条评论
XZYQvQ · 2018年7月8日 10:31 下午
Do you like van♂游戏?
boshi · 2018年7月9日 1:11 下午
I like VAN 游 Shi*