在两年前,我学习了快速傅里叶变换。当时有许多的问题没能透彻地理解。
在半年前,我接触到了任意模数 NTT,但是当时只背了个板子,没有弄懂它的原理。
今天,隔壁机房一个高二的大佬点醒了我,我突然就明白任意模数 NTT 究竟再干啥了。
任意模数 NTT 原理
当我们处理模数为任意 $10^9$左右质数的多项式乘法时,我们往往将每一次 dft 拆分为两次元素大小在 $10^5$级别的 $dft$。
这样,我们总共需要 $8$次 $dft$。
然而,有一个技巧却可以允许我们使用 $4$次 $dft$做到原先的 $8$次 dft 完成的事情。这个技巧人称 $mtt$,它的精髓在于,将两次 $dft$合并为一次完成。
基础知识
默认大家对 fft,dft,ntt 的原理非常熟悉,并且已经学习了有关复数和三角函数的相关知识。
我们先定义一些东西。
我们默认所有的数列的长度、函数的项数等,统统为 $n$。
定义 $dft$是一个作用于函数 $f$的算子,其中 $F=dft(f)$得到的是一个数列 $F$,其中 $F_k=f(\omega_n^k)$。如果我们把多项式函数和数列同等对待的话,我们也可以认为 $F_k=\sum_{i=0}^{n-1}f_i\omega_n^{ki}$
定义 $reverse$是一个作用于数列 $s$的算子,它的结果是 $s$翻转之后的数列,如 $rev({1,2,3})={3,2,1}$
定义 $idft$是一个作用于数列 $F$的算子,它是 $dft$的逆运算。并且,如果我们把多项式函数和数列同等对待的话,$f=idft(F)$,则 $f_k=\frac{1}{n}\sum_{i=0}^{n-1}F_i\omega_n^{-ki}$
定义 $conj(z)$为 $z$的共轭复数。
如果将任意一个算子、算符作用于一个数列,得到的则是一个新的数列,每一项为原数列对应项在算子、算符作用下的结果。
算法原理
假设我们需要对于实多项式函数$A$和 $B$分别做 $dft$,即,求出
$$
\begin{aligned}
F_i & =A(\omega_n^i)=\sum_{k=0}^{n-1}a_k\omega_n^{ki}\\
G_i & =B(\omega_n^i)=\sum_{k=0}^{n-1}b_k\omega_n^{ki}
\end{aligned}
$$
怎么一次完成呢?
有公式如下:
$$
\begin{aligned}
dft(A+iB) _ k & =\sum _ {i=0}^{n-1}(a_i+ib_i)\omega _ n^{ki}=A(\omega _ n^k)+iB(\omega _ n^k)\\
dft(A-iB) _ k & =\sum _ {i=0}^{n-1}(a_i-ib_i)\omega _ n^{ki}=A(\omega _ n^k)-iB(\omega _ n^k)=conj(dft(A+iB) _ {n-k})
\end{aligned}
$$
大家对于上述公式的最后一个等号可能不是很能理解。这里待会会给出证明。
假设上述公式是正确的,我们只需要计算 $dft(A+iB)$,即可还原出 $dft(A-iB)=conj(reverse(dft(A+iB)))$,而:
$$
\begin{aligned}
dft(A)_k & =\frac{dft(A+iB)+dft(A-iB)}{2}\\
\
dft(B)_k & =\frac{dft(A+iB)-dft(A-iB)}{2i}
\end{aligned}
$$
这样,我们就通过一次 $dft$完成之前两次 dft 的工作了。
下面给出上面那个等号的证明。
$$
\begin{aligned}
dft(A+iB) _ {n-k} & =\sum _ {i=0}^{n-1}(a _ i+ib _ i)\omega _ n^{-ki}\\
& =\sum _ {i=0}^{n-1}(a _ i+ib _ i)(\cos \theta-i\sin \theta)\\
& =\sum _ {i=0}^{n-1}(a _ i\cos\theta+b _ i\sin\theta)-i(a _ i\sin\theta-b _ i\cos\theta)\\
& =\sum _ {i=0}^{n-1}conj((a _ i\cos\theta+b _ i\sin\theta)+i(a _ i\sin\theta-b _ i\cos\theta))\\
& =\sum _ {i=0}^{n-1}conj((a _ i-ib _ i)(cos\theta+i\sin\theta))\\
& =\sum _ {i=0}^{n-1}conj((a _ i-ib _ i)\omega _ n^{ki})\\
& =conj(dft(A-iB) _ k)
\end{aligned}
$$
注意,这里的多项式函数的系数必须是实数,因为推导过程中默认 $a_i,b_i$都是实数。
算法实现
在高效地实现 $mtt$时,往往会用到另一个技巧,即 $idft(s)=\frac{1}{n}dft(rev(s))$
另外,运用到一些其他的七里八里的技巧,可以让你的代码更短,但是不会太易懂。
// luogu-judger-enable-o2
#include <bits/stdc++.h>
#define ML 262149
#define pi (acos(-1))
#define M 32768ll
using namespace std;
typedef long long ll;
typedef long double ldb;
void read(ll& x)
{
x = 0; char c = getchar();
while(!isdigit(c)) c = getchar();
while(isdigit(c)) x = x*10+c-'0', c = getchar();
}
ll mod;
namespace fft
{
struct Z
{
ldb r, i;
Z (const ldb &r0 = 0, const ldb &i0 = 0) : r(r0), i(i0) {}
Z operator + (const Z& t) const {return Z(r+t.r, i+t.i);}
Z operator - (const Z& t) const {return Z(r-t.r, i-t.i);}
Z operator * (const Z& t) const {return Z(r*t.r-i*t.i, r*t.i+i*t.r);}
Z conj() const {return Z(r, -i);}
void operator /= (const ldb& t) {r /= t, i /= t;}
};
struct Fourier
{
int n, bit, rev[ML];
void init(int x)
{
n = 1, bit = 0;
while(n < x) n <<= 1, bit++;
for(int i=1; i<n; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(bit-1));
}
void dft(Z *x, int f)
{
for(int i=0; i<n; i++)
if(i < rev[i])
swap(x[i], x[rev[i]]);
for(int w=1; w<n; w<<=1)
{
for(int i=0; i<n; i+=(w<<1))
{
for(int j=0; j<w; j++)
{
Z a = x[i+j], b = x[i+j+w] * Z(cos(pi/w*j), f*sin(pi/w*j));;
x[i+j] = a + b;
x[i+j+w] = a - b;
}
}
}
if(f == -1) for(int i=0; i<n; i++) x[i] /= n;
}
} F;
Z Xq[ML], Yq[ML], xlyl[ML], xlyh[ML], xhyl[ML], xhyh[ML];
void fast_multiply(ll *x, ll *y, ll *ret)
{
for(int i=0; i<F.n; i++)
Xq[i] = Z(x[i]>>15, x[i]&((1<<15)-1)),
Yq[i] = Z(y[i]>>15, y[i]&((1<<15)-1));
F.dft(Xq, +1), F.dft(Yq, +1);
for(int i=0; i<F.n; i++)
{
int j = (F.n-i) & (F.n-1);
Z xh = (Xq[i]+Xq[j].conj()) * Z(0.5, 0);
Z xl = (Xq[i]-Xq[j].conj()) * Z(0, -0.5);
Z yh = (Yq[i]+Yq[j].conj()) * Z(0.5, 0);
Z yl = (Yq[i]-Yq[j].conj()) * Z(0, -0.5);
xhyh[j] = xh*yh, xhyl[j] = xh*yl, xlyh[j] = xl*yh, xlyl[j] = xl*yl;
}
for(int i=0; i<F.n; i++)
Xq[i] = xhyh[i] + xhyl[i] * Z(0, 1),
Yq[i] = xlyh[i] + xlyl[i] * Z(0, 1);
F.dft(Xq, +1), F.dft(Yq, +1);
for(int i=0; i<F.n; i++)
{
ll xhyh = ll(Xq[i].r/F.n + 0.5) % mod;
ll xhyl = ll(Xq[i].i/F.n + 0.5) % mod;
ll xlyh = ll(Yq[i].r/F.n + 0.5) % mod;
ll xlyl = ll(Yq[i].i/F.n + 0.5) % mod;
ret[i] = ((xhyh<<30) + (xhyl<<15) + (xlyh<<15) + (xlyl)) % mod;
}
}
}
ll a[ML], b[ML], c[ML], n, m;
int main()
{
read(n); read(m); read(mod);
fft::F.init(n+m+2);
for(int i=0; i<=n; i++) read(a[i]);
for(int i=0; i<=m; i++) read(b[i]);
fft::fast_multiply(a, b, c);
for(int i=0; i<=n+m; i++) printf("%lld ", c[i]); putchar('\n');
return 0;
}
0 条评论