题解
这道题比较特别的地方就在于求的是合法数的平方和
其实也不是太难维护
当我们做完第 $u$位的时候此位为 $i$,假设我们还有 $cnt$个合法的数 $p_1…p_{cnt}$,那么它们的贡献就是 $(10^ui+p_1)^2+…+(10^ui+p_{cnt})^2$,化简一下就是 $cnt\times 10^{2u}i^2+2\times 10^ui\sum p_i+\sum p_i^2$
所以我们维护一下每个状态的合法状态数,合法数的和,合法数的平方和,然后记搜一下就行了
$dp[u][sum][div]$表示当前在第 $u$位,数字和模 $7$为 $sum$,当前数模 $7$为 $div$,转移看代码吧
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<ctype.h>
#define Re register
#define fo(i, a, b) for (Re int i = (a); i <= (b); ++i)
#define fd(i, a, b) for (Re int i = (a); i >= (b); --i)
#define edge(i, u) for (Re int i = head[u], v = e[i].v; i; i = e[i].nxt, v = e[i].v)
#define pb push_back
#define F first
#define S second
#define ll long long
#define inf 1000000007
#define mp std::make_pair
#define eps 1e-4
#define mod 1000000007
#define lowbit(x) (x & -x)
#define N 10005
#define cl(arr) memset(arr, 0, sizeof arr)
#define bset std::bitset<N>
#define pi std::pair<int, int>
inline void read (ll &x)
{
x = 0;
Re bool flag = 0;
Re char ch = getchar();
while (!isdigit(ch) && ch != '-') ch = getchar();
if (ch == '-') flag = 1, ch = getchar();
while (isdigit(ch)) x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
if (flag) x = -x;
}
struct node {
ll cnt, sum, sq;
};
node dp[20][8][8];
ll a[N], tot, n, m, po[N];
inline node dfs (ll u, int sum, bool limit, int div)
{
if (!u)
{
node tmp;
tmp.cnt = (sum != 0) && (div != 0);
tmp.sum = tmp.sq = 0;
return tmp;
}
if (dp[u][sum][div].cnt != -1 && !limit) return dp[u][sum][div];
int up = limit ? a[u] : 9;
node ret = (node) {0, 0, 0};
fo (i, 0, up)
{
if (i == 7) continue;
node tmp = dfs(u - 1, (sum + i) % 7, limit && i == a[u], (div * 10 + i) % 7);
(ret.cnt += tmp.cnt) %= mod;
(ret.sum += tmp.cnt * po[u - 1] % mod * i + tmp.sum) %= mod;
(ret.sq += tmp.cnt * po[u - 1] % mod * po[u - 1] % mod * i * i + 2 * po[u - 1] % mod * tmp.sum % mod * i % mod + tmp.sq) %= mod;
}
if (!limit) dp[u][sum][div] = ret;
return ret;
}
inline ll solve (ll x)
{
if (x < 0) return -1;
tot = 0;
while (x)
{
a[++tot] = x % 10;
x /= 10;
}
return dfs(tot, 0, 1, 0).sq;
}
main ()
{
po[0] = 1;
fo (i, 1, 18) po[i] = po[i - 1] * 10 % mod;
memset(dp, -1, sizeof dp);
ll T;
read(T);
while (T--)
{
read(n); read(m);
printf("%lld\n", (solve(m) - solve(n - 1) + mod) % mod);
}
return 0;
}
0 条评论