JAXという自動微分ライブラリが流行りそうな機運があるので遊んでみます。
インストール
READMEに書いてある通りにやりました。
pip install --upgrade pip pip install --upgrade jax jaxlib
基本的な使い方
The Autodiff Cookbook — JAX documentation を読んで下さい
固有値の最小化の例
ここまでとても雑な説明だったのは、固有値の自動微分が今回の記事のメインテーマだからです。固有値の勾配を求める需要なんてないだろうと思っていたのですが、シュレディンガー方程式が固有値方程式なので、エネルギーの最小化をするためには固有値を最小化する必要があり、そのために固有値の勾配が欲しいという需要があるようです。(この論文では実際に自動微分を使ってエネルギーの最小化を行っています)
JAXでは一般の固有値の勾配はまだ実装されていないようなので、対称行列の固有値の勾配を使った例題を解くことにします。
として、行列Aの最大固有値が最小になるようにしてみます。(ここでは実数の範囲のみを考えることにします)固有方程式を解くと、固有値λは
となるので、最大固有値の最小値はx=0のときに1.0になります
xを与えたときにAの最大固有値を返す関数をfunc
を最急降下法で最小化するプログラムは以下のようになります。
import jax.numpy as np from jax.ops import index, index_update from jax import grad, jit @jit def func(x): A = np.zeros((2, 2)) A = index_update(A, index[0, 0], 1) A = index_update(A, index[0, 1], x) A = index_update(A, index[1, 0], x) A = index_update(A, index[1, 1], -1) w, v = np.linalg.eigh(A) return np.max(w).real func_grad = jit(grad(func)) alpha = 0.1 x = 1.0 for _ in range(100): print("x={}, f(x)={}".format(x, func(x))) x -= alpha * func_grad(x) print("min value: {} (x={})".format(func(x), x))
コメント
- 関数に
@jit
をつけるとJITコンパイルが行われて実行が高速になります。jit(grad(func))
のようにすると、勾配関数もJITコンパイルできます。 - JAXで配列に添字でアクセスするとエラーになるので、代わりに
index_update
を使っています。詳細は🔪 JAX - The Sharp Bits 🔪 — JAX documentationを見てください。(さすがにこれはあまりに汚いのでもう少しいいやり方があるかもしれません) - 固有値を求めるために
jax.scipy.linalg.eigh
を使っています。この関数は本来エルミート行列用なので、固有値が複素数で返ってきます。ここでは問題を実数の範囲に限定しているので実部を取っています。
実行結果は次のようになりました。
x=1.0, f(x)=1.4142135381698608 x=0.9292893409729004, f(x)=1.3651295900344849 x=0.8612160086631775, f(x)=1.3197321891784668 x=0.7959591150283813, f(x)=1.2781044244766235 x=0.7336825728416443, f(x)=1.2402782440185547 (中略) x=6.46666157990694e-05, f(x)=1.0 x=5.819995567435399e-05, f(x)=1.0 x=5.237996083451435e-05, f(x)=1.0 x=4.714196620625444e-05, f(x)=1.0 x=4.242776776663959e-05, f(x)=1.0 x=3.818498953478411e-05, f(x)=1.0 min value: 1.0 (x=3.43664905813057e-05)
最急降下法がそれっぽく動いていることが確認できました。この程度の問題では自動微分を使うまでもないですが、問題がもう少し複雑になるとJITによる高速な自動微分のメリットを享受できます。 今回は最急降下法を使いましたが、scipyと組み合わせればBFGSなどのもう少し高度な最適化手法を使うことができます(参考: A brief introduction to JAX and Laplace’s method - anguswilliams91.github.io)
固有値の微分とは一体なんなのかという話もする必要があるのですが、まだよく理解できていないので今日はとりあえずここまでにして後日更新します。
メモ欄
added jvp rule for eigh, tests by levskaya · Pull Request #358 · google/jax · GitHub
QR法とかの計算グラフ全部構築して勾配計算してるのかなと思ったけど解析的に求まるのか(がっかり) https://t.co/WHdvqvmllt
— きばん (@kivantium) 2020年6月21日
https://t.co/lDDFRVNhdy 付け加えるなら固有ベクトルの元行列による微分はDrazin generalized inverse使って求まりますってのと自分は StewartのMatrix Perturbation Theoryという黒本がお勧め(?)ですってのぐらいかな
— chunjp (@chunjp) 2020年6月22日