题目分析
有人问起我学会的第一个高级数据结构是什么。
我说是 spaly。
在 HNOI2017 的考场上学会的。
俗话说的好,双旋的 splay,单旋的 spaly,不旋的 saply,O(1) 的 asply,那么我们就来用 splay 做一做这道题。
首先我们手模一发单旋最小值操作。会发现,假如最小值节点是 x,那么这个操作就是把 x 放到根,x 的右子树给他原来的父亲当左子树,把原来的根节点给它做右子树。
思考思考就会发现,x 的右子树的 dfs 序应该是连续的,准确的说,以 x 为根的子树应该是 spaly 的 dfs 序从左边开始的一段连续的区间,且这个区间里的所有节点的深度都要大于等于 x 的深度。
现在我们用一棵 splay 来维护,splay 中节点的顺序就是按照权值排序,然后维护一下每个节点的 dep 值(深度)和其子树里的最小深度,然后每种操作的方法如下:
1. 插入:我们寻找 x 的前驱和后继,发现要么前驱是后继的父亲,要么后继是前驱的父亲(因为前驱和后继的 dfs 序一定相邻,所以这两个节点一定相邻),所以新加入节点的深度就是 max(dep(前驱),dep(后继))+1。除此之外,就简单地将新节点插入 splay 中即可。
2. 单旋最小/大值:找到 x 的右/左子树代表的区间长度,首先将所有节点的 dep +1,然后将 x 的右/左子树的节点 dep -1,然后再将 x 的 dep 单点赋值成 1.
3. 删除:删除 x 节点,将所有节点的深度都-1
这道题最难的地方,果然还是细节处理。我调试了两个半小时。由于每个人的写法不同,不予赘述我错了哪些细节,附赠一个丑丑的数据生成器,加油对拍吧。
代码
#include<bits/stdc++.h>
using namespace std;
int read() {
int q=0;char ch=' ';
while(ch<'0'||ch>'9') ch=getchar();
while(ch>='0'&&ch<='9') q=q*10+ch-'0',ch=getchar();
return q;
}
const int N=100005,inf=0x3f3f3f3f;
int m,rt,n;
int s[N][2],f[N],dep[N],v[N],laz[N],mn[N],sz[N];
void up(int x) {
sz[x]=sz[s[x][0]]+sz[s[x][1]]+1;
mn[x]=min(min(mn[s[x][0]],mn[s[x][1]]),dep[x]);
}
void pd(int x) {
if(!laz[x]) return;
int ls=s[x][0],rs=s[x][1],t=laz[x];
if(ls) dep[ls]+=t,mn[ls]+=t,laz[ls]+=t;
if(rs) dep[rs]+=t,mn[rs]+=t,laz[rs]+=t;
laz[x]=0;
}
int is(int x) {return s[f[x]][1]==x;}
void spin(int x,int &mb) {
int fa=f[x],g=f[fa],t=is(x);
if(f[x]==mb) mb=x;
else s[g][is(fa)]=x;
f[x]=g,f[fa]=x,f[s[x][t^1]]=fa;
s[fa][t]=s[x][t^1],s[x][t^1]=fa;
up(fa),up(x);
}
void splay(int x,int &mb) {
while(x!=mb) {
if(f[x]!=mb) {
if(is(x)^is(f[x])) spin(x,mb);
else spin(f[x],mb);
}
spin(x,mb);
}
}
int find(int x,int num) {//寻找 dfs 序第 num 的节点
pd(x);
if(sz[s[x][0]]+1==num) return x;
if(sz[s[x][0]]>=num) return find(s[x][0],num);
else return find(s[x][1],num-sz[s[x][0]]-1);
}
void add(int l,int r,int num) {//区间加
int x=find(rt,l-1),y=find(rt,r+1);
splay(x,rt),splay(y,s[x][1]);
laz[s[y][0]]+=num,dep[s[y][0]]+=num,mn[s[y][0]]+=num;
up(y),up(x);//注意 pushup
}
int pre(int x,int num) {//前驱
if(!x) return 0;
pd(x);
if(v[x]<num) {int kl=pre(s[x][1],num);return kl?kl:x;}
else return pre(s[x][0],num);
}
int nxt(int x,int num) {//后继
if(!x) return 0;
pd(x);
if(v[x]>num) {int kl=nxt(s[x][0],num);return kl?kl:x;}
else return nxt(s[x][1],num);
}
void ins(int &x,int num,int d,int las) {//插入
if(!x) {x=++n,f[x]=las,v[x]=num,dep[x]=mn[x]=d,sz[x]=1;return;}
pd(x);
if(num<v[x]) ins(s[x][0],num,d,x);
else ins(s[x][1],num,d,x);
up(x);
}
int getl(int x,int num) {//获得从左边开始的连续的 dep[x]>=num 的区间长度
if(!x) return 0;
pd(x);
if(dep[x]>=num&&mn[s[x][0]]>=num) return sz[s[x][0]]+1+getl(s[x][1],num);
else return getl(s[x][0],num);
}
int getr(int x,int num) {//获得从右边开始的连续的 dep[x]>=num 的区间长度
if(!x) return 0;
pd(x);
if(dep[x]>=num&&mn[s[x][1]]>=num) return sz[s[x][1]]+1+getr(s[x][0],num);
else return getr(s[x][1],num);
}
void chan(int x,int num) {//单点修改
pd(x);
if(v[x]==num) {mn[x]=dep[x]=1;return;}
if(num<v[x]) chan(s[x][0],num);
else chan(s[x][1],num);
up(x);
}
void del(int x) {//删除
splay(x,rt);
if(s[x][0]*s[x][1]==0) rt=s[x][0]+s[x][1],f[rt]=0;
else {
int y=s[x][1];
while(s[y][0]) pd(y),y=s[y][0];
s[y][0]=s[x][0],f[s[x][0]]=y,rt=s[x][1],f[rt]=0;
while(y) up(y),y=f[y];//记得 pushup
}
}
int main()
{
int x,y;
m=read();
mn[0]=inf,ins(rt,-inf,inf,0),ins(rt,inf,inf,0);
while(m--) {
int bj=read();
if(bj==1) {
x=read();int a=pre(rt,x),b=nxt(rt,x);
a=((a==1||a==2)?0:dep[a]),b=((b==1||b==2)?0:dep[b]);
printf("%d\n",max(a,b)+1);
ins(rt,x,max(a,b)+1,0),splay(n,rt);//这个 splay 用于维护平衡
}
if(bj==2||bj==4) {
x=find(rt,2),printf("%d\n",dep[x]);
y=min(getl(rt,dep[x]),sz[rt]-1);
add(2,sz[rt]-1,1),add(2,y,-1);
chan(rt,v[x]);
}
if(bj==3||bj==5) {
x=find(rt,sz[rt]-1),printf("%d\n",dep[x]);
y=min(getr(rt,dep[x]),sz[rt]-1);
add(2,sz[rt]-1,1),add(sz[rt]-y+1,sz[rt]-1,-1);
chan(rt,v[x]);
}
if(bj==4||bj==5) del(x),add(2,sz[rt]-1,-1);
}
return 0;
}
数据生成器
#include<bits/stdc++.h>
using namespace std;
int a[100005],js,n;
void ins() {
++js;
int x=rand()%20+1;
while(a[x]) x=rand()%20+1;
a[x]=1;printf("1 %d\n",x);
}
int main()
{
srand(time(NULL));
n=rand()%10+1,printf("%d\n",n);
while(n--) {
if(!js) ins();
else {
int bj=rand()%5+1;
if(bj==1) ins();
else printf("%d\n",bj);
if(bj==4||bj==5) --js;
}
}
return 0;
}
0 条评论