树链剖分入门练手题。
然鹅我居然调了一个上午+半个上午!
最后居然是少写了一个 f
导致 WA
了一片 [滑稽]。
总的来说,树链剖分其实并不是那么的难,只需要两个 dfs+线段树+lca 思想即可 AC
。
两遍 dfs 很好理解:
const int N=3e4+2;
struct Node{
int fa,dep,size,son,top,seg;
#define fa(x) tree[x].fa
#define d(x) tree[x].dep
#define s(x) tree[x].size
#define son(x) tree[x].son
#define top(x) tree[x].top
#define seg(x) tree[x].seg
}tree[N];
struct Edge{
int nxt,to;
#define nxt(x) edge[x].nxt
#define to(x) edge[x].to
}edge[N<<2];
int num[N<<2],rev[N<<2],sum[N<<2],Max[N<<2],head[N<<2];
int n,m,cnt,Ans_sum,Ans_max;
//以上是需要定义的东西
//声明一下这些东西的含义:
/*
fa[x]:x 在树中的父亲
dep[x]:x 在树中的深度
size[x]:x 的子树结点数(子树大小)
son[x]:x 的重儿子,即 u→son[u] 为重边
top[x]:x 所在重路径的顶部结点(深度最小)
seg[x]:x 在线段树中的位置(下标)
rev[x]: 线段树中第 x 个位置对应的树中结点编号,即 rev[seg[x]]=x
//上为树链剖分的定义所需
//第二个结构体为图
num 为权值,Max、sum 为线段树维护的最大值与和
head 为链式前向星的必备数组
*/
//第一遍 dfs 处理树链剖分七个值的前四个
inline void dfs1(int u,int f){
s(u)=1,fa(u)=f,d(u)=d(f)+1;
for(register int i=head[u];i;i=nxt(i)){
int v=to(i);if(v==f)continue;
dfs1(v,u);s(u)+=s(v);
if(s(v)>s(son(u)))son(u)=v;
}return;
}
//第二遍处理后三个
inline void dfs2(int u,int f){
if(son(u)){
seg(son(u))=++seg(0),top(son(u))=top(u);
rev[seg(0)]=son(u),dfs2(son(u),u);
}for(register int i=head[u];i;i=nxt(i)){
int v=to(i);if(!top(v)&&v!=f){
seg(v)=++seg(0),top(v)=v;
rev[seg(0)]=v;dfs2(v,u);
}
}return;
}
部分效果 (start 即为我们的 top(重链顶端)):
线段树单点修改即可:
inline void pushup(int x){
sum[x]=sum[x<<1]+sum[(x<<1)+1];
Max[x]=max(Max[x<<1],Max[(x<<1)+1]);
}
inline void build(int k,int l,int r){
int mid=(l+r)>>1;if(l==r)
{sum[k]=Max[k]=num[rev[l]];return;}
build(k<<1,l,mid);build((k<<1)+1,mid+1,r);pushup(k);
}
inline void change(int k,int l,int r,int val,int x){
if(x>r||x<l)return;int mid=(l+r)>>1;
if(l==r&&r==x){sum[k]=Max[k]=val;return;}
if(mid>=x)change(k<<1,l,mid,val,x);
if(mid<x)change((k<<1)+1,mid+1,r,val,x);pushup(k);
}
inline void query(int k,int l,int r,int L,int R){
if(L>r||R<l)return;
if(L<=l&&r<=R){
Ans_max=max(Ans_max,Max[k]);
Ans_sum+=sum[k];return;
}int mid=(l+r)>>1;
if(mid>=L)query(k<<1,l,mid,L,R);
if(mid<r)query((k<<1)+1,mid+1,r,L,R);
}
怎么询问呢?
先看一下我们的重链:
由于 dfs 的顺序,同一条重链上的节点在线段树中位置是连续的。
所以每次对于一个节点 x,我们只需要询问 x 到 top(x)
之间的路径即可 (连在一起的线段树可以直接区间询问),然后 x 再跳到 top(x) 的爸爸
(x->top(x) 已经询问完了),就这样一直往上跳,跳到最后 x 和另一个节点 y 都在同一条重链上即可 (即 top(x)==top(y))。
到了一条重链上,那么就可以直接查询了。
最后综合几次查询的结果,即为 x 到 y 直接的结果了。
就像下图一样:
注意:轻儿子的 top 当然是自己,即只记录自己的答案,重儿子可以跟着重链顶端一起询问 (不理解画个图,演示一下 dfs 的过程就好多了)
询问代码:
inline void ask(int x,int y){
int fx=top(x),fy=top(y);
while(fx!=fy){
if(d(fx)<d(fy))swap(fx,fy),swap(x,y);
query(1,1,seg(0),seg(fx),seg(x));
x=fa(fx),fx=top(x);
}if(d(x)>d(y))swap(x,y);
query(1,1,seg(0),seg(x),seg(y));return;
}
所以只需要再加一个建图就好了。
注意建双向边,dfs 的时候判断一下是不是 fa。
AC 代码:
//树链剖分模板
//总结:两遍 dfs+线段树+lca 思想,线段树很重要
#include<bits/stdc++.h>
#define ll long long
#define inf 0x3f3f3f3f
#define A printf("A")
#define ld long double
#define RI register int
#define max(x,y) (x)>(y)?(x):(y)
#define min(x,y) (x)<(y)?(x):(y)
#define Match
using namespace std;
const int N=3e4+2;
struct Node{
int fa,dep,size,son,top,seg;
#define fa(x) tree[x].fa
#define d(x) tree[x].dep
#define s(x) tree[x].size
#define son(x) tree[x].son
#define top(x) tree[x].top
#define seg(x) tree[x].seg
}tree[N];
struct Edge{
int nxt,to;
#define nxt(x) edge[x].nxt
#define to(x) edge[x].to
}edge[N<<2];
int num[N<<2],rev[N<<2],sum[N<<2],Max[N<<2],head[N<<2];
int n,m,cnt,Ans_sum,Ans_max;
template <typename Tp> inline void IN(Tp &x){
int f=1;x=0;char ch=getchar();
while(ch<'0'||ch>'9')if(ch=='-')f=-1,ch=getchar();
while(ch>='0'&&ch<='9')x=x*10+ch-'0',ch=getchar();x*=f;
}
inline void pushup(int x){
sum[x]=sum[x<<1]+sum[(x<<1)+1];
Max[x]=max(Max[x<<1],Max[(x<<1)+1]);
}
inline void build(int k,int l,int r){
int mid=(l+r)>>1;if(l==r)
{sum[k]=Max[k]=num[rev[l]];return;}
build(k<<1,l,mid);build((k<<1)+1,mid+1,r);pushup(k);
}
inline void change(int k,int l,int r,int val,int x){
if(x>r||x<l)return;int mid=(l+r)>>1;
if(l==r&&r==x){sum[k]=Max[k]=val;return;}
if(mid>=x)change(k<<1,l,mid,val,x);
if(mid<x)change((k<<1)+1,mid+1,r,val,x);pushup(k);
}
inline void query(int k,int l,int r,int L,int R){
if(L>r||R<l)return;
if(L<=l&&r<=R){
Ans_max=max(Ans_max,Max[k]);
Ans_sum+=sum[k];return;
}int mid=(l+r)>>1;
if(mid>=L)query(k<<1,l,mid,L,R);
if(mid<r)query((k<<1)+1,mid+1,r,L,R);
}
inline void dfs1(int u,int f){
s(u)=1,fa(u)=f,d(u)=d(f)+1;
for(register int i=head[u];i;i=nxt(i)){
int v=to(i);if(v==f)continue;
dfs1(v,u);s(u)+=s(v);
if(s(v)>s(son(u)))son(u)=v;
}return;
}
inline void dfs2(int u,int f){
if(son(u)){
seg(son(u))=++seg(0),top(son(u))=top(u);
rev[seg(0)]=son(u),dfs2(son(u),u);
}for(register int i=head[u];i;i=nxt(i)){
int v=to(i);if(!top(v)&&v!=f){
seg(v)=++seg(0),top(v)=v;
rev[seg(0)]=v;dfs2(v,u);
}
}return;
}
inline void add(int x,int y){
nxt(++cnt)=head[x],head[x]=cnt,to(cnt)=y;
nxt(++cnt)=head[y],head[y]=cnt,to(cnt)=x;
}
inline void ask(int x,int y){
int fx=top(x),fy=top(y);
while(fx!=fy){
if(d(fx)<d(fy))swap(fx,fy),swap(x,y);
query(1,1,seg(0),seg(fx),seg(x));
x=fa(fx),fx=top(x);
}if(d(x)>d(y))swap(x,y);
query(1,1,seg(0),seg(x),seg(y));return;
}char op[10];
int main(){
// freopen("1036.in","r",stdin);
// freopen("1036.out","w",stdout);
scanf("%d",&n);
for(register int x,y,i=1;i<n;++i)
{scanf("%d%d",&x,&y);add(x,y);}
for(register int i=1;i<=n;++i)scanf("%d",&num[i]);
dfs1(1,0);seg(0)=seg(1)=top(1)=rev[1]=1;
dfs2(1,0);build(1,1,seg(0));scanf("%d",&m);
for(register int x,y,i=1;i<=m;++i){
scanf("%s",op);scanf("%d%d",&x,&y);
if(op[0]=='C')change(1,1,seg(0),y,seg(x));
else{
Ans_sum=0,Ans_max=-(inf<<1);ask(x,y);
if(op[1]=='S')printf("%d\n",Ans_sum);
else printf("%d\n",Ans_max);
}
}return 0;
}
0 条评论