题目大意
求 $n$元组 $(A_1,A_2,…,A_n)$的方案数,满足
- $\sum{a_i} = m$
- $2a_i \leq a_{i-1} + a_{i + 1} \quad(2\leq i \leq n – 1)$
$n,m \leq 10^5$
心路历程
(想看题解的可以跳过这段)
考场上化了半天的式子,$c_i$是 $a_i$的二阶差分,有:
$$\sum_{i=1}^n \frac {i(i+1)}2 c_i = m$$
到这里问题有二,一是你很难确定 $c_i$的具体范围 (实际上是与 $\sum c_i$有关,还是没有完全脱离其它 $c_i$的限制),二是假设 $c_i$是非负整数,你还是不知道如何求 $c_i$的方案数。
然后就打算从别的方向想,以为是什么暴力 $dp$的化简,但是无果。
然后就弃疗了,最后一个小时都在罚坐。
题解
事实上一个下凸的序列,总是有一个最低点,不妨从这里入手。
假设最低点下标为 $i$(相同高度的点取最左边的),我们会发现,因为 $a_{i-1}>a_i$的缘故,我们总有
$$a_{i-k} \geq a_i + k\quad (1 \leq k \leq i – 1)$$
我们假设 $b_i$为 $a_i$的最小值
举个栗子,假设 $i=3$,$n=5$,我们有 $(b_1,b_2,b_3,b_4,b_5)=(2, 1, 0, 0, 0)$
注意到,所有以 $i$为最低点的合法 $a$序列中,我们都是以对应的 $b$序列通过加一些数生成的。
具体怎么加数,我们有以下三种操作
- 给每一位 $+1$,即总体 $+n$。
- 对于 $i$左边的数,选择最前面的 $k$个数,分别加上 $(k,k-1,…,2,1)$
- 对于 $i$右边的数,选择最后面的 $k$个数,分别加上 $(1,2,…,k-1,k)$
以上三种操作都可以进行任意次。
那么方案数就是操作的种类数 (不同的操作当且仅当整个操作序列按一定顺序排序后有一处操作不相同)。
因为我们只关注的是序列 $a$的和等于 $m$,所以我们只需要关注每个操作造成的序列和的变化。
因为操作数是无限的,我们可以转化为一个完全背包的模型。
对于固定的 $i$,注意到第二第三种操作的 $k$最大就是 $O(\sqrt m)$,否则序列和变化会超过 $m$,因此总共物品数就是 $O(\sqrt m)$,所以我们只需要 $O(m\sqrt m)$跑一遍完全背包即可。
容易发现,$i$只有在前 $O(\sqrt m)$个数中才能成立,否则 $b_i$序列和就会大于 $m$。
对于 $i-1$到 $i$,我们发现我们的操作只是减去一个物品和加上一个物品,因此可以 $O(m)$转移
总复杂度 $O(m\sqrt m)$
#include<bits/stdc++.h>
#define Re register
#define fo(i, a, b) for (int i = (a); i <= (b); ++i)
#define fd(i, a, b) for (int i = (a); i >= (b); --i)
#define edge(i, u) for (int i = head[u], v = e[i].v; i; i = e[i].nxt, v = e[i].v)
#define pb push_back
#define F first
#define S second
#define ll long long
#define inf LLONG_MAX
#define mp std::make_pair
#define mod 1000000007
#define lowbit(x) (x & -x)
#define N 500005
#define clr(arr) memset(arr, 0, sizeof arr)
#define bset std::bitset<N>
#define idx(i, u) ((i) * (nn) + (u))
#define eps (1e-8)
inline int read ()
{
int x = 0; bool flag = 0; char ch = getchar();
while (!isdigit(ch) && ch != '-') ch = getchar();
if (ch == '-') flag = 1, ch = getchar();
while (isdigit(ch)) x = (x << 3) + (x << 1) + (ch ^ 48), ch = getchar();
if (flag) x = -x; return x;
}
inline ll pow (ll x, int y)
{
ll ret = 1;
while (y)
{
if (y & 1) ret = ret * x % mod;
x = x * x % mod;
y >>= 1;
}
return ret;
}
ll n, m, f[N], c[N], p, ans;
int main ()
{
n = read(); m = read();
f[0] = 1;
fo (i, n, m)
f[i] = (f[i] + f[i - n]) % mod;
fo (i, 1, n)
c[i] = c[i - 1] + i;
fo (i, 1, n - 1)
{
if (c[i] > m) break;
fo (j, c[i], m)
f[j] = (f[j] + f[j - c[i]]) % mod;
}
p = m;
ans = f[p];
fo (i, 1, n - 1)
{
p -= i;
if (p < 0) break;
ll now = c[n - i];
fd (j, m, now)
f[j] = (f[j] - f[j - now]) % mod;
fo (j, c[i], m)
f[j] = (f[j] + f[j - c[i]]) % mod;
ans = (ans + f[p]) % mod;
}
printf("%lld", (ans + mod) % mod);
}
0 条评论