Prim算法之堆优化

缘起

【1】和【2】中我们指出了prim算法是可以使用【1】中介绍的堆结构(完全二叉树)进行优化的. 性能可以到达nlogn.

分析

其实优化点就是在每次选取入伙MST的顶点的时候, 不就是要选择最短的么? 于是我们只需要维护一个小根堆即可. 每次交换堆顶和堆数组末尾元素, 然后n–再siftdown操作即可. 每次操作的复杂度由原先年的n降低为logn. 于是整体复杂度由n^2降低为nlogn. 而之前在【1】中我们也指出了,明白堆的实现原理即可, 实际写算法的时候我们还是坚持使用c++的stl或者java的util包下面的类进行堆结构的实现. 但是马上就会打脸. 因为stl的优先队列没有暴露sift接口. 这样的话, 我们无法在堆中元素的distance2U数值发生变化之后及时调整堆结构. 唯一的笨办法就是把pq清空, 再重新倒入pq(即重新建堆).可是一旦这样的话, 堆优化就名不副实了. 复杂度依旧是 O(n^2)的. 话虽如此, 但是我们依旧将这种代码贴上来 , 就当是学习c++的 优先队列的 STL了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#include "stdafx.h"
#include <iostream>
#include <queue>
#pragma warning(disable:4996)

#define LOCAL
using namespace std;
const int MAX_NUM = 20;
const int INF = 0x3f3f3f3f;


struct Graph
{
int n;
int adj[MAX_NUM][MAX_NUM];
bool isIn[MAX_NUM];
int neighbor[MAX_NUM];
// 已经不需要distance数组了, 都维护在堆里面了
struct HeapNode // 定义堆中的节点
{
int distance2U,index; // 到MST中的距离以及顶点的索引

bool operator <(HeapNode node)
{
return distance2U > node.distance2U; // 即本节点到MST的距离大于node到MST的距离, 则本节点的优先级就比node小
}
};

struct Cmp
{
bool operator()(const HeapNode& a,const HeapNode &b) // 注意, 为了效率, 入参是引用, 为了鲁棒性, 这里加了const限制,该运算符就是重载 < , a<b 表示a的优先级低于b, 即本运算符返回true的话, 表示a的优先级低于b, 而优先队列pop出来的都是当前队列中优先级最高的
{
if (a.distance2U > b.distance2U) return true;
else if(a.distance2U == b.distance2U && a.index > b.index) return true; // 在距离相等的情况下,索引大的, 优先级低, 要排后面去
else return false;
}
};

HeapNode nodes[MAX_NUM];

priority_queue<HeapNode, vector<HeapNode>,Cmp> pq; // 类型为HeapNode的优先队列, 默认使用vector装载

Graph()
{
memset(adj, 0x3f, sizeof(adj));
memset(isIn, 0,sizeof(isIn));
isIn[1] = true;
puts("输入顶点的个数");
scanf("%d", &n);
puts("输入各边及其权重");
int x,y,w;
while(~scanf("%d%d%d", &x, &y, &w)) adj[x][y] = adj[y][x] = w;
for (int i = 2; i<=n; i++) neighbor[i] = 1;
}

int prim()
{
for (int i = 2; i<=n; i++)
{
nodes[i].index = i;
nodes[i].distance2U = adj[1][i];
pq.push(nodes[i]); // 初始化优先队列
}
int cost = 0;
int m = n-1;
while(m--)
{
HeapNode top = pq.top();
pq.pop(); // 这两行堆操作的代码将原本线性复杂度转换为了logn复杂度
isIn[top.index] = true;
cost += top.distance2U;
for (int i = 2; i<=n;i++)
{
if (!isIn[i] && nodes[i].distance2U > adj[i][top.index])
{
nodes[i].distance2U = adj[i][top.index];
neighbor[i] = top.index;
}
}
while(!pq.empty()) // 注意, 这里
{
pq.pop();
}
for(int i = 2; i<=n; i++)
{
if (!isIn[i])
{
pq.push(nodes[i]);
}
}
}
puts("MST如下");
for (int i = 2; i<=n; i++) printf("(%d,%d)\n", i, neighbor[i]);
return cost;
}
};

int main()
{
#ifdef LOCAL
freopen("d:\\data.in","r",stdin);
#endif
Graph g;
printf("MST总花销为%d\n",g.prim());
return 0;
}
/*
测试数据
6
3 1 1
3 2 5
3 4 5
3 5 6
3 6 4
1 2 6
2 5 3
5 6 6
6 4 2
1 4 5
*/

读者阅读了上面的代码之后可能会向我们建议——pq中装 HeapNode 不行吗? 为什么要装载HeapNode呢? 就是因为你装的是HeapNode, 而不是HeapNode, 所以你更新HeapNode之后, 堆中的数据并没有变化. 但是我想说的是, 就算变化了又怎么样? 优先队列并没有暴露siftdown接口, 你无法重新调整堆. 所以还是不行. 所以其实每次这样重新生成堆的做法其实并没有真正的堆优化. 真正的堆优化应该是一直存在一个堆在那里.

下面使用自己原生写的堆,则可以自由调整堆了.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
#include "stdafx.h"
#include <iostream>
#include <queue>
#pragma warning(disable:4996)

#define LOCAL
using namespace std;
const int MAX_NUM = 20;
const int INF = 0x3f3f3f3f;


struct Graph
{
int n;
int adj[MAX_NUM][MAX_NUM];
bool isIn[MAX_NUM];
int neighbor[MAX_NUM];
// 已经不需要distance数组了, 都维护在堆里面了
struct HeapNode // 定义堆中的节点
{
int distance2U,index; // 到MST中的距离以及顶点的索引

bool operator <(HeapNode node)
{
return distance2U > node.distance2U; // 即本节点到MST的距离大于node到MST的距离, 则本节点的优先级就比node小
}
};

struct Cmp
{
bool operator()(const HeapNode& a,const HeapNode &b) // 注意, 为了效率, 入参是引用, 为了鲁棒性, 这里加了const限制,该运算符就是重载 < , a<b 表示a的优先级低于b, 即本运算符返回true的话, 表示a的优先级低于b
{
if (a.distance2U > b.distance2U) return true;
else if(a.distance2U == b.distance2U && a.index > b.index) return true; // 在距离相等的情况下,索引大的, 优先级低, 要排后面去
else return false;
}
};

HeapNode nodes[MAX_NUM]; // 堆数组
int m; // 堆数组中元素的个数 即 nodes[1,..,m]就是堆数组中元素

void siftdown(int index)
{
bool isok = false;
while(!isok && 2*index<=m)
{
int tmp = index;
if (nodes[index] < nodes[2*index]) // 因为我们要建立的是大根堆(即每次拿出优先级最高的那个, 而优先级高的含义是distance2U小)
{
tmp = 2*index;
}
if (2*index+1<=m && nodes[tmp] < nodes[2*index+1])
{
swap(nodes[index], nodes[2*index+1]);
index = index*2+1;
}
else if (tmp!=index)
{
swap(nodes[index], nodes[2*index]);
index *=2;
}
else
{
isok = true;
}
}
}

void siftup(int index) // 向上滚动操作 时间复杂度是O(logn) 这个操作原本用于建堆, 但是需要for(int i = 1;i<=n;i++)siftup(i) 复杂度是O(nlogn), 没有使用siftdown建堆复杂度低. 但是复杂度和siftdown一样, 而且代码简洁, 所以用于调整堆是不错的.
{
bool isok = false; // 是否已经不需要滚动了
while(!isok && index>1)
{
if (nodes[index/2] < nodes[index]) // 如果父节点的距离>子节点的距离, 因为要搞的是大根堆(大的定义是优先级高,而优先级高的含义就是距离小). 所以要交换
{
swap(nodes[index/2],nodes[index]);
index /=2;
}
else
{
isok = true;
}
}
}

Graph()
{
memset(adj, 0x3f, sizeof(adj));
memset(isIn, 0,sizeof(isIn));
isIn[1] = true;
puts("输入顶点的个数");
scanf("%d", &n);
puts("输入各边及其权重");
int x,y,w;
while(~scanf("%d%d%d", &x, &y, &w)) adj[x][y] = adj[y][x] = w;
for (int i = 2; i<=n; i++) neighbor[i] = 1;
}

int prim()
{
int cost = 0;
m = n-1;
for (int i = 1; i<=m;i++)
{
nodes[i].distance2U = adj[1][i+1];
nodes[i].index = i+1;
}
for (int i = m/2;i;i--)
{
siftdown(i);
} // 初始化堆,使用siftdown算法的复杂度低


while(m)
{
HeapNode top = nodes[1];
swap(nodes[1], nodes[m]);
m--;
isIn[top.index] = true;
cost += top.distance2U;
for (int i = 1; i<=m;i++) // 别急着更新堆, 因为马上根上的节点的数值都会变化的(这个for就是更新),所以更新完再进行堆结构的调整
{
if (nodes[i].distance2U > adj[nodes[i].index][top.index]) // 注意, 堆数组中所有的节点一定是没有进入MST的
{
nodes[i].distance2U = adj[nodes[i].index][top.index];
neighbor[nodes[i].index] = top.index;
}
}
// 更新完堆数组中各个节点的数值之后开始调整堆结构,注意,调整的顺序是i从第到高,这一点和【1】中siftdown为什么要从后往前调用是一个理由.
for (int i = 1; i<=m;i++)
{
siftup(i);
}
}
puts("MST如下");
for (int i = 2; i<=n; i++) printf("(%d,%d)\n", i, neighbor[i]);
return cost;
}
};

int main()
{
#ifdef LOCAL
freopen("d:\\data.in","r",stdin);
#endif
Graph g;
printf("MST总花销为%d\n",g.prim());
return 0;
}
/*
测试数据
6
3 1 1
3 2 5
3 4 5
3 5 6
3 6 4
1 2 6
2 5 3
5 6 6
6 4 2
1 4 5
*/

参考

【1】https://yfsyfs.github.io/2019/05/25/%E5%A0%86%E6%8E%92%E5%BA%8F/

【2】https://yfsyfs.github.io/2019/05/25/%E6%97%A0%E5%90%91%E8%BF%9E%E9%80%9A%E5%9B%BE%E7%9A%84%E6%9C%80%E5%B0%8F%E7%94%9F%E6%88%90%E6%A0%91%E4%B9%8BPrim%E7%AE%97%E6%B3%95/