Python实现快排及其可视化
最近装了个`Anaconda`,准备学习一下数据可视化。本着三天打鱼两天装死的心态,重新抱起崭新的算法书,认真学起了快排算法。学完后用`Python`实现了一遍基本的快排,然后使用`matplotlib`进行动态绘图,最后使用`imageio`生成GIF图片。谨以此文以记之!
快排基本原理
快排采用和归并排序相同的分而治之的思想,将待排序数组分成左右两个子数组,对两部分子数组独立排序。当子数组均有序时,整个数组也就有序了。
排序步骤如下:
- 将原始数组`data`随机打乱,以消除对输入的依赖(本步可选)
- 选择数组的首个元素`data[0]`作为切分元素`v`
- 切分数组
  - 从左往右找到第一个大于(或等于)切分元素`v`的元素`data[i]`
  - 从右到左找到第一个小于(或等于)切分元素`v`的元素`data[j]`
  - 交换`data[i]`与`data[j]`
  - 重复以上三步直到`i>=j`
  - 交换`data[j]`与切分元素`data[0]`
- 递归调用,对切分后的左侧子数组进行排序
- 递归调用,对切分后的右侧子数组进行排序
文字性的描述总是那么苍白无力,但还好也能说明一些问题。可以看出,快排的关键在于切分,切分后的数组应该满足:
- 切分元素的位置(设为`j`)已经固定
- `data[lo]`到`data[j-1]`区间内的元素均不大于切分元素`data[j]`
- `data[j+1]`到`data[hi]`区间内的元素均不小于切分元素`data[j]`

其中`data[lo]`代表数组或子数组的首个元素,`data[hi]`代表数组或子数组的末尾元素。
简单点说,就是先找一个参考点,把小于这个参考点的元素都扔到它的左边,大于这个参考点的数都扔到它的右边。这样一来,参考点的位置就固定了,然后对左边的数据和右边的数据各自再递归的扔几遍,等所有子数组都扔完了,整个数组也就有序了。
不过需要注意的是,扔的时候不是随便扔,是把从左往右找到的第一个大于参考点的值和从右往左找到的第一个小于参考点的值进行替换。
基本实现
Talk is cheap, show me the code
def sort(data):
    '''Sort the list ``data`` in place, in ascending order, using quicksort.'''
    __sort(data, 0, len(data) - 1)


def __sort(data, lo, hi):
    '''Recursively sort the inclusive range ``data[lo..hi]`` in place.'''
    # Ranges with fewer than two elements are already sorted.
    if lo >= hi:
        return
    key = __partition(data, lo, hi)
    __sort(data, lo, key - 1)
    __sort(data, key + 1, hi)


def __swap(data, lo, hi):
    '''Exchange the elements at indexes ``lo`` and ``hi``.'''
    data[lo], data[hi] = data[hi], data[lo]


def __partition(data, lo, hi):
    '''Partition ``data[lo..hi]`` around ``data[lo]``.

    Returns the final index ``j`` of the pivot. On return every element of
    ``data[lo..j-1]`` is not greater than the pivot and every element of
    ``data[j+1..hi]`` is not less than it.
    '''
    i = lo
    j = hi
    v = data[lo]  # pivot (partitioning element)
    while True:
        # Scan left to right for an element that is not less than v.
        i += 1
        while data[i] < v:
            if i == hi:  # guard against running past the right end
                break
            i += 1
        # Scan right to left for an element that is not greater than v.
        while v < data[j]:
            if j == lo:  # guard against running past the left end
                break
            j -= 1
        # The two scans met or crossed: the range is partitioned.
        if i >= j:
            break
        __swap(data, i, j)
    # Put the pivot into its final position.
    __swap(data, lo, j)
    return j
以上便是参考`Algorithms`书上`java`代码的`Python`实现。下面是个使用示例:
import random
def main():
    '''Shuffle the integers 0..19, sort them with ``sort`` and print both.'''
    data = list(range(20))
    random.shuffle(data)
    print(data)
    sort(data)
    print(data)


if __name__ == '__main__':
    main()
执行结果如下:
$ python quick.py
[4, 9, 1, 13, 18, 5, 6, 14, 2, 16, 7, 12, 15, 8, 11, 17, 0, 19, 10, 3]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
快排优化
快排有很多优化算法,目前我只习得一种最简单的,可以将切分函数两个内部`while`循环中的`if`语句去掉:
# first one
if i == hi:
break
# second one
if j == lo:
break
这两个判断都是为了防止访问数组越界而设,其实第二个是完全没有必要加的,因为lo
对应的就是切分元素本身,自己肯定不会小于自己,所以这个判断完全是多余的;对于第一个,想要去掉的话,只要保证数组最后一个元素最大即可,实现上只要在执行排序函数之前将最大值换至最后即可。
__swap(data, data.index(max(data)), len(data) - 1)
本文后续的code会将第二个判断去掉,但第一个的还保留着,毕竟把最大值直接挪到最后总感觉怪怪的,在可视化的时候也会牺牲一点随机性。
面向对象编程
为了方便代码的阅读和管理,我将快排代码封装成`QuickSort`类,同时加入变量`swap_times`用于记录总的数据交换次数。
class QuickSort(object):
    '''Quick sort algorithm that counts the swaps it performs.'''

    def sort(self, data):
        '''Sort ``data`` in place (ascending) and return the swap count.'''
        self.swap_times = 0
        # Optional sentinel trick: with the largest element moved to the end,
        # the ``i == hi`` bound check in __partition could be dropped.
        # self.__swap(data, data.index(max(data)), len(data) - 1)
        self.__sort(data, 0, len(data) - 1)
        return self.swap_times

    def __swap(self, data, lo, hi):
        '''Exchange ``data[lo]`` and ``data[hi]`` and count the swap.'''
        data[lo], data[hi] = data[hi], data[lo]
        self.swap_times += 1

    def __sort(self, data, lo, hi):
        '''Recursively sort the inclusive range ``data[lo..hi]`` in place.'''
        if lo >= hi:
            return
        key = self.__partition(data, lo, hi)
        self.__sort(data, lo, key - 1)
        self.__sort(data, key + 1, hi)

    def __partition(self, data, lo, hi):
        '''Partition ``data[lo..hi]`` around ``data[lo]``.

        Returns the final index of the pivot.
        '''
        i = lo
        j = hi
        v = data[lo]  # pivot (partitioning element)
        while True:
            # Scan left to right for an element that is not less than v.
            i += 1
            while data[i] < v:
                # This check can be dropped if the last element is the largest.
                if i == hi:
                    break
                i += 1
            # Scan right to left for an element that is not greater than v.
            # No bound check is needed: data[lo] == v always stops the scan.
            while v < data[j]:
                j -= 1
            if i >= j:
                break
            self.__swap(data, i, j)
        # Put the pivot into its final position.
        self.__swap(data, lo, j)
        return j
打印数据交换记录
为了了解排序过程中的数据交换,可以在`__swap`函数中打印每一次交换后的数组。
def __init__(self, debug=False, save_fig=False):
    '''Store the ``debug`` flag; ``save_fig`` is accepted but unused here.'''
    self.debug = debug
def __swap(self, data, lo, hi):
    '''Exchange ``data[lo]`` and ``data[hi]``, count it and optionally log it.'''
    data[lo], data[hi] = data[hi], data[lo]
    self.swap_times += 1
    if self.debug:
        # Print the whole array after the swap plus the swapped indexes.
        print('{0} swap({1}, {2})'.format(data, lo, hi))
示例:
#main.py
def main():
    '''Sort a shuffled list of 0..19 with swap logging enabled.'''
    random.seed(time.time())
    data = list(range(20))
    random.shuffle(data)
    qs = QuickSort(debug=True)
    # sort() returns a plain int here, so it must not be tuple-unpacked.
    swap_times = qs.sort(data)


if __name__ == '__main__':
    main()
➜ algorithm git:(master) ✗ ./main.py
[14, 3, 11, 10, 4, 1, 2, 12, 18, 17, 7, 8, 13, 15, 0, 9, 16, 6, 5, 19] swap(4, 19)
[14, 3, 11, 10, 4, 1, 2, 12, 5, 17, 7, 8, 13, 15, 0, 9, 16, 6, 18, 19] swap(8, 18)
[14, 3, 11, 10, 4, 1, 2, 12, 5, 6, 7, 8, 13, 15, 0, 9, 16, 17, 18, 19] swap(9, 17)
[14, 3, 11, 10, 4, 1, 2, 12, 5, 6, 7, 8, 13, 9, 0, 15, 16, 17, 18, 19] swap(13, 15)
[0, 3, 11, 10, 4, 1, 2, 12, 5, 6, 7, 8, 13, 9, 14, 15, 16, 17, 18, 19] swap(0, 14)
[0, 3, 11, 10, 4, 1, 2, 12, 5, 6, 7, 8, 13, 9, 14, 15, 16, 17, 18, 19] swap(0, 0)
[0, 3, 2, 10, 4, 1, 11, 12, 5, 6, 7, 8, 13, 9, 14, 15, 16, 17, 18, 19] swap(2, 6)
[0, 3, 2, 1, 4, 10, 11, 12, 5, 6, 7, 8, 13, 9, 14, 15, 16, 17, 18, 19] swap(3, 5)
[0, 1, 2, 3, 4, 10, 11, 12, 5, 6, 7, 8, 13, 9, 14, 15, 16, 17, 18, 19] swap(1, 3)
[0, 1, 2, 3, 4, 10, 11, 12, 5, 6, 7, 8, 13, 9, 14, 15, 16, 17, 18, 19] swap(1, 1)
[0, 1, 2, 3, 4, 10, 11, 12, 5, 6, 7, 8, 13, 9, 14, 15, 16, 17, 18, 19] swap(4, 4)
[0, 1, 2, 3, 4, 10, 9, 12, 5, 6, 7, 8, 13, 11, 14, 15, 16, 17, 18, 19] swap(6, 13)
[0, 1, 2, 3, 4, 10, 9, 8, 5, 6, 7, 12, 13, 11, 14, 15, 16, 17, 18, 19] swap(7, 11)
[0, 1, 2, 3, 4, 7, 9, 8, 5, 6, 10, 12, 13, 11, 14, 15, 16, 17, 18, 19] swap(5, 10)
[0, 1, 2, 3, 4, 7, 6, 8, 5, 9, 10, 12, 13, 11, 14, 15, 16, 17, 18, 19] swap(6, 9)
[0, 1, 2, 3, 4, 7, 6, 5, 8, 9, 10, 12, 13, 11, 14, 15, 16, 17, 18, 19] swap(7, 8)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 11, 14, 15, 16, 17, 18, 19] swap(5, 7)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 11, 14, 15, 16, 17, 18, 19] swap(5, 5)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 11, 14, 15, 16, 17, 18, 19] swap(8, 8)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 11, 13, 14, 15, 16, 17, 18, 19] swap(12, 13)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] swap(11, 12)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] swap(15, 15)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] swap(16, 16)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] swap(17, 17)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] swap(18, 18)
根据打印结果可以逐步分析快排算法的执行过程,明确知晓每一次交换的数据。
数据可视化
打印log固然可以分析算法的执行流程,但是不够直观,所以想着用可视化工具进一步处理,以动态图片形式显示快排过程。为此,只需在交换数据的函数`__swap`中使用`matplotlib`的柱状图绘制交换完成后的数组即可。
import matplotlib.pyplot as plt
...
class QuickSort(object):
    def __init__(self, debug=False, save_fig=False):
        ...
        self.save_fig = save_fig
        self.fig, self.ax = plt.subplots()
        # open interactive mode of matplot
        plt.ion()
        if self.save_fig:
            # One timestamped directory per run holds the saved frames.
            self.path = './images/{0}'.format(time.strftime('%Y%m%d_%H%M%S'))
            os.makedirs(self.path)

    def __swap(self, data, lo, hi):
        ...
        self.__plot_figure(data, lo, hi, show_swap=True)

    def __plot_figure(self, data, lo=0, hi=0, show_swap=False):
        '''plot and save figure'''
        # Redraw from scratch so the previous frame does not show through.
        self.ax.clear()
        self.ax.set_title('data quicksort')
        self.ax.bar(range(len(data)), data, label='data')
        if show_swap:
            # Overdraw the two swapped bars in red to highlight them.
            self.ax.bar([lo, hi], [data[lo], data[hi]], color='red', label='swap')
        plt.legend()
        plt.pause(0.001)
        if self.save_fig:
            # Frames are named by the running swap count: 0.png, 1.png, ...
            plt.savefig('{0}/{1}.png'.format(self.path, self.swap_times))
需要注意的几点是:
- `matplotlib`具有两种绘图模式,阻塞(`block`)模式和交互(`interactive`)模式,阻塞模式必须等待当前绘图窗口关闭方才执行后续程序,而交互模式则无需等待。为了动态显示排序过程,自然选择交互模式,所以初始化时调用`ion()`函数打开交互模式
- 相比于其它样式的图表,使用柱状图`bar`能够更直观显示数据大小及变化过程
- 在每次重绘图表时,需要清空原有图表
- 必须调用`pause`函数予以等待,否则可能出现无法显示图表的情况,等待时长自定
- 为了突出显示每次交换的两个数据,可以使用红色图表单独绘制交换数据
- 使用`matplotlib`函数库中的`savefig`可以将图表保存至本地图片文件,为后续生成`gif`图片做准备
生成GIF动图
有了前面保存好的图片,使用`imageio`库的`imread`函数逐张读取图片并`append`到列表中,再调用`mimsave`函数即可生成`gif`图片。图片间隔时间由`mimsave`函数的`duration`参数决定。
# main.py
import imageio
from quick_sort import QuickSort
import os
from os.path import join
def save_gif(path, gif_name):
    '''Combine the frames ``0.png``, ``1.png``, ... in ``path`` into a GIF.

    The GIF is written to ``path`` under the name ``gif_name``. Nothing is
    done when ``path`` is empty/None, is not a directory, or holds no PNG.
    '''
    if not path or not os.path.isdir(path):
        return
    # Count only PNG frames so that other files in the directory (e.g. a GIF
    # written by an earlier call) do not inflate the frame range.
    frame_count = sum(1 for name in os.listdir(path) if name.endswith('.png'))
    if frame_count == 0:
        return
    images = []
    for index in range(frame_count):
        file_path = join(path, '{}.png'.format(index))
        images.append(imageio.imread(file_path))
    imageio.mimsave(join(path, gif_name), images, 'GIF', duration=0.2)
至此,便完成了快排的算法实现及其可视化。
完整源码
代码已上传至github Python-demos algorithm
目录
- quick_sort.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
import time
import os
class QuickSort(object):
    '''Quick sort algorithm that plots (and optionally saves) every swap.'''

    def __init__(self, debug=False, save_fig=False):
        '''Create the figure and, if ``save_fig`` is set, the image directory.

        debug    -- print the array after every swap
        save_fig -- save one PNG per plotted frame under ./images/<timestamp>
        '''
        self.debug = debug
        self.save_fig = save_fig
        # Always defined so that sort() can return it even without saving.
        self.path = None
        self.fig, self.ax = plt.subplots()
        # Interactive mode lets the program continue while the window is open.
        plt.ion()
        if self.save_fig:
            self.path = './images/{0}'.format(time.strftime('%Y%m%d_%H%M%S'))
            os.makedirs(self.path)

    def sort(self, data):
        '''Sort ``data`` in place and return ``(swap_times, path)``.

        ``path`` is the image directory, or None when ``save_fig`` is off.
        '''
        self.swap_times = 0
        # Plot the initial, unsorted state as frame 0.
        self.__plot_figure(data)
        # set the largest element to the end
        # self.__swap(data, data.index(max(data)), len(data) - 1)
        self.__sort(data, 0, len(data) - 1)
        return self.swap_times, self.path

    def __swap(self, data, lo, hi):
        '''Exchange two elements, count the swap, then log and plot it.'''
        data[lo], data[hi] = data[hi], data[lo]
        self.swap_times += 1
        if self.debug:
            print('\t{0} swap({1}, {2})'.format(data, lo, hi))
        self.__plot_figure(data, lo, hi, show_swap=True)

    def __plot_figure(self, data, lo=0, hi=0, show_swap=False):
        '''plot and save figure'''
        # Redraw from scratch so the previous frame does not show through.
        self.ax.clear()
        self.ax.set_title('data quicksort')
        self.ax.bar(range(len(data)), data, label='data')
        if show_swap:
            # Overdraw the two swapped bars in red to highlight them.
            self.ax.bar([lo, hi], [data[lo], data[hi]], color='red', label='swap')
        plt.legend()
        plt.pause(0.001)
        if self.save_fig:
            # Frames are named by the running swap count: 0.png, 1.png, ...
            plt.savefig('{0}/{1}.png'.format(self.path, self.swap_times))

    def __sort(self, data, lo, hi):
        '''Recursively sort the inclusive range ``data[lo..hi]`` in place.'''
        if lo >= hi:
            return
        key = self.__partition(data, lo, hi)
        self.__sort(data, lo, key - 1)
        self.__sort(data, key + 1, hi)

    def __partition(self, data, lo, hi):
        '''Partition ``data[lo..hi]`` around ``data[lo]``; return pivot index.'''
        i = lo
        j = hi
        v = data[lo]  # pivot (partitioning element)
        while True:
            # Scan left to right for an element that is not less than v.
            i += 1
            while data[i] < v:
                # This check can be dropped if the last element is the largest.
                if i == hi:
                    break
                i += 1
            # Scan right to left for an element that is not greater than v.
            # No bound check is needed: data[lo] == v always stops the scan.
            while v < data[j]:
                j -= 1
            if i >= j:
                break
            self.__swap(data, i, j)
        # Put the pivot into its final position.
        self.__swap(data, lo, j)
        return j
- main.py
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import time
import random
import imageio
from quick_sort import QuickSort
import os
from os.path import join
def save_gif(path, gif_name):
    '''Combine the frames ``0.png``, ``1.png``, ... in ``path`` into a GIF.

    The GIF is written to ``path`` under the name ``gif_name``. Nothing is
    done when ``path`` is empty/None, is not a directory, or holds no PNG.
    '''
    if not path or not os.path.isdir(path):
        return
    # Count only PNG frames so that other files in the directory (e.g. a GIF
    # written by an earlier call) do not inflate the frame range.
    frame_count = sum(1 for name in os.listdir(path) if name.endswith('.png'))
    if frame_count == 0:
        return
    images = []
    for index in range(frame_count):
        file_path = join(path, '{}.png'.format(index))
        images.append(imageio.imread(file_path))
    imageio.mimsave(join(path, gif_name), images, 'GIF', duration=0.2)
def main():
    '''Sort a shuffled list of 0..19, save the frames as a GIF, print stats.'''
    random.seed(time.time())
    # Populate the list before shuffling; an empty list would leave nothing
    # to sort or visualise.
    data = list(range(20))
    random.shuffle(data)
    print('source: {0}'.format(data))

    # The measured time covers sorting, plotting and GIF generation.
    start = time.time()
    qs = QuickSort(debug=False, save_fig=True)
    swap_times, fig_path = qs.sort(data)
    save_gif(fig_path, 'quick_sort.gif')
    stop = time.time()

    print('result: {0}\n'.format(data))
    print('----------------------------------')
    print('swap times: {0}'.format(swap_times))
    print('spend time: {0}s'.format(stop - start))
    print('image path: {0}'.format(fig_path))
    print('----------------------------------')


if __name__ == '__main__':
    main()
版权声明:本博客所有文章除特殊声明外,均采用 CC BY-NC 4.0 许可协议。转载请注明出处 litreily的博客!