因为 $n$这个值是最大的,前缀最大不用考虑 $n$后面的,后缀最大不用考虑 $n$前面的
设 $f(i, j)$表示 $i$个数的排列,有 $j$个数字是前缀最大的方案数
考虑枚举最小的数字放在哪个位置,放在第一个位置则等于 $f(i – 1, j – 1)$。放在别的位置则等于 $f(i – 1, j)$,共有 $i – 1$个 “别的位置”
所以:
$$f(i, j) = f(i – 1, j – 1) + (i – 1) \times f(i – 1, j)$$
这就是第一类斯特林数 $S _ 1(i, j)$
其原因是:有 $j$个前缀最大,相当于把 $i$个数的排列分成了 $j$段,第 $k$段为 $[$第 $k$个前缀最大 $,$第 $k + 1$个前缀最大 $)$
答案 $Ans = \sum _ {i = 1} ^ n S _ 1(i – 1, a – 1) \times S _ 1(n – i, b – 1) \times C _ {n – 1} ^ {i – 1}$
也就是选择 $i – 1$个数字放到 $n$前面,剩下的放到 $n$后面这样
$n$前面的结成 $a – 1$个环,$n$后面的结成 $b – 1$个环,相当于一共用 $n – 1$个数字结成 $a + b – 2$个环
当然这 $a + b – 2$个环中有 $a – 1$个是正向排列的,$b – 1$个是反向排列的,所以有:
$Ans = S _ 1(n – 1, a + b – 2) \times C _ {a + b – 2} ^ {b – 1}$
然后因为第一类斯特林数有生成函数:
$$\prod _ {i = 0} ^ {n – 1} (x + i)$$
这个生成函数的 $k$次项的系数就是 $S _ 1(n, k)$
这个用分治 NTT 求就行了
具体看代码吧
#include <bits/stdc++.h>
#define NS (262144)
#define LGS (18)
#define MOD (998244353)
#define G (3)
#define pls(a, b) ((a) + (b) < MOD ? (a) + (b) : (a) + (b) - MOD)
#define mns(a, b) ((a) - (b) < 0 ? (a) - (b) + MOD : (a) - (b))
#define mul(a, b) (1ll * (a) * (b) % MOD)
#define Inv(a) (qpow((a), MOD - 2))
using namespace std;
template<typename _Tp> inline void IN(_Tp& dig)
{
char c; bool flag = 0; dig = 0;
while (c = getchar(), !isdigit(c)) if (c == '-') flag = 1;
while (isdigit(c)) dig = dig * 10 + c - '0', c = getchar();
if (flag) dig = -dig;
}
int qpow(int a, int b)
{
int res = 1;
while (b)
{
if (b & 1) res = mul(res, a);
a = mul(a, a), b >>= 1;
}
return res;
}
int n, A, B;
int rev[NS];
struct poly
{
int d[NS], N, bs;
int& operator [] (const int a) {return d[a];}
void resize(int s)
{
int tmp = N;
N = 1, bs = 0;
while (N < s) N <<= 1, bs++;
for (int i = tmp; i < N; i += 1) d[i] = 0;
}
void ntt(int t)
{
for (int i = 1; i < N; i += 1)
{
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bs - 1));
if (i < rev[i]) swap(d[i], d[rev[i]]);
}
for (int l = 1; l < N; l <<= 1)
{
int dt = qpow(G, (MOD - 1) / (l << 1));
if (t == -1) dt = Inv(dt);
for (int i = 0; i < N; i += (l << 1))
{
int g = 1, t1, t2;
for (int j = i; j < i + l; j += 1, g = mul(g, dt))
{
t1 = d[j], t2 = mul(g, d[j + l]);
d[j] = pls(t1, t2), d[j + l] = mns(t1, t2);
}
}
}
if (t == -1)
{
int inv = Inv(N);
for (int i = 0; i < N; i += 1) d[i] = mul(d[i], inv);
}
}
void operator *= (poly &oth)
{
for (int i = 0; i < N; i += 1) d[i] = mul(d[i], oth[i]);
}
} P[LGS];
stack<int> rub;
int Binary(int l, int r)
{
if (l == r)
{
int a = rub.top(); rub.pop();
P[a].resize(2), P[a][0] = l, P[a][1] = 1;
return a;
}
int mid = (l + r) >> 1;
int a = Binary(l, mid), b = Binary(mid + 1, r);
P[a].resize(r - l + 2), P[b].resize(r - l + 2);
P[a].ntt(1), P[b].ntt(1), P[a] *= P[b], P[a].ntt(-1), rub.push(b);
return a;
}
int C(int a, int b)
{
int x = 1, y = 1;
for (int i = a - b + 1; i <= a; i += 1) x = mul(x, i);
for (int i = 1; i <= b; i += 1) y = mul(y, i);
return mul(x, Inv(y));
}
int main(int argc, char const* argv[])
{
IN(n), IN(A), IN(B), n--;
if (!A || !B || A + B - 2 > n) puts("0"), exit(0);
if (!n) puts("1"), exit(0);
for (int i = 0; i < LGS; i += 1) rub.push(i);
int a = Binary(0, n - 1);
printf("%lld\n", mul(P[a][A + B - 2], C(A + B - 2, B - 1)));
return 0;
}
5 条评论
boshi · 2019年1月27日 12:40 下午
有一种 $O(n\log n)$的倍增算法,比你的快到不知道哪里去了。
Remmina · 2019年1月27日 8:24 下午
你行你写啊 QvQ
boshi · 2019年1月30日 9:55 下午
https://www.luogu.org/recordnew/show/15932564
比你快哦
Remmina · 2019年1月30日 10:49 下午
我是要你写博客,不是要你装 X
boshi · 2019年1月31日 8:24 下午
行行行