论如何使用 K-D tree 卡过”HNOI2015 接水果”
思路:
首先,我们会比较自然地想到将盘子和水果的关系抽象成整棵树的 DFS 序上的点对。
如果一条树上路径[a,b] 是[c,d] 的子路径,那么 c 和 d 一定会在与 a,b 有关的特定的区间内。
因此求多少已知路径包含给定路径将非常容易,因为这个问题可以转化为有多少点对 (x,y) 满足 $x\in [l_1,r_1],y\in [l_2,r_2]$。这个问题可以由 K-D 树或者扫描线或者什么诡异的方法解决,如用 k-d 树,我们只需要实现区间加和单点修改。
接着我们在此基础上,考虑如何满足 “第 k 大” 这个限制。
如果我们二分每这个第 k 大值,该问题又可以转化为判定性问题:求有多少小于该值的满足条件的点对。
所以我们对每个水果,同时二分答案,按二分的第 k 大值排序。盘子也按权值排序。将水果和盘子按权值从小到达,把盘子插入到 k-d 树中,用水果询问。这样我们就知道每个水果路径包含了多少比他小的盘子路径,利用这个数量继续二分即可。
这是一种类似整体二分的方法,但相比整体二分,这种方法更像二分答案。
时间复杂度 $O(n\sqrt{n}log{n})$
卡·常数 的诅咒
然而非常激动地想到这种做法,非常激动地码完,非常激动地过了样例,却只有 20 分,剩下全部 TLE。开 O2 后还是只有 70 分。
找到本地数据测试后,发现有一些精心构造的数据需要 5 秒才能通过,于是理所当然地在洛谷上 TLE 了。
于是我开始优化代码
- 问题 1:同时二分时有些的 l,r 区间已经重合,没有必要处理这些询问。
- 解决 1:特判掉这些没用的询问-0.07s
- 问题 2:二分的区间太大,没有必要二分 0 到 1e9 的所有数。
- 解决 2:将所有可能作为答案的数离散化,二分这些数的编号-0.5s
- 问题 3:gprof 后发现各种运算符重载和 min,max 等比较耗时
- 解决 3:运算符用 const&重载,手写 min,max-0.3s
- 问题 4:读入也许比较耗时
- 解决 4:读入优化-0.04s
- 问题 5:最影响整体复杂度的是 kd 树的区间加
- 解决 5:对 kd 树实行标记永久化,榨干区间加的常数-0.6s
- 问题 6:二分答案时会多次询问同样的矩形,每次在 kd 树中递归非常浪费时间
- 解决 6:预处理所有可能用到的矩形在 kd 树中用到了哪些节点,以及对这些节点实行了什么操作,真正区间加时直接访问这些节点-0.8s
至此,我将 5s 的代码优化成了 2.7s,于是我就仰天长啸,拍案而起,但最终还是
觉得对得起浪费掉的这两个晚上的。
(由于各种原因,代码非常长)
#pragma GCC optimize("Ofast")
#pragma GCC target("sse3","sse2","sse")
#pragma GCC target("avx","sse4","sse4.1","sse4.2","ssse3")
#pragma GCC target("f16c")
#pragma GCC optimize("inline","fast-math","unroll-loops","no-stack-protector")
#pragma GCC diagnostic error "-fwhole-program"
#pragma GCC diagnostic error "-fcse-skip-blocks"
#pragma GCC diagnostic error "-funsafe-loop-optimizations"
#pragma GCC diagnostic error "-std=c++14"
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <ctime>
#define MX 40004
int D;
int real[MX],rnum;
int gfake(int a)
{
return std::lower_bound(real+1,real+rnum+1,a)-real;
}
void lsh()
{
std::sort(real+1,real+rnum+1);
rnum=std::unique(real+1,real+rnum+1)-real-1;
}
struct node
{
int p[2],mn[2],mx[2],s[2],cnt,laz;
bool operator < (const node& t)const{return p[D]!=t.p[D]?p[D]<t.p[D]:p[D^1]<t.p[D^1];}
}tseq[MX*2];
struct edge
{
int u,v,w;
}e[MX*2],plt[MX],qur[MX];
struct bina
{
int t,u,v,k,l,r,m,a,id; //type; u; v; kth; l; r; mid; ans; id;
bool operator < (const bina& tmp)const{return real[m]==real[tmp.m]?t<tmp.t:real[m]<real[tmp.m];}
}bin[MX*2];
int n,p,q;
int fst[MX],nxt[MX*2],lnum;
int dfn[MX],rak[MX],mst[MX][2],dcnt;
int out[MX],rout[MX][1000],rtyp[MX][1000],rsiz[MX];
int mmin(const int &a,const int &b){return a>b?b:a;}
int mmax(const int &a,const int &b){return a<b?b:a;}
struct KDT
{
node tre[MX*2];
int cnt,rot;
void init(){for(int i=1;i<=cnt;++i)tre[i].cnt=tre[i].laz=0;}
void upd(int a)
{
tre[a].mn[0]=tre[a].mx[0]=tre[a].p[0];
if(tre[a].s[0])
{
tre[a].mn[0]=mmin(tre[a].mn[0],tre[tre[a].s[0]].mn[0]),
tre[a].mx[0]=mmax(tre[a].mx[0],tre[tre[a].s[0]].mx[0]);
}
if(tre[a].s[1])
{
tre[a].mn[0]=mmin(tre[a].mn[0],tre[tre[a].s[1]].mn[0]),
tre[a].mx[0]=mmax(tre[a].mx[0],tre[tre[a].s[1]].mx[0]);
}
tre[a].mn[1]=tre[a].mx[1]=tre[a].p[1];
if(tre[a].s[0])
{
tre[a].mn[1]=mmin(tre[a].mn[1],tre[tre[a].s[0]].mn[1]),
tre[a].mx[1]=mmax(tre[a].mx[1],tre[tre[a].s[0]].mx[1]);
}
if(tre[a].s[1])
{
tre[a].mn[1]=mmin(tre[a].mn[1],tre[tre[a].s[1]].mn[1]),
tre[a].mx[1]=mmax(tre[a].mx[1],tre[tre[a].s[1]].mx[1]);
}
}
void psd(int a)
{
if(!tre[a].laz)return;
tre[tre[a].s[0]].cnt+=tre[a].laz,
tre[tre[a].s[1]].cnt+=tre[a].laz,
tre[tre[a].s[0]].laz+=tre[a].laz,
tre[tre[a].s[1]].laz+=tre[a].laz,
tre[a].laz=0;
}
bool itsct(const node& a,const node& b)
{
return (*a.mn<=*b.mx&&*b.mn<=*a.mx&&
*(a.mn+1)<=*(b.mx+1)&&*(b.mn+1)<=*(a.mx+1));
}
void build(int &a,int l,int r,int d)
{
if(l>r)return;
D=d;
int mid=(l+r)>>1;
std::nth_element(tseq+l,tseq+mid,tseq+r+1);
tre[a=++cnt]=tseq[mid];
build(tre[a].s[0],l,mid-1,d^1);
build(tre[a].s[1],mid+1,r,d^1);
upd(a);
}
void pre_add(int a,node &t,int tar)
{
++rsiz[tar];
rout[tar][rsiz[tar]]=a;
if( *t.mn<=*tre[a].mn&&*tre[a].mx<=*t.mx&&
*(t.mn+1)<=*(tre[a].mn+1)&&*(tre[a].mx+1)<=*(t.mx+1))
{rtyp[tar][rsiz[tar]]|=1;return;}
if( *t.mn<=*tre[a].p&&*tre[a].p<=*t.mx&&
*(t.mn+1)<=*(tre[a].p+1)&&*(tre[a].p+1)<=*(t.mx+1))
{rtyp[tar][rsiz[tar]]|=2;}
int ls=*tre[a].s,rs=*(tre[a].s+1);
if( ls&&itsct(*(tre+ls),t))pre_add(ls,t,tar);
if( rs&&itsct(*(tre+rs),t))pre_add(rs,t,tar);
}
void add(int id)
{
for(int i=1;i<=rsiz[id];++i)
{
if(rtyp[id][i]&1)++tre[rout[id][i]].laz;
if(rtyp[id][i]&2)++tre[rout[id][i]].cnt;
}
}
int query(int a,node &t,int d)
{
if(!a)return 0;
D=d;
if(t.p[0]==tre[a].p[0]&&t.p[1]==tre[a].p[1])return tre[a].cnt+tre[a].laz;
else if(t<tre[a])return query(tre[a].s[0],t,d^1)+tre[a].laz;
else return query(tre[a].s[1],t,d^1)+tre[a].laz;
}
}ktr;
struct LCA
{
int fa[17][MX],dep[MX];
void init()
{
for(int i=1;i<=16;++i)
for(int j=1;j<=n;++j)
fa[i][j]=fa[i-1][fa[i-1][j]];
}
int lca(int a,int b)
{
if(dep[a]<dep[b])a^=b,b^=a,a^=b;
for(int i=16;i>=0;--i)
if(dep[fa[i][a]]>=dep[b])
a=fa[i][a];
if(a==b)return a;
for(int i=16;i>=0;--i)
if(fa[i][a]!=fa[i][b])
a=fa[i][a],b=fa[i][b];
return fa[0][a];
}
int son(int a,int b)
{
for(int i=16;i>=0;--i)
if(dep[fa[i][b]]>dep[a])
b=fa[i][b];
return b;
}
}lca;
void addeg(int nu,int nv)
{
nxt[++lnum]=fst[nu];
fst[nu]=lnum;
e[lnum]=(edge){nu,nv};
}
int read()
{
int x=0;char ch=getchar();
while(!isdigit(ch))ch=getchar();
while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
return x;
}
void input()
{
int a,b;
memset(fst,0xff,sizeof(fst)),lnum=-1;
n=read(),p=read(),q=read();
for(int i=1;i<n;++i)
{
a=read(),b=read();
addeg(a,b);
addeg(b,a);
}
for(int i=1;i<=p;++i)plt[i].u=read(),plt[i].v=read(),plt[i].w=read();
for(int i=1;i<=q;++i)qur[i].u=read(),qur[i].v=read(),qur[i].w=read();
}
void dfs(int x,int f,int d)
{
lca.fa[0][x]=f;
lca.dep[x]=d;
dfn[x]=++dcnt;
rak[dcnt]=x;
mst[x][0]=dcnt;
for(int i=fst[x];i!=-1;i=nxt[i])
if(e[i].v!=f)
dfs(e[i].v,x,d+1);
mst[x][1]=dcnt;
}
void work()
{
node t;int tmp=0,f;
dfs(1,0,1);
lca.init();
for(int i=1;i<=p;++i)real[++rnum]=plt[i].w;
lsh();
for(int i=1;i<=p;++i)
{
int u=plt[i].u,v=plt[i].v;
if(dfn[u]>dfn[v])u^=v,v^=u,u^=v;
bin[i]=(bina){0,u,v,0,0,0,gfake(plt[i].w),0,i};
}
for(int i=1;i<=q;++i)
{
int u=qur[i].u,v=qur[i].v;
if(dfn[u]>dfn[v])u^=v,v^=u,u^=v;
bin[p+i]=(bina){1,u,v,qur[i].w,1,rnum,(rnum+1)>>1,0,i};
}
for(int i=1;i<=p+q;++i)
if(bin[i].t)
{
int du=dfn[bin[i].u],dv=dfn[bin[i].v];
tseq[++tmp]=(node){du,dv,du,dv,du,dv,0,0,0,0};
}
ktr.build(ktr.rot,1,tmp,0);
for(int j=1;j<=p+q;j++)
if(!bin[j].t)
{
int u=bin[j].u,v=bin[j].v,g=lca.lca(u,v);
if(g!=u&&g!=v)
{
t=(node){0,0,mst[u][0],mst[v][0],mst[u][1],mst[v][1],0,0,0,0};
if(t.mn[0]<=t.mx[0]&&t.mn[1]<=t.mx[1])ktr.pre_add(ktr.rot,t,bin[j].id);
}
else
{
int x=lca.son(u,v);
t=(node){0,0,1,mst[v][0],mst[x][0]-1,mst[v][1],0,0,0,0};
if(t.mn[0]<=t.mx[0]&&t.mn[1]<=t.mx[1])ktr.pre_add(ktr.rot,t,bin[j].id);
t=(node){0,0,mst[v][0],mst[x][1]+1,mst[v][1],n,0,0,0,0};
if(t.mn[0]<=t.mx[0]&&t.mn[1]<=t.mx[1])ktr.pre_add(ktr.rot,t,bin[j].id);
}
}
for(int i=1;i<=17;++i)
{
f=0;
ktr.init();
std::sort(bin+1,bin+p+q+1);
for(int j=1;j<=p+q;++j)
{
if(!bin[j].t)
{
int u=bin[j].u,v=bin[j].v,g=lca.lca(u,v);
if(g!=u&&g!=v)ktr.add(bin[j].id);
else ktr.add(bin[j].id);
}
else if(bin[j].l<bin[j].r)
{
f=1;
t=(node){dfn[bin[j].u],dfn[bin[j].v],0,0,0,0,0,0,0,0};
bin[j].a=ktr.query(ktr.rot,t,0);
if(bin[j].a>=bin[j].k)bin[j].r=bin[j].m;
else bin[j].l=bin[j].m+1;
bin[j].m=(bin[j].l+bin[j].r)>>1;
}
}
if(!f)break;
}
for(int i=1;i<=p+q;++i)if(bin[i].t)out[bin[i].id]=bin[i].m;
for(int i=1;i<=q;++i)printf("%d\n",real[out[i]]);
}
int main()
{
input();
work();
return 0;
}
0 条评论