kivantium活動日記

プログラムを使っていろいろやります

自動微分ライブラリJAXを用いた対称行列の固有値の微分

JAXという自動微分ライブラリが流行りそうな機運があるので遊んでみます。

github.com

インストール

READMEに書いてある通りにやりました。

pip install --upgrade pip
pip install --upgrade jax jaxlib 

基本的な使い方

The Autodiff Cookbook — JAX documentation を読んで下さい

固有値の最小化の例

ここまでとても雑な説明だったのは、固有値の自動微分が今回の記事のメインテーマだからです。固有値の勾配を求める需要なんてないだろうと思っていたのですが、シュレディンガー方程式固有値方程式なので、エネルギーの最小化をするためには固有値を最小化する必要があり、そのために固有値の勾配が欲しいという需要があるようです。(この論文では実際に自動微分を使ってエネルギーの最小化を行っています)

JAXでは一般の固有値の勾配はまだ実装されていないようなので、対称行列の固有値の勾配を使った例題を解くことにします。

\displaystyle{
A=\left(
    \begin{array}{cc}
      1 &x \\
      x & -1
    \end{array}
  \right)
}

として、行列Aの最大固有値が最小になるようにしてみます。(ここでは実数の範囲のみを考えることにします)固有方程式を解くと、固有値λは

\displaystyle{
\lambda = \pm\sqrt{x^2+1}
}

となるので、最大固有値の最小値はx=0のときに1.0になります

xを与えたときにAの最大固有値を返す関数をfunc最急降下法で最小化するプログラムは以下のようになります。

import jax.numpy as np
from jax.ops import index, index_update
from jax import grad, jit

@jit
def func(x):
    A = np.zeros((2, 2))
    A = index_update(A, index[0, 0], 1)
    A = index_update(A, index[0, 1], x)
    A = index_update(A, index[1, 0], x)
    A = index_update(A, index[1, 1], -1)
    w, v = np.linalg.eigh(A)
    return np.max(w).real

func_grad = jit(grad(func))
alpha = 0.1
x = 1.0
for _ in range(100):
    print("x={}, f(x)={}".format(x, func(x)))
    x -= alpha * func_grad(x)

print("min value: {} (x={})".format(func(x), x))

コメント

  • 関数に@jitをつけるとJITコンパイルが行われて実行が高速になります。jit(grad(func))のようにすると、勾配関数もJITコンパイルできます。
  • JAXで配列に添字でアクセスするとエラーになるので、代わりにindex_updateを使っています。詳細は🔪 JAX - The Sharp Bits 🔪 — JAX documentationを見てください。(さすがにこれはあまりに汚いのでもう少しいいやり方があるかもしれません)
  • 固有値を求めるためにjax.scipy.linalg.eighを使っています。この関数は本来エルミート行列用なので、固有値複素数で返ってきます。ここでは問題を実数の範囲に限定しているので実部を取っています。

実行結果は次のようになりました。

x=1.0, f(x)=1.4142135381698608
x=0.9292893409729004, f(x)=1.3651295900344849
x=0.8612160086631775, f(x)=1.3197321891784668
x=0.7959591150283813, f(x)=1.2781044244766235
x=0.7336825728416443, f(x)=1.2402782440185547
(中略)
x=6.46666157990694e-05, f(x)=1.0
x=5.819995567435399e-05, f(x)=1.0
x=5.237996083451435e-05, f(x)=1.0
x=4.714196620625444e-05, f(x)=1.0
x=4.242776776663959e-05, f(x)=1.0
x=3.818498953478411e-05, f(x)=1.0
min value: 1.0 (x=3.43664905813057e-05)

最急降下法がそれっぽく動いていることが確認できました。この程度の問題では自動微分を使うまでもないですが、問題がもう少し複雑になるとJITによる高速な自動微分のメリットを享受できます。 今回は最急降下法を使いましたが、scipyと組み合わせればBFGSなどのもう少し高度な最適化手法を使うことができます(参考: A brief introduction to JAX and Laplace’s method - anguswilliams91.github.io

固有値微分とは一体なんなのかという話もする必要があるのですが、まだよく理解できていないので今日はとりあえずここまでにして後日更新します。

メモ欄

added jvp rule for eigh, tests by levskaya · Pull Request #358 · google/jax · GitHub

広告コーナー