[树链剖分 + 堆] UOJ53. [UR #4] 追击圣诞老人

求前 \(K\) 小的路径,容易想到弄个堆怎样搞搞...

这题的关键在于每个点能到达的点很多,可以看作是树上几条链的形式。

这就让我们想起超级钢琴的那个套路,取出链上最小值后,分成两半放到堆里。

具体的,设 \(4\) 元组 \((sum,top,end,p)\) 表示:

当前路程和为 \(sum\) ,能走到头尾为 \(top, end\) 的链,当前取链中最小值 \(p\)

每次取当前最小,放回的是:沿 \(p\) 继续到新的点,以及更换末节点 \([top,p)\) ,\((p,end]\) 的情况。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#include<cstdio>
#include<cctype>
#include<cstring>
#include<algorithm>
using namespace std;
inline char gc(){
static char buf[100000],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline int getint(){
char ch=gc(); int res=0,ff=1;
while(!isdigit(ch)) ch=='-'?ff=-1:0, ch=gc();
while(isdigit(ch)) res=(res<<3)+(res<<1)+ch-'0', ch=gc();
return res*ff;
}
const int maxn=500005,maxe=maxn;
int n,K,w[maxn];
int fir[maxn],nxt[maxe],son[maxe],tot;
void add(int x,int y){
son[++tot]=y; nxt[tot]=fir[x]; fir[x]=tot;
}
int dep[maxn],sz[maxn],hvy[maxn],top[maxn],pos[maxn],pre[maxn];
void dfs_info(int x){
sz[x]=1;
for(int j=fir[x];j;j=nxt[j]){
dep[son[j]]=dep[x]+1; pre[son[j]]=x; dfs_info(son[j]);
sz[x]+=sz[son[j]]; if(sz[hvy[x]]<sz[son[j]]) hvy[x]=son[j];
}
}
int c[maxn];
void dfs_chain(int x,int tp){
pos[x]=++c[0]; c[c[0]]=x; top[x]=tp;
if(hvy[x]) dfs_chain(hvy[x],tp);
for(int j=fir[x];j;j=nxt[j])
if(son[j]!=hvy[x]) dfs_chain(son[j],son[j]);
}
int seg_N,ch[maxn][2],mpos[maxn];
inline int Init_node(){
seg_N++; ch[seg_N][0]=ch[seg_N][1]=0; mpos[seg_N]=0;
return seg_N;
}
int merge(int x,int y){
return w[x]<w[y]?x:y;
}
int Build(int L,int R){
int p=Init_node();
if(L==R){ mpos[p]=c[L]; return p; }
int mid=(L+R)>>1;
ch[p][0]=Build(L,mid); ch[p][1]=Build(mid+1,R);
mpos[p]=merge(mpos[ch[p][0]],mpos[ch[p][1]]);
return p;
}
int Query(int p,int L,int R,int qL,int qR){
if(R<qL||qR<L) return 0;
if(qL<=L&&R<=qR) return mpos[p];
int mid=(L+R)>>1;
return merge(Query(ch[p][0],L,mid,qL,qR),Query(ch[p][1],mid+1,R,qL,qR));
}
int Chain_Query(int x,int d){
int res=0;
while(dep[top[x]]>=d){
res=merge(res,Query(1,1,n,pos[top[x]],pos[x]));
if(dep[top[x]]==d) return res; x=pre[top[x]];
}
return merge(res,Query(1,1,n,pos[x]-(dep[x]-d),pos[x]));
}

struct data{
int sum,x,ed,d,p;
data(int t1=0,int t2=0,int t3=0,int t4=0,int t5=0){ sum=t1; x=t2; ed=t3; d=t4; p=t5; }
bool operator < (const data &A)const{
return sum>A.sum;
}
};
data H[maxn*2];
int H_sz;
void Push(data t){
H[++H_sz]=t; push_heap(H+1,H+1+H_sz);
}
struct d_chain{
int ed1,d1,p1,ed2,p2,d2,ed3,d3,p3;
} b[maxn];
int LCA(int x,int y){
while(top[x]!=top[y]){
if(dep[x]<dep[y]) swap(x,y);
x=pre[top[x]];
}
if(dep[x]<dep[y]) swap(x,y); return y;
}
void Get(int id,int x,int y,int z){
if(dep[y]<dep[z]) swap(y,z); if(dep[x]<dep[y]) swap(x,y);
int lca1=LCA(x,y), lca2=LCA(y,z), lca3=LCA(x,z), lca=LCA(lca1,lca3);
b[id].ed1=x; b[id].d1=dep[lca]; b[id].p1=Chain_Query(b[id].ed1,b[id].d1);
if(lca1!=y) b[id].ed2=y, b[id].d2=dep[lca1]+1, b[id].p2=Chain_Query(b[id].ed2,b[id].d2);
if(dep[lca2]<dep[lca3]) swap(lca2,lca3);
if(lca2!=z) b[id].ed3=z, b[id].d3=dep[lca2]+1, b[id].p3=Chain_Query(b[id].ed3,b[id].d3);
}
void Push_new(int sum,int x){
if(b[x].p1) Push(data(sum+w[b[x].p1],x,b[x].ed1,b[x].d1,b[x].p1));
if(b[x].p2) Push(data(sum+w[b[x].p2],x,b[x].ed2,b[x].d2,b[x].p2));
if(b[x].p3) Push(data(sum+w[b[x].p3],x,b[x].ed3,b[x].d3,b[x].p3));
}
int main(){
n=getint(); K=getint();
w[0]=1e9; for(int i=1;i<=n;i++) w[i]=getint();
for(int i=2;i<=n;i++){
int x=getint(); add(x,i);
}
dfs_info(1); dfs_chain(1,1);
Build(1,n);
for(int i=1;i<=n;i++){
int x=getint(),y=getint(),z=getint();
Get(i,x,y,z); Push(data(w[i],0,i,dep[i],i));
// printf("%d %d %d: ",x,y,z);
// printf("%d %d | %d %d | %d %d |\n",b[i].ed1,b[i].d1,b[i].ed2,b[i].d2,b[i].ed3,b[i].d3);
}
while(K){
data t=H[1]; pop_heap(H+1,H+1+H_sz--);
int d=t.d, ed=t.ed, p=t.p;
if(d<dep[p]){
int p2=Chain_Query(pre[p],d);
Push(data(t.sum-w[p]+w[p2],t.x,pre[p],d,p2));
}
if(dep[p]<dep[ed]){
int p2=Chain_Query(ed,d+1);
Push(data(t.sum-w[p]+w[p2],t.x,ed,d+1,p2));
}
// if(t.x==t.p) continue;
K--; printf("%d\n",t.sum);
Push_new(t.sum,p);
}
return 0;
}