我用 Python 实现了冒泡排序、选择排序、插入排序、归并排序、快速排序。然后简单讲了讲快速排序的优化,我们可以通过小数组采用插入排序来减少递归的开销;对于有一定顺序的数组,我采用三数取中来提高性能;对于包含大量重复数的数组,我用了三路快速排序来提高性能。 最后,我把这些排序算法应用在随机数组、升序数组、降序数组、包含大量重复数的数组上,比较了一下它们的耗时。
def exchange(a,i,j):
temp = a[i]
a[i] = a[j]
a[j] = temp
def BubbleSort(nums):
for i in range(len(nums)-1):
for j in range(len(nums)-i-1):
if nums[j] > nums[j+1]:
exchange(nums,j,j+1)
首先,找到数组中最小的那个元素,然后将它和数组的第一个元素交换位置(如果第一个元素就是最小元素那么它就和自己交换)。然后在剩下的元素中找到最小的元素,将它与数组的第二个元素交换位置。如此往复,直到将整个数组排序。这种方法叫做选择排序,因为它在不断地选择剩余元素之中的最小者。
def exchange(a,i,j):
temp = a[i]
a[i] = a[j]
a[j] = temp
def InsertSort(nums):
for i in range(len(nums)-1):
j = i + 1
while i >= 0 and nums[i] > nums[j]:
exchange(nums,i,j)
j -= 1
i -= 1
通常人们整理桥牌的方法是一张一张的来,将每一张牌插入到其他已经有序的牌中的适当位置。在计算机的实现中,为了给要插入的元素腾出空间,我们需要将其余所有元素在插入之前都向右移动一位。这种算法叫做插入排序。
def exchange(a,i,j):
temp = a[i]
a[i] = a[j]
a[j] = temp
def InsertSort(nums):
for i in range(len(nums)-1):
j = i + 1
while i >= 0 and nums[i] > nums[j]:
exchange(nums,i,j)
j -= 1
i -= 1
归并排序体现的是一种分治思想(Divide and conquer),下面是其排序的步骤:
具体步骤如下:
def merge(a,aux,low,mid,high):
i = low
j = mid+1
k = 0
for k in range(low,high+1):
if i > mid:
a[k] = aux[j]
j += 1
elif j > high:
a[k] = aux[i]
i += 1
else:
if aux[i] > aux[j]:
a[k] = aux[j]
j += 1
else:
a[k] = aux[i]
i += 1
我们要对数组a[low..high]
进行排序,先将它分为a[low..mid]
和a[mid+1..high]
两部分,分别递归调用将它们单独排序,最后将有序的子数组归并为最终的排序结果。
def sort(a,aux,low,high):
# 退出条件
if low >= high:
return
mid = (low + high) // 2
sort(a,aux,low,mid)
sort(a,aux,mid+1,high)
merge(a,aux,low,mid,high)
为了保证归并排序函数 MergeSort () 输入只有未排序的数组,这里调用前面的辅助函数 sort ():
def MergeSort(nums):
aux = nums.copy()
low = 0
high = len(nums)-1
sort(nums,aux,low,high)
return nums
快速排序是一种分治的排序算法。它将一个数组分成两个子数组,将两部分独立地排序。
分治策略指的是:将原问题分解为若干个规模更小但结构与原问题相似的子问题。递归地解这些子问题,然后将这些子问题的解组合为原问题的解。
下面是一个示例:
下面的代码短小利于理解,但是空间复杂度大,使用了三个列表解析式,而且每次选取进行比较时需要遍历整个序列。
def QuickSort(a):
if len(a) < 2:
return a
else:
pivot = a[0]
less_than_pivot = [x for x in a if x < pivot]
more_than_pivot = [x for x in a if x > pivot]
pivot_list = [x for x in a if x == pivot]
return QuickSort(less_than_pivot) + pivot_list + QuickSort(more_than_pivot)
1. 切分 ——partition ()
切分方法:先随意地取a[low]
作为切分元素(即那个将会被排定的元素),然后我们从数组的左端开始向右扫描直到找到一个大于等于它的元素,再从数组的右端开始向左扫描直到找到一个小于等于它的元素。这两个元素是没有排定的,因此我们交换它们的位置。如此继续,当两个指针相遇时,我们只需要将切分元素a[low]
和左子元素最右侧的元素a[j]
交换然后返回 j 即可。
def partition(a,low,high):
i = low # 循环内i=i+1
j = high + 1 # 循环内j=j-1
while True:
# 如果a[i]比基准数小,则后移一位直到有大于等于基准数的数出现
i += 1 # 保证i每次循环都变化,不会陷入死循环(所有数都相等时这种情况)
while a[i] < a[low] and i < high:
i += 1
# 如果a[j]比基准数大,则前移一位直到有小于等于基准数的数出现
j -= 1 # 保证j每次循环都变化,不会陷入死循环(所有数都相等时这种情况)
while a[j] > a[low] and j > low:
j -= 1
# 如果两个指针交叉,说明已经排序完了
if i >= j:
break
exchange(a,i,j)
# 指针相遇后,j所在的元素小于low,进行互换
exchange(a,low,j)
return j
这里有个细节需要注意下,这个代码相比我最初的代码改变了:
def partition(a,low,high):
- i = low + 1
+ i = low # 循环内i=i+1
- j = high
+ j = high + 1 # 循环内j=j-1
while True:
# 如果a[i]比基准数小,则后移一位直到有大于等于基准数的数出现
+ i += 1 # 保证i每次循环都变化,不会陷入死循环(所有数都相等时这种情况)
while a[i] < a[low] and i < high:
i += 1
# 如果a[j]比基准数大,则前移一位直到有小于等于基准数的数出现
+ j -= 1 # 保证j每次循环都变化,不会陷入死循环(所有数都相等时这种情况)
while a[j] > a[low] and j > low:
j -= 1
# 如果两个指针交叉,说明已经排序完了
if i >= j:
break
exchange(a,i,j)
# 指针相遇后,j所在的元素小于low,进行互换
exchange(a,low,j)
return j
如果没有这些代码,当碰到[2,2,2]
这样的情况时,i 和 j 一直不会改变,永远无法满足if i >= j
,然后函数就一直在while True
里边死循环。
2. sort()函数
快速排序递归地将子数组a[low..high]
排序,先用partition()
方法将a[j]
放到一个合适位置,然后再用递归调用将其他位置的元素排序。
def sort(a,low,high):
if low >= high:
return
j = partition(a,low,high)
sort(a,low,j-1)
sort(a,j+1,high)
3. QuickSort () 函数
为了保证快速排序函数 QuickSort () 输入只有未排序的数组,这里调用前面的辅助函数 sort ():
def QuickSort(nums):
low = 0
high = len(nums)-1
sort(nums,low,high)
return nums
1. 优化小数组效率
对于规模很小的情况,快速排序的优势并不明显(可能没有优势),而递归型的算法还会带来额外的开销。于是对于这类情况可以选择非递归型的算法来替代。
那就有两个问题:多小的数组算小数组?替换的算法是什么?
通常这个阈值设定为 10,替换的算法一般是插入排序。
下面是 Python 实现,这里只需要在 sort () 函数中加一个数组大小判断即可:
CUTOFF = 10
def sort(a,low,high):
if low >= high:
return
# 当数组大小小于CUTOFF时,调用插入排序
if high - low <= CUTOFF - 1:
InsertSort(a[low:high+1])
return
j = partition(a,low,high)
sort(a,low,j-1)
sort(a,j+1,high)
2. 合理选择 pivot
前面也讨论过,直接选择分区的第一个或最后一个元素做 pivot 是不合适的。对于已经排好序,或者接近排好序的情况,会进入最差情况,时间复杂度退化到 n^2。
pivot 选取的理想情况是:让分区中比 pivot 小的元素数量和比 pivot 大的元素数量差不多。较常用的做法是三数取中( median of three ),即从第一项、最后一项、中间一项中取中位数作为 pivot。当然这并不能完全避免最差情况的发生。所以很多时候会采取更小心、更严谨的 pivot 选择方案(对于大数组特别重要)。比如先把大数组平均切分成左中右三个部分,每个部分用三数取中得到一个中位数,再从得到的三个中位数中找出中位数。
CUTOFF = 10
def get_median(nums,low,high):
# 计算数组中间的元素的下标
mid = (low + high) // 2
# 目标: arr[mid] <= arr[high]
if nums[mid] > nums[high]:
exchange(nums,mid,high)
# 目标: arr[low] <= arr[high]
if nums[low] > nums[high]:
exchange(nums,low,high)
# 目标: arr[low] >= arr[mid]
if nums[low] < nums[mid]:
exchange(nums,low,mid)
# 此时,arr[mid] <= arr[low] <= arr[high],low的位置上保存这三个位置中间的值
return nums[low]
def sort(a,low,high):
if low >= high:
return
# 当数组大小小于CUTOFF时,调用插入排序
if high - low <= CUTOFF - 1:
InsertSort(a[low:high+1])
return
# 三数取中(median of three),low的位置上保存这三个位置中间的值
_ = get_median(a,low,high)
j = partition(a,low,high)
sort(a,low,j-1)
sort(a,j+1,high)
3. 处理重复元素问题
当一个数组里的元素全部一样大(或者存在大量相同元素)会令快速排序进入最差情况,因为不管怎么选 pivot,都会使分区结果一边很大一边很小。
为了解决这个问题,我们需要修改分区过程,思路跟上面说的两路分区(基本的快排)类似,只是现在我们需要小于 pivot、等于 pivot、大于 pivot 三个分区。
举个例子,待分割序列:6 4 6 7 1 6 7 6 8 6
,其中pivot=6
:
1 4 6 6 7 6 7 6 8 6
1 4 6
和 7 6 7 6 8 6
1 4 6 6 6 6 6 7 8 7
1 4
和 7 8 7
经过对比,我们可以看出,在一次划分后,把与 key 相等的元素聚在一起,能减少迭代次数,效率会提高不少。
具体过程:
如下图,我们可以设置四个游标,左端 p、i,右端 j、q。i、j 的作用跟之前两路划分时候的左右游标相同,就是从两端向中间遍历序列,并将遍历到的元素与 pivot 比较,如果等于 pivot,则移到两端(i 对应的元素移到左端,j 对应的元素移到右端。移动的方式就是拿此元素和 a 或 d 对应的元素进行交换,所以 p 和 q 的作用就是记录等于 pivot 的元素移动过后的边界),反之,如果大于或小于 pivot,还按照之前两路划分的方式进行移动。这样一来,中间部分就和两路划分相同,两头是等于 pivot 的部分,我们只需要将这两部分移动到中间即可。
def partition(a,low,high):
p = low + 1
i = low + 1
j = high
q = high
while True:
# 如果a[i]比基准数小,则后移一位直到有大于等于基准数的数出现
while a[i] <= a[low] and i < high:
# 与pivot相等的元素将其交换到p所在的位置
if a[i] == a[low]:
exchange(a,p,i)
p += 1
i += 1
# 如果a[j]比基准数大,则前移一位直到有小于等于基准数的数出现
while a[j] >= a[low] and j > low:
# 与pivot相等的元素将其交换到q所在的位置
if a[j] == a[low]:
exchange(a,j,q)
q -= 1
j -= 1
# 如果两个指针交叉,说明已经排序完了
if i >= j:
break
exchange(a,i,j)
# 因为工作指针i指向的是当前需要处理元素的下一个元素,故而需要退回到当前元素的实际位置,然后将等于pivot元素交换到序列中间
i -= 1
p -= 1
while p >= low:
exchange(a, i, p)
i -= 1
p -= 1
# 因为工作指针j指向的是当前需要处理元素的上一个元素,故而需要退回到当前元素的实际位置,然后将等于pivot元素交换到序列中间
j += 1
q += 1
while q <= high:
exchange(a, q, j)
j += 1
q += 1
return i,j
下面是 sort () 函数,这里我只写了修改的部分:
def sort(a,low,high):
# ...
i,j = partition(a,low,high)
sort(a,low,i)
sort(a,j,high)
下面是经过优化的快速排序代码:
CUTOFF = 10
def exchange(a,i,j):
temp = a[i]
a[i] = a[j]
a[j] = temp
def InsertSort(nums):
for i in range(len(nums)-1):
j = i + 1
while i >= 0 and nums[i] > nums[j]:
exchange(nums,i,j)
j -= 1
i -= 1
def partition(a,low,high):
p = low + 1
i = low + 1
j = high
q = high
while True:
# 如果a[i]比基准数小,则后移一位直到有大于等于基准数的数出现
while a[i] <= a[low] and i < high:
# 与pivot相等的元素将其交换到p所在的位置
if a[i] == a[low]:
exchange(a,p,i)
p += 1
i += 1
# 如果a[j]比基准数大,则前移一位直到有小于等于基准数的数出现
while a[j] >= a[low] and j > low:
# 与pivot相等的元素将其交换到q所在的位置
if a[j] == a[low]:
exchange(a,j,q)
q -= 1
j -= 1
# 如果两个指针交叉,说明已经排序完了
if i >= j:
break
exchange(a,i,j)
# 因为工作指针i指向的是当前需要处理元素的下一个元素,故而需要退回到当前元素的实际位置,然后将等于pivot元素交换到序列中间
i -= 1
p -= 1
while p >= low:
exchange(a, i, p)
i -= 1
p -= 1
# 因为工作指针j指向的是当前需要处理元素的上一个元素,故而需要退回到当前元素的实际位置,然后将等于pivot元素交换到序列中间
j += 1
q += 1
while q <= high:
exchange(a, q, j)
j += 1
q += 1
return i,j
def get_median(nums,low,high):
# 计算数组中间的元素的下标
mid = (low + high) // 2
# 目标: arr[mid] <= arr[high]
if nums[mid] > nums[high]:
exchange(nums,mid,high)
# 目标: arr[low] <= arr[high]
if nums[low] > nums[high]:
exchange(nums,low,high)
# 目标: arr[low] >= arr[mid]
if nums[low] < nums[mid]:
exchange(nums,low,mid)
# 此时,arr[mid] <= arr[low] <= arr[high],low的位置上保存这三个位置中间的值
return nums[low]
def sort(a,low,high):
if low >= high:
return
# 当数组大小小于CUTOFF时,调用插入排序
if high - low <= CUTOFF - 1:
InsertSort(a[low:high+1])
return
# 三数取中(median of three),low的位置上保存这三个位置中间的值
_ = get_median(a,low,high)
i,j = partition(a,low,high)
sort(a,low,i)
sort(a,j,high)
def QuickSort3Ways(nums):
low = 0
high = len(nums)-1
sort(nums,low,high)
return nums
nums = [4,5,6,1,2,3,3,3,1,2]
print(QuickSort(nums))
快速排序和归并排序是互补的:
不同数据集可以用同一个计时函数,具体如下:
import time
# 计时函数
def count_time(a,sortname):
time_start = time.time()
if sortname == 'BubbleSort':
BubbleSort(a)
if sortname == 'SelectSort':
SelectSort(a)
if sortname == 'InsertSort':
InsertSort(a)
if sortname == 'MergeSort':
MergeSort(a)
if sortname == 'QuickSort':
QuickSort(a)
if sortname == 'QuickSort3Ways':
QuickSort3Ways(a)
time_end = time.time()
return (time_end - time_start)
随机数据生成器:
import random
def timeRandomInput(sortname,length,numberOfArrays):
totalTime = 0
#测试数组数
for _ in range(numberOfArrays):
#数组大小
a = []
for _ in range(length):
a.append(random.randint(1, 1000000)) # 测试数据范围
totalTime += count_time(a,sortname)
return totalTime
这里我们生成一个长度为 5000 的数组,然后重复测试 10 次,最后计算各个排序算法用时:
length = 5000
numberOfArrays = 10
print("BubbleSort's total time:")
print(timeRandomInput('BubbleSort',length,numberOfArrays))
print("SelectSort's total time:")
print(timeRandomInput('SelectSort',length,numberOfArrays))
print("InsertSort's total time:")
print(timeRandomInput('InsertSort',length,numberOfArrays))
print("MergeSort's total time:")
print(timeRandomInput('MergeSort',length,numberOfArrays))
print("QuickSort's total time:")
print(timeRandomInput('QuickSort',length,numberOfArrays))
print("QuickSort3Ways's total time:")
print(timeRandomInput('QuickSort3Ways',length,numberOfArrays))
BubbleSort's total time:
30.023681640625
SelectSort's total time:
11.03202223777771
InsertSort's total time:
24.185371160507202
MergeSort's total time:
0.1900651454925537
QuickSort's total time:
0.1554875373840332
QuickSort3Ways's total time:
0.19011521339416504
这里我们看下这些排序算法在降序数据集下的表现,首先改变数据生成函数:
import random
def timeRandomInput(sortname,length,numberOfArrays):
totalTime = 0
#测试数组数
for _ in range(numberOfArrays):
#数组大小
a = []
for _ in range(length):
a.append(random.randint(1, 1000000)) # 测试数据范围
+ a.sort(reverse = True)
totalTime += count_time(a,sortname)
return totalTime
这里如果生成一个长度为 10000 的数组,快速排序会出现RecursionError: maximum recursion depth exceeded in comparison
错误。这个因为 Python 中默认的最大递归深度是 989。解决方案:手动设置递归调用深度,具体代码如下:
import random
+import sys
+sys.setrecursionlimit(1000000)
def timeRandomInput(sortname,length,numberOfArrays):
totalTime = 0
#测试数组数
for _ in range(numberOfArrays):
#数组大小
a = []
for _ in range(length):
a.append(random.randint(1, 1000000)) # 测试数据范围
a.sort(reverse = True)
totalTime += count_time(a,sortname)
return totalTime
数组大小改变为 5000,重复 10 次,下面是测试结果:
BubbleSort's total time:
45.00776267051697
SelectSort's total time:
11.393858909606934
InsertSort's total time:
48.275355100631714
MergeSort's total time:
0.18087530136108398
QuickSort's total time:
14.895536661148071
QuickSort3Ways's total time:
0.10853052139282227
这里我们看下这些排序算法在升序数据集下的表现,首先改变数据生成函数:
import random
import sys
sys.setrecursionlimit(1000000)
def timeRandomInput(sortname,length,numberOfArrays):
totalTime = 0
#测试数组数
for _ in range(numberOfArrays):
#数组大小
a = []
for _ in range(length):
a.append(random.randint(1, 1000000)) # 测试数据范围
+ a.sort(reverse = False)
totalTime += count_time(a,sortname)
return totalTime
同样的,这里数组大小为 5000,重复 10 次,下面是测试结果:
BubbleSort's total time:
14.935291051864624
SelectSort's total time:
11.371372699737549
InsertSort's total time:
0.008459329605102539
MergeSort's total time:
0.15901756286621094
QuickSort's total time:
16.011647939682007
QuickSort3Ways's total time:
0.10053849220275879
这里我们看下这些排序算法在含有大量重复数的数据集下的表现,首先改变数据生成函数:
import random
import sys
sys.setrecursionlimit(1000000)
def timeRandomInput(sortname,length,numberOfArrays):
totalTime = 0
#测试数组数
for _ in range(numberOfArrays):
#数组大小
a = []
for _ in range(length):
- a.append(random.randint(1, 1000000)) # 测试数据范围
+ a.append(random.randint(999990, 1000000)) # 测试数据范围
totalTime += count_time(a,sortname)
return totalTime
同样的,这里数组大小为 5000,重复 10 次,下面是测试结果:
BubbleSort's total time:
28.813392877578735
SelectSort's total time:
11.362754821777344
InsertSort's total time:
22.454782247543335
MergeSort's total time:
0.1563563346862793
QuickSort's total time:
0.15424251556396484
QuickSort3Ways's total time:
0.08862972259521484
BubbleSort | SelectSort | InsertSort | MergeSort | QuickSort | QuickSort3Ways | |
---|---|---|---|---|---|---|
随机数据集 | 30.023 | 11.032 | 24.185 | 0.190 | 0.155 | 0.190 |
升序数据集 | 14.935 | 11.371 | 0.008 | 0.159 | 16.011 | 0.100 |
降序数据集 | 45.007 | 11.393 | 48.275 | 0.180 | 14.895 | 0.108 |
大量重复数的数据集 | 28.813 | 11.362 | 22.454 | 0.156 | 0.154 | 0.088 |
经过优化后的三路快速排序在升序、降序、包含大量重复数的情况下表现均非常优异。