Python实现快排及其可视化

最近装了个Anaconda,准备学习一下数据可视化。本着三天打鱼两天装死的心态,重新抱起崭新的算法书,认真学起了快排算法。学完后用Python实现了一遍基本的快排,然后使用matplotlib进行动态绘图,最后使用imageio生成GIF图片。谨以此文以记之!

快排基本原理

快排采用和归并排序相同的分而治之的思想,将待排序数组分成左右两个子数组,对两部分子数组独立排序。当子数组均有序时,整个数组也就有序了。

排序步骤如下:

  1. 将原始数组data随机打乱,以消除对输入的依赖(本步可选)
  2. 选择数组的首个元素data[0]作为切分元素v
  3. 切分数组
    • 从左往右找到第一个大于切分元素v的元素data[i]
    • 从右到左找到第一个小于切分元素v的元素data[j]
    • 交换data[i]data[j]
    • 重复以上三步直到i>=j
    • 交换data[j]与切分元素data[0]
  4. 递归调用,对切分后的左侧子数组进行排序
  5. 递归调用,对切分后的右侧子数组进行排序

文字性的描述总是那么苍白无力,但还好也能说明一些问题。可以看出,快排的关键在于切分,切分后的数组应该满足:

  1. 切分元素的位置(设为j)已经固定
  2. data[lo]data[j-1]区间内的元素均不大于切分元素data[j]
  3. data[j+1]data[hi]区间内的元素均不小于切分元素data[j]

其中data[lo]代表数组或子数组的首个元素,data[hi]代表数组或子数组的末尾元素。

简单点说,就是先找一个参考点,把小于这个参考点的元素都扔到它的左边,大于这个参考点的数都扔到它的右边。这样一来,参考点的位置就固定了,然后对左边的数据和右边的数据各自再递归的扔几遍,等所有子数组都扔完了,整个数组也就有序了。

不过需要注意的是,扔的时候不是随便扔,是把从左往右找到的第一个大于参考点的值和从右往左找到的第一个小于参考点的值进行替换。

基本实现

Talk is cheap, show me the code

def sort(data):
    __sort(data, 0, len(data) - 1)

def __sort(data, lo, hi):
    if lo >= hi:
        return
    key = __partition(data, lo, hi)
    __sort(data, lo, key - 1)
    __sort(data, key + 1, hi)

def __swap(data, lo, hi):
    data[lo], data[hi] = data[hi], data[lo]

def __partition(data, lo, hi):
    '''partition array'''
    i = lo
    j = hi
    v = data[lo] # slicing element
    while True:
        # find one element that larger than v scan from left to right(→)
        i += 1
        while data[i] < v:
            if i == hi:
                break
            i += 1

        # find one element that smaller than v scan from right to left(←)
        while v < data[j]:
            if j == lo:
                break
            j -= 1

        if i >= j:
            break
        __swap(data, i, j)
    __swap(data, lo, j)
    return j

以上便是参考Algorithms书上java代码的Python实现。下面是个使用示例:

import random

def main():
    data = [_ for _ in range(20)]
    random.shuffle(data)
    print(data)
    sort(data)
    print(data)

if __name__ == '__main__':
    main()

执行结果如下:

$ python quick.py
[4, 9, 1, 13, 18, 5, 6, 14, 2, 16, 7, 12, 15, 8, 11, 17, 0, 19, 10, 3]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]

快排优化

快排有很多优化算法,目前我只习得一种最简单的,可以将切分函数两个内部while循环中的if语句去掉

# first one
if i == hi:
    break

# second one
if j == lo:
    break

这两个判断都是为了防止访问数组越界而设,其实第二个是完全没有必要加的,因为lo对应的就是切分元素本身,自己肯定不会小于自己,所以这个判断完全是多余的;对于第一个,想要去掉的话,只要保证数组最后一个元素最大即可,实现上只要在执行排序函数之前将最大值换至最后即可。

__swap(data, data.index(max(data)), len(data) - 1)

本文后续的code会将第二个判断去掉,但第一个的还保留着,毕竟把最大值直接挪到最后总感觉怪怪的,在可视化的时候也会牺牲一点随机性。

面向对象编程

为了方便代码的阅读和管理,我将快排代码封装成QuickSort类,同时加入变量swap_times用于记录总的数据交换次数。

class QuickSort(object):
    '''Quick sort algorithm'''
    def sort(self, data):
        self.swap_times = 0
        # set the largest element to the end
        # self.__swap(data, data.index(max(data)), len(data) - 1)
        self.__sort(data, 0, len(data) - 1)
        return self.swap_times

    def __swap(self, data, lo, hi):
        data[lo], data[hi] = data[hi], data[lo]
        self.swap_times += 1

    def __sort(self, data, lo, hi):
        if lo >= hi:
            return
        key = self.__partition(data, lo, hi)
        self.__sort(data, lo, key - 1)
        self.__sort(data, key + 1, hi)

    def __partition(self, data, lo, hi):
        '''partition array'''
        i = lo
        j = hi
        v = data[lo] # slicing element
        while True:
            # find one element that larger than v scan from left to right(→)
            i += 1
            while data[i] < v:
                # below judge can dropped if the end element is the largest
                if i == hi:
                    break
                i += 1

            # find one element that smaller than v scan from right to left(←)
            while v < data[j]:
                j -= 1

            if i >= j:
                break
            self.__swap(data, i, j)
        self.__swap(data, lo, j)
        return j

打印数据交换记录

为了了解排序过程中数据交换,可以在__swap函数中打印每一次交换后的数组。

def __init__(self, debug=False, save_fig=False):
    self.debug = debug

def __swap(self, data, lo, hi):
    data[lo], data[hi] = data[hi], data[lo]
    self.swap_times += 1

    if self.debug:
        print('{0} swap({1}, {2})'.format(data, lo, hi))

示例:

#main.py
def main():
    data = []
    random.seed(time.time())
    data = [_ for _ in range(20)]
    random.shuffle(data)

    qs = QuickSort(debug=True)
    swap_times, = qs.sort(data)


if __name__ == '__main__':
    main()
➜  algorithm git:(master) ✗ ./main.py
[14, 3, 11, 10, 4, 1, 2, 12, 18, 17, 7, 8, 13, 15, 0, 9, 16, 6, 5, 19] swap(4, 19)
[14, 3, 11, 10, 4, 1, 2, 12, 5, 17, 7, 8, 13, 15, 0, 9, 16, 6, 18, 19] swap(8, 18)
[14, 3, 11, 10, 4, 1, 2, 12, 5, 6, 7, 8, 13, 15, 0, 9, 16, 17, 18, 19] swap(9, 17)
[14, 3, 11, 10, 4, 1, 2, 12, 5, 6, 7, 8, 13, 9, 0, 15, 16, 17, 18, 19] swap(13, 15)
[0, 3, 11, 10, 4, 1, 2, 12, 5, 6, 7, 8, 13, 9, 14, 15, 16, 17, 18, 19] swap(0, 14)
[0, 3, 11, 10, 4, 1, 2, 12, 5, 6, 7, 8, 13, 9, 14, 15, 16, 17, 18, 19] swap(0, 0)
[0, 3, 2, 10, 4, 1, 11, 12, 5, 6, 7, 8, 13, 9, 14, 15, 16, 17, 18, 19] swap(2, 6)
[0, 3, 2, 1, 4, 10, 11, 12, 5, 6, 7, 8, 13, 9, 14, 15, 16, 17, 18, 19] swap(3, 5)
[0, 1, 2, 3, 4, 10, 11, 12, 5, 6, 7, 8, 13, 9, 14, 15, 16, 17, 18, 19] swap(1, 3)
[0, 1, 2, 3, 4, 10, 11, 12, 5, 6, 7, 8, 13, 9, 14, 15, 16, 17, 18, 19] swap(1, 1)
[0, 1, 2, 3, 4, 10, 11, 12, 5, 6, 7, 8, 13, 9, 14, 15, 16, 17, 18, 19] swap(4, 4)
[0, 1, 2, 3, 4, 10, 9, 12, 5, 6, 7, 8, 13, 11, 14, 15, 16, 17, 18, 19] swap(6, 13)
[0, 1, 2, 3, 4, 10, 9, 8, 5, 6, 7, 12, 13, 11, 14, 15, 16, 17, 18, 19] swap(7, 11)
[0, 1, 2, 3, 4, 7, 9, 8, 5, 6, 10, 12, 13, 11, 14, 15, 16, 17, 18, 19] swap(5, 10)
[0, 1, 2, 3, 4, 7, 6, 8, 5, 9, 10, 12, 13, 11, 14, 15, 16, 17, 18, 19] swap(6, 9)
[0, 1, 2, 3, 4, 7, 6, 5, 8, 9, 10, 12, 13, 11, 14, 15, 16, 17, 18, 19] swap(7, 8)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 11, 14, 15, 16, 17, 18, 19] swap(5, 7)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 11, 14, 15, 16, 17, 18, 19] swap(5, 5)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 11, 14, 15, 16, 17, 18, 19] swap(8, 8)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 11, 13, 14, 15, 16, 17, 18, 19] swap(12, 13)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] swap(11, 12)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] swap(15, 15)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] swap(16, 16)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] swap(17, 17)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] swap(18, 18)

根据打印结果可以逐步分析快排算法的执行过程,明确知晓每一次交换的数据。

数据可视化

打印log固然可以分析算法的执行流程,但是不够直观,所以想着用可视化工具进一步处理,以动态图片形式显示快排过程。为此,只需在交换数据的函数__swap中使用matplotlib的柱状图绘制交换完成后的数组即可。

import matplotlib.pyplot as plt
...

class QuickSort(object):
    def __init__(self, debug=False, save_fig=False):
        ...
        self.save_fig = save_fig
        self.fig, self.ax = plt.subplots()
        # open interactive mode of matplot
        plt.ion()

        if self.save_fig:
            self.path = './images/{0}'.format(time.strftime('%Y%m%d_%H%M%S'))
            os.makedirs(self.path)

    def __swap(self, data, lo, hi):
        ...
        self.__plot_figure(data, lo, hi, show_swap=True)

    def __plot_figure(self, data, lo=0, hi=0, show_swap=False):
        '''plot and save figure'''
        self.ax.clear()
        self.ax.set_title('data quicksort')
        self.ax.bar(range(len(data)), data, label='data')
        if show_swap:
            self.ax.bar([lo, hi], [data[lo], data[hi]], color='red', label='swap')
        plt.legend()
        plt.pause(0.001)

        if self.save_fig:
            plt.savefig('{0}/{1}.png'.format(self.path, self.swap_times))

需要注意的几点是:

  1. matplotlib具有两种绘图模式,阻塞(block)模式和交互(interactive)模式,阻塞模式必须等待当前绘图窗口关闭方才执行后续程序,而交互模式则无需等待。为了动态显示排序过程,自然选择交互模式,所以初始化时调用ion()函数打开交互模式
  2. 相比于其它样式的图表,使用柱状图bar能够更直观显示数据大小及变化过程
  3. 在每次重绘图表时,需要清空原有图表
  4. 必须调用pause函数予以等待,否则可能出现无法显示图表的情况,等待时长自定
  5. 为了突出显示每次交换的两个数据,可以使用红色图表单独绘制交换数据
  6. 使用matplotlib函数库中的savefig可以将图表为至本地图片文件,为后续生成gif图片做准备

data bar

生成GIF动图

有了前面保存好的图片,使用imageio库的append函数及mimsave即可生成gif图片。图片间隔时间由mimsave函数的duration参数决定。

# main.py
import imageio
from quick_sort import QuickSort

import os
from os.path import join


def save_gif(path, gif_name):
    if not os.path.exists(path) or len(os.listdir(path))==0:
        return

    images = []
    for file_name in range(len(os.listdir(path))):
        file_path = join(path, '{}.png'.format(file_name))
        images.append(imageio.imread(file_path))
    imageio.mimsave(join(path, gif_name), images, 'GIF', duration=0.2)

quick sort

至此,便完成了快排的算法实现及其可视化。

完整源码

代码已上传至github Python-demos algorithm目录

  • quick_sort.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt
import time
import os

class QuickSort(object):
    '''Quick sort algorithm'''
    def __init__(self, debug=False, save_fig=False):
        self.debug = debug
        self.save_fig = save_fig
        self.fig, self.ax = plt.subplots()
        plt.ion()

        if self.save_fig:
            self.path = './images/{0}'.format(time.strftime('%Y%m%d_%H%M%S'))
            os.makedirs(self.path)

    def sort(self, data):
        self.swap_times = 0
        self.__plot_figure(data)
        # set the largest element to the end
        # self.__swap(data, data.index(max(data)), len(data) - 1)
        self.__sort(data, 0, len(data) - 1)
        return self.swap_times, self.path

    def __swap(self, data, lo, hi):
        data[lo], data[hi] = data[hi], data[lo]
        self.swap_times += 1

        if self.debug:
            print('\t{0} swap({1}, {2})'.format(data, lo, hi))

        self.__plot_figure(data, lo, hi, show_swap=True)

    def __plot_figure(self, data, lo=0, hi=0, show_swap=False):
        '''plot and save figure'''
        self.ax.clear()
        self.ax.set_title('data quicksort')
        self.ax.bar(range(len(data)), data, label='data')
        if show_swap:
            self.ax.bar([lo, hi], [data[lo], data[hi]], color='red', label='swap')
        plt.legend()
        plt.pause(0.001)

        if self.save_fig:
            plt.savefig('{0}/{1}.png'.format(self.path, self.swap_times))

    def __sort(self, data, lo, hi):
        if lo >= hi:
            return
        key = self.__partition(data, lo, hi)
        self.__sort(data, lo, key - 1)
        self.__sort(data, key + 1, hi)

    def __partition(self, data, lo, hi):
        '''partition array'''
        i = lo
        j = hi
        v = data[lo] # slicing element
        while True:
            # find one element that larger than v scan from left to right(→)
            i += 1
            while data[i] < v:
                # below judge can dropped if the end element is the largest
                if i == hi:
                    break
                i += 1

            # find one element that smaller than v scan from right to left(←)
            while v < data[j]:
                j -= 1

            if i >= j:
                break
            self.__swap(data, i, j)
        self.__swap(data, lo, j)
        return j

  • main.py
#!/bin/env python
# -*- encoding: utf-8 -*-

import time
import random
import imageio
from quick_sort import QuickSort

import os
from os.path import join


def save_gif(path, gif_name):
    if not os.path.exists(path) or len(os.listdir(path))==0:
        return

    images = []
    for file_name in range(len(os.listdir(path))):
        file_path = join(path, '{}.png'.format(file_name))
        images.append(imageio.imread(file_path))
    imageio.mimsave(join(path, gif_name), images, 'GIF', duration=0.2)


def main():
    data = []
    random.seed(time.time())
    random.shuffle(data)

    print('source: {0}'.format(data))
    start = time.time()
    qs = QuickSort(debug=False, save_fig=True)
    swap_times, fig_path = qs.sort(data)
    save_gif(fig_path, 'quick_sort.gif')
    stop = time.time()
    print('result: {0}\n'.format(data))

    print('----------------------------------')
    print('swap times: {0}'.format(swap_times))
    print('spend time: {0}s'.format(stop - start))
    print('image path: {0}'.format(fig_path))
    print('----------------------------------')


if __name__ == '__main__':
    main()