読者です 読者をやめる 読者になる 読者になる

kivantium活動日記

プログラムを使っていろいろやります

全探索によるニューラルネットワーク最適化の実験

先月に[1602.02830] Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1というニューラルネットワークを二値化して計算する論文が発表されました。(訂正:2つの論文があるように記述していましたが、両者は同じチームによる単なるバージョン違いだったようです)
2値化して十分な精度が出るのであれば専用ハードウェアを構成することでニューラルネットワークの計算を大幅に高速化できる可能性があります。また、演算器を100個並べられるなら、どんなソートアルゴリズムを使う?:Fluentd、Memcached、IoT、ドローン、機械学習、映像解析――ソフトとハードを隔てる壁が壊れつつある今、ITエンジニアは現実的に何ができるようになるのか - @ITではFPGAニューラルネットワークのパラメータを全探索することで最適化できる可能性について言及されていました。
この記事では、重みを二値化したニューラルネットワークについて全探索でパラメータを調整することで誤差逆伝搬に変わる新しいニューラルネットワーク最適化の可能性について検討してみたいと思います。

Binarized Neural Networks

二値化の方法について[1602.02830] Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1の論文を眺めてみます(こちらの論文にしたのはラストオーサーがBengioだという単なる権威主義です)どうせ全探索するのでForwardの計算だけ見ることにします。論文中のアルゴリズムを見ると
f:id:kivantium:20160314230248p:plain
のようになっているので、重みとBatch Normalization後の出力値について二値化を行うニューラルネットワークを構成していることが分かります。

Batch Normalization

Batch Normalizationは[1502.03167] Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shiftで提案された方法です。単純な方法ですが効果はかなり強力らしく、いろいろなところで採用されているようです。Chainerを使ってコンピュータにイラストを描かせる - Qiitaではイラスト生成に使われるネットワークのDCGANで一番重要なのはBatch Normalizationの導入だったという印象が語られています。

Batch Normalizationの処理は
f:id:kivantium:20160314231051p:plain
という感じで、普通の正規化を行った後で適当な変換を行います。
この実験では \gamma=1,\ \ \beta=0と決め打ちしパラメータの学習は行わないことにします。

実験の設定

全探索するためあまりパラメータは増やせないので小規模なデータセットとしてirisを使います。入力次元が4・クラス数が3なので隠れ層を3つに設定すれば重みは4*3+3*3=21なので、探索対象は2^{21}\sim 1000000程度なので現実的な時間で全探索が終わります。

全体の流れとしては

  • 入力を正規化する
  • 0または1の二値の重みを掛けて隠れ層を計算する
  • 隠れ層を正規化する
  • 0または1の二値の重みを掛けて出力を計算する
  • 出力が最大の要素番号のラベルに分類する

を全ての重みについて実行し、訓練データに対する正解率が最も高かった重みを使ってテストデータの分類を行うというものになっています。

比較対象としてk-NN法(k=1)を使った分類の結果を求めています。

以下に実際に使ったソースコードを示します。

#include <iostream>
#include <algorithm>
#include <cmath>

using namespace std;

// 平均と分散を指定して正規化を行う関数
void Normalize(vector<vector<double> > &minibatch, vector<double> &mean, vector<double> &var) {
    for(int i=0; i<minibatch.size(); ++i) {
        for(int j=0; j<minibatch[0].size(); ++j) 
            minibatch[i][j] = (minibatch[i][j]-mean[j]) / var[j];
    }
}
// バッチ正規化を行い平均と分散を返す関数
void BatchNormalization(vector<vector<double> > &minibatch, vector<double> &mean, vector<double> &var) {
    // 平均の計算
    for(int i=0; i<minibatch.size(); ++i) {
        for(int j=0; j<minibatch[0].size(); ++j) 
            mean[j] +=  minibatch[i][j];
    }
    for(int i=0; i<mean.size(); ++i) mean[i] /= minibatch.size();

    // 分散の計算
    for(int i=0; i<minibatch.size(); ++i) {
        for(int j=0; j<minibatch[0].size(); ++j) 
            var[j] += pow(minibatch[i][j]-mean[j], 2);
    }
    const double epsilon = 1e-5; // divided by zeroを防ぐために小さい値を足す
    for(int i=0; i<var.size(); ++i) var[i] = sqrt(var[i]/minibatch.size()+epsilon);

    // 正規化
    Normalize(minibatch, mean, var);
}

// バッチ正規化を行う関数(平均と分散は捨てる)
void BatchNormalization(vector<vector<double> > &minibatch) {
    vector<double> mean(minibatch[0].size(), 0);
    vector<double> var(minibatch[0].size(), 0);
    BatchNormalization(minibatch, mean, var);
}

int main(void){
    const int train_size = 120; // 訓練データ
    const int test_size = 30;   // テストデータ
    const int input_dim = 4;    // 入力の次元
    const int hidden_dim = 3;   // 隠れ層の次元
    const int output_dim = 3;   // 出力の次元

    // データの読み込み
    vector<vector<double>> train_data(train_size, vector<double>(input_dim));
    vector<int> train_label(train_size);

    vector<vector<double>> test_data(test_size, vector<double>(input_dim));
    vector<int> test_label(test_size);

    for(int i=0; i<train_size; ++i) {
        for(int j=0; j<input_dim; ++j) 
            cin >> train_data[i][j];
        cin >> train_label[i];
    }
    for(int i=0; i<test_size; ++i) {
        for(int j=0; j<input_dim; ++j) 
            cin >> test_data[i][j];
        cin >> test_label[i];
    }

    // 入力の正規化
    vector<double> mean(input_dim, 0);
    vector<double> var(input_dim, 0);
    BatchNormalization(train_data, mean, var);

    // 隠れ層の値を入れるvector
    vector<vector<double> > hidden_layer(train_size, vector<double>(hidden_dim));
    // 出力を入れるvector
    vector<double> output(output_dim);
    // 訓練データで最も良い結果を出すパラメータを保存する変数
    int max = 0;
    int bestw1, bestw2;

    // パラメータ全探索
    for(int w1bit=1; w1bit<(1<<hidden_dim*input_dim); ++w1bit) {
        // 各データについて隠れ層の値を計算
        for(int data=0; data<train_size; ++data) {
            for(int i=0; i<hidden_dim; ++i) {
                hidden_layer[data][i] = 0;
                for(int j=0; j<input_dim; ++j) {
                    hidden_layer[data][i]
                    += ((w1bit >> (i*hidden_dim+j))&1) * train_data[data][j];
                }
            }
        }

        // 隠れ層に対するバッチ正規化
        BatchNormalization(hidden_layer);

        for(int w2bit=1; w2bit<(1<<hidden_dim*output_dim); ++w2bit) {
            int correct = 0; // 正解数を保存する変数
            // 各データについて出力を計算
            for(int data=0; data<train_size; ++data) {
                for(int i=0; i<output_dim; ++i) {
                    output[i] = 0;
                    for(int j=0; j<hidden_dim; ++j) {
                        output[i]
                        += ((w2bit >> (i*output_dim+j))&1) * hidden_layer[data][j];
                    }
                }
                // 出力が最大の要素を調べる
                auto iter = max_element(output.begin(), output.end());
                size_t index = distance(output.begin(), iter);
                if (index == train_label[data]) correct++;
            }
            // 正解数が最大なら更新
            if(correct > max) {
                max = correct;
                bestw1 = w1bit;
                bestw2 = w2bit;
            }
        }
    }

    // テスト用の隠れ層
    vector<vector<double> > hidden_layer_test(test_size, vector<double>(hidden_dim));

    // 訓練データと同じパラメータで正規化
    Normalize(test_data, mean, var);
    
    // 各テストデータについて隠れ層を計算
    for(int data=0; data<test_size; ++data) {
        for(int i=0; i<hidden_dim; ++i) {
            hidden_layer_test[data][i] = 0;
            for(int j=0; j<input_dim; ++j) {
                hidden_layer_test[data][i]
                += ((bestw1 >> (i*hidden_dim+j))&1) * test_data[data][j];
            }
        }
    }

    // 隠れ層のバッチ正規化
    BatchNormalization(hidden_layer_test);

    // 出力の計算
    int correct = 0;
    for(int data=0; data<test_size; ++data) {
        for(int i=0; i<output_dim; ++i) {
            output[i] = 0;
            for(int j=0; j<hidden_dim; ++j) {
                output[i]
                += ((bestw2 >> (i*output_dim+j))&1) * hidden_layer_test[data][j];
            }
        }
        // 正解数のカウント
        auto iter = max_element(output.begin(), output.end());
        size_t index = distance(output.begin(), iter);
        if(index==test_label[data]) correct++;
    }
    cout << "best train result: " << max << "/" << train_size << endl;
    cout << "test result: " << correct << "/" << test_size << endl;

    // k-nearest neighbor(k=1)による分類
    correct = 0;
    for(int data=0; data<test_size; ++data) {
        double min = 1e10;
        int index = 0;
        for(int search=0; search<train_size; ++search) {
            double dist = 0;
            for(int i=0; i<input_dim; ++i) {
                dist += pow(train_data[search][i]-test_data[data][i], 2);
            }
            if(dist < min) {
                min = dist;
                index = search;
            }
        }
        if(train_label[index] == test_label[data]) correct++;
    }
    cout << "k-NN result: " << correct << "/" << test_size << endl;
}

コンパイルコマンドと出力は以下の通りです。(data.txtは記事の最後)

g++ iris.cpp -std=c++11 -O2
time ./a.out < data.txt
best train result: 104/120
test result: 24/30
k-NN result: 28/30

real	0m7.130s
user	0m7.120s
sys	0m0.008s

全探索で決定したパラメータを使った正解率は80%と重みを2値に限定した割にはそこそこの結果を出しているように思えますが、k-NN法による正解率の方が高くあまり実用的な方法ではないようです。k-NN法も含めて7sで学習が終了しており何も工夫していなくてもこのくらいの範囲なら十分高速に計算できることも分かります。

今後の展望

今回は簡単のために重みを0, 1の2値に限定しましたが、-1, 1の2値にした場合やBatch Normalizationのパラメータ調整を加えた場合・活性化関数を加えた場合などいろいろ調整できるところは残っているように思います。

また現状の全探索する方法では2^Nの時間がかかってしまいMNISTなどのもっと次元の高いタスクでは全く使い物になりませんが、SATなどの高速な解法が分かっている別の組み合わせ最適化問題に帰着することができればもっと高速に局所解にはまることのない学習アルゴリズムに発展させることができるのではないかと思っています。もう少し二値ニューラルネットワークについて考えていきたいです。

実験に使ったデータ

https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.dataのデータを加工しました。

5.4 3.7 1.5 0.2 0
4.8 3.4 1.6 0.2 0
4.8 3.0 1.4 0.1 0
4.3 3.0 1.1 0.1 0
5.8 4.0 1.2 0.2 0
5.7 4.4 1.5 0.4 0
5.4 3.9 1.3 0.4 0
5.1 3.5 1.4 0.3 0
5.7 3.8 1.7 0.3 0
5.1 3.8 1.5 0.3 0
5.4 3.4 1.7 0.2 0
5.1 3.7 1.5 0.4 0
4.6 3.6 1.0 0.2 0
5.1 3.3 1.7 0.5 0
4.8 3.4 1.9 0.2 0
5.0 3.0 1.6 0.2 0
5.0 3.4 1.6 0.4 0
5.2 3.5 1.5 0.2 0
5.2 3.4 1.4 0.2 0
4.7 3.2 1.6 0.2 0
4.8 3.1 1.6 0.2 0
5.4 3.4 1.5 0.4 0
5.2 4.1 1.5 0.1 0
5.5 4.2 1.4 0.2 0
4.9 3.1 1.5 0.1 0
5.0 3.2 1.2 0.2 0
5.5 3.5 1.3 0.2 0
4.9 3.1 1.5 0.1 0
4.4 3.0 1.3 0.2 0
5.1 3.4 1.5 0.2 0
5.0 3.5 1.3 0.3 0
4.5 2.3 1.3 0.3 0
4.4 3.2 1.3 0.2 0
5.0 3.5 1.6 0.6 0
5.1 3.8 1.9 0.4 0
4.8 3.0 1.4 0.3 0
5.1 3.8 1.6 0.2 0
4.6 3.2 1.4 0.2 0
5.3 3.7 1.5 0.2 0
5.0 3.3 1.4 0.2 0
5.0 2.0 3.5 1.0 1
5.9 3.0 4.2 1.5 1
6.0 2.2 4.0 1.0 1
6.1 2.9 4.7 1.4 1
5.6 2.9 3.6 1.3 1
6.7 3.1 4.4 1.4 1
5.6 3.0 4.5 1.5 1
5.8 2.7 4.1 1.0 1
6.2 2.2 4.5 1.5 1
5.6 2.5 3.9 1.1 1
5.9 3.2 4.8 1.8 1
6.1 2.8 4.0 1.3 1
6.3 2.5 4.9 1.5 1
6.1 2.8 4.7 1.2 1
6.4 2.9 4.3 1.3 1
6.6 3.0 4.4 1.4 1
6.8 2.8 4.8 1.4 1
6.7 3.0 5.0 1.7 1
6.0 2.9 4.5 1.5 1
5.7 2.6 3.5 1.0 1
5.5 2.4 3.8 1.1 1
5.5 2.4 3.7 1.0 1
5.8 2.7 3.9 1.2 1
6.0 2.7 5.1 1.6 1
5.4 3.0 4.5 1.5 1
6.0 3.4 4.5 1.6 1
6.7 3.1 4.7 1.5 1
6.3 2.3 4.4 1.3 1
5.6 3.0 4.1 1.3 1
5.5 2.5 4.0 1.3 1
5.5 2.6 4.4 1.2 1
6.1 3.0 4.6 1.4 1
5.8 2.6 4.0 1.2 1
5.0 2.3 3.3 1.0 1
5.6 2.7 4.2 1.3 1
5.7 3.0 4.2 1.2 1
5.7 2.9 4.2 1.3 1
6.2 2.9 4.3 1.3 1
5.1 2.5 3.0 1.1 1
5.7 2.8 4.1 1.3 1
6.5 3.2 5.1 2.0 2
6.4 2.7 5.3 1.9 2
6.8 3.0 5.5 2.1 2
5.7 2.5 5.0 2.0 2
5.8 2.8 5.1 2.4 2
6.4 3.2 5.3 2.3 2
6.5 3.0 5.5 1.8 2
7.7 3.8 6.7 2.2 2
7.7 2.6 6.9 2.3 2
6.0 2.2 5.0 1.5 2
6.9 3.2 5.7 2.3 2
5.6 2.8 4.9 2.0 2
7.7 2.8 6.7 2.0 2
6.3 2.7 4.9 1.8 2
6.7 3.3 5.7 2.1 2
7.2 3.2 6.0 1.8 2
6.2 2.8 4.8 1.8 2
6.1 3.0 4.9 1.8 2
6.4 2.8 5.6 2.1 2
7.2 3.0 5.8 1.6 2
7.4 2.8 6.1 1.9 2
7.9 3.8 6.4 2.0 2
6.4 2.8 5.6 2.2 2
6.3 2.8 5.1 1.5 2
6.1 2.6 5.6 1.4 2
7.7 3.0 6.1 2.3 2
6.3 3.4 5.6 2.4 2
6.4 3.1 5.5 1.8 2
6.0 3.0 4.8 1.8 2
6.9 3.1 5.4 2.1 2
6.7 3.1 5.6 2.4 2
6.9 3.1 5.1 2.3 2
5.8 2.7 5.1 1.9 2
6.8 3.2 5.9 2.3 2
6.7 3.3 5.7 2.5 2
6.7 3.0 5.2 2.3 2
6.3 2.5 5.0 1.9 2
6.5 3.0 5.2 2.0 2
6.2 3.4 5.4 2.3 2
5.9 3.0 5.1 1.8 2
5.1 3.5 1.4 0.2 0
4.9 3.0 1.4 0.2 0
4.7 3.2 1.3 0.2 0
4.6 3.1 1.5 0.2 0
5.0 3.6 1.4 0.2 0
5.4 3.9 1.7 0.4 0
4.6 3.4 1.4 0.3 0
5.0 3.4 1.5 0.2 0
4.4 2.9 1.4 0.2 0
4.9 3.1 1.5 0.1 0
7.0 3.2 4.7 1.4 1
6.4 3.2 4.5 1.5 1
6.9 3.1 4.9 1.5 1
5.5 2.3 4.0 1.3 1
6.5 2.8 4.6 1.5 1
5.7 2.8 4.5 1.3 1
6.3 3.3 4.7 1.6 1
4.9 2.4 3.3 1.0 1
6.6 2.9 4.6 1.3 1
5.2 2.7 3.9 1.4 1
6.3 3.3 6.0 2.5 2
5.8 2.7 5.1 1.9 2
7.1 3.0 5.9 2.1 2
6.3 2.9 5.6 1.8 2
6.5 3.0 5.8 2.2 2
7.6 3.0 6.6 2.1 2
4.9 2.5 4.5 1.7 2
7.3 2.9 6.3 1.8 2
6.7 2.5 5.8 1.8 2
7.2 3.6 6.1 2.5 2

ちなみに、最近FPGAに関する新刊がいくつか出ているようです。

独自CPU開発で学ぶコンピュータのしくみ

独自CPU開発で学ぶコンピュータのしくみ

FPGAの原理と構成

FPGAの原理と構成