ランダムなデータをソートするアルゴリズムは動画にして観察するとなぜだか癒されます。マージソートをプログラミングするためには再帰処理も覚える必要があり勉強にもなります。という事で、ここではソートの中でも一際人気なマージソートをPythonで可視化しながら学びます。
こんにちは。wat(@watlablog)です。マージソートの学習は再帰処理も学べる良い題材です!ここでは可視化コードを紹介しながら覚えていきます!
マージソートの概要
ソートアルゴリズムと計算量
たいていのプログラミング言語には標準でソート(並べ替え)の関数が付属していますが、あえてそのアルゴリズムを学ぶ事でプログラミングにおける考え方を習得します。
当ブログでは既に「Pythonコードと図解で理解するバブルソートのアルゴリズム」でバブルソート、「Pythonで選択ソートのアルゴリズムを実装する方法【動画付】」で選択ソート、「【動画付き】Pythonで挿入ソートのアルゴリズムを実装する方法」で挿入ソートと3つのソートアルゴリズムを学びました。
これらのソートはいずれも計算量が\(O(N^{2})\)と遅いソートに分類されますが、アルゴリズムがシンプルで実装しやすいために少ないデータ量においては選択肢の一つになるでしょう。
シリーズ第4段となるマージソートは最悪計算量が\(O(N \log N)\)と二乗オーダーではない高度なソートに分類されます。
高度な分、マージと再帰処理というテクニックを覚える必要はありますが、プログラミングの学習にはちょうど良い題材です。
また、記事の最後ではマージソートの処理を動画で可視化する事によって理解を深めるとともに癒されます。是非最後までご覧ください。
Pythonコードで「マージ」を体験する
マージソートは完全にランダムなデータを並べ替えるアルゴリズムですが、「マージ」という別途定義されたアルゴリズムを使います。
マージ(Merge)とは、日本語で併合するとか混ぜるという意味があります。マージアルゴリズムと呼ぶ時は、2つのソート済みデータをまとめて1つのソート済みデータに変換するアルゴリズムを指します。
図解すると以下のようになります。
既にソートされたリストであれば、左から順番に要素を抽出して大小を比較していけば全体がソートされた1つのリストになります。
以下は手順の図解です。リストの長さが異なっている場合、最後は余った分を全体リストに後付けすれば問題ありません。
これをPythonでプログラミングすると以下のコードとなります。Pythonはちょっと文法を覚えるだけでアルゴリズムをそのまま理解する事ができるため、コード内のコメントで解説は十分と思います。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import numpy as np #マージする関数 def merge(x, y): # 左側lと右側rを順番に確認し、小さい方をリストに加える # どちらかの配列長が参照終了したらループを抜ける l, r = 0, 0 list = [] while l < len(x) and r < len(y): if x[l] <= y[r]: list.append(x[l]) l += 1 else: list.append(y[r]) r += 1 # 左側と右側配列の長い方(残った方)をリストに加える if l < r: list.extend(x[l:]) else: list.extend(y[r:]) return list # merge関数を実行 x = np.arange(0, 9, 2) y = np.arange(1, 14, 2) list = merge(x, y) print('initial array1=', x) print('initial array2=', y) print('merge result=', list) |
コードを実行すると以下の結果を得ます。既にソートされたリスト(ndarrayですが)を作成し、マージを行う事で全体がソートされた1つのリストができました。
1 2 3 |
initial array1= [0 2 4 6 8] initial array2= [ 1 3 5 7 9 11 13] merge result= [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 13] |
Pythonコードで「再帰処理」を体験する
マージソートをプログラミングするためには、マージアルゴリズムに加えて再帰処理(Recursive processing)を覚える必要があります。
再帰処理とは、あるものを定義する時にそのもの自身を参照する処理の事です。こちらもまずはPythonコードで体験してみましょう。
以下のコードが再帰処理の例です。
過去の基本情報技術者試験の問いを参照しました。この再帰関数は終了条件を指定している所がキーポイントです。終了条件を指定しないと再帰処理は無限ループに陥ってしまいます。
1 2 3 4 5 6 7 8 9 10 11 |
# 再帰関数の例 def f(x): # 終了条件 if x <= 1: return 1 # 終了条件を満たさない場合はreturnに自身の関数を指定 else: return x + f(x-1) y = f(10) print(y) |
1 |
55 |
これは処理をノートに書いていけばわかりやすくなりますが、最終的には初期値から1ずつ引いていった値を全て足すという処理になります。
再帰処理をコーディングする事でこのような処理も書く事ができ、一般に再帰処理を覚える事でプログラミングの幅がひろがると言われています。
マージソートの処理イメージ
マージアルゴリズムと再帰処理を体験した事で、マージソート(Merge sort)のアルゴリズムをプログラミングする事ができます。
マージソートは完全にランダムなリストを要素数が1になるまで分割します。この「要素数が1になるまで」という処理に再帰処理を使います。
要素数が1になったという事は、これはソート済リストになった事を意味するため、マージ処理を使う事ができます。
そのため分割したリストを逆順にマージしていけば、最終的に全体がソートされたリストを得る事ができます。
コードを紹介する前に、マージソートのイメージを図解しておきましょう。
ちなみにマージソートは同じ数が存在しても、元のリストと並び順が変わらない安定したソートにもなっています。
ついでに動画も見てみましょう。動画の作り方と、もっとスッキリする動画は記事の最後で…。
図にすると簡単に理解する事ができます。それではPythonでコーディングしていきましょう。
Pythonでマージソートするコード
以下がマージソートの全コードです。マージ関数merge()とマージソート関数merge_sort()を作成しました。マージソート関数内に再帰処理による分割を入れています。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import random #マージする関数 def merge(x, y): # 左側lと右側rを順番に確認し、小さい方をリストに加える # どちらかの配列長が参照終了したらループを抜ける l, r = 0, 0 list = [] while l < len(x) and r < len(y): if x[l] <= y[r]: list.append(x[l]) l += 1 else: list.append(y[r]) r += 1 # 左側と右側配列の長い方(残った方)をリストに加える if l < r: list.extend(x[l:]) else: list.extend(y[r:]) return list def merge_sort(x): # 分割できなくなったら終了 if len(x) <= 1: return x # 再帰処理で分割をする mid = len(x) // 2 l = merge_sort(x[:mid]) r = merge_sort(x[mid:]) list = merge(l, r) return list # ランダムなリストを作成してマージソートを実行 x = random.sample(range(8), k=8) list = merge_sort(x) print('initial array:', x) print('merge sorted:', list) |
結果は以下。ランダムなリストを入力してソートされました。
1 2 |
initial array: [3, 4, 6, 5, 2, 0, 1, 7] merge sorted: [0, 1, 2, 3, 4, 5, 6, 7] |
Pythonでマージソート処理を動画にするコード
最後はシリーズ恒例、動画を作成するコードを紹介します。ソートが速いのでcreate_gif()関数をちょっといじっています(最後の状態を長く見せるために処理を追加)。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import numpy as np import random import glob import os from PIL import Image from matplotlib import pyplot as plt # GIFアニメーションを作成 def create_gif(in_dir, out_filename): path_list = sorted(glob.glob(os.path.join(*[in_dir, '*']))) # ファイルパスをソートしてリストする imgs = [] # 画像をappendするための空配列を定義 # ファイルのフルパスからファイル名と拡張子を抽出 for i in range(len(path_list)): print(len(path_list)) img = Image.open(path_list[i]) # 画像ファイルを1つずつ開く # 最後の画像だけ10枚追加して余韻を残す。 if i == len(path_list) - 1: for j in range(10): imgs.append(img) imgs.append(img) # 画像をappendで配列に格納していく # appendした画像配列をGIFにする。durationで持続時間、loopでループ数を指定可能。 imgs[0].save(out_filename, save_all=True, append_images=imgs[1:], optimize=False, duration=50, loop=0) return # プロットする関数 def plot(data, l, list, save_dir): #print(data.index(l[0])) start = data.index(l[0]) end = start + len(list) data[start:end] = list #print(data) # ここからグラフ描画 i = len(glob.glob1(save_dir, "*.png")) print(i) # フォントの種類とサイズを設定する。 plt.rcParams['font.size'] = 14 plt.rcParams['font.family'] = 'Times New Roman' # 目盛を内側にする。 plt.rcParams['xtick.direction'] = 'in' plt.rcParams['ytick.direction'] = 'in' # グラフの上下左右に目盛線を付ける。 fig = plt.figure() ax1 = fig.add_subplot(111) ax1.yaxis.set_ticks_position('both') ax1.xaxis.set_ticks_position('both') # 軸のラベルを設定する。 ax1.set_xlabel('x') ax1.set_ylabel('y') ax1.set_xticks(np.arange(0, len(x) + 1, len(x)/10)) ax1.set_xlim(0, len(x) + 1) ax1.set_ylim(0, len(x) + 1) # データプロットの準備。 ax1.bar(np.arange(1, len(data) + 1, 1), data, label='data') plt.text(1, int(len(x) * 0.9), 'count=' + str(i), fontsize=20) # グラフを表示する。 fig.tight_layout() fig.legend() # dirフォルダが無い時に新規作成 if os.path.exists(save_dir): pass else: os.mkdir(save_dir) # 画像保存パスを準備 path = os.path.join(*[save_dir, str("{:05}".format(i)) + '.png']) # 画像を保存する plt.savefig(path) plt.close() return data # マージする関数 def merge(x, y): # 左側lと右側rを順番に確認し、小さい方をリストに加える # どちらかの配列長が参照終了したらループを抜ける l, r = 0, 0 list = [] while l < len(x) and r < len(y): if x[l] <= y[r]: list.append(x[l]) l += 1 else: list.append(y[r]) r += 1 # 左側と右側配列の長い方(残った方)をリストに加える if l < r: list.extend(x[l:]) else: list.extend(y[r:]) return list # マージソートする関数 def merge_sort(x, count): # 可視化用に初期配列を保持しておく(再帰処理でも失われないように) if count == 0: x_ = x.copy() else: x_ = count #print(x_) # 分割できなくなったら終了 if len(x) <= 1: return x # 再帰処理で分割をする mid = len(x) // 2 l = merge_sort(x[:mid], x_) r = merge_sort(x[mid:], x_) list = merge(l, r) count = plot(x_, l, list, 'img_merge-sort') return list # ランダムなリストを作成してマージソートを実行後GIF動画を作る。 x = random.sample(range(100), k=100) list = merge_sort(x, 0) print('initial array:', x) print('merge sorted:', list) create_gif('img_merge-sort', 'movie-merge-sort.gif') |
以下が結果です。100個のデータが揃っていく様は気持ちが良いです。是非マシンスペックの許す限りのデータ数でやってみてください。
まとめ
この記事ではマージアルゴリズムと再帰処理を学び、マージソートアルゴリズムのコードを紹介しました。
ソート1つとっても様々な手法があり、人類のアイデアの多さに脱帽しますね。
ちょっと高度なマージソートも理解する事ができました!
Twitterでも関連情報をつぶやいているので、wat(@watlablog)のフォローお待ちしています!