読者です 読者をやめる 読者になる 読者になる

kivantium活動日記

プログラムを使っていろいろやります

C++によるSMOを用いたSVMの実装

機械学習

機械学習の手法にはいろいろありますが、その中でもサポートベクトルマシン(SVM; support vector machine)は高い精度で知られる有名な手法です。
以前C++で多層パーセプトロンを実装したので、今度はSVMC++で実装してみました。kivantium.hateblo.jp

SVMの解説

実装する前にSVMの原理を説明します。PRML下巻7章を参考にしました。

パターン認識と機械学習 下 (ベイズ理論による統計的予測)

パターン認識と機械学習 下 (ベイズ理論による統計的予測)

SVMは2値分類に使われる手法で、入力に対してある関数を計算してその値の符号によってクラス分類をします。
入力x(ベクトルですが、うまく表記出来ないので普通の文字を使います)に対して重みベクトルwとバイアスパラメータbを用いてy(x)=w^T \phi(x)+bという関数を用意します。\phi(x)は入力を特徴空間の値に変換する関数ですが後で消えるので分からなければ無視してください。y(x)<0ならクラス-1、y(x)>0ならクラス1となるようにうまくwbの値を決めるのがSVMの目標となります。

ここで、どういうwを採用するのがいいかを考えます。
f:id:kivantium:20150624092839p:plain
右と左の境界線を比べるとなんとなく左の方がいい境界のように見えるかと思います。これは左の図のほうが境界線とデータが離れているからです。このように分類境界をデータからなるべく離すようなwbを求めるのがSVMの基本的アイデアになります。

データと分類境界の距離を求めてみます。特徴空間中のデータ点\phi(x)から分類境界に下ろした垂線の足を\phi(x^\perp)とすれば、wは法線ベクトルなので
{\displaystyle \phi(x)=\phi(x^\perp)+r\frac{w}{\|w\|}}と表わせます。このとき|r|が垂線の長さになります。
x^\perpは分類境界(y(x)=0)上の点なので {\displaystyle y(x^{\perp})=w^T \phi(x^\perp)+b = 0}が成立します。

{\displaystyle \phi(x^\perp)=\phi(x)-r\frac{w}{\|w\|}}を代入すれば、{\displaystyle r=\frac{w^T\phi(x)+b}{\|w\|}=\frac{y(x)}{\|w\|}}と分かります。

ここで、全てのデータがこの関数で正しく分類できると仮定し、データx_nに対応するクラスをt_nとおけば、
y(x_n)<0のときt_n=-1y(x_n)>0のときt_n=1なので)
{\displaystyle |r|=\frac{t_n y(x_n)}{\|w\|}}と書けます。
さらに「分類境界とデータをなるべく離す」というのを「境界に一番近いデータをなるべく境界から離す」と考えれば、SVMの目標は
{\displaystyle \mathop{\rm arg~max}\limits_{w,b}\frac{1}{\|w\|} \mathop{\rm{min}}\limits_{n} t_n (w^T\phi(x)+b)}
を求めることだと定式化できます。
ここで、wは向きを表すだけなので定数倍しても距離が変わらないことを利用すると境界に最も近い点についてt_n(w^T \phi(x_n)+b)=1となるようなうまいwを取れることが分かります。このとき先ほどの目標は
{\displaystyle \mathop{\rm arg~max}\limits_{w,b}\frac{1}{\|w\|}}
と非常に簡単な形にすることができます。これは
{\displaystyle \mathop{\rm arg~min}\limits_{w,b}\frac{1}{2}\|w\|^2}
と同等です。
この「境界に一番近いデータ」のことをサポートベクトルと呼び、サポートベクトルから境界までの距離をマージンと呼びます。


ここで、この式は全てのデータがこの関数で正しく分類できると仮定していることを思い出します。実際には全てのデータが正しく分類できるとは限らないので工夫が必要となります。その工夫がスラック変数の導入です。
各訓練データごとに\xi_nという変数を、データが正しく分類され、かつマージン境界の上または外側に存在する場合は\xi_n=0、それ以外の場合には\xi_n=|t_n-y(x_n)|として定義します。
すると、全てのデータについてt_n(w^T \phi(x_n)+b)\geq1-\xi_nが成立します。

もちろん誤分類は少ない方がいいので、\xi_nの総和はなるべく小さくしたいところです。そこで、目標を
{\displaystyle \mathop{\rm arg~min}\limits_{w,b}\frac{1}{2}\|w\|^2+C\sum_{n=1}^{N}\xi_n}
と変更します。Cはペナルティとマージンの大きさのトレードオフを決めるパラメータで事前に自分で決めておく必要があります。

この関数を最小化するためにはラグランジュの未定乗数法を使います。

ラグランジュの未定乗数法

ラグランジュの未定乗数法は変数に制約条件が与えられた下での関数の停留点を求めるときに使う手法です。理論的な解説はしませんが、ラグランジュの未定乗数法を用いるとJ個の制約条件g_j(x)\geq0とK個の制約条件h_k(x)\geq0の下でf(x)を最小化するにはラグランジュ関数
{\displaystyle L(x, \{a_j\}, \{\mu_k\}) = f(x)-\sum_{j=1}^{J}a_j g_j(x)-\sum_{k=1}^K \mu_kh_k(x)}

a_j\geq0かつa_j g_j(x)=0かつ\mu_k\geq0かつ\mu_kh_k(x)=0という条件のもとでxについては最小化・a, \muについては最大化すればよいということが分かります。
ラグランジュ関数に課せられる条件をKKT条件(Karush-Kuhn-Tucker condition)と呼びます。


先ほどの目的関数についてラグランジュ関数を求めると
{\displaystyle L(x, b, \xi, a, \mu) = \frac{1}{2}\|w\|^2+C\sum_{n=1}^{N}\xi_n-\sum_{n=1}^{N}a_n(t_ny(x_n)-1+\xi_n)-\sum_{n=1}^{N}\mu_n\xi_n}

となります。KKT条件は

  •  {\displaystyle a_n\geq0}
  •  {\displaystyle t_ny(x_n)-1+\xi_n\geq0}
  •  {\displaystyle a_n(t_ny(x_n)-1+\xi_n)=0}
  •  {\displaystyle \mu_n\geq0}
  •  {\displaystyle \xi_n\geq0}
  •  {\displaystyle \mu_n\xi_n=0}

で与えられます。(nは1からNまでの任意の整数)

この関数の停留点を求めるためにラグランジュ関数の偏微分が0になる値を計算すると
 {\displaystyle \frac{\partial L}{\partial w}=0 \Rightarrow w=\sum_{n=1}^{N}a_nt_n\phi(x_n)}
 {\displaystyle \frac{\partial L}{\partial b}=0 \Rightarrow \sum_{n=1}^{N}a_nt_n=0}
 {\displaystyle \frac{\partial L}{\partial \xi_n}=0 \Rightarrow a_n=C-\mu_n}

を得ます。これを最初のラグランジュ関数に代入すると双対形のラグランジュ関数
{\displaystyle W(a)=\sum_{n=1}^N a_n-\frac{1}{2}\sum_{n=1}^N\sum_{m=1}^Na_na_mt_nt_mk(x_n,x_m)}
を得ます。k(x_n, x_m)カーネル関数で、\phi(x)に応じて決まります。今回はガウスカーネル
{\displaystyle k(x, x')=\exp(\frac{\|x-x'\|^2}{2\sigma^2})}を使います。

a_n=C-\mu_nかつ\mu_n\geq0であることからa_n\leq Cです。a_n\geq 0と合わせるとこの関数の制約は

  • {\displaystyle 0\leq a_n\leq C}
  • {\displaystyle \sum_{n=1}^{N} a_nt_n=0}

となります。この制約のもとでW(a)を最大化できたらSVMの完成です。

これが解けたとすると、予測はwに代入して得られる関数
{\displaystyle y(x)=\sum_{n=1}^Na_nt_nk(x, x_n)+b}
の符号を見ることで行えます。
bはサポートベクトル(0< a_n < Cを満たす)の添字の集合をSとして、
{\displaystyle t_n(\sum_{m\in S}a_mt_mk(x_n, x_m)+b)=1}
を満たす値として計算できますが、数値計算の誤差を避けるため
{\displaystyle b=\frac{1}{N_M}\sum_{n\in M}(t_n-\sum_{m\in S}a_mt_mk(x_n, x_m)}
として求めることが多いです。(Mは0< a_n< Cを満たすデータ点の添字の集合)

ここで問題になるのがW(a)の最大化ですが、最大化を高速に行うアルゴリズムがSMOです。

SMO

SMOは2つのラグランジュ乗数に注目して最大化を繰り返すことで最終的にW(a)を最大化するアルゴリズムです。原著論文は
Fast Training of Support Vector Machines using Sequential Minimal Optimizationです。
このページの説明はこの論文を読んで実装ができることを目標にしたので、ここまで数式を追って来た人ならおそらく論文を読めると思います。擬似コードが非常に充実しているので擬似コードを読むだけで実装できます。一応簡単な解説をしておきます。

2つのラグランジュ乗数を選んで \alpha_1, \alpha_2とすれば、KKT条件より
 0\leq\alpha_1\leq C
 0\leq\alpha_2\leq C
 t_1\alpha_1+t_2\alpha_2=\rm{const.}
が成立するのでt_1, t_2(論文ではy_1, y_2になっているので注意)は以下の正方形の線分上にあります。
f:id:kivantium:20150624135717p:plain
この線分上に収まる範囲で最大値を与える \alpha_1, \alpha_2を求めれば2つについてラグランジュ関数が最大化されます。
更新式は
 {\displaystyle \alpha_2^{new} = \alpha_2^{old}-\frac{t_2(E_1-E_2)}{\eta}}
です。( \eta=2k(x_1, x_2)-k(x_1, x_1)-k(x_2, x_2), E_i = y(x_i)-t_i)

最適化する2つのラグランジュ乗数はKKT条件を満たしていないものを探して選ぶと良い結果が得られるようです。

これを繰り返すことでラグランジュ関数が最大化されます。

ソースコード

論文の擬似コードを素直に実装(一部手を抜きましたが)したのが次のコードです。
一応いくつかのデータで試してそれらしき結果が出ることは確認していますが、バグが残っている可能性があるので注意してください。

/* svm.h */
#ifndef LIBNN_SVM
#define LIBNN_SVM

#include <algorithm>
#include <fstream>
#include <cstdlib>
#include <ctime>
#include <cmath>

class svm {
    // dimension of vector
    int dim;
    // lagrange multiplier
    float *a;
    // train data
    float *point;
    // train label
    float *target;
    // error cache
    float *E;
    // threshold
    float b;
    float eps;
    float tol;
    float C;
    int N;

    // kernel function (gauss kernel)
    float kernel(float *x1, float *x2, float delta=1.0){
        float tmp = 0;
        for(int i=0; i<dim; ++i){
            tmp += (x1[i]-x2[i])*(x1[i]-x2[i]);
        }
        return exp(-tmp/(2.0*delta*delta));
    }
public:
    float predict(float *x){
        float tmp = 0;
        for(int i=0; i<N; ++i){
            tmp += a[i]*target[i]*kernel(x, point+i*dim);
        }
        return tmp-b;
    }
    int takeStep(int i1, int i2){
        if(i1 == i2) return 0;
        float alph1 = a[i1];
        float alph2 = a[i2];
        int y1 = target[i1];
        int y2 = target[i2];
        float E1 = E[i1];
        float E2 = E[i2];
        int s = y1*y2;
        // Compute L, H
        float L, H;
        if(y1!=y2){
            L = std::max((float)0.0, alph2-alph1);
            H = std::min(C, C+alph2-alph1);
        } else{
            L = std::max((float)0.0, alph1+alph2-C);
            H = std::min(C, alph1+alph2);
        }
        if(L==H) return 0;

        float k11 = kernel(point+i1*dim, point+i1*dim);
        float k12 = kernel(point+i1*dim, point+i2*dim);
        float k22 = kernel(point+i2*dim, point+i2*dim);

        float eta = 2*k12-k11-k22;

        float a1, a2;

        if(eta < 0){
            a2 = alph2 - y2*(E1-E2)/eta;
            if(a2<L) a2 = L;
            else if(a2>H) a2 = H;
        } else{
            a1 = a[i1];
            a2 = a[i2];
            float v1 = predict(point+i1*dim) - b - y1*a1*k11 - y2*a2*k12; // assume K12 = K21
            float v2 = predict(point+i2*dim) - b - y1*a1*k12 - y2*a2*k22;
            float Wconst = 0;
            for(int i=0; i<N; ++i){
                if(i!=i1 && i!=i2) Wconst+=a[i1];
            }
            for(int i=0; i<N; ++i){
                for(int j=0; j<N; ++j){
                    if(i!=i1 && i!=i2 && j!=i1 && j!=i2){
                        Wconst += target[i]*target[j]*kernel(point+i*dim, point+j*dim)*a[i]*a[j]/2.0;
                    }
                }
            }
            a2 = L;
            a1 = y1*a[i1]+y2*a[i2]-y2*L;
            float Lobj = a1+a2-k11*a1*a1/2.0-k22*a2*a2/2.0-s*k12*a1*a2/2.0
                -y1*a1*v1-y2*a2*v2+Wconst;
            a2 = H;
            a1 = y1*a[i1]+y2*a[i2]-y2*H;
            float Hobj = a1+a2-k11*a1*a1/2.0-k22*a2*a2/2.0-s*k12*a1*a2/2.0
                -y1*a1*v1-y2*a2*v2+Wconst;
            if(Lobj > Hobj + eps) a2 = L;
            else if(Lobj < Hobj - eps) a2 = H;
            else a2 = alph2;
        }

        if(a2 < 1e-8) a2 = 0;
        else if(a2 > C-1e-8) a2 = C;

        if(abs(a2-alph2) < eps*(a2+alph2+eps)) return 0;

        a1 = alph1+s*(alph2-a2);

        float b_old = b;
        float b1 = E1 + y1*(a1-a[i1])*k11 + y2*(a2-a[i2])*k12 + b;
        float b2 = E2 + y1*(a1-a[i1])*k12 + y2*(a2-a[i2])*k22 + b;
        if(b1==b2) b = b1;
        else b = (b1+b2)/2;
        float da1 = a1-a[i1];
        float da2 = a2-a[i2];
        for(int i=0; i<N; ++i){
            E[i] = E[i] + y1*da1*kernel(point+i1*dim, point+i*dim)
                +y2*da2*kernel(point+i2*dim, point+i*dim) + b_old - b;
        }

        a[i1] = a1;
        a[i2] = a2;

        return 1;
    }
    int examineExample(int i2){
        float y2 = target[i2];
        float alph2 = a[i2];
        float E2 = E[i2];
        float r2 = E2*y2;
        int i1 = 0;

        if((r2<-tol && alph2<C) || (r2>tol && alph2>0)){
            int number = 0;
            for(int i=0; i<N; ++i){
                if(a[i]!=0 || a[i]!=C) number++;
            }
            if(number > 1){
                float max = 0;
                for(int i=0; i<N; ++i){
                    if(abs(E[i]-E2) > max){
                        max = abs(E[i]-E2);
                        i1 = i;
                    }
                }
                if(takeStep(i1, i2)) return 1;
            }
            srand((unsigned)time(NULL));
            i1 = rand()%N;
            if(takeStep(i1, i2)) return 1;
        }
        return 0;
    }

    /* constructor
     * dimension: dimension of input
     * C: constant
     * */
    svm(int dimension, float argC=1.0){
        dim = dimension;
        C = argC;
        eps = 0.01;
        tol = 0.01;
    }

    // deconstructor
    ~svm(){
        delete[] a;
        delete[] point;
        delete[] target;
        delete[] E;
    }
    int test(int i){
        return target[i];
    }
    /* train svm
     * x: train data(size is dim*N)
     * t: train label(size is N)
     * size: data size
     */
    void train(float x[], int t[], int size){
        N = size;
        // initialize a
        a = new float[N];
        for(int i=0; i<N; ++i) a[i] = 0;
        point = new float[N*dim];
        for(int i=0; i<N; ++i){
           for(int j=0; j<dim; ++j) point[i*dim+j] = x[i*dim+j];
        }
        target = new float[N];
        for(int i=0; i<N; ++i) target[i] = t[i];
        E = new float[N];
        for(int i=0; i<N; ++i) E[i] = -target[i];
        float threshold = 0;
        int numChanged = 0;
        int examineAll = 1;
        while(numChanged>0 || examineAll){
            numChanged = 0;
            if(examineAll){
                for(int i=0; i<N; ++i) numChanged += examineExample(i);
            }
            else{
                for(int i=0; i<N; ++i){
                    if(a[i]!=0 && a[i]!=C) numChanged += examineExample(i);
                }
            }
            if(examineAll == 1){
                examineAll = 0;
            }
            else if(numChanged == 0){
                examineAll = 1;
            }
        }
    }
};
#endif

使い方

2次元ガウス分布test.csvに対して使うコードがこの通りです
f:id:kivantium:20150624142524p:plain:w300 (test.csvの分布)

#include <iostream>
#include <cstdio>
#include "svm.h"
using namespace std;

int main(void){
    // number of test data
    const int sample = 100;
    // size (dimension) of input vector
    const int size = 2;
    
    // create SVM (dimension is size)
    svm detector(size);
  
    // train data
    float x[size*sample];
    // label data
    int t[sample];
    
    // load CSV
    FILE *fp = fopen("test.csv", "r");
    if(fp==NULL) return -1;
    for(int i=0; i<sample; i++){
        // load data
        for(int j=0; j<size; j++) fscanf(fp, "%f,", x+size*i+j);
        // load label
        fscanf(fp, "%d", t+i);
    }
   
    // train SVM 
    detector.train(x, t, sample);
    // show the result
    int correct = 0;
    for(int i=0; i<sample; i++){
       if(detector.predict(x+size*i)>0){
           if(detector.test(i)==1) correct++;
       }else{
           if(detector.test(i)==-1) correct++;
       }
    }
    cout << (correct*100.0/sample) << "%" << endl;
    
    return 0;
}

実行すると100%分類できたことが分かります。

コードにはGitHubにも上げてあります。

以上SVMを実装した話でした。