読者です 読者をやめる 読者になる 読者になる

kivantium活動日記

プログラムを使っていろいろやります

ボトムアップ型自動微分の実験

Deep Learningの関係で自動微分が脚光を浴びつつあるような気がしますが、自動微分を解説したページは少なくまだまだマイナーな分野だと思います。先日ようやく「アルゴリズムの自動微分と応用」を一通り眺めたのでいろいろ実験して遊んでいます。今日は自動微分のうち、ボトムアップ型自動微分をオペレータオーバーロードを用いて実現する方法について書きます。

自動微分とは

自動微分は、アルゴリズムによって定義された関数からその関数の偏導関数値を計算するアルゴリズムを導出するための技術です。一般的にはy=x^2のような関数が先にあって、その関数を計算するアルゴリズムやプログラムがあるというように考えますが、自動微分ではどちらかというとアルゴリズムが先にあってアルゴリズムが表現する関数が生まれるというような考え方をします。

プログラムで微分を扱う上でよく知られている技術には数値微分と数式微分があります。

数値微分

数値微分は、小さな値hを用いて{\displaystyle \frac{f(a+h)-f(a)}{h}}{\displaystyle \frac{f(a+h)-f(a-h)}{2h}}を計算することで微分の近似値を得る技術です。数値微分はあくまで微分の近似値を計算する技術ですが、自動微分ではアルゴリズムで定義された関数から解析的に得た偏導関数の値を計算します。また、数値微分では摂動幅hを設定する必要がありますが、hが大きすぎると近似精度が悪くなり、小さすぎると丸め誤差の影響で精度が悪くなるという問題があります。

この問題を理解するために簡単なプログラムを書きました。{\displaystyle x=\frac{\pi}{4}}における\sin(x)微分{\displaystyle \frac{f(a+h)-f(a)}{h}}を計算することで求めます。解析的には{\displaystyle \frac{\sqrt{2}}{2}}が解なので、この2倍を出力するようにしてよく知っている\sqrt{2}の値にどれくらい近づくかを見てみます。

#include <iostream>
#include <iomanip>
#include <cmath>

using namespace std;

int main(void){
    // 小数点以下15桁表示する
    cout << fixed << setprecision(15);
    
    // 刻み幅を小さくして数値微分
    double h = 0.1;
    for(int i=0; i<15; ++i) {
        cout << 2 * (sin(M_PI/4 + h) - sin(M_PI/4))/h << " (h=" << h << ")" << endl;
        h /= 10;
    }
}

結果は

1.341205945807979 (h=0.100000000000000)
1.407118983378419 (h=0.010000000000000)
1.413506219948735 (h=0.001000000000000)
1.414142849338607 (h=0.000100000000000)
1.414206491290315 (h=0.000010000000000)
1.414212855488372 (h=0.000001000000000)
1.414213490757987 (h=0.000000100000000)
1.414213568473599 (h=0.000000010000000)
1.414213635086980 (h=0.000000001000000)
1.414215411443819 (h=0.000000000100000)
1.414224293228016 (h=0.000000000010000)
1.414202088767524 (h=0.000000000001000)
1.414424133372449 (h=0.000000000000100)
1.421085471520200 (h=0.000000000000010)
1.554312234475219 (h=0.000000000000001)

のようになりました。大きすぎても小さすぎても良くなくて、h=0.000000010000000のときの結果が最も良いようです。

一般に、{\displaystyle \frac{f(a+h)-f(a)}{h}}の計算誤差をなるべく小さくするためには関数の計算値に含まれる計算誤差を\deltaとして{\displaystyle h \simeq 2\sqrt{\frac{\delta}{|f''(a)|}}}とするのがよいことが知られています。しかし、f'を求めるのにf''の値が必要になってしまうのでこの式を使うのは現実的ではありません。

また、入力変数が多い関数に対しては入力変数の数だけ関数を計算する必要があり計算量が増えてしまいますが、自動微分を使うことで計算量を抑えることができます。

数式微分

MathematicaMaximaのようなソフトでは数式を処理して偏導関数を求める数式微分を行いますが、自動微分では繰り返しや分岐を含むようなアルゴリズムに対しても偏導関数を求めるアルゴリズム偏導関数そのものではない)を求めます。

自動微分の種類

自動微分にはボトムアップ型(フォーワードモードとも)とトップダウン型(リバースモードとも)の2種類があります。どちらも微分の連鎖律を用いて関数を単純な演算に分解することで偏導関数値を求めますが、分解の方法が異なります。

ボトムアップ型自動微分

ボトムアップ型自動微分では一つの入力変数に対して、全ての中間変数の偏導関数値を計算していく手法です。

例としてWikipediaの自動微分のページにあるf(x_1, x_2)=x_1 x_2+\sin{x_1}の例を考えます。この関数の計算をステップごとに中間変数に保存するとすれば、

  •  w_1 \leftarrow x_1
  •  w_2 \leftarrow x_2
  •  w_3 \leftarrow w_1 \cdot w_2
  •  w_4 \leftarrow \sin{w_1}
  •  w_5 \leftarrow w_3 + w_4

のように5つの中間変数を用いて表すことができます。

ここで入力変数x_1について各中間変数の微分を計算すると、

  •  {\displaystyle \frac{\partial w_1}{\partial x_1} = 1}
  •  {\displaystyle \frac{\partial w_2}{\partial x_1} = 0}
  •  {\displaystyle \frac{\partial w_3}{\partial x_1}= \frac{\partial (w_1 \cdot w_2 )}{\partial x_1} = \frac{\partial w_1}{\partial x_1}\cdot w_2 + w_1 \cdot \frac{\partial w_2}{\partial x_1}}
  •  {\displaystyle \frac{\partial w_4}{\partial x_1} = \frac{\partial \sin{w_1}}{\partial x_1}  = \frac{\partial w_1}{\partial x_1} \cdot \cos{w_1}}
  •  {\displaystyle \frac{\partial w_5}{\partial x_1} = \frac{\partial (w_3+w_4)}{\partial x_1} = \frac{\partial w_3}{\partial x_1} + \frac{\partial w_4}{\partial x_1}}

のように分解することができます。これは関数値を求める際に必要な各中間変数の計算と同時に行うことができます。

計算グラフを使って表すと以下のようになります。(Wikipediaから引用)
f:id:kivantium:20160325000216p:plain:w600

トップダウン型自動微分

トップダウン型自動微分では一つの出力変数について、全ての中間変数に対する偏導関数値を計算していく手法です。

先ほどと同じ例を考えます。出力変数w_5について各中間変数に対する微分を計算すると、

  •  {\displaystyle \frac{\partial w_5}{\partial w_5} = 1}
  •  {\displaystyle \frac{\partial w_5}{\partial w_4} = \frac{\partial w_5}{\partial w_5}\cdot \frac{\partial w_5}{\partial w_4} = \frac{\partial w_5}{\partial w_5}\cdot\frac{\partial (w_3+w_4)}{\partial w_4}=\frac{\partial w_5}{\partial w_5}}
  •  {\displaystyle \frac{\partial w_5}{\partial w_3} = \frac{\partial w_5}{\partial w_5}\cdot \frac{\partial (w_3+w_4)}{\partial w_3} = \frac{\partial w_5}{\partial w_5}}
  •  {\displaystyle \frac{\partial w_5}{\partial w_2} = \frac{\partial w_5}{\partial w_3} \cdot \frac{\partial w_3}{\partial w_2} = \frac{\partial w_5}{\partial w_3} \cdot \frac{\partial (w_1 \cdot w_2)}{\partial w_2} = \frac{\partial w_5}{\partial w_3}\cdot w_1}
  •  {\displaystyle \frac{\partial w_5}{\partial w_1} = \frac{\partial w_5}{\partial w_3} \cdot \frac{\partial w_3}{\partial w_1} + \frac{\partial w_5}{\partial w_4} \cdot \frac{\partial w_4}{\partial w_1}= \frac{\partial w_5}{\partial w_3} \cdot \frac{\partial (w_1 \cdot w_2)}{\partial w_1}+ \frac{\partial w_5}{\partial w_4} \cdot \frac{\partial \sin{w_1}}{\partial w_1} = \frac{\partial w_5}{\partial w_3} \cdot w_2 + \frac{\partial w_5}{\partial w_4} \cdot \cos{w_1}}

のように分解することができます。

w_1, \ w_2についての分解が分かりにくいですが、w_5の変化は中間変数w_3, \ w_4の変化を経由してしか起こらないことに注目すると納得できるかもしれません。計算グラフは以下の通りです。
f:id:kivantium:20160325002547p:plain:w600

トップダウン型自動微分では一度計算した後で各中間変数がどのように計算されたのかの履歴を逆向きに辿って偏導関数値を計算する必要があるため、計算グラフを保存するためのメモリが必要になります。

勾配を求める際には、入力変数の方が多いときはトップダウン型、出力変数の方が多い時はボトムアップ型を使うのが計算量の点で有利になります。ニューラルネットワークバックプロパゲーショントップダウン型の自動微分の一種です。

自動微分の実現方法

自動微分の実現方法には、ソースコードを解析して偏導関数値を求めるプログラムを作る方法とオペレータオーバーロードによる方法があります。ソースコードを解析する方法はコンパイラを作るのと同じような労力が必要になるので難しいですが、オペレータオーバーロードによる方法は比較的容易なので今回はこちらを採用します。

オペレータオーバーロードによるボトムアップ型自動微分

ボトムアップ型自動微分では関数値を計算するのと同時に偏導関数値を計算することができます。そのためdoubleを自動微分用に拡張したBUdoubleというクラスをつくって計算を行うと同時に偏導関数値を保存するようにします。それを実現したのが以下のコードです。

#include <iostream>
#include <cmath>
#include <iomanip>

using namespace std;
 
// ボトムアップ型微分積分用doubleクラス
class BU_double {
    // 変数値
    double val;
    // 偏導関数値
    double d_val;
public:
    // コンストラクタ
    BU_double(double v=0, double dv=0) {
        val = v;
        d_val = dv;
    }
    // 微分する入力変数として選択する関数
    void select(void) {
        d_val = 1.0;
    }
    // 変数値を返す
    double get_value(void) {
        return val;
    }
    // 偏導関数値を返す
    double get_d_value(void) {
        return d_val;
    }

    // 各種演算子の定義
    friend BU_double operator + (BU_double x, BU_double y) {
        return BU_double(x.val+y.val, x.d_val+y.d_val);
    }
    friend BU_double operator - (BU_double x, BU_double y) {
        return BU_double(x.val-y.val, x.d_val-y.d_val);
    }
    friend BU_double operator * (BU_double x, BU_double y) {
        return BU_double(x.val*y.val, x.d_val*y.val+x.val*y.d_val);
    }
    friend BU_double operator / (BU_double x, BU_double y) {
        double w = x.val/y.val;
        return BU_double(w, (x.d_val-w*y.d_val)/y.val);
    }
    friend BU_double operator + (BU_double x) {
        return BU_double(x.val, x.d_val);
    }
    friend BU_double operator - (BU_double x) {
        return BU_double(-x.val, -x.d_val);
    }
    friend bool operator < (BU_double x, BU_double y) {
        return x.val < y.val;
    }
    friend bool operator <= (BU_double x, BU_double y) {
        return x.val <= y.val;
    }
    friend bool operator > (BU_double x, BU_double y) {
        return x.val > y.val;
    }
    friend bool operator >= (BU_double x, BU_double y) {
        return x.val >= y.val;
    }
    
    // 基本関数の定義
    friend BU_double sqrt(BU_double x) {
        return BU_double(sqrt(x.val), 0.5*x.d_val/sqrt(x.val));
    }
    friend BU_double exp(BU_double x) {
        return BU_double(exp(x.val), x.d_val*exp(x.val));
    }
    friend BU_double log(BU_double x) {
        return BU_double(log(x.val), x.d_val/x.val);
    }
    friend BU_double sin(BU_double x) {
        return BU_double(sin(x.val), cos(x.val));
    }
    friend BU_double cos(BU_double x) {
        return BU_double(cos(x.val), -sin(x.val));
    }

    // coutに出力するフォーマットの定義
    friend ostream& operator<<(ostream &s, BU_double x) {
        return s << "BU_double("<< x.val << ", " << x.d_val << ") ";
    }

};

int main(void){
    cout << setprecision(15);
    BU_double x, y;

    // sin(x)の微分 
    x = M_PI / 4;
    x.select();
    y = sin(x);
    cout << 2 * y.get_d_value() << endl; // 2016-07-13に修正

    
    // y = 10*x^2 のfor文を使ったアホな定義
    x = 5;
    y = 0;
    x.select();
    for(int i=0; i<10; ++i) y = y + x*x;
    // x = 5.0のときのyとdy/dxを出力
    cout << y << endl;

    // ニュートン法の実行
    x = 2.0; // 適当な初期値
    for(int i=0; i<10000; ++i) {
        x.select();
        // y = (x-sqrt(2))*(x^3+1)
        y = (x-sqrt(2))*(x*x+x+1);
        // yの値がdoubleの丸め誤差と同程度に小さければ終了
        if (y.get_value() < 1e-15) break;
        // ニュートン法の更新式で次のxを求める
        x = x - y.get_value()/y.get_d_value();
    }
    // x(=sqrt(2))を出力
    cout << x.get_value() << endl;
}

出力は以下の通りです。

1.41421356237309
BU_double(250, 100) 
1.4142135623731

出力の1行目は{\displaystyle \sin{x}}微分{\displaystyle x=\frac{\pi}{4}}のときの値の2倍(すなわち\sqrt{2})を表しています。最初に示した数値微分より高い精度で計算できていることが分かります。
2行目はy=10x^2という関数にx=5.0を入れて計算したときの関数値と導関数値です。for文を使った繰り返し文として定義しているにも関わらずきちんと導関数値が求まっていることが分かります
3行目はニュートン法を用いて求めた(x-\sqrt{2})(x^2+x+1)の解を表しています。ニュートン法f(x)=0の解を {\displaystyle x_{n+1} = x_n - \frac{f(x)}{f'(x)}}という更新式を用いて反復することで求めるというアルゴリズムです。導関数の計算が自動的に行われるため簡潔に記述することができます。


次回はトップダウン型自動微分を実装して実験してみようと思います。

参考文献

アルゴリズムの自動微分と応用 (現代非線形科学シリーズ)

アルゴリズムの自動微分と応用 (現代非線形科学シリーズ)

ここで使った自動微分クラスはこの本に書いてあるものをそのまま使っています。日本語で自動微分をテーマにした本はこの一冊しかなさそうです。Amazonの在庫は中古しかありませんが出版社のサイトから直接注文したら定価で入手することができました。在庫は少ないようです。


数値計算の常識

数値計算の常識

ニュートン法停止条件などで参考にしました。