k邻近算法 KNN kd树

缘起

开始学习机器学习十大经典算法(【1】)之 knn,它是有监督学习算法的一种. 所谓有监督指的是, 训练样本会告诉你输出. 即给你一张图片 告诉你, 它是喵还是汪~ 而诸如kmeans算法,它不会有一个输出, 就是无监督学习.

分析

什么是knn?

KNN(k-nearest neighbor, k近邻法)就是训练数据集有N个 ,每个都是d维的(所以这N个数据可以视作d维空间中的N个点). 然后他们的类别(或者称标签)结果是知道的. 然后来了一个新的数据(视作d维空间中的点P),要你判断该数据的类别. 那么knn的思想是极为自然的,就是取这N个点中距离P最近的k个点. 然后看看这k个点的标签. 例如这k个点中贴A标签的点是最多的. 那么我们就认为P是A标签的可能性最大. 所以自然的把A标签贴在P上.

敏锐如你注意到了knn算法的三个要点

  1. 度量的选择. 例如欧式度量, 曼哈顿度量、切比雪夫度量(即L^\inf 范数)
  2. k的选择. k越小,则模型越简单(即越武断), k越大, 模型越复杂(即越广纳言路)
  3. 最后如何根据k个点的标签选择P的标签. 即投票规则.

虽然knn有上面三个选择. 但是实际上敏锐如你一定注意到了,其实knn的核心就在于如何快速遴选出距离P最近的k个点. 关于这个,【2】中已经论述过了——使用kd树. 而且【3】中我们也论述了对于倾斜的数据如何基于方差来优化kd树. 本文并不打算使用方差优化, 因为【3】中已经论述过它是怎么做的了. 而且做法是trival的. 并且kd树还有空间复杂度优化的做法(【4】),本文也不打算用那种代码实现(基于代码风格的一致). 本文就想根据【2】来写. 代码后面会奉上. 不着急, 我们先凭借上面讲的对knn的基本认识来讲讲knn在实际场景时的一些优化.

ps: 如果实例点是随机分布的,kd树搜索的平均计算复杂度是 O(logN) ,N 为训练数据集的大小,kd树更适用与训练数据集大小远大于空间维数的d的knn问题,当空间维数d接近训练数据集大小时,它的效率将迅速下降,几乎接近线性扫描。但是实际应用中,训练数据集是海量的, 所以这种情况一般不会发生. 所以kd树算法对于knn问题还是十分有效的.

knn在实际场景中的优化

优化1: 样本倾斜

盗图一张, 上图中的Y明明更有可能是红色,但是如果使用”谁票数多就选谁”的投票规则的话,则Y会被错误识别成蓝色. 出现这种问题的根源在于训练样本是倾斜的. 解决也比较容易. 只需要给投的票用距离加权即可——即离Y越远你投的票的权重越低. 而关于距离的权重函数一般有如下两种

  1. weight = 1 / (distance + const)
  2. 正态分布的密度函数——高斯函数

在本文给出的代码中都不会写,只会采用最朴素的”谁票数多就选谁”的投票规则. 因为这些都是封装到投票规则中的. 都是trival的. 而我们只是抓住主线. 但是优化的思想讲还是要讲出来的.

优化2 海量数据

前面说了我们一般使用kd树处理knn. 但是kd树依旧对于高维数据集(例如一条数据可能有上百个字段)的处理乏力.

kd-tree缺点:对于20维以下的数据效果较好,但是对于高维数据,处理困难

解决方法:ball-tree(球树)
ball-tree:专门针对kd-tree对高维数据处理困难提出的

kd-tree和ball-tree对海量数据可能处理还是不够,因此提出LSH

LSH局部敏感哈希,专门针对海量高维数据提出的KNN优化算法

即 ball-tree是针对kd树对于高维数据集的处理乏力而诞生的, 而LSH是进一步针对数据量而提出的优化算法.

#####

本文将给出ball-tree和kd-tree的算法(暴力算法就不写了,大家都知道的). LSH 暂时不弄.

ball tree 算法分成

  1. 建树

    1) 先构建一个超球体,这个超球体是可以包含所有训练数据集的最小球体。

    2) 从球中选择一个离球的中心最远的训练集的点A,然后选择第二个训练集点B离A最远,将球中所有的其他训练集点分配到离A或B最近的一个上(即如果离A近,就和A归为一类, 反之和B归为一类),然后计算每个聚类的中心,以及聚类能够包含属于该聚类的所有训练数据点所需的最小半径。这样我们得到了两个子超球体,和KD树里面的左右子树对应。即球树中的节点是一个个的超球体. 球体的中心未必是训练数据集中的点, 但是每个球树节点维护了属于当前球的训练数据集列表.

    3)对于这两个子超球体,递归执行步骤2). 最终得到了一个球树。 递归出口是如果超球体中只有2个顶点的话, 则以此2点的中心做球返回即可.

    具体见下图, (b)图中的节点的数字表示属于该球的训练数据集的大小.

  2. 查询

    使用球树找出给定目标点的最近邻方法是首先自上而下贯穿整棵树二分找出包含目标点所在的叶子节点,并在这个叶子球里找出与目标点最邻近的点(需要遍历叶子球中所有的训练数据点),这将确定出目标点距离它的最近邻点的一个上限值,然后跟KD树查找一样,检查兄弟结点,如果目标点到兄弟结点的球心的距离>=兄弟结点的半径与当前的上限值之和,那么兄弟结点里不可能存在一个更近的点(此即ball树的剪枝,原理本质和kd树的剪枝是一样的);否则的话,必须进一步检查位于兄弟结点以下的子树。

    检查完兄弟节点后,我们向父节点回溯,继续搜索最小邻近值。当回溯到根节点时,此时的最小邻近值就是最终的搜索结果。 然后将此最小邻近点哈希掉, 表示已经搜过了(下次就不搜了). 然后跑k遍此球树算法就得到k邻近点.

关于ball树的剪枝原理, 一图胜千言

上图P作为兄弟节点对应的球中的节点是不可能和目标节点产生更短的距离的. 从ball树剪枝的原理和kd树剪枝的原理对比就能看出为什么ball树比kd树在高维要快. 网上论述的”kd树用的是矩形, 带角, 可能产生不必要的搜索”这种个说法我个人感觉是人云亦云. 我的理解如下

1
2
3
kd树的剪枝的短板在于它只用了一个维度!而和它比较的距离是全部维度 具体可以参见【2】的代码第73行.  这样剪枝对于维度一旦高起来(例如超过20维),这种剪枝形同空气~ 即很难剪枝生效.

而ball树的剪枝用到了全部的维度! 所以对于训练数据集维度较大剪枝效果很好.

其实大部分网上的博文(不论是kd树还是ball树)都是使用上述算法进行查询的——开一个哈希数组,跑k遍算法. 但是我习惯的写法是【2】中的那样——维护一个k大小的关于距离的大根堆. 伊始里面全是0. 然后跑一遍kd树或者ball树算法即可.

这里值得注意的细节是ball树的建树方法和kd树建树的过程是不一样的. 确切讲就是kd树的节点是训练数据集本身. 而ball树的节点是球,而且最核心的是球心未必是训练集中的点. 所以造成的影响就是kd树的查询算法是从根节点开始先序遍历(因为节点是训练数据集中的点, 所以每个节点只需要计算一次距离,而ball树需要遍历一个节点中所有数据集). 而ball树的查询算法是从叶子节点开始的(因为不可能从根节点开始,否则的话, 根节点拥有的测试数据集是全集, 遍历一遍太耗时了, 换言之——如果你有那闲工夫遍历一遍的话, 那都算出来了, 还建个屁ball树啊?).

最后我们谈谈ball树的建树过程. 我们谈到每次都要确定一个能把当前训练数据集都罩住的球. 弱弱的问一句: 这样不会是耗时为O(n)吗? 是的. 那为什么要建树呢? 有那闲工夫, 我早把target的k邻近求出来的. 但是别忘了, 你那是对一个target. 我这里出功出力对训练数据集建立一棵ball树. 这棵ball树是可以对后面所有target都可以用的. 所以是值得的.

程序的测试数据以及测试数据的意义见附录

KNN算法的扩展

  1. 限定半径

    有时候我们会遇到这样的问题,即样本中某系类别的样本非常的少,甚至少于K,这导致稀有类别样本在找K个最近邻的时候,会把距离其实较远的其他样本考虑进来,而导致预测不准确。为了解决这个问题,我们限定最近邻的一个最大距离,也就是说,我们只在一个距离范围内搜索所有的不超过k近邻(一言以蔽之就是宁缺毋滥~),这避免了上述问题。这个距离我们一般称为限定半径。

  2. 最近质心推定

    这个算法比KNN还简单。它首先把样本按输出类别归类。对于第 L类的C个样本。它会对这C个样本的d维特征中每一维特征求平均值,最终该类别所有维度的d个平均值形成所谓的质心点。对于样本中的所有出现的类别,每个类别会最终得到一个质心点。当我们做预测时,仅仅需要比较预测样本和这些质心的距离,最小的距离对于的质心类别即为预测的类别。这个算法通常用在文本分类处理上。

KNN小结

 KNN算法是很基本的机器学习算法了,它非常容易学习,在维度很高的时候也有很好的分类效率,因此运用也很广泛,这里总结下KNN的优缺点。

    KNN的主要优点有:

    1) 理论成熟,思想简单,既可以用来做分类也可以用来做回归

    2) 可用于非线性分类

    3) 训练时间复杂度比支持向量机(SVM)之类的算法低,仅为O(n)

    4) 和朴素贝叶斯之类的算法比,对数据没有假设,准确度高,对异常点不敏感

    5) 由于KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其他方法更为适合

    6)该算法比较适用于样本容量比较大的类域的自动分类,而那些样本容量较小的类域采用这种算法比较容易产生误分

    

    KNN的主要缺点有:

    1)计算量大,尤其是特征数非常多的时候

    2)样本不平衡的时候,对稀有类别的预测准确率低

    3)KD树,球树之类的模型建立需要大量的内存

    4)使用懒散学习方法,基本上不学习,导致预测时速度比起逻辑回归之类的算法慢

    5)相比决策树模型,KNN模型可解释性不强

C++实现

kd树实现knn (vs2010调试通过) 代码没写注释, 详见【2】
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#include "stdafx.h"
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <queue>
#include <math.h>
using namespace std;
#define LOCAL
#define SQUARE(x) ((x)*(x))
#define DBL_MAX 1.7e300
#include <time.h>

const int maxn = 150000, maxd = 100, maxt = 10; // 训练集最大数据量为15万, 最多支持20维数据, 最多10种类别
int n,d,k,m,depth;
int v[maxt];

struct KdTreeNode
{
double x[maxd];
int type; // 此数据属于的类别
bool flag;
bool operator<(const KdTreeNode &o) const
{
return x[depth]<o.x[depth];
}
}kdTree[maxn<<2], input[maxn];

typedef pair<double, int> P; // first是欧式距离, second是在kdTree中的下标索引

priority_queue<P> pq;

void build(int begin, int end, int cur, int dep)
{
if (begin>end) return;
depth = dep%d;
int mid = begin+end>>1;
nth_element(input+begin, input+mid, input+end+1);
kdTree[cur] = input[mid];
kdTree[cur].flag = true;
build(begin, mid-1, cur<<1, dep+1);
build(mid+1,end, cur<<1|1, dep+1);
}

double edis(KdTreeNode &target, KdTreeNode &curNode)
{
double ans = 0;
for (int i = 0; i<d; i++)
{
ans+=SQUARE(target.x[i]- curNode.x[i]);
}
return sqrt(ans);
}

void handlecur(KdTreeNode &target,int cur)
{
double dis = edis(target, kdTree[cur]);
if (dis<pq.top().first)
{
pq.pop();
pq.push(P(dis, cur));
}
}

void query(KdTreeNode &target, int cur, int dep)
{
KdTreeNode curNode = kdTree[cur];
if (!curNode.flag) return;
handlecur(target, cur); // 处理当前节点
int idm = dep%d;
int tx = target.x[idm]<curNode.x[idm]?(cur<<1):(cur<<1|1);
int ty = tx^1;
query(target, tx, dep+1); // 搜索target所在的一侧
if (abs(target.x[idm]-curNode.x[idm])<pq.top().first) // kd树剪枝. 只用到了idm这一维度, 而pq.top().first用到了全部d个维度,所以d一大, kd树的剪枝效率堪忧
{
query(target, ty, dep+1); // 搜索target的另一侧
}
}

int vote()
{
memset(v, 0, sizeof(v));
while(!pq.empty())
{
P top = pq.top();
pq.pop();
v[kdTree[top.second].type]++;
} // 统计k邻近属于哪个类别分别有多少个点
int ans = 0, ans_max = 0;
for (int i = 1; i<maxt; i++)
{
if (ans_max<v[i])
{
ans_max = v[i];
ans = i;
}
}
return ans;
}

int main()
{
#ifdef LOCAL
freopen("d:\\data.in", "r", stdin);
freopen("d:\\my.out", "w", stdout);
#endif
clock_t start = clock();
scanf("%d%d%d%d", &n, &d, &k, &m);
for (int i = 1; i<=n; i++)
{
for (int j = 0;j<d; j++)
{
scanf("%lf", &input[i].x[j]);
}
scanf("%d", &input[i].type);
}
build(1,n,1,0);
while(m--)
{
for (int i = 0; i<k; i++)
{
pq.push(P(DBL_MAX,0));
} // 每次预测样本都要初始化优先队列
KdTreeNode target;
for (int i = 0; i<d; i++)
{
scanf("%lf", &target.x[i]);
}
query(target, 1, 0);
printf("%d\n", vote()); // 投票推断
}
clock_t end = clock();
printf("任务耗时%ld毫秒.\n", end-start);
return 0;
}

测试效果(n、d、k、m的含义见附录)

1
2
硬件环境: 8G内存, Windows10 64位, i7-8750H 2.2GHz 
n d k m 为 1000 2 10 1w 训练时间大概是3秒.
ball树实现knn (vs2010调试通过)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
#include "stdafx.h"
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <time.h>
#include <queue>
#include <vector>
#include <math.h>
using namespace std;
#define LOCAL
#define SQUARE(x) ((x)*(x))
#define DBL_MAX 1.7e300
typedef pair<double, int> P;
const double eps = 1e-8;
const int maxn = 150000, maxd = 100, maxt = 10;
int n,d,k,m, cnt,v[maxt]; // cnt是ball tree节点个数
typedef vector<int>::iterator vit;

struct BallTreeNode
{
double x[maxd],r; // 球心和半径
bool flag;
vector<int> g; // g中装填着属于此ball树节点的训练集点编号
}ballTree[maxn<<2];

priority_queue<P> pq;

struct Input
{
double x[maxd];
int type;
}input[maxn];

double edis(double bn[], double bi[])
{
double ans=0;
for (int i = 0; i<d; i++)
{
ans+=SQUARE(bn[i]-bi[i]);
}
return sqrt(ans);
}

void kk(BallTreeNode &bn, int &a, int &b) // 本函数的作用在于根据当前球包含的训练数据集确定球树节点bn的球心和半径, 以及建树第二步的a和, 从这个函数就能看出ball树建树的成本是比较高的
{
bn.r = -1;
double adis = -1, bdis = -1;
if (bn.g.size()==1) // 如果是单点, 则球无限大
{
bn.r = DBL_MAX;
memcpy(bn.x, input[bn.g[0]].x, sizeof(double)*maxd);
return;
}
int num = bn.g.size(); // num个点
for (int i = 0; i<d; i++)
{
double sum = 0;
for (vit j = bn.g.begin(); j!=bn.g.end(); j++)
{
sum+=input[*j].x[i];
}
bn.x[i] = sum/num;
}
for (vit i = bn.g.begin(); i!=bn.g.end(); i++)
{
double t = edis(bn.x, input[*i].x);
if (t>adis)
{
adis = t;
a = *i;
}
bn.r = max(bn.r, t);
}
for (vit i = bn.g.begin(); i!=bn.g.end(); i++)
{
double t = edis(input[a].x, input[*i].x);
if (t>bdis)
{
bdis = t;
b = *i;
}
}
bn.r+=eps;
}

void classify(int cur, int a, int b) // 将ballTreee[cur].g 分成不交的两部分, 分别进入cur<<1和cur<<1|1这2个子节点去
{
BallTreeNode curNode = ballTree[cur];
for (vit i = curNode.g.begin(); i!=curNode.g.end(); i++)
{
if (edis(input[a].x, input[*i].x)<edis(input[b].x, input[*i].x)) // ballTree[cur].g中距离a较近的点进入左子树, 否则进入右子树
{
ballTree[cur<<1].g.push_back(*i);
}
else
{
ballTree[cur<<1|1].g.push_back(*i);
}
}
}

void build(int cur)
{
int a, b; // 建树第二步的a和b
kk(ballTree[cur],a,b); // 确定囊括input[dep][begin,...,end]的球心和半径, 即这样就确定了当前ball树节点, 并确定建树第二步的a和b
ballTree[cur].flag = true;
if (ballTree[cur].g.size()==1) return; // 叶子节点
classify(cur, a, b); // 初始化左右子树包含的点集
build(cur<<1);
build(cur<<1|1); // 递归建树
}

int vote()
{
memset(v, 0, sizeof(v));
while(!pq.empty())
{
P top = pq.top();
pq.pop();
v[input[top.second].type]++;
}
int ans = 0, ans_max = 0;
for (int i = 1; i<maxt; i++)
{
if (ans_max<v[i])
{
ans_max = v[i];
ans = i;
}
}
return ans;
}

bool inball(Input &target, int cur) // target是否在ballTree[cur]中
{
if (!ballTree[cur].flag)
{
return false;
}
return ballTree[cur].r+eps>edis(ballTree[cur].x, target.x);
}

void handlecur(Input &target, int cur)
{
BallTreeNode curNode = ballTree[cur];
for (vit i = curNode.g.begin(); i!=curNode.g.end(); i++) // 遍历当前节点
{
double dis = edis(target.x, input[*i].x);
if (dis < pq.top().first)
{
pq.pop();
pq.push(P(dis, *i));
}
}
}

void query(Input &target, int cur, int from) // 在以cur为根节点的ball子树上搜索, from 表明它是从哪里来进入的
{
int root = cur; // 保存当前根节点
while(ballTree[cur].flag)
{
bool inl = inball(target, cur<<1), inr = inball(target, cur<<1|1);
if (inl)
{
if (inr) // 如果左边右边都有的话,则选择g规模小的那个节点
{
cur = ballTree[cur<<1].g.size()>ballTree[cur<<1|1].g.size()?(cur<<1|1):(cur<<1);
}
else
{
cur = cur<<1;
}
}
else if(inr)
{
cur = cur<<1|1;
}
else // 不属于任何一个子节点, 那么就到当前节点止步吧~
{
cur<<=1;
break;
}
}
cur>>=1; // 先找到target所在的节点
handlecur(target, cur); // 处理当前节点
while(cur!=root)
{
BallTreeNode sibling = ballTree[cur^1]; // 兄弟
if ((cur^1)!=from && sibling.flag && edis(target.x,sibling.x)<pq.top().first+sibling.r) // edis用到了全部维度, 所以在d比较大的时候, ball树比kd树搜索更有效率,注意这里的from的作用, 就是为了防止从一个节点进入它的兄弟,然后兄弟又重新回到它的自己,这样就死循环了
{
query(target, cur^1, cur); // 跑到兄弟节点上搜索
}
cur>>=1; // 到父节点去
}
}

int main()
{
#ifdef LOCAL
freopen("d:\\data.in", "r", stdin);
freopen("d:\\my.out", "w", stdout);
#endif
clock_t start = clock();
scanf("%d%d%d%d", &n,&d,&k,&m);
for (int i = 1; i<=n; i++)
{
for (int j = 0; j<d; j++)
{
scanf("%lf", &input[i].x[j]);
}
scanf("%d", &input[i].type);
}
for (int i = 1; i<=n; i++)
{
ballTree[1].g.push_back(i);
} // 初始化列表
build(1); // 建ball树, ball树是相当浪费内存的. 按照本文cur<<1, cur<<1|1这种写法, 基本ballTree的长度几乎要是输入的长度的10倍才行, 所以应该使用空间复杂度优化的写法, 但是那样写的话, 代码显得又太复杂, 所以就不那样写了
while(m--)
{
while(!pq.empty())
{
pq.pop();
}
for (int i = 0;i<k; i++)
{
pq.push(P(DBL_MAX, 0));
} // 初始化堆
Input target;
for (int i =0; i<d; i++)
{
scanf("%lf", &target.x[i]);
}
query(target, 1, 0);
printf("%d\n", vote());
}
clock_t end = clock();
printf("任务耗时%ld毫秒.\n", end-start);
return 0;
}

测试效果

1
2
硬件环境: 8G内存, Windows10 64位, i7-8750H 2.2GHz 
n d k m 为 1000 2 10 1w 训练时间大概9

由此可见, kd树的编码复杂度比ball树低很多. 而且维数如果不大的话(例如数百维的训练数据集),kd树即可, 不需要使用ball树.

但是可能是我写的ball树太矬了~ ball树的表现不论是高维还是低维都不如kd树. 而且太耗费内存了. 几乎已经到了不可用的地步~ 上面ball树的代码, 训练数据集到达5000, 你开ballTree长度达到60w都会RE. 可见上面写法几乎是不实用的. 遂想按照【4】的写法进行改写.

空间复杂度优化的ball树写法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
#include "stdafx.h"
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <time.h>
#include <queue>
#include <vector>
#include <math.h>
using namespace std;
#define LOCAL
#define SQUARE(x) ((x)*(x))
#define DBL_MAX 1.7e300
typedef pair<double, int> P;
const double eps = 1e-8;
const int maxn = 150000, maxd = 100, maxt = 10;
int n,d,k,m, cnt,v[maxt]; // cnt是ball tree节点个数
typedef vector<int>::iterator vit;

struct BallTreeNode
{
double x[maxd],r; // 球心和半径
int lc, rc, fa; // 左孩子、右孩子、父节点
vector<int> g,lg,rg; // g中装填着属于此ball树节点的训练集点编号,lg存储左子树,rg存储右子树
}ballTree[maxn];

priority_queue<P> pq;

struct Input
{
double x[maxd];
int type;
}input[maxn];

double edis(double bn[], double bi[])
{
double ans=0;
for (int i = 0; i<d; i++)
{
ans+=SQUARE(bn[i]-bi[i]);
}
return sqrt(ans);
}

void kk(BallTreeNode &bn, int &a, int &b) // 本函数的作用在于根据当前球包含的训练数据集确定球树节点bn的球心和半径, 以及建树第二步的a和, 从这个函数就能看出ball树建树的成本是比较高的
{
bn.r = -1;
double adis = -1, bdis = -1;
if (bn.g.size()==1) // 如果是单点, 则球无限大
{
bn.r = DBL_MAX;
memcpy(bn.x, input[bn.g[0]].x, sizeof(double)*maxd);
return;
}
int num = bn.g.size(); // num个点
for (int i = 0; i<d; i++)
{
double sum = 0;
for (vit j = bn.g.begin(); j!=bn.g.end(); j++)
{
sum+=input[*j].x[i];
}
bn.x[i] = sum/num;
}
for (vit i = bn.g.begin(); i!=bn.g.end(); i++)
{
double t = edis(bn.x, input[*i].x);
if (t>adis)
{
adis = t;
a = *i;
}
bn.r = max(bn.r, t);
}
for (vit i = bn.g.begin(); i!=bn.g.end(); i++)
{
double t = edis(input[a].x, input[*i].x);
if (t>bdis)
{
bdis = t;
b = *i;
}
}
bn.r+=eps;
}

void classify(int cur, int a, int b) // 将ballTreee[cur].g 分成不交的两部分, 分别进入左右子树
{
for (vit i = ballTree[cur].g.begin(); i!=ballTree[cur].g.end(); i++)
{
if (edis(input[a].x, input[*i].x)<edis(input[b].x, input[*i].x)) // ballTree[cur].g中距离a较近的点进入左子树, 否则进入右子树
{
ballTree[cur].lg.push_back(*i);
}
else
{
ballTree[cur].rg.push_back(*i);
}
}
}

int build(int fa, vector<int> &g) // g是当前节点拥有的列表(即在当前球中的所有训练数据集点)
{
int cur = ++cnt; // 当前节点就是ballTree[cur]
int a, b; // 建树第二步的a和b
ballTree[cur].fa = fa;
for (vit i = g.begin(); i!=g.end(); i++)
{
ballTree[cur].g.push_back(*i); // 搞当前节点的g
}
kk(ballTree[cur],a,b); // 确定囊括input[dep][begin,...,end]的球心和半径, 即这样就确定了当前ball树节点, 并确定建树第二步的a和b
if (ballTree[cur].g.size()==1) return cur; // 叶子节点
classify(cur, a, b); // 初始化左右子树包含的点集lg和rg
int lc = build(cur,ballTree[cur].lg); // 建立左子树
int rc = build(cur, ballTree[cur].rg); // 建立右子树
ballTree[cur].lc = lc, ballTree[cur].rc = rc;
return cur;
}

int vote()
{
memset(v, 0, sizeof(v));
while(!pq.empty())
{
P top = pq.top();
pq.pop();
v[input[top.second].type]++;
}
int ans = 0, ans_max = 0;
for (int i = 1; i<maxt; i++)
{
if (ans_max<v[i])
{
ans_max = v[i];
ans = i;
}
}
return ans;
}

bool inball(Input &target, int cur) // target是否在ballTree[cur]中, 注意, 如果不是数据节点的话,r是0
{
if (cur>cnt) return false; // 如果不是数据节点
return ballTree[cur].r+eps>edis(ballTree[cur].x, target.x);
}

void handlecur(Input &target, int cur)
{
BallTreeNode curNode = ballTree[cur];
for (vit i = curNode.g.begin(); i!=curNode.g.end(); i++) // 遍历当前节点
{
double dis = edis(target.x, input[*i].x);
if (dis < pq.top().first)
{
pq.pop();
pq.push(P(dis, *i));
}
}
}

void getsibling(int cur, int &sibling) // 得到ballTree[cur]的兄弟节点,注意, 这里cur不会是ball树的根节点的
{
int fa = ballTree[cur].fa;
sibling = ballTree[fa].lc+ballTree[fa].rc-cur;
}

void query(Input &target, int cur, int from) // 在以cur为根节点的ball子树上搜索, from 表明它是从哪里来进入的
{
int root = cur; // 保存当前根节点
while(ballTree[cur].lc || ballTree[cur].rc) // 只要非叶子
{
bool inl = inball(target, ballTree[cur].lc), inr = inball(target, ballTree[cur].rc);
int lc = ballTree[cur].lc, rc = ballTree[cur].rc;
if (inl)
{
if (inr) // 如果左边右边都有的话,则选择g规模小的那个节点
{
cur = ballTree[lc].g.size()>ballTree[rc].g.size()?rc:lc;
}
else
{
cur = lc;
}
}
else if(inr)
{
cur = rc;
}
else // 不属于任何一个子节点, 那么就到当前节点止步吧~
{
break;
}
} // 最后是叶子节点或者target不属于任何一个子节点
handlecur(target, cur); // 处理当前节点
while(cur!=root)
{
int sibling;
getsibling(cur,sibling);
BallTreeNode siblingNode = ballTree[sibling]; // 兄弟节点
if (sibling!=from && sibling<=cnt && edis(target.x,siblingNode.x)<pq.top().first+siblingNode.r) // edis用到了全部维度, 所以在d比较大的时候, ball树比kd树搜索更有效率,注意这里的from的作用, 就是为了防止从一个节点进入它的兄弟,然后兄弟又重新回到它的自己,这样就死循环了
{
query(target, sibling, cur); // 跑到兄弟节点上搜索
}
cur = ballTree[cur].fa; // 到父节点去
}
}

int main()
{
#ifdef LOCAL
freopen("d:\\data.in", "r", stdin);
freopen("d:\\ac.out", "w", stdout);
#endif
clock_t start = clock();
scanf("%d%d%d%d", &n,&d,&k,&m);
for (int i = 1; i<=n; i++)
{
for (int j = 0; j<d; j++)
{
scanf("%lf", &input[i].x[j]);
}
scanf("%d", &input[i].type);
}
for (int i = 1; i<=n; i++)
{
ballTree[0].g.push_back(i);
} // 初始化列表
build(0, ballTree[0].g);
while(m--)
{
while(!pq.empty())
{
pq.pop();
}
for (int i = 0;i<k; i++)
{
pq.push(P(DBL_MAX, 0));
} // 初始化堆
Input target;
for (int i =0; i<d; i++)
{
scanf("%lf", &target.x[i]);
}
query(target, 1, 0);
printf("%d\n", vote());
}
clock_t end = clock();
printf("任务耗时%ld毫秒.\n", end-start);
return 0;
}

测试结果(注意,对于这种优化过后的ball树, 节点仅仅有1w以下的节点, 极大的节约了空间复杂度)

1
2
3
硬件环境: 8G内存, Windows10 64位, i7-8750H 2.2GHz 
n,d,k,m为5000 2 10 10000 ball树耗时 60s左右,kd树是4.4s左右
n,d,k,m为5000 95 10 100, ball树耗时 12s左右, kd树是1.7s左右

唉~ ball树感觉自己写的好搓~ 但是算法给出的答案是正确的.

附录

测试数据
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
测试数据集示例
说明: 第一行是n、d(数据的维度)、k(k邻近)、m(测试样本数量). 后面n行训练集数据, 最后一列是样本输出
(即类别) 然后是m行d维样本测试数据

691 2 10 200
8.326976 0.953952 3
7.153469 1.673904 2
1.441871 0.805124 1
...
7.083449 0.622589 2
2.080068 1.254441 2
0.522844 1.622458 2
10.362000 1.544827 3
2.89191 0.143349
...
0.146645 1.15478
12.1767 1.90774

PS: 这些数据的背后意义:某个妹纸想要找一个男朋友,她收集了婚恋网上面对她感兴趣男士的一些数据,第一个是一年的飞行里程,第二个是每年玩游戏的时间(10小时为单位),第三个是结果:3表示很感兴趣,2表示感兴趣,1表示没兴趣。 这也说明了knn算法的一大作用在于推断标签上. 而knn的另一大作用是回归——通过找出一个样本的k个最近邻居,将这些邻居的某个(些)属性的平均值赋给该样本,就可以得到该样本对应属性的值。

1
2
3
4
5
6
7
100多年前,有位英国遗传学家(Galton)注意到当父亲身高很高时,他的儿子的身高一般不会比父亲身高更高。同样
如果父亲很矮,他的儿子也一般不会比父亲矮,而会向一般人的均值靠拢。当时这位英国遗传学家将这现象称为回归,
现在这个概念引伸到随机变量有向 回归线 集中的趋势。
即观察值不是全落在回归线上,而是散布在回归线周围。但离回归线越近,观察值越多,偏离较远的观察值极少,这种
不完全呈函数关系,但又有一定数量的关系的现象称回归。回归分为线性回归和非线性回归。
回归分析是为了了解和预测观测值的,通过计算和估计回归方程的系数得到结果。
一般来说线性回归都需要建立回归方程。
makedata程序

手动输入数据对程序进行测试是一件麻烦的事情, 我们可以写个程序随机灌点数据.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#include "stdafx.h"
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <algorithm>
#include <string>
#include <vector>
#include<random>

using namespace std;
const int INF = 0x3f3f3f3f;
const int maxn = 1000000;
typedef long long LL;
typedef vector<int>::iterator vit;


int main()
{
freopen("d:\\data.in", "w", stdout);
srand((int)(time(0)));


int n = 1000, d =2, k = 10, m = 10000;
printf("%d %d %d %d\n", n,d,k,m);
default_random_engine e(time(0));
uniform_real_distribution<double> u(0,15);
uniform_real_distribution<double> v(0,2);
for(int i = 0; i < n; ++i)
{
for (int j = 0; j<d; j++)
{
if (j&1)
{
cout << u(e) << ' ';
}
else
{
cout << v(e) << ' ';
}
}
cout << rand()%9+1 << endl;
}
for (int i = 0; i<m;i++)
{
for (int j = 0; j<d; j++)
{
if (j&1)
{
cout << u(e) << ' ';
}
else
{
cout << v(e) << ' ';
}
}
cout<< endl;
}

fclose(stdout);
return 0;
}

参考

【1】https://zhuanlan.zhihu.com/p/59767178

【2】https://yfsyfs.github.io/2019/06/03/hdu-4347-The-Closest-M-Points-KD%E6%A0%91-k-%E9%82%BB%E8%BF%91%E7%AE%97%E6%B3%95/

【3】https://yfsyfs.github.io/2019/09/21/hdu-5809-Ants-kd%E6%A0%91-%E5%B9%B6%E6%9F%A5%E9%9B%86/

【4】https://yfsyfs.github.io/2019/09/20/%E6%B4%9B%E8%B0%B7-P4357-CQOI2016-K%E8%BF%9C%E7%82%B9%E5%AF%B9-kd%E6%A0%91/