此生此世做过的第一恶心的树型 DP
题意:
很简单:就是要求一棵树上相距距离不超过 k 的点有几对。多组数据,n<=10000
分析:
看到不能 $O(n^2)$赶紧想有没有带 $log$的算法。没想到。于是百度。网上说是树的点分治,大概看懂了。
算法思想:
对于一棵树,其中任何一条路径都有:要么经过根,要么不经过根。于是我们把这棵树的根节点揪出来,寻找所有经过根的合法路径,然后把这个根删掉,得到很多子树,对子树再进行上述计算。由于新的子树上的路径绝对不经过旧的大树的根,所以路径不会有重复。又由于很显然地每一条路径都只能属于一棵子树也必然属于一棵树,所以这种方法是正确的。经过合理的选择根节点,我们只计算了 $logn$层,所以这种方法的复杂度为 $O(logn)*P(n)$,$P(n)$的复杂度取决于你寻找路径的算法的好坏。
下面我们讨论如何寻找经过一个树的根的路径总数。
首先,我们可以获得这棵树中所有节点到根的距离。那么所有和小于 k 的点对有可能构成合法路径的两端。为什么只是有可能呢?因为两个点有可能出现在这棵树的同一个子树上,他们构成的路径不经过根。
设 $A$为满足 $dis[x]+dis[y]\leq k$的 $(x,y)$的数量,$B$为满足 $dis[x]+dis[y]\leq k$且 $x,y$所在子树相同的 $(x,y)$数量,那么这棵树中经过根节点的路径条数就是 $A-B$。
我们可以 $O(n)$求出 $dis[i]$,$O(nlogn)$将 $dis[]$排序,再 $O(n)$利用单调性找出每一个 x 所对应的 $dis$最大的 $y$(当 $dis[x]$增加时,$dis[y]$不会增加, 呈单调递减),也就是 $A$。对于这棵树的每一棵子树,我们又可以 $O(n)$求出所有 $B$。然后 $A-B$就是答案。这样的理想时间复杂度为 $O(nlognlogn)$
还需要注意几点:
1. 单纯的将子树分治是不可行的,因为出题人会`专门把树扯成一条链让你 $O(n^2logn)$地吃屎。于是我们需要专门用一个 DP 找出树的重心不断进行分治,而复杂度为 $O(nlognlogn)$。
2. 这一道题很无耻的卡 memset(), 删掉 memset 你就奇迹般地从 TLE 变成了 547ms。此题还会与分段式桶排发生反应,我的比 sort 在 1e7 下快 10 倍的桶排竟然会 TLE, 而 sort 奇迹般 AC。考试时如果遇到这种让我做一下午一晚上的题,我一定会说: 打暴力。如果非要给这个暴力加上一个期限,我希望是+1s。(+1s 就可以用 memset)
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
#define MX 10010
using namespace std;
int fst[MX],nxt[MX*2],v[MX*2],w[MX*2],lnum;
int vis[MX];
int n,k;
int mx[MX],sum[MX];
inline void addeg(int nu,int nv,int nw)
{
lnum++;
nxt[lnum]=fst[nu];
fst[nu]=lnum;
v[lnum]=nv;
w[lnum]=nw;
}
inline void init()
{
lnum=-1;
memset(fst,-1,sizeof(fst));
}
inline void input()
{
int a,b,c;
for(int i=1;i<n;i++)
{
scanf("%d%d%d",&a,&b,&c);
addeg(a,b,c);
addeg(b,a,c);
}
}
int root,sz;
void _getroot(int x,int fa)
{
sum[x]=mx[x]=0;
for(int i=fst[x];i!=-1;i=nxt[i])
{
if(v[i]==fa||vis[v[i]])continue;
_getroot(v[i],x);
sum[x]+=sum[v[i]]+1;
mx[x]=max(mx[x],sum[v[i]]+1);
}
mx[x]=max(mx[x],sz-sum[x]-1);
if(mx[x]<mx[root])root=x;
}
inline void getroot(int x,int fa)
{
root=0;
mx[root]=99999999;
_getroot(x,fa);
}
int q[MX],dp[MX];
int dis[MX];
inline void getdep(int pred,int x,int fa)
{
int h=1,t=1,now;
memset(dp,0xff,sizeof(dp));
dis[0]=0;
dp[x]=pred;
dis[++dis[0]]=pred;
q[h]=x;
while(h>=t)
{
now=q[t++];
for(int i=fst[now];i!=-1;i=nxt[i])
{
if(v[i]==fa||vis[v[i]]||dp[v[i]]!=-1)continue;
dp[v[i]]=dp[now]+w[i];
dis[++dis[0]]=dp[v[i]];
q[++h]=v[i];
}
}
}
int tdis[MX];
int sch(int x,int fa)
{
int a=0,b=0;
vis[x]=1;
tdis[0]=0;
for(int i=fst[x];i!=-1;i=nxt[i])
{
if(v[i]==fa||vis[v[i]])continue;
getdep(w[i],v[i],x);
sort(dis+1,dis+dis[0]+1);
for(int j=1;j<=dis[0];j++)tdis[++tdis[0]]=dis[j];
for(int j=1,c=dis[0];j<=dis[0];j++)
{
while(dis[c]+dis[j]>k&&c>=1)c--;
if(c<=j)break;
b+=c-j;
}
}
sort(tdis+1,tdis+tdis[0]+1);
for(int j=tdis[0];j>=1;j--)if(tdis[j]<=k){a+=j;break;}
for(int i=1,j=tdis[0];i<=tdis[0];i++)
{
while(tdis[j]+tdis[i]>k&&j>=1)j--;
if(j<=i)break;
a+=j-i;
}
a-=b;
for(int i=fst[x];i!=-1;i=nxt[i])
{
if(v[i]==fa||vis[v[i]])continue;
sz=sum[v[i]]+1;
getroot(v[i],x);
a+=sch(root,x);
}
return a;
}
int main()
{
while(~scanf("%d%d",&n,&k))
{
memset(vis,0,sizeof(vis));
if(n==0&&k==0)break;
init();
input();
sz=n;
getroot(1,0);
printf("%d\n",sch(root,0));
}
return 0;
}
0 条评论