hdu 4347 The Closest M Points KD树 k-邻近算法

缘起

传送门: hdu4347

简而言之就是给你一堆N维空间中的点, 度量是欧式的. 然后给你一个点, 要你求出这堆点中离该点最近的前m个点, 并且按照由近及远的顺序输出.

分析

这是KD树的模板题. kd树(k-dimension 树)是bst的高维推广. 我们知道bst可以实现对一维数据的二分查找. 注意,bst其实和二分查找是完全类似的, 前者并不能比后者更快的找到目标,但是前者较后者有一个极大的好处——可以推广到高维数据. 而推广到高维的bst就是kd树. 而求给顶点的m邻近点集就是kd树的典型运用.

盗一张kd树的经典图

首先是二维数据集.

其次, 不断的按照 x->y->x 进行二分. 即黑粗线(貌似也没有多粗)是先对x轴进行二分(分成2个空间). 然后对分成的左边和右边在y方向用黑细线进行二分(分成四个空间).最后对这四个空间分别在x轴用黑虚线进行二分,分成了8个空间. 这里的二分其实是取中位值(即大约一半小于它,一半大于它)

最后, 将上面的过程用树刻画出来,得到的就是该二维数据集的2d树

kd树的每个节点都是对当前空间的二分. 例如(7,2)是针对整个数据集进行二分,但是二分的依据是x坐标,(5,4)是针对左边数据集的二分,但是二分的依据是y坐标. 以此类推.

有了kd树的基本认识,我们来ac这道题目.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#include "stdafx.h"
#include <iostream>
#include <queue>
#include <stack>
#pragma warning(disable:4996)
#define LOCAL
#define square(x) ((x)*(x))
using namespace std;

const int MAX_N = 55555;

int n,k, depth; // n是顶点的个数, k是维数, depth 是树的深度,depth 仅仅在构建kd树的时候使用, 在查询的时候不用(因为kd树已经建立起来了)

struct KDNode
{
int x[7]; // x[0,...k-1]是该 kd树节点的坐标
bool flag; // 是否是数据(即是否是节点数据, false表示不是节点数据, true表示是节点数据,这个用于运行kd树的查询算法)
KDNode():flag(false){}
bool operator <(const KDNode & node) const
{
return x[depth % k] < node.x[depth %k]; // 越大优先级越高
}
} kdTree[MAX_N << 2], input[MAX_N]; // kd树节点, 这里 kdTree 是kd树(这里使用数组实现二叉树结构,注意,因为使用的是递归构建, 所以不是完全二叉树, 有些博客写错了),input是输入数据, 索引0不用于存储数据

priority_queue<pair<double, KDNode>> pq; // 优先队列, 注意, pq默认是大根堆, 而这里的pair的first是该kd树节点到目标顶点的欧式距离. second域就是该kd树节点

void buildKD(int low, int high,int index, int dep) // 使用input[low, high]构建kd树, 递归构建(和线段树很像)
{
if (low>high) return;
depth = dep %k; // 注意, depth一定要同步变化, 不然的话, KDNode之间的比较标准一直是第一维度,就嗝屁了, 一开始不停的WA 就是这个原因!
int mid = (low+high)>>1;
nth_element(input+low, input+mid, input+high+1);// 让input[low, input+high] 从小到大排在 mid-low+1的元素恰在其位 , 注意,该api对input进行了一定的调整,注意,随着dep的不同, 判定大小的标准也在变化
kdTree[index] = input[mid]; // 复制节点
kdTree[index].flag = true; // 是数据节点,构建根节点
buildKD(low, mid-1, index<<1, dep+1); // 构建左子树
buildKD(mid+1, high, (index<<1)|1, dep+1); // 构建右子树
}

void queryKD(KDNode target, int m, int index, int dep) // kd树查询m-邻近算法, target是目标, index是当前kdTree节点, dep是当前深度
{
if (!kdTree[index].flag) // 如果不是数据节点的话, 直接返回,这是递归的出口
{
return;
}
KDNode cur = kdTree[index]; // 当前kd树的顶点
double dis = 0;
for (int i = 0; i<k;i++)
{
dis += square(kdTree[index].x[i] - target.x[i]);
} // 得到当前节点和目标的欧式距离
bool flag = false; // 是否要探索target 所在cur的一边的镜像对称面.
pair<double, KDNode> curpair(dis, cur);
if (target.x[dep%k] < cur.x[dep%k]) // 如果在当前节点左侧的话
{
queryKD(target, m, index<<1,dep+1); // 继续探索左边
}
else
{
queryKD(target, m, (index<<1)|1, dep+1);
}
if (pq.size() < m)
{
pq.push(curpair);
flag = true; // 那肯定要探索啊
}
else
{
if (dis < pq.top().first) // 如果小于, 则换掉顶部, 注意, 这里pq是一个大根堆(而且已经有m个元素了, 即m个元素的大根堆), 即最大的放上面, 来的如果比顶部小,则抛弃最大的, 换小的. 如果来的大,则忽略.这样搞,最后就是从小到大m个元素(这种思想在堆排序中用过)
{
pq.pop();
pq.push(curpair);
}
if (square(cur.x[dep%k] - target.x[dep%k]) < pq.top().first) // 如果在当前分割超平面上的距离严格小于最大堆堆顶的距离的话, 就有探索镜面对称的那一边的必要, 否则根本没必要(这属于递归的剪枝)
{
flag = true;
}
}
// 最后来探索镜面对称的空间
if (flag)
{
if (target.x[dep%k] < cur.x[dep%k]) // 如果target在当前节点的左边,则就要探索右边(镜面对称嘛)
{
queryKD(target, m, (index<<1)|1, dep+1);
}
else
{
queryKD(target, m, index << 1, dep+1);
}
}
}


int main()
{
#ifdef LOCAL
freopen("d:\\data.in","r",stdin);
freopen("d:\\my.out", "w", stdout);
#endif
while(~scanf("%d%d", &n, &k)) // 注意, 题目有多组输入
{
for (int i = 1; i<=n;i++) // 得到输入的数据,注意,这里不能一边输入一边构建树,因为我们要选择中位数来做分点(不然kd树亦能构建,只是不平衡). 在得到全部输入之前你是不知道中位数长啥样的, 所以必须要先兜着
{
for(int j = 0; j<k;j++)
{
scanf("%d", &input[i].x[j]);
}
}
memset(kdTree, 0, sizeof(kdTree)); // 切记清空, 因为上一次得到的flag可能一些变成了true, 要抹掉
buildKD(1, n, 1, 0); // 构建kd树, 对本题而言,构建一次就够了
int t;
scanf("%d", &t); // 查询的个数
while(t--)
{
KDNode target; // 目标节点
int m; // m-邻近节点
for (int i = 0; i<k;i++) // 输入目标顶点的坐标
{
scanf("%d", &target.x[i]);
}
scanf("%d", &m); // 求m-邻近
queryKD(target,m,1,0); // 此算法结束之后, pq中存放的就是m个距离target最小的点
printf("the closest %d points are:\n",m);
stack<KDNode> s;
while(!pq.empty()) // 注意, pq是大根堆, 所以出堆应该是倒序的, 而题目要求由近及远输出, 所以要用一个栈兜着,再输出
{
s.push(pq.top().second);
pq.pop();
}
while(!s.empty())
{
KDNode n = s.top();
s.pop();
for(int i = 0; i< k; i++)
{
if (i) printf(" ");
printf("%d", n.x[i]);
}
puts("");
}
}
}
return 0;
}

最后ac的结果是

Status Accepted
Time 2028ms
Memory 10444kB
Length 3548
Lang C++
Submitted 2019-06-04 11:12:58
RemoteRunId 29382610

我们来解释一下上面的算法

首先解释一下为什么输入数组的长度是N,但是构建的二叉树的空间是4N? 因为最坏情况下构建kd树的过程如下图所示

仅有四个顶点,但是却要用15个数组空间. 对于线段树而言,也是一样的.

buildKD 方法用于递归构造KD树. 这里使用数组的方式实现二叉树. 注意, 在算法竞赛中,使用数组而不是链表实现树结构是常见的方法. 因为使用的是递归. 所以最后的树不是完全二叉树,而可能就是一棵普通的二叉树,因为取了中间值(nth_element)的缘故,所以是一棵平衡二叉树. 对搜索性能是有保证的. 而queryKD方法使用了优先队列来维护到目标点的距离.

最后来讲一下为什么要搜索对面的(即queryKD中的flag是怎么回事?)

A是target. B是cur点. 则目前优先队列中有的点是B和C,但是实际上D离A更近. 所以D是要考察的. 怎么判断D需要进入优先队列呢? 这里使用了一个准则

如果A到B的当前维(图中的Y方向)的距离<m大小的优先队列中的堆顶元素的话(即queryKD的73行代码), 则我们是要到B的另一半孩子节点上进行搜索的. 即queryKD方法中的最后if(flag)判断.