题目链接

因为 $n$这个值是最大的,前缀最大不用考虑 $n$后面的,后缀最大不用考虑 $n$前面的

设 $f(i, j)$表示 $i$个数的排列,有 $j$个数字是前缀最大的方案数

考虑枚举最小的数字放在哪个位置,放在第一个位置则等于 $f(i – 1, j – 1)$。放在别的位置则等于 $f(i – 1, j)$,共有 $i – 1$个 “别的位置”

所以:

$$f(i, j) = f(i – 1, j – 1) + (i – 1) \times f(i – 1, j)$$

这就是第一类斯特林数 $S _ 1(i, j)$

其原因是:有 $j$个前缀最大,相当于把 $i$个数的排列分成了 $j$段,第 $k$段为 $[$第 $k$个前缀最大 $,$第 $k + 1$个前缀最大 $)$

答案 $Ans = \sum _ {i = 1} ^ n S _ 1(i – 1, a – 1) \times S _ 1(n – i, b – 1) \times C _ {n – 1} ^ {i – 1}$

也就是选择 $i – 1$个数字放到 $n$前面,剩下的放到 $n$后面这样

$n$前面的结成 $a – 1$个环,$n$后面的结成 $b – 1$个环,相当于一共用 $n – 1$个数字结成 $a + b – 2$个环

当然这 $a + b – 2$个环中有 $a – 1$个是正向排列的,$b – 1$个是反向排列的,所以有:

$Ans = S _ 1(n – 1, a + b – 2) \times C _ {a + b – 2} ^ {b – 1}$

然后因为第一类斯特林数有生成函数:

$$\prod _ {i = 0} ^ {n – 1} (x + i)$$

这个生成函数的 $k$次项的系数就是 $S _ 1(n, k)$

这个用分治 NTT 求就行了

具体看代码吧

#include <bits/stdc++.h>

#define NS (262144)
#define LGS (18)
#define MOD (998244353)
#define G (3)

#define pls(a, b) ((a) + (b) < MOD ? (a) + (b) : (a) + (b) - MOD)
#define mns(a, b) ((a) - (b) < 0 ? (a) - (b) + MOD : (a) - (b))
#define mul(a, b) (1ll * (a) * (b) % MOD)
#define Inv(a) (qpow((a), MOD - 2))

using namespace std;

template<typename _Tp> inline void IN(_Tp& dig)
{
    char c; bool flag = 0; dig = 0;
    while (c = getchar(), !isdigit(c)) if (c == '-') flag = 1;
    while (isdigit(c)) dig = dig * 10 + c - '0', c = getchar();
    if (flag) dig = -dig;
}

int qpow(int a, int b)
{
    int res = 1;
    while (b)
    {
        if (b & 1) res = mul(res, a);
        a = mul(a, a), b >>= 1;
    }
    return res;
}

int n, A, B;

int rev[NS];

struct poly
{
    int d[NS], N, bs;
    int& operator [] (const int a) {return d[a];}
    void resize(int s)
    {
        int tmp = N;
        N = 1, bs = 0;
        while (N < s) N <<= 1, bs++;
        for (int i = tmp; i < N; i += 1) d[i] = 0;
    }
    void ntt(int t)
    {
        for (int i = 1; i < N; i += 1)
        {
            rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bs - 1));
            if (i < rev[i]) swap(d[i], d[rev[i]]);
        }
        for (int l = 1; l < N; l <<= 1)
        {
            int dt = qpow(G, (MOD - 1) / (l << 1));
            if (t == -1) dt = Inv(dt);
            for (int i = 0; i < N; i += (l << 1))
            {
                int g = 1, t1, t2;
                for (int j = i; j < i + l; j += 1, g = mul(g, dt))
                {
                    t1 = d[j], t2 = mul(g, d[j + l]);
                    d[j] = pls(t1, t2), d[j + l] = mns(t1, t2);
                }
            }
        }
        if (t == -1)
        {
            int inv = Inv(N);
            for (int i = 0; i < N; i += 1) d[i] = mul(d[i], inv);
        }
    }
    void operator *= (poly &oth)
    {
        for (int i = 0; i < N; i += 1) d[i] = mul(d[i], oth[i]);
    }
} P[LGS];

stack<int> rub;

int Binary(int l, int r)
{
    if (l == r)
    {
        int a = rub.top(); rub.pop();
        P[a].resize(2), P[a][0] = l, P[a][1] = 1;
        return a;
    }
    int mid = (l + r) >> 1;
    int a = Binary(l, mid), b = Binary(mid + 1, r);
    P[a].resize(r - l + 2), P[b].resize(r - l + 2);
    P[a].ntt(1), P[b].ntt(1), P[a] *= P[b], P[a].ntt(-1), rub.push(b);
    return a;
}

int C(int a, int b)
{
    int x = 1, y = 1;
    for (int i = a - b + 1; i <= a; i += 1) x = mul(x, i);
    for (int i = 1; i <= b; i += 1) y = mul(y, i);
    return mul(x, Inv(y));
}

int main(int argc, char const* argv[])
{
    IN(n), IN(A), IN(B), n--;
    if (!A || !B || A + B - 2 > n) puts("0"), exit(0);
    if (!n) puts("1"), exit(0);
    for (int i = 0; i < LGS; i += 1) rub.push(i);
    int a = Binary(0, n - 1);
    printf("%lld\n", mul(P[a][A + B - 2], C(A + B - 2, B - 1)));
    return 0;
}
分类: 文章

Remmina

No puzzle that couldn't be solved.

5 条评论

boshi · 2019年1月27日 12:40 下午

有一种 $O(n\log n)$的倍增算法,比你的快到不知道哪里去了。

发表回复

Avatar placeholder

您的邮箱地址不会被公开。 必填项已用 * 标注