如果限制条件只有 1,2,4,是个人基本都会做。
问题在3,不妨考虑 $x>y$的情况(方便去掉绝对值),显然另一种情况 $y>x$是对称的,我们只需要考虑现在的这种情况即可。
因为涉及到减法,有一些退位的问题就很烦,比赛的时候过的人也少,队里就没人做这题了。
我们把 k 给二进制展开成 $(k_t…k_1)_2$
我们从高位向低位 dfs,每次枚举两个数的第 $u$位,分别为 $i, j$,记 $nk_u = i – j$(表示 $nk$的第 $u$位二进制数),显然 $nk_u$可能是 $-1, 0,1$其中的一个,我们的目标是让 $nk\leq k$。
我们会发现,如果做到第 $u$位的时候,若前面的数 $nk – k \geq 2$的话,我们就算把 $nk$剩下的位数全部置为 $-1$,最多只能使前面的数 $-1$,所以这样的情况对答案是没有贡献的。
因此我们只要记录前面的数的 $nk – k$是多少 (代码里这个值为 $limk$),然后分类讨论转移即可 (具体看代码)。
自己不会做最大的问题是没想到每一位 $nk-k\leq 1$必定成立的性质,长脑子了。
还有一个小小的 trick 是可以每次都把 $dp$数组置 $-1$(前提是数组比较小),这样可以记录更多状态,减少无用搜索数量
#include<bits/stdc++.h>
#define Re register
#define fo(i, a, b) for (Re int i = (a); i <= (b); ++i)
#define fd(i, a, b) for (Re int i = (a); i >= (b); --i)
#define edge(i, u) for (Re int i = head[u], v = e[i].v; i; i = e[i].nxt, v = e[i].v)
//#define pb push_back
#define F first
#define S second
#define ll long long
#define inf 1000000007
#define mp std::make_pair
#define eps 1e-4
#define mod 10086
#define lowbit(x) (x & -x)
#define N 505
#define M 2005
#define clr(arr) memset(arr, 0, sizeof arr)
#define bset std::bitset<N>
#define pi std::pair<int, int>
#define ls t[u].l
#define rs t[u].r
#define pls t[pu].l
#define prs t[pu].r
ll read ()
{
Re char ch = getchar();
Re ll ret = 0;
while (!isdigit(ch)) ch = getchar();
while (isdigit(ch)) {ret = ret * 10 + ch - '0'; ch = getchar();}
return ret;
}
char s[N];
int tb[35], ta[35], tk[35], tw[35], pa[35], pb[35];
ll dp[35][2][2][2][3][2];
ll dfs (int u, bool lima, bool limb, int limk, bool limw, bool ab)
{
if (!u)
{
if (limk <= 0 && !ab)
{
/* int ra = 0, rb = 0;
fd (i, 4, 1) {ra |= pa[i] << i - 1;}//printf("%d", pa[i]); printf(" ");
fd (i, 4, 1) {rb |= pb[i] << i - 1;}//printf("%d", pb[i]); puts("");
printf("%d %d %d %d\n", ra, rb, ra - rb, ra ^ rb);
*/ return 1;
}
return 0;
}
if (dp[u][lima][limb][limw][limk + 1][ab] != -1)
return dp[u][lima][limb][limw][limk + 1][ab];
int upi = lima ? ta[u] : 1;
int upj = limb ? tb[u] : 1;
ll ret = 0;
fo (i, 0, upi)
fo (j, 0, upj)
{
if (ab && i < j) continue;
int k = i ^ j;
if (limw && k > tw[u]) continue;
k = i - j;
int nk;
if (limk == -1)
{
nk = -1;
}
else
if (!limk)
{
if (k > tk[u]) nk = 1;
if (k == tk[u]) nk = 0;
if (k < tk[u]) nk = -1;
}
else
{
nk = k + 2 - tk[u];
if (nk >= 2) continue;
}
// pa[u] = i; pb[u] = j;
ret += dfs(u - 1, lima && i == ta[u], limb && j == tb[u], nk, limw && ((i ^ j) == tw[u]), ab && i == j);
}
dp[u][lima][limb][limw][limk + 1][ab] = ret;
return ret;
}
ll solve (int a, int b, int k, int w)
{
memset(dp, -1, sizeof dp);
fo (i, 1, 30)
{
ta[i] = (a >> i - 1) & 1;
tb[i] = (b >> i - 1) & 1;
tk[i] = (k >> i - 1) & 1;
tw[i] = (w >> i - 1) & 1;
}
return dfs(30, 1, 1, 0, 1, 1);
}
int a, b, k, w;
int main ()
{
int T = read();
while (T--)
{
a = read(); b = read(); k = read(); w = read();
printf("%lld\n", solve(a, b, k, w) + solve(b, a, k, w) + std::min(a, b) + 1);
}
return 0;
}
0 条评论