动态区间第 k 大的一种 $O(nlog^2n)$的树套树解法
题意:
给定一个序列,有 “将一个数修改为另一个” 的操作,和询问 “[l,r] 区间内的第 k 小值是几” 的询问。要求在 1s 内对一个长 $10^4$的序列完成 $10^4$组修改和询问。
思路:
1.
我最先想到的树套树的思路是:线段树套 Splay。
线段树的每一个区间 [l,r] 定义为原序列的 [l,r] 中的数字组成的排序 Splay。这样,对于每个修改操作,我们对 $logn$棵包含了这个位置的 Splay 删除一个数,添加一个数即可。对于每个询问操作 [l,r],我们需要二分答案 x,再在 logn 棵恰好组成 [l,r] 区间的 Splay 中查询 x 的排名。
由于线段树有 logn 层,每层的所有 Splay 合并后恰为一个完整的原序列,所以空间复杂度为 O(nlogn) 的。
经过上面的分析可知:修改复杂度为 $O(nlog^2n)$,查询为 $O(nlog^3n)$。实践证明可以通过所有测试点。洛谷最慢测试点 340ms.
2.
第二种思路,即本文终点介绍的思路,两种操作都是 $O(nlog^2n)$的,只是还需要对所有出现过的数字离散化。
刚才,我们的线段树对应的是区间,Splay 对应的是值。如果互换一下呢?
现在线段树是建立在离散化后的实数域上的。线段树区间 [l,r] 定义为原序列中值属于 [l,r] 的所有值的下标的排序 Splay。即:将值属于 [l,r] 的所有位置提取为一个新的序列,保存在这个 Splay 里。
对于修改操作,我们依旧修改 $logn$棵 Splay。如果我们将 a[i] 修改为 b,则将所有的线段树节点 $[l,r](l\leq a[i]\leq r)$中的 Splay 中删除 i。同时向所有的线段树节点 $[l,r](l\leq b\leq r)$中的 Splay 添加 i。故一次修改操作的复杂度为 $O(log^2n)$。
对于查询操作,注意到这棵线段树是支持前缀和的。即实数区间 [l,r] 内的数 x$(a\leq x \leq b)$的个数等于 [1,r] 内的个数减 [1,l-1] 内的个数(这里的 x 就是原序列的下标)。所以我们只需要在树上二分即可。如果现在我们确定了 k 小值一定在实数区间 [a,b] 内,那么如果 [a,(a+b)/2] 内的 Splay 中满足上述条件的节点小于 k,则 k 小值一定在 [(a+b)/2+1,b] 内,反之同理。这样的查询操作只需要对 logn 个线段树节点进行查询,故一次询问的时间复杂度为 $O(log^2 n)$。
实践证明可以通过所有测试点,洛谷最慢测试点 92ms.
细节问题
这种方法虽然省去了一个 log,但是其代码量却比之前多了一个 log。
以下是一些需要注意的地方:
- 离散化不但要离散原序列的值,也要离散修改出的值。
- 在最初插入节点时最好使用类似线段树建树一样的归并方式,这样可以降低常数。实践证明不这样洛谷最慢点为 296ms.
- 最好不要用这种方法因为我打了 200 行。
/*
A data structure used to maintain interval kth number
With splays in a segment tree
*/
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#include <map>
#define MX 61005
#define mid ((l+r)>>1)
#define ls (a<<1)
#define rs (a<<1|1)
using namespace std;
typedef struct splnode
{
int x,f,siz,s[2];
}node;
typedef struct tqeury
{
int l,r,x,t;
}query;
query qur[MX];
node t[MX*18];
int seq[MX],real[MX];
vector<int>ord[MX];
map<int,int>mp;
map<int,int>::iterator itr;
int tnum;
int n,m;
int bar[MX],bnum;
typedef struct trenode
{
int root,l,r;
inline int pos(int a){return t[t[a].f].s[1]==a;}
inline void upd(int a){t[a].siz=t[t[a].s[0]].siz+t[t[a].s[1]].siz+1;}
inline void rot(int a)
{
int f=t[a].f,g=t[f].f,p=pos(a),q=pos(f);
t[f].s[p]=t[a].s[!p],t[a].s[!p]=f,t[f].f=a;
if(t[f].s[p])t[t[f].s[p]].f=f;
if(t[a].f=g)t[g].s[q]=a;
upd(f),upd(a);
}
inline void spl(int tar,int a)
{
while(t[a].f!=tar)
if(t[t[a].f].f==tar)rot(a);
else if(pos(a)==pos(t[a].f))rot(t[a].f),rot(a);
else rot(a),rot(a);
if(!tar)root=a;
}
int merg(int f,int l,int r)
{
if(l>r)return 0;
int a=++tnum;
t[a]=(node){bar[mid],f,1,0,0};
t[a].s[0]=merg(a,l,mid-1);
t[a].s[1]=merg(a,mid+1,r);
upd(a);
return a;
}
void insrt(int &a,int f,int x)
{
if(!a)t[a=++tnum]=(node){x,f,1,0,0},spl(0,tnum);
else if(x<t[a].x)insrt(t[a].s[0],a,x);
else insrt(t[a].s[1],a,x);
}
int findn(int a,int x)
{
if(!a)return 0;
else if(t[a].x==x)return a;
else if(x<t[a].x)return findn(t[a].s[0],x);
else return findn(t[a].s[1],x);
}
void del(int x)
{
int a=findn(root,x);
spl(0,a);
int la=t[a].s[0],ra=t[a].s[1];
while(t[la].s[1])la=t[la].s[1];
spl(a,la);
t[la].s[1]=ra,t[ra].f=la,t[root=la].f=0;
spl(0,ra);
}
int rank(int a,int x)
{
if(!a)return 0;
else if(x>=t[a].x)return rank(t[a].s[1],x)+t[t[a].s[0]].siz+1;
else return rank(t[a].s[0],x);
}
}segt;
segt tre[MX*4];
void build(int a,int l,int r)
{
tre[a].l=l,tre[a].r=r;
if(l<r)build(ls,l,mid),build(rs,mid+1,r);
bar[1]=-MX;
bnum=1;
for(int p=l;p<=r;p++)
for(int i=0;i<ord[p].size();i++)
bar[++bnum]=ord[p][i];
bar[++bnum]=MX;
sort(bar+1,bar+bnum+1);
tre[a].root=tre[a].merg(0,1,bnum);
}
void del(int a,int p,int x)
{
int l=tre[a].l,r=tre[a].r;
tre[a].del(x);
if(l==r)return;
else if(p<=mid)del(ls,p,x);
else del(rs,p,x);
}
void ins(int a,int p,int x)
{
int l=tre[a].l,r=tre[a].r;
tre[a].insrt(tre[a].root,0,x);
if(l==r)return;
else if(p<=mid)ins(ls,p,x);
else ins(rs,p,x);
}
int kth(int a,int ql,int qr,int k)
{
int dlt=tre[ls].rank(tre[ls].root,qr)-tre[ls].rank(tre[ls].root,ql);
if(tre[a].l==tre[a].r)return tre[a].r;
else if(dlt<k)return kth(rs,ql,qr,k-dlt);
else return kth(ls,ql,qr,k);
}
void lsh()
{
int x;
for(x=1;x<=n;x++)mp[seq[x]]=1;
for(x=1;x<=m;x++)if(qur[x].t==0)mp[qur[x].x]=1;
for(x=1,itr=mp.begin();itr!=mp.end();itr++,x++)itr->second=x;
for(x=1,itr=mp.begin();itr!=mp.end();itr++,x++)real[itr->second]=itr->first;
for(x=1;x<=n;x++)seq[x]=mp[seq[x]];
for(x=1;x<=m;x++)if(qur[x].t==0)qur[x].x=mp[qur[x].x];
}
void inpt()
{
char str[10];
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)scanf("%d",&seq[i]);
for(int i=1;i<=m;i++)
{
scanf("%s",str);
if(str[0]=='C')qur[i].t=0,scanf("%d%d",&qur[i].l,&qur[i].x);
else qur[i].t=1,scanf("%d%d%d",&qur[i].l,&qur[i].r,&qur[i].x);
}
lsh();
for(int i=1;i<=n;i++)ord[seq[i]].push_back(i);
n=mp.size();
build(1,1,n);
}
void work()
{
for(int i=1;i<=m;i++)
{
if(qur[i].t==0)
{
del(1,seq[qur[i].l],qur[i].l);
ins(1,qur[i].x,qur[i].l);
seq[qur[i].l]=qur[i].x;
}
else printf("%d\n",real[kth(1,qur[i].l-1,qur[i].r,qur[i].x)]);
}
}
void init()
{
tnum=0;
mp.clear();
for(int i=1;i<=n;i++)ord[i].clear();
}
int main()
{
int T;
scanf("%d",&T);
for(int i=1;i<=T;i++)
{
init();
inpt();
work();
}
return 0;
}
1 条评论
konnyakuxzy · 2018年1月9日 8:35 下午
我去这代码确实挺长的 QvQ
您码力太强了 Orz
不过确实奇怪网上居然没有这种权值线段树の题解