kivantium活動日記

プログラムを使っていろいろやります

ソートアルゴリズム高速化への道

先日、アルゴリズムの授業でソートのアルゴリズムをいくつか習いました。ソートアルゴリズムの名前と原理くらいは聞いたことがありましたが、実装したことはなかったのでいい機会だと思い実装してみることにしてみました。ただ実装するだけでは面白くないので高速化の限界に挑戦してみたいと思います。

計測用プログラム

今回の計測では、ランダム値が入った配列のソートを100回行い、平均時間を各アルゴリズムに競わせるというシンプルなルールにしました。プログラムは以下の通りです。
C++11で入ったメルセンヌ・ツイスタなどの機能を使っているので、ビルド時には-std=c++11を指定する必要があります。

実験に使用したパソコンのCPUはCore i3-3227U@1.90GHz、コンパイラgcc version 4.8.4で最適化オプションには-O3を指定しました。

#include <iostream>
#include <algorithm>
#include <climits>
#include <iomanip>
#include <chrono>
#include <random>
#include <cmath> 

using namespace std;

constexpr int size = 1000000; // 配列サイズ
int t[size];                  // 作業用配列
int ans[size];                // 答え合わせ用配列

constexpr int times = 100;   // 繰り返し回数
int record[times];           // 経過時間記録用配列

int main() {
    // メルセンヌ・ツイスタの準備
    // #include <random> が必要
    random_device rd;
    mt19937 mt(rd());
    uniform_int_distribution<int> rand(INT_MIN, INT_MAX);

    for (int loop = 0; loop < times; ++loop){
        // 配列にランダムな値を入れる
        for(auto & ref : ans) ref = rand(mt);
        // 作業用配列にコピー
        for(int i=0; i<size; ++i) t[i] = ans[i];
        // 答えの作成
        sort(ans, ans+size);

        // 時間計測 (#include <chrono> が必要)
        // 現在の時間
        auto start = chrono::system_clock::now();

        // ソート関数がここに入る
        MySortFunction(t, t+size);

        // 終了時間
        auto end = chrono::system_clock::now();
        // 経過時間
        auto dur = end - start;
        // マイクロ秒に変換
        auto usec = chrono::duration_cast<chrono::microseconds>(dur).count();
        // 記録
        record[loop] = usec;

        // 正しくソートされていることを確認
        bool flag = true;
        for (int i=0; i<size; ++i) {
            if(t[i] != ans[i]) flag = false;
        }
        // 間違っていたらその場で終了
        if (!flag){
            cout << "TEST FAILED!!" << endl;
            break;
        }
    }

    // 集計結果
    long int total_time = 0;
    for (int i=0; i<times; ++i) total_time += record[i];

    // 平均の表示
    cout << "average time: " << total_time/times/1000 << "."
         << setfill('0') << setw(3) << total_time/times%1000 << "msec" << endl;

    // 標準偏差の計算
    int sum = 0;
    for (int i=0; i<times; ++i) sum += pow(record[i]-total_time/times, 2);
    cout << "standard deviation: " << sqrt(sum/times)/1000.0 << endl;
}

では、競技者(競技アルゴリズム?)たちを紹介していきます。

測定したアルゴリズムの紹介

バブルソート 〜かませ犬って呼ばないで〜

最初に紹介するバブルソートは、単純で分かりやすいためソートアルゴリズムの話題で最初に出てくる有名なアルゴリズムです。残念ながら、最初に出てくるのは他の高速アルゴリズムのすごさを際立たせるためなのではないかと思ってしまうくらいに遅いアルゴリズムです。
平均計算時間は O(N^2)ということで、おおよそ配列の長さの2乗に比例する時間がかかります。


詳細はWikipediaを読んでもらいたいのですが、一応説明しておきます。便宜上、配列の添字が小さい方を「左」とし左から右に向かって大きくなるように並べるとします。

  • 左端から順に全ての要素に関して隣接する要素と比較する
  • 左の要素の方が大きければ順番を入れ替える
  • 右端まで比較が終わったら右端の要素が最大値になっている
  • また左端から同じような比較・入れ替えを行っていくと今度は右から2番目までが大きさの順になる
  • これを繰り返していくと全ての要素が大きさの順に並ぶ

というようなアルゴリズムです。ソースコードは以下の通りです。

void swap (int *a, int *b){
    int tmp = *b;
    *b = *a;
    *a = tmp;
}

void BubbleSort (int* first, int* last){
    int size = last - first;
    for(int i=0; i<size; ++i){
        for(int j=1; j<size-i; ++j){
           if(first[j] < first[j-1]){
              swap(&first[j-1], &first[j]);
            }
        }
    }
}

挿入ソート 〜バブルソートなんかに負けない!〜

挿入ソートも計算量が O(N^2)アルゴリズムですが、バブルソートよりは高速になっています。

  • 左から順に整列データを作っていく
  • k番目の要素がk-1番目の要素よりも小さいときは、k番目の要素があるべき位置に入るように他の要素を右にずらしてスペースを確保してk番目の要素をそこに挿入する
  • 全ての要素が整列するまでこの操作を続ける

という感じのプログラムです。ソースコードは(Wikipediaのコピペなのですが)

void InsertSort(int *first, int *last){
    int size = last - first;
    for (int i=1; i<size; i++) {
        int tmp = first[i];
        if (first[i-1] > tmp) {
            int j = i;
            do {
                first[j] = first[j-1];
                j--;
            } while (j > 0 && first[j-1] > tmp);
            first[j] = tmp;
        }
    }
}

マージソート 〜マジで相当速いソート〜

マージソート

  • 配列を前半と後半に分ける
  • それぞれをソートする
  • ソートされたデータ列を新しい配列にきれいに並べる(マージ)

というようなアルゴリズムです。「それぞれをソートする」の操作でマージソートを呼び出す再帰的なアルゴリズムになっています。
分割の結果がある程度小さくなったら別のアルゴリズムに切り替えるという手法もありますが、ここでは1つになるまで細分化することにしています。

マージソートの計算時間は O(n \log n)で、通常の実装では配列の他に O(n)のメモリが必要になります。
ソースコードは以下の通りです。

void merge (int *first1, int *last1, int *first2, int *last2){
    int size1 = last1 - first1;
    int size2 = last2 - first2;
    int *result = new int[size1+size2];
    int i = 0;
    int j = 0;
    int index = 0;
    while (true){
        if (i < size1 && (first1[i] <= first2[j] || j >= size2)){
            result[index] = first1[i];
            ++i; ++index;
        }
        if (j < size2 && (first1[i] > first2[j] || i >= size1)){
            result[index] = first2[j];
            ++j; ++index;
        }
        if (i==size1 && j==size2){
            for (i=0; i<size1; ++i) first1[i] = result[i];
            for (j=0; j<size2; ++j) first2[j] = result[j+size1];
            delete[] result;
            return;
        }
    }
}
void MergeSort (int *first, int *last){
    int size = last - first;
    if(size <= 1) return;
    MergeSort(first, first+size/2);
    MergeSort(first+size/2, last);
    merge(first, first+size/2, first+size/2, last);
}

ヒープソート 〜最高でも O(n \log n)、最低でも O(n \log n)

ヒープソートは二分ヒープ木というデータ構造に配列の数字を入れていって順番に取り出すことで整列するアルゴリズムです。配列の値がどのようなものであっても計算量があまり変わらないのが特長です。
ここでは最大ヒープ(親ノードが常に子ノードよりも大きくなるような二分木)を使って、根にある最大値を順番に後ろに置いていくような実装にしました。
ソースコード i番目の要素の親が \frac{i-1}{2} i番目の要素の子が2i+12i+2になるように設定していることに注意して読むといいと思います。

void HeapSort(int *first, int *last){
    int size = last - first;
    // 配列を前から順にヒープ構造に追加する
    for (int i=1; i<size; ++i){
        int j = i; // ヒープの末尾に追加
        // 親ノードが子ノードより大きくなるように並び替え
        while (j > 0){
            if (first[(j-1)/2] < first[j]){
               swap(&first[(j-1)/2], &first[j]);
               j = (j-1)/2;
            } else {
               break;
            } 
        }
    }

    for (int i=size-1; i>0; --i) {
        swap(&first[0], &first[i]); // 根を後ろに送る
        int j = 0;
        int k = 0;
        // ヒープの条件を満たすように入れ替え
        while(true){
            int left = 2*j + 1;
            int right = 2*j + 2;
            if (left >= i) break;
            if (first[left] > first[k]) k = left;
            if ((right < i) && (first[right] > first[k])) k = right;
            if (k == j) break;
            swap (&first[k], &first[j]);
            j = k;
        }
    }
}

クイックソート 〜クイックの名は伊達じゃない〜

クイックソートはその名の通り高速であることで知られるアルゴリズムです。

  • ピボットと呼ばれる適当な数を選び、それより小さい数を前に、大きい数を後ろに移動する
  • 前と後ろを別々にソートする

という単純な方法です。「別々にソート」はクイックソート再帰的に呼び出すことで実現することが多いです。
ピボットの選択方法はかなり速度に影響するそうですが、ここでは一般的な配列の最初・真ん中・最後の中央値を取る方法を採用しました。ピボットに選ばれる数が配列の最小値であった場合は無限ループに陥るので回避するコードを入れてあります。

int partition(int *first, int *last, int pivot) {
    int l = 0;
    int r = last - first -1;
    while (true){
        while (first[l] < pivot) ++l;
        while (first[r] >= pivot) --r;
        if (l >= r) return l;
        swap(&first[l], &first[r]);
    }
}

void QuickSort(int *first, int *last){
    int size = last - first;
    if (size <= 1) return;
    // pivotの決定
    // 先頭・真ん中・末尾の中央値
    int a = first[0];
    int b = first[size/2];
    int c = first[size-1];
    int pivot;
    if ((a-b)*(b-c)*(c-a) != 0){
        pivot = max(min(a,b), min(max(a,b),c));
    } else { // 中央値が取れない場合は最初に見つかった異なる2つの数のうち大きい方
        bool flag = true;
        for (int i=1; i<size; ++i){
            if (first[i-1] != first[i]) {
                pivot = max(first[i-1], first[i]);
                flag = false;
                break;
            }
        }
        if (flag) return; // 全て同じならソート済みと判定
    }
    int k = partition(first, last, pivot); // 配列の分割
    QuickSort(first, first+k);
    QuickSort(first+k, last);
}

測定

これらのアルゴリズムで配列の長さを変えて実行時間を100回測定した平均を示します。単位はミリ秒でカッコ内は標準偏差です。
また、「標準ソート」はに入っているstd::sort()を指します。

アルゴリズム 1万 10万 100万 200万 400万 800万
バブルソート 214 (1) - - - - -
挿入ソート 26.8 (0.5) - - - - -
マージソート 2.108 (0.08) 24.05 (0.2) 269.9 (3) 570.6 (2) 1188 (50) 2367 (20)
ヒープソート 1.4 (0.2) 18.37 (0.1) 242.3 (5) 556.2 (4) 1282 (37) 2745 (3)
クイックソート 1.16 (0.03) 13.6 (0.1) 158.4 (2) 328.7 (3) 683.7 (3) 1400.3 (5)
標準ソート 0.847 (0.2) 9.787 (0.08) 115.7 (1) 242.9 (2) 506.3 (4) 1056.2 (4)

バブルソート・挿入ソートは10万以上の配列に対して測定すると時間が掛かり過ぎたので測定を行いませんでした。
高速なソート4つの100万回以上での実行時間をグラフにするとこんな感じです。

f:id:kivantium:20151103150429j:plain:w600
自分で実装した中ではクイックソートがその名の通り一番速いという結果になりました。しかし、残念ながらSTLのstd::sort()がどの実装よりも速いということが判明しました。
せっかく独自実装したのに標準ライブラリに速度で負けているようでは劣化した車輪の再発明でしかないので、高速化してstd::sort()よりも高速なソートを目指すことにしました。

高速化チャレンジ

XORスワップ

標準ライブラリはテンプレートを利用して型にとらわれない実装をしているので、それに勝つためにはint型に限定した高速手法を試してみるのはどうだろうと最初に思いつきました。ソートを行うときに最もたくさん行われる動作は二つの数字を入れ替えるswapですが、int型の場合にはXOR交換アルゴリズムと呼ばれる一時変数を使うことなくスワップを行う手法が存在します。一時変数を使わずに済めばレジスタ数やメモリアクセスの回数を削減できるので高速化できるはずだと思ってクイックソートのスワップ関数をXORスワップに書き換えてみました。

追記: >|a|<と>|b|<が同じポインタのときはswapせずにゼロクリアしてしまうのでポインタが等しくないことを最初に確認する必要がありました(2016/12/30)

void xor_swap(int *a, int *b){
    if(a == b) return;
    *a = *a ^ *b;
    *b = *a ^ *b;
    *a = *a ^ *b;
}

長さ100万で実験した結果、(以下出てくる数値は全て長さ100万の時の測定値とします)

  • 普通のスワップ: 158.4 ms(sd=2)
  • XORスワップ: 162.9 ms (sd=1)

ということで逆に速度が低下してしまったので、採用しないことにしました。
速くなる可能性はあっても遅くはならないだろうと思ったので意外でした。

XOR交換アルゴリズム - Wikipediaにもあるように最近のプロセッサでXOR swapを採用する有効性はないというかむしろ有害な可能性の方が高いようです……

インライン化

スワップ関数はサイズが小さいのでinline指定したら速度が上がる可能性があると考えて実験してみました。

158.4 ms (sd=2) → 154.3 ms (sd=4)

高速化したと言えるか微妙なラインです。一応平均値の検定をしてみると5%有意水準なので有効と言えそうです。
しかし、-Sオプションをつけて出力されるアセンブリを比較してみると、partition関数のアセンブリはinlineをつけてもつけなくても意味的には同じものを出力していました(アセンブリには詳しくないので間違っているかもしれません)

理由が分からないのに採用するのは科学的な態度ではないですが、速くなっているようなのでinline指定したほうがいいような気がします。

有意性の検定について

t値は\frac{158.4-154.3}{\sqrt{\frac{2^2+4^2}{100-1}}} = 9.1でp<0.05なので有意という結論を出しました。
t値が不当に大きい気がするので何か間違っている気もします。お気づきの方がいらっしゃったらご指摘してくださるととても嬉しいです。

STLのソースコードを参考にする

小手先のテクニックで劇的な高速化をするのは難しそうだと思ったので、STLのソースコードを参考にすることにしました。
今回使ったライブラリはlibstdc++なのでドキュメントを辿りながらどのような実装になっているのかを調べてみました。(解説サイトがあったので紹介しておきます→STLのsortの計算量 - SRM diary(Sigmar) - TopCoder部

一番効果があった変更はpartition関数の内部で使っていたローカル変数を無くすことでした。
もともと

int partition(int *first, int *last, int pivot) {
    int l = 0;
    int r = last - first -1;
    while (true){
        while (first[l] < pivot) ++l;
        while (first[r] >= pivot) --r;
        if (l >= r) return l;
        swap(&first[l], &first[r]);
    }
}

と定義していたpartitionを

int* partition2(int *first, int *last, int pivot) {
    while (true) {
        while (*first < pivot) ++first;
        --last;
        while (pivot < *last) --last;
        if (!(first < last)) return first;
        swap(first, last);
        ++first;
    }
}

と書き換えるだけで
155.6 ms (sd=1.8)→139.7 ms (sd=4)
と10ms近い高速化が実現できました。というわけで採用。

挿入ソートとのハイブリッド

Wikipediaのクイックソートのページに、長さがある程度以下になったらクイックソートから挿入ソートに切り替えた方が速くなるという記述がありました。再帰呼び出しのオーバーヘッドの分を考えると短い配列で挿入ソートの方が速いのは確かにありえそうです。というわけでQuickSort関数に

if (size <= N) {
   InsertSort(first, last);
   return;
}

というような記述を追加して試してみました

N=16 N=32 N=64 N=128
112.08 (0.8) 108.6 (3) 107.8 (0.8) 114.4 (1)

確かに速くなりました。N=32かN=64が最適そうです。なんとなくN=32を採用しました。

イントロソートの導入

さて、libstdc++のソート関数はイントロソート(英語版Wikipedia)というクイックソートとヒープソートを組み合わせたアルゴリズムを採用しています。これは、

  • 再帰呼び出しの深さが一定になるまでクイックソートを行う
  • 一定より深くなったらヒープソートを行う

というものです。こうすることでクイックソートの高速さを活かしつつ、クイックソートが遅くなる最悪ケースで極端に遅くなるのを防ぐことができます。さらにlibstdc++では配列の長さが一定以下になったら再帰呼び出しをやめて、最後に一括で挿入ソートを行っています。これで再帰呼び出しが深くなるのを防ぐことができます。これを試してみました。

int median3(int a, int b, int c){
    return max(min(a,b), min(max(a,b),c));
}

void introsort(int *first, int *last, int maxdepth){
    if (last - first <= N) return;
    if (maxdepth == 0) {
        HeapSort(first, last);
        return;
    }
    --maxdepth;
    int pivot = median3(*first,
  int median3(int a, int b, int c){
                        *(first+(last-first)/2),
                        *(last-1));
    int* p = partition2(first, last, pivot);
    introsort(first, p, maxdepth);
    introsort(p, last, maxdepth);
}
        
void IntroSort(int *first, int *t_end){
    int maxdepth = log2(t_end-t) * depth;
    introsort(t, t_end, maxdepth);
    InsertSort(t, t_end);
}

まず再帰呼び出しをやめる配列の長さは、depth=2(libstdc++と同じ)に固定して

N=8 N=16 N=32 N=64 N=128
120.1 (0.8) 114.6 (1) 109.4 (0.8) 109.0 (1) 110.6 (4)

だったのでN=32を採用しました。(libstdc++では16を使っている)

再帰呼び出しの最大深さはlog2(N)のdepth倍として、

depth=1 depth=2 depth=3 depth=4 depth=5 depth=10
126.7 (2) 109.4 (0.8) 110.3 (0.7) 110.2 (0.7) 110.2 (0.8) 110.6 (1)

だったのでlibstdc++と同じlog2(N)*2が最適なようです。

クイックソート+挿入ソートの最速値が108.6 msで、今作ったイントロソートの最速値が109.4 msなので最悪値のパフォーマンスを考えてイントロソートを採用することにしました。

並列化

ここまでの調整で、std::sort()の115.7 msと同じか少し上回る速度を出すことに成功しました。アルゴリズムをこれ以上変えずに高速化する簡単な方法として並列化があります。クイックソートでは分割した部分について独立にソートを行うことができるので並列化に向いています。

C++11で新たに導入されたstd::threadを使ってイントロソートを並列化してみました。こんな感じです。
なお、threadのインクルードが必要になります。

void IntroSort(int *first, int *last){
    int maxdepth = log2(last-first) * 2;

    int pivot = median3(*first,
                        *(first+(last-first)/2),
                        *(last-1));

    int* p = partition2(first, last, pivot);

    int pivot1 = median3(*first,
                         *(first+(p-first)/2),
                         *(p-1));

    int* p1 = partition2(first, p, pivot1);

    int pivot2 = median3(*p,
                         *(p+(last-p)/2),
                         *last-1);

    int* p2 = partition2(p, last, pivot2);
    thread t1(introsort, first, p1, maxdepth);
    thread t2(introsort, p1, p, maxdepth);
    thread t3(introsort, p, p2, maxdepth);
    thread t4(introsort, p2, last, maxdepth);
    t1.join();
    t2.join();
    t3.join();
    t4.join();
    InsertSort(first, last);
}

4スレッドにしたのは僕の環境で立てられる最大数が4だったからです。最大スレッド数は

unsigned int n = thread::hardware_concurrency();
cout << "max thread: " << n << endl;

で確認できます。

コンパイルには-pthreadオプションが必要になるので、

g++ sort.cpp -std=c++11 -Wall -O3 -pthread

という感じです。

並列化を行ったイントロソートとstd::sort()の速度比較を示します。

アルゴリズム 1万 10万 100万 200万 400万 800万
標準ソート 0.85 (0.2) 9.79 (0.1) 115.7 (1) 242.9 (2) 506.3 (4) 1056.2 (4)
並列イントロソート 1.28 (0.3) 11.8 (3) 86.2 (9) 162.9 (15.6) 324.8 (32) 746.1 (109)

f:id:kivantium:20151103161721j:plain:w600

並列化を行うことで平均値で3割近い高速化を実現することができました。

まとめ

  • バブル・挿入・マージ・ヒープ・クイックソートの速度を比較した
  • ソートアルゴリズムの高速化の実験を行い、標準ライブラリよりも高速なソートを実現することができた

ここで使ったスレッドで並列化する手法の他にも、SIMDで並列化する方法などがあるそうでまだまだ高速化の余地は残っていそうです。

ここで示したコードについてもこうやったらもっと速くなるというテクニックをご存知の方がいらっしゃったら紹介してくれると泣いて喜びます。

最後に今回の検証でこれが最速という結論になったソートアルゴリズムを示して終わります。

inline void swap (int *a, int *b){
    int tmp = *b;
    *b = *a;
    *a = tmp;
}
void InsertSort(int *first, int *last){
    int size = last - first;
    for (int i=1; i<size; i++) {
        int tmp = first[i];
        if (first[i-1] > tmp) {
            int j = i;
            do {
                first[j] = first[j-1];
                j--;
            } while (j > 0 && first[j-1] > tmp);
            first[j] = tmp;
        }
    }
}
void HeapSort(int *first, int *last){
    int size = last - first;
    for (int i=1; i<size; ++i){
        int j = i;
        while (j > 0){
            if (first[(j-1)/2] < first[j]){
               swap(&first[(j-1)/2], &first[j]);
               j = (j-1)/2;
            } else {
               break;
            } 
        }
    }

    for (int i=size-1; i>0; --i) {
        swap(&first[0], &first[i]);
        int j = 0;
        int k = 0;
        while(true){
            int left = 2*j + 1;
            int right = 2*j + 2;
            if (left >= i) break;
            if (first[left] > first[k]) k = left;
            if ((right < i) && (first[right] > first[k])) k = right;
            if (k == j) break;
            swap (&first[k], &first[j]);
            j = k;
        }
    }
}

int* partition(int *first, int *last, int pivot) {
    while (true) {
        while (*first < pivot) ++first;
        --last;
        while (pivot < *last) --last;
        if (!(first < last)) return first;
        swap(first, last);
        ++first;
    }
}

inline int median3(int a, int b, int c){
    return max(min(a,b), min(max(a,b),c));
}

void introsort(int *first, int *last, int maxdepth){
    if (last - first <= 32) return;
    if (maxdepth == 0) {
        HeapSort(first, last);
        return;
    }
    --maxdepth;
    int pivot = median3(*first,
                        *(first+(last-first)/2),
                        *(last-1));
    int* p = partition(first, last, pivot);
    introsort(first, p, maxdepth - 1);
    introsort(p, last, maxdepth - 1);
}
        

void IntroSort(int *first, int *last){
    int maxdepth = log2(last-first) * 2;

    int pivot = median3(*first,
                        *(first+(last-first)/2),
                        *(last-1));

    int* p = partition(first, last, pivot);

    int pivot1 = median3(*first,
                         *(first+(p-first)/2),
                         *(p-1));

    int* p1 = partition(first, p, pivot1);

    int pivot2 = median3(*p,
                         *(p+(last-p)/2),
                         *last-1);

    int* p2 = partition(p, last, pivot2);
    thread t1(introsort, first, p1, maxdepth);
    thread t2(introsort, p1, p, maxdepth);
    thread t3(introsort, p, p2, maxdepth);
    thread t4(introsort, p2, last, maxdepth);
    t1.join();
    t2.join();
    t3.join();
    t4.join();

    InsertSort(first, last);
}

広告コーナー