読者です 読者をやめる 読者になる 読者になる

kivantium活動日記

プログラムを使っていろいろやります

オンライン学習アルゴリズムAROWについて

機械学習 論文紹介

以前AROWを使った記事を書きました(Caffeによる特徴抽出+AROWによる分類を試した - kivantium活動日記)が、AROWの内部を全然知らずに使っていたので勉強しました。その記録です。

ここの記述が様々なオンライン学習アルゴリズムの特徴をうまく説明していたので参考にするといいかもしれません。kazoo04.hatenablog.com

オンライン機械学習についてはこういう本が出ていて評判も良いです。(ここで唐突に貼られるAmazonへのリンク)

オンライン機械学習 (機械学習プロフェッショナルシリーズ)

オンライン機械学習 (機械学習プロフェッショナルシリーズ)

オンライン学習アルゴリズムとは

SVMなどの機械学習アルゴリズムでは分類を行う前に全ての学習データについて訓練を行います。しかし、学習データが大規模なときは全てのデータについて一度に学習を行うのが時間やメモリの都合上難しくなります。また気温変化や株価の解析など、データが順番に手に入るような状況ではデータ一つ一つについて学習を行って現在の学習能力での予想値を出すことが出来れば便利です。このような、データが一つ与えられるごとにパラメータを更新して学習することができるアルゴリズムをオンライン学習(逐次学習)アルゴリズムと呼びます。オンライン学習アルゴリズムにはパーセプトロン, CW, AROW, SCWなどがあります。

以下の説明はAdaptive Regularization of Weight Vectorsを参考にしました。

CW (Confidence Weighted Online Learning)

AROWの説明の前にAROWの前身となったCWを説明します。ここでは2値分類版のCWを取り上げます。

CWではD次元ベクトルxを入力として受け取り、重みベクトルwをかけた値w\cdot xの符号で分類を行います。

ここでwの分布にガウス分布\mathcal{N}(\mu, \Sigma) を導入するのがCWのポイントです。分布の平均\muはknowledgeを、\Sigmaはconfidenceを表しています。実際に予測を行うときは\muの値のみをそのままwの値として使うようです。

CWはオンライン学習アルゴリズムなのでデータを一つ与えるたびにパラメータwを更新していきます。CWでの学習のポイントは2点です。

  • 与えられたデータについて正しい分類を行う確率が\etaよりも高くなるようにする。ここで\eta0.5<\eta\leq1を満たす値です。
  • 更新を行う際はなるべく以前の分布に近い分布を保つ。分布の近さの指標にはKLダイバージェンスを使います。

これを式にすると、CWの更新アルゴリズム
{\displaystyle (\mu_t, \Sigma_t) = \min_{\mu, \Sigma} D_{KL}(\mathcal{N}(\mu, \Sigma)||\mathcal{N}(\mu_{t-1}, \Sigma_{t-1})
)\\
\mathrm{s.t.} Pr_{w\sim \mathcal{N}(\mu, \Sigma)}[y_t(w\cdot x)\geq 0 ]\geq \eta}
と表されます。

この更新式は解析的に解くことが出来るので順次更新することができます。

AROW (Adaptive Regularization of Weight Vectors)

CWでは更新するときに必ず正しい分類を行う確率が0.5を超えるようにするというアルゴリズムの性質上、誤ったラベルが付けられていたときにそのラベルについても正しい分類を行うように極端な学習を行ってしまい精度が落ちるという問題点があります。また、\Sigmaは誤ったラベルであっても単調減少する問題もあります。これらの弱点を克服するために考案されたのがAROWです。

AROWではデータが与えられるたびに次の目的関数を最小化します。
{\displaystyle \mathcal{C}(\mu, \Sigma) = D_{KL}(\mathcal{N}(\mu, \Sigma)||\mathcal{N}(\mu_{t-1},\Sigma_{t-1}))+\lambda_1 l_{h^2}(y_t, \mu\cdot x_t)+\lambda_2 x_t^T\Sigma x_t}
ここで、l_{h^2}(y_t, \mu\cdot x)l_{h^2}(y_t, \mu\cdot x)=\max\{0, 1-y_t(\mu\cdot x_t)\})^2という形の損失関数です。また、\lambda_1, \lambda_2は正のパラメータです。

この目的関数は3つの項からなっています。

  • 1つ目の項はKLダイバージェンスを最小化することで分布が急激に変わることを防ぐ役割があります。現在の分布は今までに与えられたデータに対してはうまく分類できることが分かっているので分布の保存は有効だと考えられます。
  • 第2項は新しいパラメーターが現在のデータに対して損失関数を小さくできることを表します。
  • 第3項はデータが増えることでconfidenceが一般的には上がっていく(=\Sigmaが小さくなる)ことを表しています。

第2項と第3項がCWの弱点を補強するために新たに付け加えられた項です。


この目的関数は\muに依存する部分をC_1(\mu)\Sigmaに依存する部分をC_2(\Sigma)と分離できます。
パラメータの更新は次のように行います。
1.  \muを更新する( \mu_t = \mathop{\rm arg~min}\limits_{\mu} C_1(\mu))
2.  \muの更新で値が変わった場合のみ\Sigmaを更新する( \Sigma_t = \mathop{\rm arg~min}\limits_{\Sigma} C_2(\Sigma))

更新は偏微分を0として代入すれば解析的に解けます。

というわけで全体の更新アルゴリズムは次のようになります。
f:id:kivantium:20150719012715p:plain:w600

これをうまいこと実装してやればAROWの出来上がりですが、MochiMochiを見れば分かるので自分ではやりませんでした。

他クラス分類への応用

CWの他クラス分類への拡張方法はMulti-Class Confidence Weighted Algorithmsに書いてあります。この方法を応用してAROWも他クラス分類に拡張できるそうです。


が、論文を読んでも実装が思い浮かばなかったので修行してもう一回考えることにします。

眠くなってきたのでこのあたりで一度更新します。