发现并不需要求出期望,之需要求出所有方案的和即可。考虑 $\rm{DP}$ :
- 设 $f(i,j)$ 表示选出的 $m$ 张牌中有 $i$ 张是强化牌,一共打出了 $j$ 张强化牌时所有方案的倍率和。
-
设 $g(i,j)$ 表示选出的 $m$ 张牌中有 $i$ 张是攻击牌,一共打出了 $j$ 张攻击牌时所有方案的伤害和。
然后打牌的策略很显然是:
- 如果 $m$ 张牌中有 $i$ 张强化牌,$i<k-1$ ,那么将所有强化牌打完,然后打出前 $k-i$ 大的攻击牌。
- 如果 $m$ 张牌中有 $i$ 张强化牌,$i\geq k-1$ ,那么打出前 $k-1$ 大的强化牌,然后打出最大的攻击牌。
上面的 $\rm{DP}$ 不好转移。
考虑设 $dp_1(i,j)$ 表示一共选了 $i$ 张强化牌,其中最小的一张在所有强化牌中为第 $j$ 小时所有方案的倍率和,同样设 $dp_2(i,j)$ 表示一共选了 $i$ 张攻击牌,其中最小的一张在所有攻击牌中为第 $j$ 小时所有方案的伤害和。
先考虑 $dp_1$ 的转移,假设我们是从大往小选牌,上一次如果选到了第 $k$ 小,那么对于所有的 $k>j$ ,都可以用 $w_j\times dp_{1}(i-1,k)$ 来更新 $dp_1(i,j)$ :
$$
dp_1(i,j)=w_j\times \sum_{k=j+1}^{n} dp_1(i-1,k)
$$
注意这里是将 $w_1$ 数组从小到大排了序的。
上面的 $\rm{DP}$ 是 $O(n^3)$ 的,用前缀和优化可以做到 $O(n^2)$ 。
然后考虑 $dp_2$ 的转移,一样是从大到小选牌,如果上一次选到了第 $k$ 小,那么对于所有的 $k>j$ ,都可以用 $dp_2(i-1,k)$ 来更新 $dp_2(i,j)$ 。然后显然 $dp_{2}(i,j)$ 还有 $w_j$ 的贡献,从 $n-j$ 这些牌里面选 $i$ 张牌的方案数一共有 ${n-j\choose i-1}$ 种,每一种方案现在的伤害和都需要加上 $w_j$ ,所以总共加上 ${n-j\choose i-1}\times w_j$ :
$$
dp_2(i,j)={n-j\choose i-1}\times w_j+\sum_{k=j+1}^{n} dp_2(i-1,k)
$$
用前缀和优化一样能做到 $O(n^2)$ 。
最后统计答案,枚举 $i$ ,表示 $m$ 张牌中有 $i$ 张强化牌。
对于 $i<k-1$ ,答案为 $f(i,i)\times g(m-i,k-i)$ ,否则答案为 $f(i,k-1)\times g(m-i,1)$ ,求个和即可。
还需要考虑 $f,g$ 与 $dp_1,dp_2$ 的关系。
对于 $f(i,j)$ ,显然有 $j$ 张是需要打出的,$i-j$ 张不需要打出,一定要满足打出的牌中最小的比不打出的牌中最大的要大,枚举打出的牌中最小的牌的排名 $k$ ,然后用组合数计算即可。
$g(i,j)$ 的计算方式如法炮制。
Code:
#include <cstdio>
#include <string>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N=1.5e3+5;
const int mod=998244353;
int C[N<<1][N<<1];
int T,n,m,k,w1[N],w2[N],dp1[N][N],dp2[N][N],sum1[N],sum2[N];
template <typename _Tp> inline void IN(_Tp&x) {
char ch;bool flag=0;x=0;
while(ch=getchar(),!isdigit(ch)) if(ch=='-') flag=1;
while(isdigit(ch)) x=x*10+ch-'0',ch=getchar();
if(flag) x=-x;
}
inline int modpow(int x,int y,int res=1) {
for(;y;y>>=1,x=1ll*x*x%mod) if(y&1) res=1ll*res*x%mod;
return res;
}
inline int F(int x,int y,int res=0) {
if(x<y) return 0;
if(!y) return C[n][x];
for(int i=x-y+1;i<=n-y+1;++i) (res+=1ll*dp1[y][i]*C[i-1][x-y]%mod)%=mod;
return res;
}
inline int G(int x,int y,int res=0) {
if(x<y) return 0;
for(int i=x-y+1;i<=n-y+1;++i) (res+=1ll*dp2[y][i]*C[i-1][x-y]%mod)%=mod;
return res;
}
inline void solve() {
IN(n),IN(m),IN(k);
for(int i=1;i<=n;++i) IN(w1[i]);
for(int i=1;i<=n;++i) IN(w2[i]);
sort(w1+1,w1+1+n),
sort(w2+1,w2+1+n);
for(int i=1;i<=n;++i) {
for(int j=1;j<=n;++j) dp1[i][j]=dp2[i][j]=0;
sum1[i]=(sum1[i-1]+w1[i])%mod,dp1[1][i]=w1[i];
sum2[i]=(sum2[i-1]+w2[i])%mod,dp2[1][i]=w2[i];
}
for(int i=2;i<=n;++i) {
for(int j=1;j<=n-i+1;++j)
dp1[i][j]=1ll*w1[j]*(sum1[n]-sum1[j]+mod)%mod,
dp2[i][j]=(1ll*w2[j]*C[n-j][i-1]%mod+(sum2[n]-sum2[j]+mod)%mod)%mod;
for(int j=1;j<=n;++j)
sum1[j]=(sum1[j-1]+dp1[i][j])%mod,
sum2[j]=(sum2[j-1]+dp2[i][j])%mod;
}
int ans=0;
for(int i=0;i<m;++i)
if(i<k-1) (ans+=1ll*F(i,i)*G(m-i,k-i)%mod)%=mod;
else (ans+=1ll*F(i,k-1)*G(m-i,1)%mod)%=mod;
printf("%d\n",ans);
}
int main() {
C[0][0]=1;
for(int i=1,limit=3000;i<=limit;++i) {
C[i][0]=C[i][i]=1;
for(int j=1;j<i;++j) C[i][j]=(C[i-1][j]+C[i-1][j-1])%mod;
}
IN(T);
while(T--) solve();
return 0;
}
0 条评论