[KD Tree] BZOJ3053 The Closest M Points

记一下 \(KD\) 树的模板（题意：给定 \(n\) 个 \(K\) 维点，每次询问距离给定点最近的 \(m\) 个点）。感觉这东西好玄学。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#include<bits/stdc++.h>
// Shorthand for building the (distance, index) pairs pushed into the heap.
#define mp make_pair
// Array capacity; points are stored at indices 1..n, so n must stay below this.
const int maxn=100010;
using namespace std;
// (squared distance, node index) pairs kept in the max-heap of current best answers.
typedef pair<int,int> pii;
// Square of an integer. dist/wnt add up to K such terms into an int, so coordinate
// differences must be small enough that those sums stay below the 1<<30 sentinel.
//
// No custom abs() is defined here: nothing called it, and defining a C standard
// library function signature in the global namespace is reserved to the
// implementation. Code that needs it gets ::abs / std::abs with identical results.
int sqr(int x){ return x*x; }
// Dimension compared by node's operator<; Build sets it right before nth_element.
int now_id;
// Number of dimensions in the current test case (coordinates are indexed 1..K).
int K;
// An input point, and also a k-d tree node once copied into T.
//   a[1..K]   coordinates (index 0 unused, so at most 5 dimensions fit in a[6])
//   mn / mx   per-dimension bounding box of the subtree rooted here (see maintain)
//   l / r     child indices into T, 0 meaning "no child"
//   g         not read anywhere in the visible code
struct node{
    int a[6],mn[6],mx[6],g,l,r;
    int& operator [] (int x){ return a[x]; }
    const int& operator [] (int x) const { return a[x]; }
    // Orders points by the coordinate selected by the global now_id. Takes const
    // references: nth_element calls this many times and each node is ~88 bytes.
    friend bool operator < (const node& lhs,const node& rhs){
        return lhs.a[now_id]<rhs.a[now_id];
    }
    // Print the K coordinates separated by single spaces, then a newline.
    void Print() const{
        printf("%d",a[1]);
        for(int i=2;i<=K;i++) printf(" %d",a[i]);
        putchar('\n');
    }
} Point[maxn];
// Max-heap of (squared distance, node index): the current best candidates, worst on top.
priority_queue<pii> Q;
// k-d tree storage; the subtree built over Point[L..R] is rooted at its midpoint index.
node T[maxn];
// Index of the tree root returned by Build.
int rt;
// Scratch buffer Solve uses to reverse the heap's pop order.
int q[maxn];
// Number of entries currently stored in q.
int cnt;
void maintain(int x){
for(int i=1;i<=K;i++){
T[x].mn[i]=T[x].mx[i]=T[x][i];
if(T[x].l)
T[x].mn[i]=min(T[x].mn[i],T[T[x].l].mn[i]),
T[x].mx[i]=max(T[x].mx[i],T[T[x].l].mx[i]);
if(T[x].r)
T[x].mn[i]=min(T[x].mn[i],T[T[x].r].mn[i]),
T[x].mx[i]=max(T[x].mx[i],T[T[x].r].mx[i]);
}
}
// Build a balanced k-d tree over Point[L..R], splitting on dimension t and cycling
// t -> t%K+1 at each level. The median element is copied into T[mid] and becomes the
// subtree root. Returns mid, or 0 for an empty range.
int Build(int L,int R,int t){
    if(R<L) return 0;
    const int mid=L+(R-L)/2;
    // node's operator< reads the global now_id, so select the axis before partitioning.
    now_id=t;
    nth_element(Point+L,Point+mid,Point+R+1);
    T[mid]=Point[mid];
    const int next_dim=t%K+1;
    const int left=Build(L,mid-1,next_dim);
    T[mid].l=left;
    const int right=Build(mid+1,R,next_dim);
    T[mid].r=right;
    maintain(mid);
    return mid;
}
int dist(int g,node x){
int r=0;
if(g==0) return 1<<30;
for(int i=1;i<=K;i++) r+=sqr(T[g][i]-x[i]);
return r;
}
int wnt(int g,node x){
int r=0;
for(int i=1;i<=K;i++){
if(T[g].mn[i]>x[i]) r+=sqr(T[g].mn[i]-x[i]);
if(T[g].mx[i]<x[i]) r+=sqr(T[g].mx[i]-x[i]);
}
return r;
}
void Query(int g,node x){
if(!g) return ;
int d=dist(g,x),dl=wnt(T[g].l,x),dr=wnt(T[g].r,x);
if(d<Q.top().first) Q.pop(),Q.push(mp(d,g));
if(dl<dr){
if(dl<Q.top().first) Query(T[g].l,x);
if(dr<Q.top().first) Query(T[g].r,x);
}
else{
if(dr<Q.top().first) Query(T[g].r,x);
if(dl<Q.top().first) Query(T[g].l,x);
}
}
void Solve(node x,int w){
for(int i=1;i<=w;i++) Q.push(mp(1<<30,0));
Query(rt,x); cnt=0;
while(!Q.empty()){
q[++cnt]=Q.top().second;
Q.pop();
}
printf("the closest %d points are:\n",w);
for(int i=cnt;i;i--) T[q[i]].Print();
}
// Point count of the current test case (read in main).
int n;
// Query count of the current test case (read in main).
int m;
// Reads test cases until EOF. Each case is:
//   - n and K,
//   - n points of K coordinates,
//   - m,
//   - then m queries, each K coordinates followed by how many neighbours to report.
int main(){
    // Fixed storage limits, derived from the arrays themselves.
    // a[] is 1-indexed, so it holds one fewer dimension than its length.
    const int capacity=static_cast<int>(sizeof(Point)/sizeof(Point[0]));
    const int maxDim=static_cast<int>(sizeof(Point[0].a)/sizeof(Point[0].a[0]))-1;
    while(scanf("%d%d",&n,&K)==2){
        // Reject input that would overflow Point/a[], or make Build compute t%0.
        if(n<0||n>=capacity||K<1||K>maxDim) break;
        for(int i=1;i<=n;i++)
            for(int j=1;j<=K;j++) scanf("%d",&Point[i][j]);
        rt=Build(1,n,1);
        // Truncated input: stop instead of replaying a stale query count.
        if(scanf("%d",&m)!=1) break;
        for(int qi=1;qi<=m;qi++){
            // Value-initialised so no member of the query node is left indeterminate.
            node x=node();
            for(int j=1;j<=K;j++) scanf("%d",&x[j]);
            int y=0;  // neighbour count; stays 0 if the read fails
            scanf("%d",&y);
            Solve(x,y);
        }
    }
    return 0;
}