kivantium活動日記

プログラムを使っていろいろやります

自動微分ライブラリJAXを用いた対称行列の固有値の微分

JAXという自動微分ライブラリが流行りそうな機運があるので遊んでみます。

github.com

インストール

READMEに書いてある通りにやりました。

pip install --upgrade pip
pip install --upgrade jax jaxlib 

基本的な使い方

The Autodiff Cookbook — JAX documentation を読んで下さい

固有値の最小化の例

ここまでとても雑な説明だったのは、固有値の自動微分が今回の記事のメインテーマだからです。固有値の勾配を求める需要なんてないだろうと思っていたのですが、シュレディンガー方程式固有値方程式なので、エネルギーの最小化をするためには固有値を最小化する必要があり、そのために固有値の勾配が欲しいという需要があるようです。(この論文では実際に自動微分を使ってエネルギーの最小化を行っています)

JAXでは一般の固有値の勾配はまだ実装されていないようなので、対称行列の固有値の勾配を使った例題を解くことにします。

\displaystyle{
A=\left(
    \begin{array}{cc}
      1 &x \\
      x & -1
    \end{array}
  \right)
}

として、行列Aの最大固有値が最小になるようにしてみます。(ここでは実数の範囲のみを考えることにします)固有方程式を解くと、固有値λは

\displaystyle{
\lambda = \pm\sqrt{x^2+1}
}

となるので、最大固有値の最小値はx=0のときに1.0になります

xを与えたときにAの最大固有値を返す関数をfunc最急降下法で最小化するプログラムは以下のようになります。

import jax.numpy as np
from jax.ops import index, index_update
from jax import grad, jit

@jit
def func(x):
    A = np.zeros((2, 2))
    A = index_update(A, index[0, 0], 1)
    A = index_update(A, index[0, 1], x)
    A = index_update(A, index[1, 0], x)
    A = index_update(A, index[1, 1], -1)
    w, v = np.linalg.eigh(A)
    return np.max(w).real

func_grad = jit(grad(func))
alpha = 0.1
x = 1.0
for _ in range(100):
    print("x={}, f(x)={}".format(x, func(x)))
    x -= alpha * func_grad(x)

print("min value: {} (x={})".format(func(x), x))

コメント

  • 関数に@jitをつけるとJITコンパイルが行われて実行が高速になります。jit(grad(func))のようにすると、勾配関数もJITコンパイルできます。
  • JAXで配列に添字でアクセスするとエラーになるので、代わりにindex_updateを使っています。詳細は🔪 JAX - The Sharp Bits 🔪 — JAX documentationを見てください。(さすがにこれはあまりに汚いのでもう少しいいやり方があるかもしれません)
  • 固有値を求めるためにjax.scipy.linalg.eighを使っています。この関数は本来エルミート行列用なので、固有値複素数で返ってきます。ここでは問題を実数の範囲に限定しているので実部を取っています。

実行結果は次のようになりました。

x=1.0, f(x)=1.4142135381698608
x=0.9292893409729004, f(x)=1.3651295900344849
x=0.8612160086631775, f(x)=1.3197321891784668
x=0.7959591150283813, f(x)=1.2781044244766235
x=0.7336825728416443, f(x)=1.2402782440185547
(中略)
x=6.46666157990694e-05, f(x)=1.0
x=5.819995567435399e-05, f(x)=1.0
x=5.237996083451435e-05, f(x)=1.0
x=4.714196620625444e-05, f(x)=1.0
x=4.242776776663959e-05, f(x)=1.0
x=3.818498953478411e-05, f(x)=1.0
min value: 1.0 (x=3.43664905813057e-05)

最急降下法がそれっぽく動いていることが確認できました。この程度の問題では自動微分を使うまでもないですが、問題がもう少し複雑になるとJITによる高速な自動微分のメリットを享受できます。 今回は最急降下法を使いましたが、scipyと組み合わせればBFGSなどのもう少し高度な最適化手法を使うことができます(参考: A brief introduction to JAX and Laplace’s method - anguswilliams91.github.io

固有値微分とは一体なんなのかという話もする必要があるのですが、まだよく理解できていないので今日はとりあえずここまでにして後日更新します。

メモ欄

added jvp rule for eigh, tests by levskaya · Pull Request #358 · google/jax · GitHub

2020年5月

出来事

インターンの話

4月の記事に書いたようにインターンの内容に不満があったので、ゴールデンウィーク明けに今の業務をつづける意味を感じないという話を上司にしたところ、とりあえず出社が可能になるまでは休職することになりました。その時は休職は明日からでも良いという話でしたが、事務手続きに2週間ほどかかったので実際に休職になったのは先週からです。休職した分インターン期間を変更するなどの配慮をしてもらいましたが、今後どうなるかは状況次第です。 今までの経験を振り返っても、何も分からない担当者と働いてひどい目にあったときはともかくとして、特に業務内容や人間関係に問題がないときでも逃げるように辞めていたような気がしますし、上手に労働ができた記憶がありません。これからどうやって生きていけばいいんだろうなぁ。

Webサービスの話

業務後の時間を使ってひたすらWebサービスを作っていました。

nijisearch.kivantium.net

自分からイラストを探しにいかなくても勝手にイラストが集まってくるのは見ていて楽しいので作った意味があったと思っています。 サーバー代くらいは稼げるといいなと思ってGoogle Adsenseの広告を貼っていますが、今のPVでは全く利益が出ないので世の中の厳しさを感じています。自分が欲しいものを作っただけなので別に儲からなくてもいいのですが。 このサービスの本質はタイムラインからのイラスト自動収集だけで残りはだいたいやるだけなのですが、作業の量としては圧倒的にやるだけ部分が多くなっています。世の中ってそういうものよね。勉強しないといけないことがたくさんあるので、あまりこれに時間を使ってはいけないと思っているのですが、手を動かした分だけ機能が増えるのは楽しいのでつい弄ってしまってよくないです。

読んだ本

業務時間と趣味の開発で時間を取られて他のことをする時間がだいぶ減っていました。

統計力学〈1〉 (新物理学シリーズ)

統計力学〈1〉 (新物理学シリーズ)

統計力学を使ってニューラルネットワークを解析するみたいな話を聞いて解析力学を勉強しようと思って学部生のときに買ったものの積んでいた本です。等重率の原理から出発して熱力学のいろんな現象が出てくる話が中心だったのですが、熱力学にあまり思い入れがないこともあってか前評判ほどの面白さは感じられませんでした。

専門ではないが知っておきたい分野の本をどのくらいしっかり読み込めばいいのか分からないでいます。というか、勉強の仕方が未だに分かりません。ここ数年は本がなくても内容が復元できるようなノートを書きながら本を読み進めて、一冊分ノートを書き終わったらおしまいという勉強をしているのですが、時間の割にあまり身についていない気がします。ただ読むだけだと身につかないですが、だからといって何も見ずに内容を再現できるレベルで理解するようにしていたらいつになっても読み終わらないわけで、その中間を模索しています。

山田エルフ大先生がかわいかったです。

なもり先生の絵がかわいかったです。

知能がない。

Django ChannelsでWebsocket通信を行ってVue.jsで表示する

前回の続きです。 kivantium.hateblo.jp

サーバー上で重たい処理をする場合、処理が全て終わってからHTMLを生成しているとかなりレスポンスが悪くなってしまいます。先に結果表示ページのHTMLを表示しておいて、中身を後からWebsocketで順次通信するといい感じになりそうです。Websocket通信したデータを表示するためにVue.jsを使うことにします。

前回からの変更点

前回のHello, worldで使ったディレクトリ構成を編集する形で作業を進めます。

mysite/routing.py を以下の内容で書き換えます。ws/test/にアクセスすることでWebsocket通信を行えるように設定しています。

from channels.auth import AuthMiddlewareStack
from channels.routing import ProtocolTypeRouter, URLRouter
from django.urls import path

from . import consumers

websocket_urlpatterns = [
    path('ws/test/', consumers.Consumer),
]

application = ProtocolTypeRouter({
    # (http->django views is added by default)
    'websocket': AuthMiddlewareStack(
        URLRouter(
            websocket_urlpatterns
        )
    ),
})

mysite/consumers.py にWebsocket通信で行う処理を記述します。本当は非同期処理をするべきらしいのですが、動く書き方が分からなかったので同期処理で書いています。(参考: django で Websocket - 空のブログ

import json
import time
import threading

from channels.generic.websocket import WebsocketConsumer

class Consumer(WebsocketConsumer):
    # 接続されたときの処理
    def connect(self):
        # 接続を許可する
        self.accept()
        # メッセージ送信関数を新しいスレッドで呼び出す
        # 本当は非同期で書くべきだがうまく動かなかった
        self.sending = True
        self.sender = threading.Thread(
            target=self.send_message, args=('Hello', ))
        self.sender.start()

    # 接続が切断されたときの処理
    def disconnect(self, close_code):
        # スレッドを終了するフラグを立てる
        self.sending = False
        # スレッドの終了を待つ
        self.sender.join()

    # メッセージを受け取ったときの処理
    def receive(self, text_data):
        text_data_json = json.loads(text_data)
        message = text_data_json['message']
        print("Received message:", message)

    # メッセージ送信関数
    def send_message(self, message):
        while True:
            # 終了フラグが立っていたら終了する
            if not self.sending:
                break
            self.send(text_data=json.dumps({
                'message': message,
            }))
            time.sleep(1)

mysite/templates/mysite/index.html にこれと通信するためのHTMLとJavaScriptを書きます。

<!doctype html>
<html>
  <head>
    <title>test</title>
    <script src="https://cdn.jsdelivr.net/npm/vue/dist/vue.js"></script>
  </head>
  <body>
    <div id="app">
      <ul >
        <li v-for="message in messages">[[ message ]]</li>
      </ul>
    </div>
    <script>
      var vm = new Vue({
        el: '#app',
        // Vueの記号がDjangoと被らないようにする
        delimiters: ['[[', ']]'],
        data: {
          messages: [],
          ws: new WebSocket(
            // http or https の判定
            // https://www.koatech.info/blog/vue-websocket-sample/
            (window.location.protocol == "https:" ? "wss" : "ws")
            + '://' + window.location.host + '/ws/test/'
          ),
        },
        created: function() {
          // Vueオブジェクトにアクセスするために変数化する
          // https://www.koatech.info/blog/vue-websocket-sample/
          const self = this;

          // メッセージを受け取ったときの処理
          self.ws.onmessage = function(e) {
            const data = JSON.parse(e.data);
            console.log(data.message)
            console.log(vm.messages)
            // メッセージを表示する
            vm.messages.push(data.message);
            // サーバーにメッセージを送る
            self.ws.send(JSON.stringify({
              'message': "Konnichiwa",
            }));
          };

          // 接続が切断されたときの処理
          self.ws.onclose = function(e) {
            console.error('Websocket has been closed unexpectedly');
          };
        }
      })
    </script>
  </body>
</html>

mysite/views.py を次のように書き換えてindex.htmlを表示するようにします。

from django.shortcuts import render

def index(request):
    return render(request, 'mysite/index.html')

最後に、テンプレートを認識させるために mysite/settings.pyINSTALLED_APPSmysite を追加します。

INSTALLED_APPS = [
    ......
    'channels',
    'mysite',
]

以上の変更を行ったものをデプロイしてブラウザで閲覧すると1秒ごとにHelloが追加される様子が確認できます。

f:id:kivantium:20200502150709g:plain
スクリーンショット

自分の目的にはこれで十分でしたが、チャットなどのより高度なことをしたい場合はChennelsのドキュメントを読んで下さい。 channels.readthedocs.io

Django ChennelsアプリをNginxとSupervisorでデプロイする

DjangoでWebsocketを使うときにはChannelsというライブラリがよく使われています。これまではHerokuにデプロイをしてきましたが、HerokuとChannelsの相性が良くないのかすぐに接続が切れてしまうので、これからはAWS上で開発しようと思いました。公式ドキュメントを読んでもデプロイ方法がよく分からなかったのでメモしておきます。

AWS LightsailでUbuntu 18.04のインスタンスを立てたとして、SSHで入ってからHello, world!するところまでを見ていきます。

ライブラリのインストール

Daphneの起動を楽にするためにvenvを使います。(参考: Djangoのインストール · Django Girls Tutorial

ssh ubuntu@xxx.xxx.xxx.xxx
sudo apt update
sudo apt install python3-venv
python3 -m venv env
source env/bin/activate

requirements.txtに以下を記述します。

django~=3.0.5
channels~=2.4.0

pipをアップデートしてから必要なライブラリをインストールします。

python -m pip install --upgrade pip
pip install -r requirements.txt

Hello, world!アプリの設定

Django プロジェクトを作成します。

django-admin startproject mysite
cd mysite

基本的にはChannels公式ドキュメントのInstallationに従って設定します。簡単のために本番環境と開発環境の設定の分離などは無視します。

mysite/settings.py を以下のように編集します。diffを示しています。

-ALLOWED_HOSTS = []
+ALLOWED_HOSTS = ['*']  # 本当は適切なホストを指定するべきだが簡単のため全て許可
 
(略)

INSTALLED_APPS = [
     'django.contrib.sessions',
     'django.contrib.messages',
     'django.contrib.staticfiles',
+    'channels',
 ]
 
+ASGI_APPLICATION = "mysite.routing.application"
+
 MIDDLEWARE = [
     'django.middleware.security.SecurityMiddleware',
     'django.contrib.sessions.middleware.SessionMiddleware',

mysite/routing.py を以下の内容で作成します。

from channels.routing import ProtocolTypeRouter

application = ProtocolTypeRouter({
    # Empty for now (http->django views is added by default)
})

mysite/urls.py を以下の内容で作成します。

from django.urls import path

from . import views

urlpatterns = [
    path('', views.index, name='index'),
]

mysite/views.py を以下の内容で作成します。

from django.http import HttpResponse


def index(request):
    return HttpResponse("Hello, world!")

ここまで来たら python manage.py runserver を実行してエラーが出ないことだけ確認します。(サーバー上にあるのでこの時点ではブラウザで表示確認ができません)

NginxとSupervisorの設定

ここもChannels公式ドキュメントのDeployingに従って設定するだけなのですが、この通りにやっても動かなかったので以下のStackOverflowに従ってアレンジしました。

stackoverflow.com

まずはNginxとSupervisorをインストールします。

sudo apt install nginx supervisor

mysite/asgi.pyを以下の内容で作成します。

"""
ASGI entrypoint. Configures Django and then runs the application
defined in the ASGI_APPLICATION setting.
"""

import os
import django
from channels.routing import get_default_application

os.environ.setdefault("DJANGO_SETTINGS_MODULE", "mysite.settings")
django.setup()
application = get_default_application()

/etc/supervisor/conf.d/asgi.confを以下の内容で作成します。

[fcgi-program:asgi]
# TCP socket used by Nginx backend upstream
socket=tcp://localhost:8000

# Directory where your site's project files are located
directory=/home/ubuntu/mysite

# Each process needs to have a separate socket file, so we use process_num
# Make sure to update "mysite.asgi" to match your project name
command=/home/ubuntu/env/bin/daphne --fd 0 --access-log - --proxy-headers mysite.asgi:application
# Number of processes to startup, roughly the number of CPUs you have
numprocs=4

# Give each process a unique name so they can be told apart
process_name=asgi%(process_num)d

# Automatically start and recover processes
autostart=true
autorestart=true

# Choose where you want your log to go
stdout_logfile=/var/log/asgi.log
redirect_stderr=true

設定を読み込みます。

sudo supervisorctl reread
sudo supervisorctl update

/etc/nginx/sites-available/defaultを以下のように編集します。

upstream channels-backend {
    server localhost:8000;
}
...
server {
    ...
    location / {
        try_files $uri @proxy_to_app;
    }
    ...
    location @proxy_to_app {
        proxy_pass http://channels-backend;

        proxy_http_version 1.1;
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection "upgrade";

        proxy_redirect off;
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Host $server_name;
    }
    ...
}

設定を読み込みます。

sudo service nginx reload

ブラウザのアドレス欄にこのサーバーのIPアドレスを入力すれば、Hello, world! と出力されたページを確認することができます。 次回はこのサーバーを使って簡単なWebsocketを使ったプログラムを書きます。

2020年4月

4月は死んだ、もういない!

出来事

インターンについて

まともに労働したのが学部4年の夏休みにインターンしたとき以来だったので記憶が薄れていたのですが、やっぱり労働は厳しいです。時給制で働く場合はどれだけ頑張っても働いた時間に対してしか賃金が支払われないので、頑張らずにゆっくり働くほうが自分にとって得になってしまいます。また、昼寝などをして一時的に職場を離れたほうが結果的に一日のパフォーマンスが上がる状況であっても、昼寝に対しては賃金が発生しません。その結果、眠い目をこすりながら動かない頭で働き続けることになり、非効率な労働をだらだらと続けることになります。非効率な行動を取ることが経済的に得であるという状況が発生してしまう労働という構造が正しくないことに気づいた後、その状況に甘んじている自分が嫌になるというのが労働が嫌になるときのパターンな気がします。業務委託契約で働いていたときは、仕様を満たすものが作れないと何も賃金がもらえないというプレッシャーがある一方、最低限のものを作って納品すればいいだろうという雑な態度になっている自分が嫌になっていた記憶があります。

業務内容が面白くても労働という構造をはさんだ瞬間につらさが発生するのですが、業務内容が面白くない場合はなおさらつらくなります。今回のインターンの目的はロボットの開発技術を学ぶことだったのですが、与えられた仕事はその会社で開発しているソフトの上でロボットの動かし方を考えるというもので、退社した瞬間に役に立たなくなりそうな技術ばかり学んでいます。さらに、コロナウィルスの影響で出社3日目から在宅勤務になっているため、自分が考えたロボットの動かし方を実機で試すこともできず、本当に何をやっているのか分からない状態になっています。だからといって別の仕事をよこせと言える状況でもないので毎日のように退職を考えています。SHIROBAKO 9話でこのままタイヤを作る仕事を続けていいのか悩むみーちゃんの心境がとてもよく分かるようになりました。

Webサービス開発について

このまま労働を続けていても何にもならないと思ったので、業務後の時間を使ってWebサービス開発の勉強をすることにしました。題材は何でも良かったのですが、自分が欲しいサービスを作るのが良いだろうということでTwitter上の画像検索サービスを作っています。(画像の元ネタ検索に便利だった TwiGaTen というサービスが終了していたのが思いついたきっかけですが、いま確認したら再開していました……。サーベイを怠って失敗するのはいつものことです。)

せっかくなので機械学習と絡めようと思って、タイムラインからイラストを含むツイートのみを選んで表示する機能を最初に実装しました。http://kivantium-playground.herokuapp.com/ で公開していますが、依存しているライブラリがHeroku上での動作に問題があるらしく、すぐ落ちてしまいます。タイムラインから画像を自動収集する機能などを実装したらAWSに移行しようと思っています。いずれはレコメンデーションや類似画像検索等も実装していきたいと思っていますが、そこまでやる気が続くかは分かりません。

読んだ本

インターンで通勤時間が発生したらたくさん本が読めると思っていたのですが、在宅勤務になったので結局あまり読めませんでした。

無能なナナがアニメ化することになりました

進撃の巨人(31) (講談社コミックス)

進撃の巨人(31) (講談社コミックス)

  • 作者:諫山 創
  • 発売日: 2020/04/09
  • メディア: コミック

化合物でもDeep Learningがしたい!

この記事は2017年12月15日に https://kivantium.net/deep-for-chem/ に投稿したものです。 情報が古くなっていますが、まだ参照されているようなので再掲します。

この記事はDeep Learningやっていき Advent Calendar 2017の15日目です。

Deep Learningの威力が有名になったのは画像認識コンテストでの圧勝がきっかけでしたが、今ではDeep Learningはあらゆる分野に応用され始めています。NIPS2017でもMachine Learning for Molecules and Materialsが開催されたように、物質化学における機械学習の存在感が高まりつつあります。この記事ではその一例として化学の研究にDeep Learningが使われている例を紹介していきます。

化学物質の研究に機械学習が使われる主なパターンには

  • 分子を入力するとその分子の性質を出力する
  • 分子の性質を入力するとその性質を持った分子を出力する
  • 分子を入力するとその反応を出力する

の3つがあります。それぞれについて詳しく説明します。

分子から性質を予測する

Deep Learning以前

Deep Learning以前の性質予測では、職人の温かい手作りによる特徴量が使われていました。分子の特徴ベクトルはmolecular fingerprintsと呼ばれます。molecular fingerprintsは化合物の特徴的な一部分(fragmentと呼ばれる)がその分子にあるかどうかを0/1で表したbitを並べて作られます。

(画像はFingerprints in the RDKit p.4より引用)

どのfragmentを用いるのが有効かはデータセット・問題に依存するので様々な種類のfingerprintが提案されてきました。

主なfingerprintを挙げると

などがあります。 fingerprintはRDKitなどのライブラリを使うと簡単に計算できます。(各ソフトで計算できるfingerprintのリスト

このようなfingerprintを使ってSVMやRandomforestなどでその分子がある性質を持つ/持たないを予測する研究がたくさんあります。化学の分野でDeep Learningが大きく注目されるきっかけになったのは、kaggleの薬の活性予測のコンペでHintonらのチームが優勝したことですが、論文を見ると特徴量には上のように設計されたものを使っており、ニューラルネットワークで設計されたものではなかったようです。

graph convolutionの登場

fingerprintの設計にニューラルネットワークが導入されたのが[Duvenaud+, 2015]です。この研究ではcircular fingerprint (上のECFPのこと)をもとにneural graph fingerprint (NFP)を提案しています。以下にアルゴリズムを示します。

従来のfingerprint設計でhashやmodになっていた部分が重みを調整できる演算に変更されています。これにより、予測にとって重要なfragmentの寄与は大きく、重要ではないfragmentの寄与は小さくなるような特徴量が設計できるようになりました。実際に分子の水への溶けやすさをNFPで予測したところ水への溶けやすさに影響するR-OHのような構造の重みが大きくなったことが報告されています。

NFPの他にも分子のグラフ構造に基づいたニューラルネットワークベースの特徴量設計の研究が行われています。これらはグラフ構造に注目したニューラルネットワークなので総称としてgraph convolutionと呼ばれています。一番有名なのはGoogle BrainのNeural Message Passing for Quantum Chemistryでしょう。この論文ではMessage Passing Neural Network (MPNN) というグラフ上のニューラルネットワークを提案し、分子のニューラルネットワークの先行研究の多くがMPNNで一般的に記述できることを主張した上で、MPNNが分子の性質を予測する上で高い性能を発揮すると主張しています。 MPNNは

という式で表されます。グラフ上で隣接するエッジからのメッセージ Mを足し合わせるような処理をしていることが分かります。 M, Uなどをうまく定めることで各種のgraph convolutionを表すことができます。詳細は論文を読んでください。Google ResearchによるブログPredicting Properties of Molecules with Machine Learningも役立つかもしれません。(ちなみにこの論文のラストオーサーは先述したkaggleコンペの論文のファーストオーサーです)

graph convolutionのイメージとしてよく使われる絵が[Han Altae-Tran+, 2017]にあります。一枚引用します。

Graph convolutionではない方法としては[Goh+, 2017]のような分子を画像にしてCNNで予測するようなものもあります。

ちなみに、同じ人がつい先日SMILES2Vecという文字列から化合物の性質を予測する論文も書いていました。

実装

分子に対するDeep Learningのライブラリで最も有名なのはDeepChemでしょう。DeepChemはTensorFlowでgraph convolutionを実装しています。Graph Convolutions For Tox21などのチュートリアルを読むとだいたい使い方が分かるのではないでしょうか(私も使ったことはないです)。ちなみに、なぜかPong in DeepChem with A3Cのようなチュートリアルもあり何がしたいのか謎です……

また、PFNが最近Chainer Chemistryを公開しました。NFP, GGNN, Weave, SchNetなどのgraph convolution手法が実装されているほか、QM9, Tox21などの有名どころのデータセットを使うコードも揃っており、普段Chainerを使っている人はこれを試してみるのもよいかもしれません。

性質から分子を作る

創薬などの応用においては、「タンパク質Xの動きを抑制する」などの特定の性質を持った分子を作ることが必要になります。化学物質の構造と生物学的な活性の関係のことをQSARと呼びますが、逆に活性から構造を予測する問題をinverse-QSARのように言うことがあります。

分子設計の難しさの一つは、可能な分子の数が非常にたくさんあることです。[Bohacek+, 1996]によれば、C,N,O,Sを30個以下しか持たない分子に限っても 10^{60}種類の分子が存在できるとされています。そのため全探索は不可能なので何らかの効率的な探索法を考える必要があります。

Deep Learning以外の方法

創薬は重要な研究分野なので以前から研究が行われていました。多くの手法は[Nishibata+, 1991][Pierce+, 2004]のように既に知られている部分構造を組み合わせることで分子を設計しています。最近の研究では[Kawai+, 2014]のように構造の組み合わせに遺伝的アルゴリズム構造を使ったり、[Podlewska+, 2017]のように目的関数を機械学習の予測値にしたりするなどの工夫がなされています。

Deep Learningによる方法

分子設計にDeep Leaningを持ち込んだ研究が[Gómez-Bombarelli+, 2016]です。この研究では分子の文字列表現であるSMILES記法をvariational autoencoder (VAE) を用いて実数ベクトルに変換し、ベイズ最適化で最適化したベクトルをSMILESに戻すことで分子を設計しています。この手法の問題点はVAE空間上で最適化ベクトルをSMILESに戻したときに生成される文字列が文法的に正しくないなどの理由で分子と対応しなくなる率が非常に高かったことです。

SMILES記法は、グラフ構造として表される化合物を環を切り開くなどして文字列として表現できるようにしています。OpenSMILES specificationのように文脈自由文法で規定される文法を持っており、文法に従わない文字列は分子を表しません。(なお、文法に従っていても対応する分子が化学的に存在できるかは別の問題です)。例えば下のような図で表される分子のSMILESはO1C=C[C@H]([C@H]1O2)c3c2cc(OC)c4c3OC(=O)C5=C4CCC(=O)5となります。同じ数字はそこで環を形成していることを表し、カッコは分岐を表しています。

文法的に正しくないSMILESの文字列が生成される問題を解決するために、VAEの入出力にSMILESの文字列をそのまま使うのではなくSMILESを生成する文脈自由文法の生成規則列を使うことにしたのが[Kusner+, 2017]のGrammar Variational Autoencoderです。この研究で技術的に面白いところはVAE表現から文字列を生成する際にプッシュダウンオートマトンを考えて、現在スタックの一番上にある文字から選択できない生成規則の確率を0にする工夫を導入しているところです。この工夫により生成される文字列はSMILESの文法的に正しいものに限定することができるためデコードの効率が上がるほか、潜在空間自体もよりよいものになったと主張されています。

これらのアプローチに影響されたのかは分かりませんが、分子の構造を直接設計するのではなく、分子を表すSMILESを生成する研究が盛んに行われています。

  • [Segler+, 2017] はChEMBLのSMILESを学習したLSTMで新しいSMILESを生成しています。また、薬の候補になりそうな分子を入力としたRNNのファインチューニングなども行っています。

分子から反応を予測する

分子からの反応予測には、複数の分子を入力して反応結果を出力するものと、一つの分子を入力してその分子を作るのに必要な反応を予測するものがあります。

反応結果の予測

反応予測をコンピュータで行う試みは1960年代から行われていますが、従来の手法では専門家がルールをたくさん記述することで実現しています。この分野にもDeep Learningの波が来ています。

[Schwaller+, 2017] は反応物のSMILESを入力に生成物のSMILESを出力する言語モデルを用いて反応の予測を行っています。 SMILESによる反応の記述は、反応物の文字列を入力して生成物の文字を出力する処理なので、英語を入力してフランス語を入力する処理に似ていると彼らは考えましたアメリカの特許にある反応のデータベースから入力と出力のペアを作り、seq2seqという翻訳に使われるRNNモデルを適用して反応の予測を行いました。

結果としてtop-1で80%という先行研究を上回る精度の予測ができるようになったと主張されています。

逆合成の予測

目的の化合物を合成するための反応経路を求めることをretrosynthesisといいますが、実際に化合物を生産する上では非常に重要な技術です。この研究でもDeep Learningを使った論文が出ています。

[Segler+, 2017]ではAlphaGoと似た手法でretrosynthesisを行っています。(図は論文のFigure 1) (a)は目的の化合物(図ではIbuprofen)からはじめて分子をばらしていき、全てが既知の入手可能な分子(図では赤で示されている)にまで還元できたら逆合成が完了するというコンセプトを示しています。 (b)は(a)で用いられた既知の反応を示しています。 (c)は(a)の結果得られた反応経路から実際に目的の化合物を合成する過程を示しています。 (d)がこの論文の中心となるアイデアを表しています。現在の分子をばらすのに使える既知の反応はいくつもあります。反応の各段階を一つの状態ととらえると反応はグラフ上の状態遷移と考えることができ、逆合成はグラフ上の最適な経路を探す問題と解釈できます。そこでゲームの状態を表す木から最適な手を探すのと同じような方法を用いて、最適な次の反応を選ぶことで逆合成を解くことができると考えられます。 (e)のように分子の状態を入力すると良さそうな反応を返すDNNの確率をガイドにしたモンテカルロ木探索を実行することで逆合成を行うことができそうです。

論文の実験ではモンテカルロ木探索を用いた提案手法が先行研究よりも高い性能を示したと主張されています。

私が知っている主な研究はこれくらいですが、他にも面白い研究を知っている方がいらっしゃったらコメントなどで教えて下さい。

PyTorchでファインチューニングしたモデルをONNXで利用する

昨日の作業の結果、Illustration2Vecのモデルが大きすぎて貧弱なサーバーでは使えないことが分かりました。今のところ二次元画像判別器の特徴量抽出にしか使っていないので、もっと軽いモデルでも代用できるはずです。軽いモデルとして有名なSqueezenetをこれまで集めたデータでファインチューニングして様子を見てみることにします。

ファインチューニングとONNXへのエキスポート

PyTorchのチュートリアルが丁寧に説明してくれているので、これをコピペして継ぎ接ぎするだけです。

継ぎ接ぎしたものがこちらになります。これを実行するとmodel.onnxというファイルが作成されます。 ONNX版Illustration2Vecのモデルサイズが910Mに対して、このモデルは2.8MBなのでだいぶ小さくなりました。精度もだいたい同じくらいだと思います。

from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

# Top level data directory. Here we assume the format of the directory conforms
# to the ImageFolder structure
data_dir = "./data/images"

# Models to choose from [resnet, alexnet, vgg, squeezenet, densenet, inception]
model_name = "squeezenet"

# Number of classes in the dataset
num_classes = 2

# Batch size for training (change depending on how much memory you have)
batch_size = 8

# Number of epochs to train for
num_epochs = 1

# Flag for feature extracting. When False, we finetune the whole model,
#   when True we only update the reshaped layer params
feature_extract = True

def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False):
    since = time.time()

    val_acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    # Get model outputs and calculate loss
                    # Special case for inception because in training it has an auxiliary output. In train
                    #   mode we calculate the loss by summing the final output and the auxiliary output
                    #   but in testing we only consider the final output.
                    if is_inception and phase == 'train':
                        # From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4*loss2
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'val':
                val_acc_history.append(epoch_acc)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history

def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # Initialize these variables which will be set in this if statement. Each of these
    #   variables is model specific.
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        """ Resnet18
        """
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "alexnet":
        """ Alexnet
        """
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224

    elif model_name == "vgg":
        """ VGG11_bn
        """
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        """ Squeezenet
        """
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        """ Densenet
        """
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "inception":
        """ Inception v3
        Be careful, expects (299,299) sized images and has auxiliary output
        """
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Handle the auxilary net
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        # Handle the primary net
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs,num_classes)
        input_size = 299

    else:
        print("Invalid model name, exiting...")
        exit()

    return model_ft, input_size

# Initialize the model for this run
model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)

# Print the model we just instantiated
print(model_ft)

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

print("Initializing Datasets and Dataloaders...")

# Create training and validation datasets
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
# Create training and validation dataloaders
dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4) for x in ['train', 'val']}

# Detect if we have a GPU available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_ft = model_ft.to(device)

# Gather the parameters to be optimized/updated in this run. If we are
#  finetuning we will be updating all parameters. However, if we are
#  doing feature extract method, we will only update the parameters
#  that we have just initialized, i.e. the parameters with requires_grad
#  is True.
params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t",name)
else:
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t",name)

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9)

# Setup the loss fxn
criterion = nn.CrossEntropyLoss()

# Train and evaluate
model_ft, hist = train_model(model_ft, dataloaders_dict, criterion, optimizer_ft, num_epochs=num_epochs, is_inception=(model_name=="inception"))

# Save PyTorch model to file
torch.save(model_ft.to('cpu').state_dict(), 'model.pth')

# Input to the model
x = torch.randn(1, 3, 224, 224, requires_grad=True)
torch_out = model_ft(x)

# Export the model
torch.onnx.export(model_ft,               # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "model.onnx",                # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=10,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output'], # the model's output names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable lenght axes
                                'output' : {0 : 'batch_size'}})

ONNXモデルの利用

こうして作成したONNXモデルをPyTorchを使わずに利用するコードはこんな感じです。

import os
from PIL import Image

import numpy as np
import onnxruntime

# 中心を正方形に切り抜いてリサイズ
def crop_and_resize(img, size):
    width, height = img.size
    crop_size = min(width, height)
    img_crop = img.crop(((width - crop_size) // 2, (height - crop_size) // 2,
                         (width + crop_size) // 2, (height + crop_size) // 2))
    return img_crop.resize((size, size))

img_mean = np.asarray([0.485, 0.456, 0.406])
img_std = np.asarray([0.229, 0.224, 0.225])

ort_session = onnxruntime.InferenceSession(
    os.path.join(os.path.dirname(__file__), "model.onnx"))

img = Image.open('image.jpg').convert('RGB')
img = crop_and_resize(img, 224)

# 画像の正規化
img_np = np.asarray(img).astype(np.float32)/255.0
img_np_normalized = (img_np - img_mean) / img_std

# (H, W, C) -> (C, H, W)
img_np_transposed = img_np_normalized.transpose(2, 0, 1)

batch_img = [img_np_transposed]

ort_inputs = {ort_session.get_inputs()[0].name: batch_img}
ort_outs = ort_session.run(None, ort_inputs)[0]
batch_result = np.argmax(ort_outs, axis=1)
print(batch_result)

このSqueezenetモデルを使って昨日と同じようなことをするviews.pyが以下のようになります。全文はGitHubを見てください。 github.com

import os
import re
import urllib.request
from urllib.parse import urlparse
from PIL import Image
from joblib import dump, load
import tweepy

from django.shortcuts import render
from social_django.models import UserSocialAuth
from django.conf import settings
import more_itertools

import numpy as np
import onnxruntime
import torchvision.transforms as transforms


def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


ort_session = onnxruntime.InferenceSession(
    os.path.join(os.path.dirname(__file__), "model.onnx"))

data_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


def index(request):
    if request.user.is_authenticated:
        user = UserSocialAuth.objects.get(user_id=request.user.id)
        consumer_key = settings.SOCIAL_AUTH_TWITTER_KEY
        consumer_secret = settings.SOCIAL_AUTH_TWITTER_SECRET
        access_token = user.extra_data['access_token']['oauth_token']
        access_secret = user.extra_data['access_token']['oauth_token_secret']
        auth = tweepy.OAuthHandler(consumer_key, consumer_secret)
        auth.set_access_token(access_token, access_secret)
        api = tweepy.API(auth)
        timeline = api.home_timeline(count=200, tweet_mode='extended')

        tweet_media = []
        for tweet in timeline:
            if 'media' in tweet.entities:
                tweet_media.append(tweet)

        batch_size = 4
        tweet_illust = []
        for batch_tweet in more_itertools.chunked(tweet_media, batch_size):
            batch_img = []
            for tweet in batch_tweet:
                media_url = tweet.extended_entities['media'][0]['media_url']
                filename = os.path.basename(urlparse(media_url).path)
                filename = os.path.join(
                    os.path.dirname(__file__), 'images', filename)
                urllib.request.urlretrieve(media_url, filename)
                img = Image.open(filename).convert('RGB')
                img = data_transforms(img)
                batch_img.append(to_numpy(img))

            ort_inputs = {ort_session.get_inputs()[0].name: batch_img}
            ort_outs = ort_session.run(None, ort_inputs)[0]
            batch_result = np.argmax(ort_outs, axis=1)
            for tweet, result in zip(batch_tweet, batch_result):
                if result == 1:
                    media_url = tweet.extended_entities['media'][0]['media_url']
                    if hasattr(tweet, "retweeted_status"):
                        profile_image_url = tweet.retweeted_status.author.profile_image_url_https
                        author = {'name': tweet.retweeted_status.author.name,
                                  'screen_name': tweet.retweeted_status.author.screen_name}
                        id_str = tweet.retweeted_status.id_str
                    else:
                        profile_image_url = tweet.author.profile_image_url_https
                        author = {'name': tweet.author.name,
                                  'screen_name': tweet.author.screen_name}
                        id_str = tweet.id_str
                    try:
                        text = tweet.retweeted_status.full_text
                    except AttributeError:
                        text = tweet.full_text
                    text = re.sub(
                        r"https?://[\w/:%#\$&\?\(\)~\.=\+\-]+$", '', text).rstrip()
                    tweet_illust.append({'id_str': id_str,
                                         'profile_image_url': profile_image_url,
                                         'author': author,
                                         'text': text,
                                         'image_url': media_url})
        tweet_illust_chunked = list(more_itertools.chunked(tweet_illust, 4))
        return render(request, 'hello/index.html', {'user': user, 'timeline_chunked': tweet_illust_chunked})
    else:
        return render(request, 'hello/index.html')

モデルがだいぶ小さくなったので、貧弱サーバーでも動かすことができました。

デプロイ

このモデルサイズならHerokuで動かせると思ったのですが、torchvisionなどの依存ライブラリの容量だけでHerokuの500MBの制限を超えてしまうようなので自分のサーバーで動かすことにしました。一応動いてはいるのですが、コールバックの設定がうまくいかないので後日直します。

2.9MBのモデルなら一瞬で推論できると思ったのですが、それでもまだ貧弱サーバーには荷が重いようで読み込みにだいぶ時間がかかります。もっと軽いモデルを作るかディープラーニングに頼らない方法を考えるのが良さそうです。

PyTorchのCPU版をrequirements.txtで指定すればHerokuにデプロイできました。

https://kivantium-playground.herokuapp.com/ から試すことができます。(開発状況によっては違うものがデプロイされているかもしれません)

広告コーナー