题解
这题比普通的法法塔好玩一点
首先我们手算找规律可以发现。对于一个回文串,设长度为 $n$,它的回文中心对答案的贡献是
$$2^{\frac{n+1}2}-1-\frac{n+1}2$$
也就是每对回文数选与不选-空集-连续回文串的个数
我们考虑将这些贡献分开计算。
连续回文串的个数我们显然可以跑一遍 $manachar$然后就统计出来了。
然后问题就转换成了计算每一个回文中心,有多少个回文对。
$Formally:$记回文中心为 $t$,计算有多少 $j$满足 $s[t-j] == s[t + j]$。
这里回文中心在字符上还是字符间要讨论一下,因为这关系到 $j$能不能取 $0$
至于这个怎么统计,不就是法法塔喵。
将字符为 $a$的地方记为 $1$,自己卷积自己,$i$为偶数的时候,第 $i>>1$位就是回文中心,$i$为计数的时候,第 $i>>1$位与第 $(i>>1)+1$中间那个位置就是回文中心。
字符 $b$同理。
然后就说完了
我的法法塔常数太大了最慢的点跑了 600+ms,QwQ
#include<bits/stdc++.h>
#define N 600005
#define fo(i, a, b) for (R int i = (a); i <= (b); ++i)
#define fd(i, a, b) for (R int i = (a); i >= (b); --i)
#define in inline
#define R register
#define mod 1000000007
#define ll long long
const double pi = acos(-1);
struct complex{
double real, imag;
in void conj ()
{
imag = -imag;
}
friend in complex operator * (const complex &x, const complex &y)
{
return (complex) {x.real * y.real - x.imag * y.imag, x.real * y.imag + x.imag * y.real};
}
friend in complex operator + (const complex &x, const complex &y)
{
return (complex) {x.real + y.real, x.imag + y.imag};
}
friend in complex operator - (const complex &x, const complex &y)
{
return (complex) {x.real - y.real, x.imag - y.imag};
}
}c1[N], c2[N], w[N];
in void dft(complex *c, int len)
{
int k = 0;
while ((1 << k) < len) ++k;
--k;
fo (i, 0, len)
{
int g = 0;
fo (j, 0, k)
if ((1 << j) & i) g |= (1 << k - j);
if (i < g) std::swap(c[i], c[g]);
}
for (int l = 2; l <= len; l <<= 1)
{
int mid = l >> 1;
for (complex *p = c; p != c + len; p += l)
for (int i = 0; i < mid; ++i)
{
complex tmp = w[len / l * i] * p[mid + i];
p[mid + i] = p[i] - tmp;
p[i] = p[i] + tmp;
}
}
}
inline ll pow (ll a, int b)
{
ll ret = 1;
while (b)
{
if (b & 1) ret = ret * a % mod;
a = a * a % mod;
b >>= 1;
}
return ret;
}
int mp[5], k, a[N], n, f[N], r, cnt[N];
char s[N];
main()
{
scanf("%s", s + 1);
n = strlen(s + 1);
fo (i, 1, n) {a[i << 1] = (s[i] == 'a'); a[i << 1 | 1] = 2;}
a[1] = 2; a[0] = 233; a[n + 1 << 1] = 234;
int x = 1, up = n << 1 | 1;
long long ans = 0;
fo (i, 2, up)
{
if (i <= x + f[x]) f[i] = std::min(x + f[x] - i, f[(x << 1) - i]);
while (a[i + f[i]] == a[i - f[i]]) ++f[i]; --f[i];
if (x + f[x] < i + f[i]) x = i;
if (f[i]) ans -= (f[i] + 1) >> 1;
}
int len = 1;
while (len <= n + n) len <<= 1;
for (int i = 0; i < len; ++i) w[i] = (complex) {cos(2 * pi * i / len), -sin(2 * pi * i / len)};
for (char ch = 'a'; ch <= 'b'; ++ch)
{
memset(c1, 0, sizeof c1);
memset(c2, 0, sizeof c2);
fo (i, 1, n) c1[i - 1].real = c2[i - 1].real = (s[i] == ch);
for (int i = 0; i < len; ++i) w[i].conj();
dft(c1, len);
dft(c2, len);
for (int i = 0; i < len; ++i) {c1[i] = c1[i] * c2[i]; w[i].conj();}
dft(c1, len);
for (int i = 0; i < len; ++i)
{
if (i & 1)
cnt[i] += ((int) ((c1[i].real + 0.1) / len)) >> 1;
else
{
if (s[(i >> 1) + 1] == ch)
{
cnt[i] += ((int) ((c1[i].real + 0.1) / len + 1)) >> 1;
}
else
{
cnt[i] += ((int) ((c1[i].real + 0.1) / len)) >> 1;
}
}
}
}
for (int i = 0; i < len; ++i)
ans = (1ll * ans + pow(2, cnt[i]) + mod - 1) % mod;
printf("%d", ans);
}
0 条评论