c++ stl nth_element 实现

缘起

我们知道c++ stl 中 nth_element 是很好用的——

nth_element(int first, int target, int *end) 作用是让[first, end)中从小到大排名第target-first+1的元素恰在其位

一个简单的例子

1
2
3
4
5
6
7
8
#include<algorithm>
using namespace std;
int main() {
int a[8] = {0, 1, 2, 5, 7, 3, 4, 1};
nth_element(a + 3, a + 4, a + 8); // 让a[3,...7] 中从小到大排在第4-3+1=2位的数字恰在其位(其他数字不保证顺序)
for (int i = 1; i <= 7; i++) printf("%d ", a[i]); printf("\n");
return 0;
}

输出

1
1 2 1 3 4 5 7

注意1 3 这个输出. 这是因为 a[3,…,7]中从小到大排名第 4-3+1=2的元素恰在其位,注意,nth_element并不保证其他元素也有序了.

那么nth_element 这个sdk是怎么实现的呢?

分析

问题化归为输入数组a,求第k大. 可以使用优先队列. 但是这里提供另一种算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#include "stdafx.h"
#include <iostream>
#include <cstdlib>
#include <ctime>
#pragma warning(disable:4996)
//#define LOCAL
using namespace std;

void swap(int &a, int &b)
{
if (a ==b) return;
a ^= b;
b ^= a;
a ^= b;
}

int pivot(int *a, int low, int high) // 处理 a[low, high)
{
int ans = low-1;
for (int i = low;i<high;i++) if (a[i] < a[high-1]) swap(a[++ans], a[i]);
swap(a[++ans], a[high-1]); // 最后的效果就是a[low,...,ans]都是 <=a[high-1]的, 其中a[ans] = a[high-1], 而a[ans]后面的都 <a[ans],即a[high-1]放在 ans
return ans;
}

int nth_element(int *a,int low, int high, int index) // 寻找a[low,...,high) 中从小到大排名index的元素
{
int ans = pivot(a, low,high);
if (ans==index) return a[index];
return index<ans? nth_element(a,low,ans, index):nth_element(a, ans+1, high, index);
}

int main()
{
#ifdef LOCAL
freopen("d:\\data.in","r",stdin);
freopen("d:\\my.out", "w", stdout);
#endif
srand((int)time(0));
const int size = 10;
int a[size+1] = {-1, 60, 59, 69, 21, 93, 83, 52, 36, 53, 94}; // 0 不用于存储数据
printf("从小到大排名第10的元素是%d\n", nth_element(a, 1, size+1, 10));
return 0;
}

其中pivot函数的作用就是将全部小于数组最后元素的元素换到a[low, …,ans-1]中去.

算法的复杂度是nlogn