洛谷 P4357 [CQOI2016]K远点对 kd树

缘起

继续来找 平面最近点对 只是这次是求的是第k远的点对. 洛谷 P4357 [CQOI2016]K远点对

分析

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
平面给你N个点, 求欧式距离下第K远的点对. 输出他们距离欧式距离的平方(一定是一个整数)

【输入】
首先是一个N,然后是一个K, 然后N行,每行两个整数X和Y,表示一个点的坐标
对于100%的测试点,N <= 100000,1 <= K <= 100,K <= \frac {N(N+1)}{2},0 <= X,Y <=
2^{31}

【输出】
输出第K远的距离.

【样例输入】
10 5
0 0
0 1
1 0
1 1
2 0
2 1
1 2
0 2
3 0
3 1

【样例输出】
9

【限制】
3s

【1】和【2】中讲解了平面最近点对的求法, 用的是分治。但是本题求的是距离第k远的点对. 所以分治不再有效. 本题要使用kd树.

首先讲清楚一下题意——第k远的含义是,距离从大到小排序,第k个距离. 所以要用大小为k的小根堆维护(里面是最远的k个距离)。 每来一个距离,如果比堆顶小,则肯定比堆中所有元素小,肯定不会是最远的k个距离. 如果比堆顶大,则堆顶出堆. 然后将新来的插入堆中. 这是堆求kth问题的老套路了. 最后的答案就是堆顶元素.

本题怎么做呢? 首先对N个点建2维kd树. 这没什么好说的。然后对每个顶点求其他点与它的距离——将有希望的顶点加入小跟堆中. 至于什么叫做有希望的, 后面我们细讲. 这样的话, 因为你计算A点的时候, 将AB这段距离加入了堆中的话, 则计算B的时候也会将BA加入堆中.所以堆的大小不应该是k,而应该是2k.

等等~ 枚举每个点? 老兄, 你是不是忘了N的规模是10w~ 所以每个点在kd树上跑的时间不能太长. 不然的话, 一定T.

而回想对于一个点在kd树上运行的过程是这样的. 首先从根节点(即当前节点的初始值)进入kd树. 然后搜索目标节点在在当前节点的一侧, 然后有选择的搜索另一侧(即kd树的剪枝). 抽象出来就是

1
考察当前节点, 然后搜当前(kd子)树的一侧, 再搜当前(kd子)树的另一侧

额~ 感觉是不是说了废话——所有的树几乎不都是这样的吗? 注意,整个递归过程中,考察当前节点时其实就是得到了一组点对的距离. 然后就要按照之前讲的堆求kth的套路. 应该考察此点对能否加入到堆中. 理所当然的,我们巴不得一开始堆中恰好就是最大的2k(为什么是2k个之前说了)个距离. 则后面来的任何递归时当前点的考察得到的距离都无法进堆而直接pass掉了(只有比堆顶元素大才能进堆并且将堆顶元素挤掉). 这样复杂度会很低. 所以我们自然希望较大的距离能尽快进堆. 所以处理完当前节点之后——

我们该首先搜那一侧?

自然我们希望kd树的节点维护了某种属性,使得我们能粗略估计该kd子树表示的平面的一侧最远能到哪里?

根据上图, kd树左侧的能延展到最远点和目标节点的距离比kd树右侧的能延展到的最远点和目标节点的距离要远, 所以如果脑袋正常的话, 肯定会优先搜索kd树左侧. 所以我们就知道了,kd树节点上需要维护k维坐标每一维坐标的极值. 就本题而言, 每个kd树的节点需要维护 mi[0,1]和mx[0,1], mi[0](mx[0])是此kd子树上所有节点的x方向的最小(大)值, mi[1](mx[1])是此kd子树上所有节点的y方向的最小(大)值.

维护这些属性在建树的时候递归就可以维护. 而维护这些属性的目的是能在决定搜左边还是搜右边之前先计算一下目标节点到kd树左侧最远的粗略估计(下简称估价函数,类似于启发式函数)以及目标节点到kd树右侧最远的粗略估计. 哪个远就优先搜哪个子树. 因为它最有可能产生很远的点对. 进而缩短搜索的次数降低复杂度.

而一定要搜吗? 这就涉及kd树搜索的剪枝. 如果估价函数都没有当前堆顶的元素大.那么该kd子树整棵都不必去搜索了. 原因是显然的. 否则才去搜索.

所以这就是整道题目使用kd树的解法.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
//#include "stdafx.h"
#include <stdio.h>
#include <algorithm>
#include <queue>
using namespace std;
#define SQUARE(x) ((x)*(x))
#define ABS(a,b) ((a)>(b)?((a)-(b)):((b)-(a)))
//#define LOCAL
typedef long long LL;
const int maxn = 100005, d = 2;
int n,k, depth;

priority_queue<LL, vector<LL>, greater<LL> >pq; // 小根堆

struct KdTreeNode
{
int x[2];
int mi[2], mx[2]; // kd树节点维护的属性
bool flag;
bool operator<(const KdTreeNode &o) const
{
return x[depth]<o.x[depth];
}
}kdTree[maxn<<2], input[maxn];

void updata(int cur)
{
kdTree[cur].mi[0] = kdTree[cur].mx[0] = kdTree[cur].x[0];
kdTree[cur].mi[1] = kdTree[cur].mx[1] = kdTree[cur].x[1]; // 初始化
if (kdTree[cur<<1].flag) // 如果有左子树
{
kdTree[cur].mi[0] = min(kdTree[cur<<1].mi[0],kdTree[cur].mi[0]);
kdTree[cur].mi[1] = min(kdTree[cur<<1].mi[1],kdTree[cur].mi[1]);
kdTree[cur].mx[0] = max(kdTree[cur<<1].mx[0],kdTree[cur].mx[0]);
kdTree[cur].mx[1] = max(kdTree[cur<<1].mx[1],kdTree[cur].mx[1]);
}
if (kdTree[cur<<1|1].flag) // 如果有右子树
{
kdTree[cur].mi[0] = min(kdTree[cur<<1|1].mi[0],kdTree[cur].mi[0]);
kdTree[cur].mi[1] = min(kdTree[cur<<1|1].mi[1],kdTree[cur].mi[1]);
kdTree[cur].mx[0] = max(kdTree[cur<<1|1].mx[0],kdTree[cur].mx[0]);
kdTree[cur].mx[1] = max(kdTree[cur<<1|1].mx[1],kdTree[cur].mx[1]);
}
}

void build(int begin , int end, int cur, int dep)
{
if (begin>end) return;
depth = dep%d;
int mid = begin+end>>1;
nth_element(input+begin, input+mid, input+end+1);
kdTree[cur]=input[mid];
kdTree[cur].flag =true;
build(begin, mid-1, cur<<1, dep+1);
build(mid+1, end, cur<<1|1, dep+1);
updata(cur); // 利用递归返回的子节点的mi/mx数据更新当前节点的mi和mx
}

LL dis(KdTreeNode &a, KdTreeNode &b)
{
return (LL)(SQUARE((LL)a.x[0]-(LL)b.x[0])+SQUARE((LL)a.x[1]-(LL)b.x[1]));
}

LL h(KdTreeNode &a, KdTreeNode &target) // kd树节点a的启发式函数
{
return SQUARE((LL)max(abs(target.x[0]-a.mi[0]), abs(target.x[0] - a.mx[0]))) + SQUARE((LL)max(abs(target.x[1]-a.mi[1]), abs(target.x[1] - a.mx[1]))); // 这里的LL强转一定要注意,很容易WA
}

void query(KdTreeNode &target, int cur)
{
KdTreeNode curNode = kdTree[cur];
if (!curNode.flag)
{
return;
}
LL t = dis(curNode, target);
if (t>pq.top())
{
pq.pop();
pq.push(t);
} // kth的套路
LL lh=0,rh=0; // lh是左子树的启发式函数, rh是右子树的启发式函数
if (kdTree[cur<<1].flag) // 如果有左子树
{
lh = h(kdTree[cur<<1], target); // 计算左子树的启发式函数值
}
if (kdTree[cur<<1|1].flag) // 如果有右子树
{
rh = h(kdTree[cur<<1|1], target); // 计算右子树的启发式函数值
}
if (lh>rh) // 优先搜索左子树
{
if (lh>pq.top()) // 如果有必要搜索左子树
{
query(target, cur<<1); // 搜索左子树
}
if (rh>pq.top()) // 如果有必要搜索右子树
{
query(target, cur<<1|1); // 搜索右子树
}
}
else // 优先搜索右子树
{
if (rh>pq.top()) // 如果有必要搜索右子树
{
query(target, cur<<1|1); // 搜索右子树
}
if (lh>pq.top()) // 如果有必要搜索左子树
{
query(target, cur<<1); // 搜索左子树
}
}
}

int main()
{
#ifdef LOCAL
freopen("d:\\data.in", "r", stdin);
//freopen("d:\\my.out", "w", stdout);
#endif
scanf("%d%d", &n,&k);
for (int i = 1; i<=n;i++)
{
scanf("%d%d", &input[i].x[0], &input[i].x[1]);
}
build(1,n,1,0);
k<<=1; // k扩倍
for (int i = 0; i<k; i++) // 初始化小根堆, 一开始里面就有k个元素, 这样就不必在后面添加元素的时候判断堆中元素是否满了k了, 这是一个编码的优化
{
pq.push(0);
}
for (int i = 1; i<=n; i++) // 虽然input 已经被nth_element改的乱七八糟, 但是还是可以用的
{
query(input[i], 1);
}
printf("%lld", pq.top());
return 0;
}

注意, 本题不能用double, 会丢失精度的~ 例如很大的long, 它给你写成2.xxxxxxxe17 这种

ac情况

1
2
3
4
5
6
所属题目
P4357 [CQOI2016]K远点对
评测状态
Accepted
评测分数
100

ps: 此题还有一些大神指出还有 旋转卡壳+凸包 的做法, 唉不说了,现在计算几何几乎还是个小白~ 算法博大精深~

本题还有空间复杂度优化的kd树的写法, 但是基于保持kd树代码风格的一致性. 我仅仅copy它的代码. kd树的空间复杂度优化的核心思想就是不在 maxn<<2, 即i的后代不再是i<<1和 i<<1|1, 而是节点的个数++(下面代码第43行)这样搞. 代价就是节点中必须要有一个ls和一个rs表明该节点的子节点索引.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#include<cstdio>
#include<algorithm>
#include<queue>
#define int long long
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
using namespace std;
const int MAXN = 200001, INF = 1e9 + 10;
char buf[1 << 21], *p1 = buf, *p2 = buf;
inline int read() {
char c = getchar(); int x = 0, f = 1;
while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return x * f;
}
int N, K;
priority_queue<int, vector<int>, greater<int> > q;
int root, WD, cur = 0;
#define ls(k) T[k].ls
#define rs(l) T[k].rs
struct Point {
int x[2];
bool operator < (const Point &rhs) const {
return x[WD] < rhs.x[WD];
}
}p[MAXN];
struct KDtree {
int ls, rs, mi[2], mx[2];
Point tp;
}T[MAXN]; // 注意, 这里是MAXN而不是MAXN<<2,根本原因在于这里每个节点多了ls和rs(左右子树根节点的索引),所以可以使用第43行的cur++来做kd树的节点的索引了
inline int sqr(int x) {
return x * x;
}
void update(int k) {
for(int i = 0; i <= 1; i++) {
T[k].mi[i] = T[k].mx[i] = T[k].tp.x[i];
if(ls(k)) T[k].mi[i] = min(T[k].mi[i], T[ls(k)].mi[i]), T[k].mx[i] = max(T[k].mx[i], T[ls(k)].mx[i]);
if(rs(k)) T[k].mi[i] = min(T[k].mi[i], T[rs(k)].mi[i]), T[k].mx[i] = max(T[k].mx[i], T[rs(k)].mx[i]);
}
}
int Build(int l, int r, int wd) {
if(l > r) return 0;
WD = wd;
int k = ++cur, mid = l + r >> 1;
nth_element(p + l, p + mid, p + r + 1);
T[k].tp = p[mid];
T[k].ls = Build(l, mid - 1, wd ^ 1);
T[k].rs = Build(mid + 1, r, wd ^ 1);
update(k);
return k;
}
int dis(Point a, Point b) {
return sqr(a.x[0] - b.x[0]) + sqr(a.x[1] - b.x[1]);
}
int GetMaxDis(KDtree a, Point b) {
int rt = 0;
for(int i = 0; i <= 1; i++)
rt += sqr(max(abs(b.x[i] - a.mi[i]), abs(b.x[i] - a.mx[i])));
return rt;
}
void Query(int k, Point a) {
int tmp = q.top(), tmpdis = dis(T[k].tp, a);
if(tmpdis > tmp) q.pop(), q.push(tmpdis);
int disl = -INF, disr = -INF;
if(ls(k)) disl = GetMaxDis(T[ls(k)], a);
if(rs(k)) disr = GetMaxDis(T[rs(k)], a);
if(disl > disr) {
if(disl > q.top()) Query(ls(k), a);
if(disr > q.top()) Query(rs(k), a);
}
else {
if(disr > q.top()) Query(rs(k), a);
if(disl > q.top()) Query(ls(k), a);
}
}
main() {
#ifdef WIN32
freopen("a.in", "r", stdin);
#endif
N =read(); K = read();
for(int i = 1; i <= N; i++)
p[i].x[0] = read(), p[i].x[1] = read();
root = Build(1, N, 0);
for(int i = 1; i <= 2 * K; i++)
q.push(0);
for(int i = 1; i <= N; i++)
Query(root, p[i]);
printf("%lld", q.top());
}

参考

【1】https://yfsyfs.github.io/2019/09/22/hdu-1007-Quoit-Design-%E6%9F%A5%E6%89%BE%E5%B9%B3%E9%9D%A2%E6%9C%80%E8%BF%91%E7%82%B9%E5%AF%B9/

【2】