原来传说中的 MTT 就是分系数 FFT 啊。。。
假如我们要求 $(A \times B) \mod MOD$
做法就是设一个整数系数 $M$
$A\times B = (\lfloor \frac A M \rfloor \times M + A \mod M) \times (\lfloor \frac B M \rfloor \times M + B \mod M)$
设 $\lfloor \frac A M \rfloor$为多项式 $a _ 1$,$A \mod M$为多项式 $b _ 1$
$\lfloor \frac B M \rfloor$为多项式 $a _ 2$,$B \mod M$为多项式 $b _ 2$
则:
$A \times B$
$= (a _ 1 \times M + b _1) \times (a _ 2 \times M + b _ 2)$
$= M ^ 2 \times (a _ 1 \times a _ 2) + M \times (a _ 1 \times b _ 2 + a _ 2 \times b _ 1) + b _ 1 \times b _ 2$
如果 $M$取 $\sqrt {MOD}$,那么 $a _ 1, b _ 1, a _ 2, b _ 2$的每一项的值都是 $\leq M$的,即 $\leq \sqrt {MOD}$的,因此两两卷积起来的每一项都 $\leq MOD$,就可以保证不爆精度
而且这种做法精度要求非常高,得用 long double
但是 $M$是一个很奇怪的东西,一般的话取 $2 ^ {15} = 32768$比较好,平方不会爆,而且似乎精度误差小(我也不太知道为什么,也许和二进制有关),实测本题取 $32000$和 $33000$都不行(但是 $32767$可以)
如果时限要求不紧随便跑 4 次 DFT 再跑 4 次 IDFT 共 8 次 FFT 就行了
如果卡实现的话有优化,可以做到一共只要 4 次 FFT,请自行百度(雾)
另外如果实在是怕被卡精度可以预处理出所有单位根。
代码:
#include <bits/stdc++.h>
#define NS (100005)
using namespace std;
typedef long long LL;
typedef complex<long double> cpx;
const long double PI = acos(-1);
const int M = (1 << 15);
template<typename _Tp> inline void IN(_Tp& dig)
{
char c; bool flag = 0; dig = 0;
while (c = getchar(), !isdigit(c)) if (c == '-') flag = 1;
while (isdigit(c)) dig = dig * 10 + c - '0', c = getchar();
if (flag) dig = -dig;
}
int rev[NS << 2];
struct Poly
{
int len, bs; cpx A[NS << 2];
cpx& operator [] (const int a) {return A[a];}
void resize(int a)
{
len = 1, bs = 0;
while (len < a) len <<= 1, bs++;
}
void FFT(int f = 1)
{
for (int i = 0; i < len; i += 1)
{
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bs - 1));
if (i < rev[i]) swap(A[i], A[rev[i]]);
}
for (int l = 1; l < len; l <<= 1)
{
cpx p(cos(PI / l), sin(PI / l) * f);
for (int i = 0; i < len; i += (l << 1))
{
cpx w(1, 0), t1, t2;
for (int j = i; j < i + l; j += 1, w *= p)
{
t1 = A[j], t2 = w * A[j + l];
A[j] = t1 + t2, A[j + l] = t1 - t2;
}
}
}
if (f == -1) for (int i = 0; i < len; i += 1) A[i] /= len;
}
void operator *= (Poly &oth)
{
for (int i = 0; i < len; i += 1) A[i] *= oth[i];
}
} a1, b1, a2, b2, aa, ab, ba, bb;
int n, m, MOD, F1[NS], F2[NS];
void MTT()
{
a1.resize(n + m + 1), b1.resize(n + m + 1);
a2.resize(n + m + 1), b2.resize(n + m + 1);
for (int i = 0; i <= n; i += 1)
a1[i].real(F1[i] / M), b1[i].real(F1[i] % M);
for (int i = 0; i <= m; i += 1)
a2[i].real(F2[i] / M), b2[i].real(F2[i] % M);
a1.FFT(), b1.FFT(), a2.FFT(), b2.FFT();
aa = a1, aa *= a2, ab = a1, ab *= b2, ba = b1, ba *= a2, bb = b1, bb *= b2;
aa.FFT(-1), ab.FFT(-1), ba.FFT(-1), bb.FFT(-1);
}
int main(int argc, char const* argv[])
{
IN(n), IN(m), IN(MOD);
for (int i = 0; i <= n; i += 1) IN(F1[i]);
for (int i = 0; i <= m; i += 1) IN(F2[i]);
MTT();
for (int i = 0; i <= n + m; i += 1)
{
LL a = aa[i].real() + 0.1;
a = a * M % MOD * M % MOD;
a = (a + (LL)(ab[i].real() + 0.5) * M % MOD) % MOD;
a = (a + (LL)(ba[i].real() + 0.5) * M % MOD) % MOD;
a = (a + (LL)(bb[i].real() + 0.5)) % MOD;
printf("%lld ", a);
}
putchar(10);
return 0;
}
0 条评论