kivantium活動日記

プログラムを使っていろいろやります

RISC-Vクロスコンパイラで生成したバイナリを自作RISC-V上で実行する

4連休の課題としてFPGAで簡単なCPUを作っているので、その進捗を記録しておきます。

RISC-V (RV32I) の作成

とりあえず今回は確実に動くCPUを作ることを目標にしました。 パイプラインなどは実装せず、フェッチ→デコード→実行→メモリ・アクセス→書き戻しの5段階にそれぞれ1クロック使って、1命令に5クロックかける設計になっています。 命令セットにはRISC-Vの一番基本的な構成であるRV32Iを採用しましたが、簡単のため特権命令や割り込み周りは省略しました。 ハードウェアには以前使ったDigilentのBasys 3を使って、Vivadoで開発しました。 kivantium.hateblo.jp

あまり工夫したところはないので実装の詳細は説明しません。ソースコードはここに置いてあります。 github.com

RISC-Vクロスコンパイラのインストール

RISC-V向けのGCCriscv-gnu-toolchainというリポジトリで公開されています。以前はriscv-toolsというリポジトリ以下で公開されていたものが移動したらしいので、古い記事を参考にするときは気をつけてください。 github.com

READMEに従ってインストールするだけですが、configureで32bit用に設定する必要があります。ソースコードだけで10GB近くあるのでディスク容量に注意してください。

sudo apt-get install autoconf automake autotools-dev curl python3 libmpc-dev libmpfr-dev libgmp-dev gawk build-essential bison flex texinfo gperf libtool patchutils bc zlib1g-dev libexpat-dev
git clone --recursive https://github.com/riscv/riscv-gnu-toolchain
./configure --prefix=/opt/riscv32 --with-arch=rv32im --with-abi=ilp32d
make linux

/opt/riscv32以下にインストールされるのでPATHを通しておきます。

export PATH=/opt/riscv32/bin:$PATH

バイナリの作成

簡単な例として、フィボナッチ数列の第10番目の項を再帰で求めるプログラムを実行することにします。test.cを以下の通り作成します。main関数を抜けないようにするために最後に無限ループを実行しています。

int fib(int n) {
  if(n <= 1) return 1;
  return fib(n-1) + fib(n-2);
}

int main() {
  fib(10);
  for(;;) {}
  return 0;
}

これを自作CPUで動くようにコンパイルします。

普通にコンパイルしてしまうと未実装の命令を使った初期化ルーチンが走ってしまうので、まずはそれを無効にします。start.Sを以下の通り作成します。

.section .text.init;
.globl _start
_start:
    call main

何もせずにmainを呼び出すアセンブリになっています。

次に、命令を0番地から実行するように指定します。link.ldを以下の通り作成します。

OUTPUT_ARCH( "riscv" )
ENTRY(_start)

SECTIONS
{
  . = 0x00000000;
  .text.init : { *(.text.init) }
  .tohost : { *(.tohost) }
  .text : { *(.text) }
  .data : { *(.data) }
  .bss : { *(.bss) }
  _end = .;
}

以下のようにしてバイナリを生成します。

riscv32-unknown-elf-gcc -march=rv32i -c -o start.o start.S
riscv32-unknown-elf-gcc -march=rv32i -c -o test.o test.c
riscv32-unknown-elf-ld test.o start.o -lc -L/opt/riscv32/riscv32-unknown-elf/lib/ -Tlink.ld -nostartfiles -static -o test.elf
riscv32-unknown-elf-objcopy -O binary test.elf test.bin
hexdump -v -e '1/4 "%08x" "\n"' test.bin > test.hex

最後に出来上がるtest.hexには、CPUが実行する命令列が16進数で書かれています。

074000ef
fe010113
00112e23
(中略)
00a00513
f7dff0ef
0000006f

ELF形式のファイルに対してobjdumpを実行すると逆アセンブルした結果を見ることができます。

$ riscv32-unknown-elf-objdump -d test.elf

test.elf:     ファイル形式 elf32-littleriscv


セクション .text.init の逆アセンブル:

00000000 <_start>:
   0:   074000ef            jal ra,74 <main>

セクション .text の逆アセンブル:

00000004 <fib>:
   4:   fe010113            addi    sp,sp,-32
   8:   00112e23            sw  ra,28(sp)
   c:   00812c23            sw  s0,24(sp)
  10:   00912a23            sw  s1,20(sp)
  14:   02010413            addi    s0,sp,32
  18:   fea42623            sw  a0,-20(s0)
  1c:   fec42703            lw  a4,-20(s0)
  20:   00100793            li  a5,1
  24:   00e7c663            blt a5,a4,30 <fib+0x2c>
  28:   00100793            li  a5,1
  2c:   0300006f            j   5c <fib+0x58>
  30:   fec42783            lw  a5,-20(s0)
  34:   fff78793            addi    a5,a5,-1
  38:   00078513            mv  a0,a5
  3c:   fc9ff0ef            jal ra,4 <fib>
  40:   00050493            mv  s1,a0
  44:   fec42783            lw  a5,-20(s0)
  48:   ffe78793            addi    a5,a5,-2
  4c:   00078513            mv  a0,a5
  50:   fb5ff0ef            jal ra,4 <fib>
  54:   00050793            mv  a5,a0
  58:   00f487b3            add a5,s1,a5
  5c:   00078513            mv  a0,a5
  60:   01c12083            lw  ra,28(sp)
  64:   01812403            lw  s0,24(sp)
  68:   01412483            lw  s1,20(sp)
  6c:   02010113            addi    sp,sp,32
  70:   00008067            ret

00000074 <main>:
  74:   ff010113            addi    sp,sp,-16
  78:   00112623            sw  ra,12(sp)
  7c:   00812423            sw  s0,8(sp)
  80:   01010413            addi    s0,sp,16
  84:   00b00513            li  a0,11
  88:   f7dff0ef            jal ra,4 <fib>
  8c:   0000006f            j   8c <main+0x18>

C言語で書いた通り、再帰でフィボナッチ数を求めた後、`8c'を無限ループするプログラムになっていることが確認できます。

生成した16進数の命令列はSystemVerilogの$readmemhを利用して命令メモリに埋め込んでいます(ソースコード)。コードを変更するたびに論理合成をやり直す必要がありますが、命令列を外部から読み込ませるのは面倒なのでこうしました。

実機での動作

関数の引数と戻り値が入るa0レジスタの値を10進数で7セグLEDに表示する回路を組みました。しばらく計算した後、最終的な関数の返り値である89が表示されます。(動作を見やすくするためにクロックを10万分周しています)

今後の課題

xv6が去年からRISC-Vに対応したそうなので、xv6が動くようなCPUが作れると良いです。 OSを動かすためには割り込みなどの特権命令やスーパーバイザーモードでの仮想アドレスを実装しないといけないみたいので道のりは険しそうですが……

参考文献

ツールチェインの使い方が特に参考になりました。

github.com 上のFPGAマガジンで実装されているRISC-Vです(本誌からリンクされてなかった……)

1命令を5クロックで実行する設計はここから持ってきました(このコードではデコーダーやALUが順序回路になっていますが、自分の実装では組み合わせ回路になっています。2クロック無駄になっていますが、簡単のためです……)

mindchasers.com ツールチェインのコンパイルオプションはここを参考にしました。

自動微分ライブラリJAXを用いた対称行列の固有値の微分

JAXという自動微分ライブラリが流行りそうな機運があるので遊んでみます。

github.com

インストール

READMEに書いてある通りにやりました。

pip install --upgrade pip
pip install --upgrade jax jaxlib 

基本的な使い方

The Autodiff Cookbook — JAX documentation を読んで下さい

固有値の最小化の例

ここまでとても雑な説明だったのは、固有値の自動微分が今回の記事のメインテーマだからです。固有値の勾配を求める需要なんてないだろうと思っていたのですが、シュレディンガー方程式固有値方程式なので、エネルギーの最小化をするためには固有値を最小化する必要があり、そのために固有値の勾配が欲しいという需要があるようです。(この論文では実際に自動微分を使ってエネルギーの最小化を行っています)

JAXでは一般の固有値の勾配はまだ実装されていないようなので、対称行列の固有値の勾配を使った例題を解くことにします。

\displaystyle{
A=\left(
    \begin{array}{cc}
      1 &x \\
      x & -1
    \end{array}
  \right)
}

として、行列Aの最大固有値が最小になるようにしてみます。(ここでは実数の範囲のみを考えることにします)固有方程式を解くと、固有値λは

\displaystyle{
\lambda = \pm\sqrt{x^2+1}
}

となるので、最大固有値の最小値はx=0のときに1.0になります

xを与えたときにAの最大固有値を返す関数をfunc最急降下法で最小化するプログラムは以下のようになります。

import jax.numpy as np
from jax.ops import index, index_update
from jax import grad, jit

@jit
def func(x):
    A = np.zeros((2, 2))
    A = index_update(A, index[0, 0], 1)
    A = index_update(A, index[0, 1], x)
    A = index_update(A, index[1, 0], x)
    A = index_update(A, index[1, 1], -1)
    w, v = np.linalg.eigh(A)
    return np.max(w).real

func_grad = jit(grad(func))
alpha = 0.1
x = 1.0
for _ in range(100):
    print("x={}, f(x)={}".format(x, func(x)))
    x -= alpha * func_grad(x)

print("min value: {} (x={})".format(func(x), x))

コメント

  • 関数に@jitをつけるとJITコンパイルが行われて実行が高速になります。jit(grad(func))のようにすると、勾配関数もJITコンパイルできます。
  • JAXで配列に添字でアクセスするとエラーになるので、代わりにindex_updateを使っています。詳細は🔪 JAX - The Sharp Bits 🔪 — JAX documentationを見てください。(さすがにこれはあまりに汚いのでもう少しいいやり方があるかもしれません)
  • 固有値を求めるためにjax.scipy.linalg.eighを使っています。この関数は本来エルミート行列用なので、固有値複素数で返ってきます。ここでは問題を実数の範囲に限定しているので実部を取っています。

実行結果は次のようになりました。

x=1.0, f(x)=1.4142135381698608
x=0.9292893409729004, f(x)=1.3651295900344849
x=0.8612160086631775, f(x)=1.3197321891784668
x=0.7959591150283813, f(x)=1.2781044244766235
x=0.7336825728416443, f(x)=1.2402782440185547
(中略)
x=6.46666157990694e-05, f(x)=1.0
x=5.819995567435399e-05, f(x)=1.0
x=5.237996083451435e-05, f(x)=1.0
x=4.714196620625444e-05, f(x)=1.0
x=4.242776776663959e-05, f(x)=1.0
x=3.818498953478411e-05, f(x)=1.0
min value: 1.0 (x=3.43664905813057e-05)

最急降下法がそれっぽく動いていることが確認できました。この程度の問題では自動微分を使うまでもないですが、問題がもう少し複雑になるとJITによる高速な自動微分のメリットを享受できます。 今回は最急降下法を使いましたが、scipyと組み合わせればBFGSなどのもう少し高度な最適化手法を使うことができます(参考: A brief introduction to JAX and Laplace’s method - anguswilliams91.github.io

固有値微分とは一体なんなのかという話もする必要があるのですが、まだよく理解できていないので今日はとりあえずここまでにして後日更新します。

メモ欄

added jvp rule for eigh, tests by levskaya · Pull Request #358 · google/jax · GitHub

Django ChannelsでWebsocket通信を行ってVue.jsで表示する

前回の続きです。 kivantium.hateblo.jp

サーバー上で重たい処理をする場合、処理が全て終わってからHTMLを生成しているとかなりレスポンスが悪くなってしまいます。先に結果表示ページのHTMLを表示しておいて、中身を後からWebsocketで順次通信するといい感じになりそうです。Websocket通信したデータを表示するためにVue.jsを使うことにします。

前回からの変更点

前回のHello, worldで使ったディレクトリ構成を編集する形で作業を進めます。

mysite/routing.py を以下の内容で書き換えます。ws/test/にアクセスすることでWebsocket通信を行えるように設定しています。

from channels.auth import AuthMiddlewareStack
from channels.routing import ProtocolTypeRouter, URLRouter
from django.urls import path

from . import consumers

websocket_urlpatterns = [
    path('ws/test/', consumers.Consumer),
]

application = ProtocolTypeRouter({
    # (http->django views is added by default)
    'websocket': AuthMiddlewareStack(
        URLRouter(
            websocket_urlpatterns
        )
    ),
})

mysite/consumers.py にWebsocket通信で行う処理を記述します。本当は非同期処理をするべきらしいのですが、動く書き方が分からなかったので同期処理で書いています。(参考: django で Websocket - 空のブログ

import json
import time
import threading

from channels.generic.websocket import WebsocketConsumer

class Consumer(WebsocketConsumer):
    # 接続されたときの処理
    def connect(self):
        # 接続を許可する
        self.accept()
        # メッセージ送信関数を新しいスレッドで呼び出す
        # 本当は非同期で書くべきだがうまく動かなかった
        self.sending = True
        self.sender = threading.Thread(
            target=self.send_message, args=('Hello', ))
        self.sender.start()

    # 接続が切断されたときの処理
    def disconnect(self, close_code):
        # スレッドを終了するフラグを立てる
        self.sending = False
        # スレッドの終了を待つ
        self.sender.join()

    # メッセージを受け取ったときの処理
    def receive(self, text_data):
        text_data_json = json.loads(text_data)
        message = text_data_json['message']
        print("Received message:", message)

    # メッセージ送信関数
    def send_message(self, message):
        while True:
            # 終了フラグが立っていたら終了する
            if not self.sending:
                break
            self.send(text_data=json.dumps({
                'message': message,
            }))
            time.sleep(1)

mysite/templates/mysite/index.html にこれと通信するためのHTMLとJavaScriptを書きます。

<!doctype html>
<html>
  <head>
    <title>test</title>
    <script src="https://cdn.jsdelivr.net/npm/vue/dist/vue.js"></script>
  </head>
  <body>
    <div id="app">
      <ul >
        <li v-for="message in messages">[[ message ]]</li>
      </ul>
    </div>
    <script>
      var vm = new Vue({
        el: '#app',
        // Vueの記号がDjangoと被らないようにする
        delimiters: ['[[', ']]'],
        data: {
          messages: [],
          ws: new WebSocket(
            // http or https の判定
            // https://www.koatech.info/blog/vue-websocket-sample/
            (window.location.protocol == "https:" ? "wss" : "ws")
            + '://' + window.location.host + '/ws/test/'
          ),
        },
        created: function() {
          // Vueオブジェクトにアクセスするために変数化する
          // https://www.koatech.info/blog/vue-websocket-sample/
          const self = this;

          // メッセージを受け取ったときの処理
          self.ws.onmessage = function(e) {
            const data = JSON.parse(e.data);
            console.log(data.message)
            console.log(vm.messages)
            // メッセージを表示する
            vm.messages.push(data.message);
            // サーバーにメッセージを送る
            self.ws.send(JSON.stringify({
              'message': "Konnichiwa",
            }));
          };

          // 接続が切断されたときの処理
          self.ws.onclose = function(e) {
            console.error('Websocket has been closed unexpectedly');
          };
        }
      })
    </script>
  </body>
</html>

mysite/views.py を次のように書き換えてindex.htmlを表示するようにします。

from django.shortcuts import render

def index(request):
    return render(request, 'mysite/index.html')

最後に、テンプレートを認識させるために mysite/settings.pyINSTALLED_APPSmysite を追加します。

INSTALLED_APPS = [
    ......
    'channels',
    'mysite',
]

以上の変更を行ったものをデプロイしてブラウザで閲覧すると1秒ごとにHelloが追加される様子が確認できます。

f:id:kivantium:20200502150709g:plain
スクリーンショット

自分の目的にはこれで十分でしたが、チャットなどのより高度なことをしたい場合はChennelsのドキュメントを読んで下さい。 channels.readthedocs.io

Django ChennelsアプリをNginxとSupervisorでデプロイする

DjangoでWebsocketを使うときにはChannelsというライブラリがよく使われています。これまではHerokuにデプロイをしてきましたが、HerokuとChannelsの相性が良くないのかすぐに接続が切れてしまうので、これからはAWS上で開発しようと思いました。公式ドキュメントを読んでもデプロイ方法がよく分からなかったのでメモしておきます。

AWS LightsailでUbuntu 18.04のインスタンスを立てたとして、SSHで入ってからHello, world!するところまでを見ていきます。

ライブラリのインストール

Daphneの起動を楽にするためにvenvを使います。(参考: Djangoのインストール · Django Girls Tutorial

ssh ubuntu@xxx.xxx.xxx.xxx
sudo apt update
sudo apt install python3-venv
python3 -m venv env
source env/bin/activate

requirements.txtに以下を記述します。

django~=3.0.5
channels~=2.4.0

pipをアップデートしてから必要なライブラリをインストールします。

python -m pip install --upgrade pip
pip install -r requirements.txt

Hello, world!アプリの設定

Django プロジェクトを作成します。

django-admin startproject mysite
cd mysite

基本的にはChannels公式ドキュメントのInstallationに従って設定します。簡単のために本番環境と開発環境の設定の分離などは無視します。

mysite/settings.py を以下のように編集します。diffを示しています。

-ALLOWED_HOSTS = []
+ALLOWED_HOSTS = ['*']  # 本当は適切なホストを指定するべきだが簡単のため全て許可
 
(略)

INSTALLED_APPS = [
     'django.contrib.sessions',
     'django.contrib.messages',
     'django.contrib.staticfiles',
+    'channels',
 ]
 
+ASGI_APPLICATION = "mysite.routing.application"
+
 MIDDLEWARE = [
     'django.middleware.security.SecurityMiddleware',
     'django.contrib.sessions.middleware.SessionMiddleware',

mysite/routing.py を以下の内容で作成します。

from channels.routing import ProtocolTypeRouter

application = ProtocolTypeRouter({
    # Empty for now (http->django views is added by default)
})

mysite/urls.py を以下の内容で作成します。

from django.urls import path

from . import views

urlpatterns = [
    path('', views.index, name='index'),
]

mysite/views.py を以下の内容で作成します。

from django.http import HttpResponse


def index(request):
    return HttpResponse("Hello, world!")

ここまで来たら python manage.py runserver を実行してエラーが出ないことだけ確認します。(サーバー上にあるのでこの時点ではブラウザで表示確認ができません)

NginxとSupervisorの設定

ここもChannels公式ドキュメントのDeployingに従って設定するだけなのですが、この通りにやっても動かなかったので以下のStackOverflowに従ってアレンジしました。

stackoverflow.com

まずはNginxとSupervisorをインストールします。

sudo apt install nginx supervisor

mysite/asgi.pyを以下の内容で作成します。

"""
ASGI entrypoint. Configures Django and then runs the application
defined in the ASGI_APPLICATION setting.
"""

import os
import django
from channels.routing import get_default_application

os.environ.setdefault("DJANGO_SETTINGS_MODULE", "mysite.settings")
django.setup()
application = get_default_application()

/etc/supervisor/conf.d/asgi.confを以下の内容で作成します。

[fcgi-program:asgi]
# TCP socket used by Nginx backend upstream
socket=tcp://localhost:8000

# Directory where your site's project files are located
directory=/home/ubuntu/mysite

# Each process needs to have a separate socket file, so we use process_num
# Make sure to update "mysite.asgi" to match your project name
command=/home/ubuntu/env/bin/daphne --fd 0 --access-log - --proxy-headers mysite.asgi:application
# Number of processes to startup, roughly the number of CPUs you have
numprocs=4

# Give each process a unique name so they can be told apart
process_name=asgi%(process_num)d

# Automatically start and recover processes
autostart=true
autorestart=true

# Choose where you want your log to go
stdout_logfile=/var/log/asgi.log
redirect_stderr=true

設定を読み込みます。

sudo supervisorctl reread
sudo supervisorctl update

/etc/nginx/sites-available/defaultを以下のように編集します。

upstream channels-backend {
    server localhost:8000;
}
...
server {
    ...
    location / {
        try_files $uri @proxy_to_app;
    }
    ...
    location @proxy_to_app {
        proxy_pass http://channels-backend;

        proxy_http_version 1.1;
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection "upgrade";

        proxy_redirect off;
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Host $server_name;
    }
    ...
}

設定を読み込みます。

sudo service nginx reload

ブラウザのアドレス欄にこのサーバーのIPアドレスを入力すれば、Hello, world! と出力されたページを確認することができます。 次回はこのサーバーを使って簡単なWebsocketを使ったプログラムを書きます。

化合物でもDeep Learningがしたい!

この記事は2017年12月15日に https://kivantium.net/deep-for-chem/ に投稿したものです。 情報が古くなっていますが、まだ参照されているようなので再掲します。

この記事はDeep Learningやっていき Advent Calendar 2017の15日目です。

Deep Learningの威力が有名になったのは画像認識コンテストでの圧勝がきっかけでしたが、今ではDeep Learningはあらゆる分野に応用され始めています。NIPS2017でもMachine Learning for Molecules and Materialsが開催されたように、物質化学における機械学習の存在感が高まりつつあります。この記事ではその一例として化学の研究にDeep Learningが使われている例を紹介していきます。

化学物質の研究に機械学習が使われる主なパターンには

  • 分子を入力するとその分子の性質を出力する
  • 分子の性質を入力するとその性質を持った分子を出力する
  • 分子を入力するとその反応を出力する

の3つがあります。それぞれについて詳しく説明します。

分子から性質を予測する

Deep Learning以前

Deep Learning以前の性質予測では、職人の温かい手作りによる特徴量が使われていました。分子の特徴ベクトルはmolecular fingerprintsと呼ばれます。molecular fingerprintsは化合物の特徴的な一部分(fragmentと呼ばれる)がその分子にあるかどうかを0/1で表したbitを並べて作られます。

(画像はFingerprints in the RDKit p.4より引用)

どのfragmentを用いるのが有効かはデータセット・問題に依存するので様々な種類のfingerprintが提案されてきました。

主なfingerprintを挙げると

などがあります。 fingerprintはRDKitなどのライブラリを使うと簡単に計算できます。(各ソフトで計算できるfingerprintのリスト

このようなfingerprintを使ってSVMやRandomforestなどでその分子がある性質を持つ/持たないを予測する研究がたくさんあります。化学の分野でDeep Learningが大きく注目されるきっかけになったのは、kaggleの薬の活性予測のコンペでHintonらのチームが優勝したことですが、論文を見ると特徴量には上のように設計されたものを使っており、ニューラルネットワークで設計されたものではなかったようです。

graph convolutionの登場

fingerprintの設計にニューラルネットワークが導入されたのが[Duvenaud+, 2015]です。この研究ではcircular fingerprint (上のECFPのこと)をもとにneural graph fingerprint (NFP)を提案しています。以下にアルゴリズムを示します。

従来のfingerprint設計でhashやmodになっていた部分が重みを調整できる演算に変更されています。これにより、予測にとって重要なfragmentの寄与は大きく、重要ではないfragmentの寄与は小さくなるような特徴量が設計できるようになりました。実際に分子の水への溶けやすさをNFPで予測したところ水への溶けやすさに影響するR-OHのような構造の重みが大きくなったことが報告されています。

NFPの他にも分子のグラフ構造に基づいたニューラルネットワークベースの特徴量設計の研究が行われています。これらはグラフ構造に注目したニューラルネットワークなので総称としてgraph convolutionと呼ばれています。一番有名なのはGoogle BrainのNeural Message Passing for Quantum Chemistryでしょう。この論文ではMessage Passing Neural Network (MPNN) というグラフ上のニューラルネットワークを提案し、分子のニューラルネットワークの先行研究の多くがMPNNで一般的に記述できることを主張した上で、MPNNが分子の性質を予測する上で高い性能を発揮すると主張しています。 MPNNは

という式で表されます。グラフ上で隣接するエッジからのメッセージ Mを足し合わせるような処理をしていることが分かります。 M, Uなどをうまく定めることで各種のgraph convolutionを表すことができます。詳細は論文を読んでください。Google ResearchによるブログPredicting Properties of Molecules with Machine Learningも役立つかもしれません。(ちなみにこの論文のラストオーサーは先述したkaggleコンペの論文のファーストオーサーです)

graph convolutionのイメージとしてよく使われる絵が[Han Altae-Tran+, 2017]にあります。一枚引用します。

Graph convolutionではない方法としては[Goh+, 2017]のような分子を画像にしてCNNで予測するようなものもあります。

ちなみに、同じ人がつい先日SMILES2Vecという文字列から化合物の性質を予測する論文も書いていました。

実装

分子に対するDeep Learningのライブラリで最も有名なのはDeepChemでしょう。DeepChemはTensorFlowでgraph convolutionを実装しています。Graph Convolutions For Tox21などのチュートリアルを読むとだいたい使い方が分かるのではないでしょうか(私も使ったことはないです)。ちなみに、なぜかPong in DeepChem with A3Cのようなチュートリアルもあり何がしたいのか謎です……

また、PFNが最近Chainer Chemistryを公開しました。NFP, GGNN, Weave, SchNetなどのgraph convolution手法が実装されているほか、QM9, Tox21などの有名どころのデータセットを使うコードも揃っており、普段Chainerを使っている人はこれを試してみるのもよいかもしれません。

性質から分子を作る

創薬などの応用においては、「タンパク質Xの動きを抑制する」などの特定の性質を持った分子を作ることが必要になります。化学物質の構造と生物学的な活性の関係のことをQSARと呼びますが、逆に活性から構造を予測する問題をinverse-QSARのように言うことがあります。

分子設計の難しさの一つは、可能な分子の数が非常にたくさんあることです。[Bohacek+, 1996]によれば、C,N,O,Sを30個以下しか持たない分子に限っても 10^{60}種類の分子が存在できるとされています。そのため全探索は不可能なので何らかの効率的な探索法を考える必要があります。

Deep Learning以外の方法

創薬は重要な研究分野なので以前から研究が行われていました。多くの手法は[Nishibata+, 1991][Pierce+, 2004]のように既に知られている部分構造を組み合わせることで分子を設計しています。最近の研究では[Kawai+, 2014]のように構造の組み合わせに遺伝的アルゴリズム構造を使ったり、[Podlewska+, 2017]のように目的関数を機械学習の予測値にしたりするなどの工夫がなされています。

Deep Learningによる方法

分子設計にDeep Leaningを持ち込んだ研究が[Gómez-Bombarelli+, 2016]です。この研究では分子の文字列表現であるSMILES記法をvariational autoencoder (VAE) を用いて実数ベクトルに変換し、ベイズ最適化で最適化したベクトルをSMILESに戻すことで分子を設計しています。この手法の問題点はVAE空間上で最適化ベクトルをSMILESに戻したときに生成される文字列が文法的に正しくないなどの理由で分子と対応しなくなる率が非常に高かったことです。

SMILES記法は、グラフ構造として表される化合物を環を切り開くなどして文字列として表現できるようにしています。OpenSMILES specificationのように文脈自由文法で規定される文法を持っており、文法に従わない文字列は分子を表しません。(なお、文法に従っていても対応する分子が化学的に存在できるかは別の問題です)。例えば下のような図で表される分子のSMILESはO1C=C[C@H]([C@H]1O2)c3c2cc(OC)c4c3OC(=O)C5=C4CCC(=O)5となります。同じ数字はそこで環を形成していることを表し、カッコは分岐を表しています。

文法的に正しくないSMILESの文字列が生成される問題を解決するために、VAEの入出力にSMILESの文字列をそのまま使うのではなくSMILESを生成する文脈自由文法の生成規則列を使うことにしたのが[Kusner+, 2017]のGrammar Variational Autoencoderです。この研究で技術的に面白いところはVAE表現から文字列を生成する際にプッシュダウンオートマトンを考えて、現在スタックの一番上にある文字から選択できない生成規則の確率を0にする工夫を導入しているところです。この工夫により生成される文字列はSMILESの文法的に正しいものに限定することができるためデコードの効率が上がるほか、潜在空間自体もよりよいものになったと主張されています。

これらのアプローチに影響されたのかは分かりませんが、分子の構造を直接設計するのではなく、分子を表すSMILESを生成する研究が盛んに行われています。

  • [Segler+, 2017] はChEMBLのSMILESを学習したLSTMで新しいSMILESを生成しています。また、薬の候補になりそうな分子を入力としたRNNのファインチューニングなども行っています。

分子から反応を予測する

分子からの反応予測には、複数の分子を入力して反応結果を出力するものと、一つの分子を入力してその分子を作るのに必要な反応を予測するものがあります。

反応結果の予測

反応予測をコンピュータで行う試みは1960年代から行われていますが、従来の手法では専門家がルールをたくさん記述することで実現しています。この分野にもDeep Learningの波が来ています。

[Schwaller+, 2017] は反応物のSMILESを入力に生成物のSMILESを出力する言語モデルを用いて反応の予測を行っています。 SMILESによる反応の記述は、反応物の文字列を入力して生成物の文字を出力する処理なので、英語を入力してフランス語を入力する処理に似ていると彼らは考えましたアメリカの特許にある反応のデータベースから入力と出力のペアを作り、seq2seqという翻訳に使われるRNNモデルを適用して反応の予測を行いました。

結果としてtop-1で80%という先行研究を上回る精度の予測ができるようになったと主張されています。

逆合成の予測

目的の化合物を合成するための反応経路を求めることをretrosynthesisといいますが、実際に化合物を生産する上では非常に重要な技術です。この研究でもDeep Learningを使った論文が出ています。

[Segler+, 2017]ではAlphaGoと似た手法でretrosynthesisを行っています。(図は論文のFigure 1) (a)は目的の化合物(図ではIbuprofen)からはじめて分子をばらしていき、全てが既知の入手可能な分子(図では赤で示されている)にまで還元できたら逆合成が完了するというコンセプトを示しています。 (b)は(a)で用いられた既知の反応を示しています。 (c)は(a)の結果得られた反応経路から実際に目的の化合物を合成する過程を示しています。 (d)がこの論文の中心となるアイデアを表しています。現在の分子をばらすのに使える既知の反応はいくつもあります。反応の各段階を一つの状態ととらえると反応はグラフ上の状態遷移と考えることができ、逆合成はグラフ上の最適な経路を探す問題と解釈できます。そこでゲームの状態を表す木から最適な手を探すのと同じような方法を用いて、最適な次の反応を選ぶことで逆合成を解くことができると考えられます。 (e)のように分子の状態を入力すると良さそうな反応を返すDNNの確率をガイドにしたモンテカルロ木探索を実行することで逆合成を行うことができそうです。

論文の実験ではモンテカルロ木探索を用いた提案手法が先行研究よりも高い性能を示したと主張されています。

私が知っている主な研究はこれくらいですが、他にも面白い研究を知っている方がいらっしゃったらコメントなどで教えて下さい。

PyTorchでファインチューニングしたモデルをONNXで利用する

昨日の作業の結果、Illustration2Vecのモデルが大きすぎて貧弱なサーバーでは使えないことが分かりました。今のところ二次元画像判別器の特徴量抽出にしか使っていないので、もっと軽いモデルでも代用できるはずです。軽いモデルとして有名なSqueezenetをこれまで集めたデータでファインチューニングして様子を見てみることにします。

ファインチューニングとONNXへのエキスポート

PyTorchのチュートリアルが丁寧に説明してくれているので、これをコピペして継ぎ接ぎするだけです。

継ぎ接ぎしたものがこちらになります。これを実行するとmodel.onnxというファイルが作成されます。 ONNX版Illustration2Vecのモデルサイズが910Mに対して、このモデルは2.8MBなのでだいぶ小さくなりました。精度もだいたい同じくらいだと思います。

from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

# Top level data directory. Here we assume the format of the directory conforms
# to the ImageFolder structure
data_dir = "./data/images"

# Models to choose from [resnet, alexnet, vgg, squeezenet, densenet, inception]
model_name = "squeezenet"

# Number of classes in the dataset
num_classes = 2

# Batch size for training (change depending on how much memory you have)
batch_size = 8

# Number of epochs to train for
num_epochs = 1

# Flag for feature extracting. When False, we finetune the whole model,
#   when True we only update the reshaped layer params
feature_extract = True

def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False):
    since = time.time()

    val_acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    # Get model outputs and calculate loss
                    # Special case for inception because in training it has an auxiliary output. In train
                    #   mode we calculate the loss by summing the final output and the auxiliary output
                    #   but in testing we only consider the final output.
                    if is_inception and phase == 'train':
                        # From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4*loss2
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'val':
                val_acc_history.append(epoch_acc)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history

def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # Initialize these variables which will be set in this if statement. Each of these
    #   variables is model specific.
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        """ Resnet18
        """
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "alexnet":
        """ Alexnet
        """
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224

    elif model_name == "vgg":
        """ VGG11_bn
        """
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        """ Squeezenet
        """
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        """ Densenet
        """
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "inception":
        """ Inception v3
        Be careful, expects (299,299) sized images and has auxiliary output
        """
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Handle the auxilary net
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        # Handle the primary net
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs,num_classes)
        input_size = 299

    else:
        print("Invalid model name, exiting...")
        exit()

    return model_ft, input_size

# Initialize the model for this run
model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)

# Print the model we just instantiated
print(model_ft)

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

print("Initializing Datasets and Dataloaders...")

# Create training and validation datasets
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
# Create training and validation dataloaders
dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4) for x in ['train', 'val']}

# Detect if we have a GPU available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_ft = model_ft.to(device)

# Gather the parameters to be optimized/updated in this run. If we are
#  finetuning we will be updating all parameters. However, if we are
#  doing feature extract method, we will only update the parameters
#  that we have just initialized, i.e. the parameters with requires_grad
#  is True.
params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t",name)
else:
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t",name)

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9)

# Setup the loss fxn
criterion = nn.CrossEntropyLoss()

# Train and evaluate
model_ft, hist = train_model(model_ft, dataloaders_dict, criterion, optimizer_ft, num_epochs=num_epochs, is_inception=(model_name=="inception"))

# Save PyTorch model to file
torch.save(model_ft.to('cpu').state_dict(), 'model.pth')

# Input to the model
x = torch.randn(1, 3, 224, 224, requires_grad=True)
torch_out = model_ft(x)

# Export the model
torch.onnx.export(model_ft,               # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "model.onnx",                # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=10,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output'], # the model's output names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable lenght axes
                                'output' : {0 : 'batch_size'}})

ONNXモデルの利用

こうして作成したONNXモデルをPyTorchを使わずに利用するコードはこんな感じです。

import os
from PIL import Image

import numpy as np
import onnxruntime

# 中心を正方形に切り抜いてリサイズ
def crop_and_resize(img, size):
    width, height = img.size
    crop_size = min(width, height)
    img_crop = img.crop(((width - crop_size) // 2, (height - crop_size) // 2,
                         (width + crop_size) // 2, (height + crop_size) // 2))
    return img_crop.resize((size, size))

img_mean = np.asarray([0.485, 0.456, 0.406])
img_std = np.asarray([0.229, 0.224, 0.225])

ort_session = onnxruntime.InferenceSession(
    os.path.join(os.path.dirname(__file__), "model.onnx"))

img = Image.open('image.jpg').convert('RGB')
img = crop_and_resize(img, 224)

# 画像の正規化
img_np = np.asarray(img).astype(np.float32)/255.0
img_np_normalized = (img_np - img_mean) / img_std

# (H, W, C) -> (C, H, W)
img_np_transposed = img_np_normalized.transpose(2, 0, 1)

batch_img = [img_np_transposed]

ort_inputs = {ort_session.get_inputs()[0].name: batch_img}
ort_outs = ort_session.run(None, ort_inputs)[0]
batch_result = np.argmax(ort_outs, axis=1)
print(batch_result)

このSqueezenetモデルを使って昨日と同じようなことをするviews.pyが以下のようになります。全文はGitHubを見てください。 github.com

import os
import re
import urllib.request
from urllib.parse import urlparse
from PIL import Image
from joblib import dump, load
import tweepy

from django.shortcuts import render
from social_django.models import UserSocialAuth
from django.conf import settings
import more_itertools

import numpy as np
import onnxruntime
import torchvision.transforms as transforms


def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


ort_session = onnxruntime.InferenceSession(
    os.path.join(os.path.dirname(__file__), "model.onnx"))

data_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


def index(request):
    if request.user.is_authenticated:
        user = UserSocialAuth.objects.get(user_id=request.user.id)
        consumer_key = settings.SOCIAL_AUTH_TWITTER_KEY
        consumer_secret = settings.SOCIAL_AUTH_TWITTER_SECRET
        access_token = user.extra_data['access_token']['oauth_token']
        access_secret = user.extra_data['access_token']['oauth_token_secret']
        auth = tweepy.OAuthHandler(consumer_key, consumer_secret)
        auth.set_access_token(access_token, access_secret)
        api = tweepy.API(auth)
        timeline = api.home_timeline(count=200, tweet_mode='extended')

        tweet_media = []
        for tweet in timeline:
            if 'media' in tweet.entities:
                tweet_media.append(tweet)

        batch_size = 4
        tweet_illust = []
        for batch_tweet in more_itertools.chunked(tweet_media, batch_size):
            batch_img = []
            for tweet in batch_tweet:
                media_url = tweet.extended_entities['media'][0]['media_url']
                filename = os.path.basename(urlparse(media_url).path)
                filename = os.path.join(
                    os.path.dirname(__file__), 'images', filename)
                urllib.request.urlretrieve(media_url, filename)
                img = Image.open(filename).convert('RGB')
                img = data_transforms(img)
                batch_img.append(to_numpy(img))

            ort_inputs = {ort_session.get_inputs()[0].name: batch_img}
            ort_outs = ort_session.run(None, ort_inputs)[0]
            batch_result = np.argmax(ort_outs, axis=1)
            for tweet, result in zip(batch_tweet, batch_result):
                if result == 1:
                    media_url = tweet.extended_entities['media'][0]['media_url']
                    if hasattr(tweet, "retweeted_status"):
                        profile_image_url = tweet.retweeted_status.author.profile_image_url_https
                        author = {'name': tweet.retweeted_status.author.name,
                                  'screen_name': tweet.retweeted_status.author.screen_name}
                        id_str = tweet.retweeted_status.id_str
                    else:
                        profile_image_url = tweet.author.profile_image_url_https
                        author = {'name': tweet.author.name,
                                  'screen_name': tweet.author.screen_name}
                        id_str = tweet.id_str
                    try:
                        text = tweet.retweeted_status.full_text
                    except AttributeError:
                        text = tweet.full_text
                    text = re.sub(
                        r"https?://[\w/:%#\$&\?\(\)~\.=\+\-]+$", '', text).rstrip()
                    tweet_illust.append({'id_str': id_str,
                                         'profile_image_url': profile_image_url,
                                         'author': author,
                                         'text': text,
                                         'image_url': media_url})
        tweet_illust_chunked = list(more_itertools.chunked(tweet_illust, 4))
        return render(request, 'hello/index.html', {'user': user, 'timeline_chunked': tweet_illust_chunked})
    else:
        return render(request, 'hello/index.html')

モデルがだいぶ小さくなったので、貧弱サーバーでも動かすことができました。

デプロイ

このモデルサイズならHerokuで動かせると思ったのですが、torchvisionなどの依存ライブラリの容量だけでHerokuの500MBの制限を超えてしまうようなので自分のサーバーで動かすことにしました。一応動いてはいるのですが、コールバックの設定がうまくいかないので後日直します。

2.9MBのモデルなら一瞬で推論できると思ったのですが、それでもまだ貧弱サーバーには荷が重いようで読み込みにだいぶ時間がかかります。もっと軽いモデルを作るかディープラーニングに頼らない方法を考えるのが良さそうです。

PyTorchのCPU版をrequirements.txtで指定すればHerokuにデプロイできました。

追記: ヒストグラムの表示

最適なしきい値を見つけるためのヒストグラム表示は次のようにすればできます。

import os
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import onnxruntime

img_mean = np.asarray([0.485, 0.456, 0.406])
img_std = np.asarray([0.229, 0.224, 0.225])

ort_session = onnxruntime.InferenceSession(
    os.path.join(os.path.dirname(__file__), "model.onnx"))

# https://stackoverflow.com/questions/34968722/
def softmax(x):
    """Compute softmax values for each sets of scores in x."""
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=0)

probs = []

path = 'test'
files = os.listdir(path)
filenames = [os.path.join(path, f) for f in files if os.path.isfile(os.path.join(path, f))]

for name in filenames:
    img = Image.open(name).convert('RGB')
    img = img.resize((224, 224))

    # Normalization
    img_np = np.asarray(img).astype(np.float32)/255.0
    img_np_normalized = (img_np - img_mean) / img_std
    # (H, W, C) -> (C, H, W)
    img_np_transposed = img_np_normalized.transpose(2, 0, 1)
    batch_img = [img_np_transposed]

    ort_inputs = {ort_session.get_inputs()[0].name: batch_img}
    ort_outs = ort_session.run(None, ort_inputs)[0][0]
    prob = softmax(ort_outs)[1]
    probs.append(prob)

plt.hist(probs)
plt.show()

タイムラインから二次元イラストだけを表示するWebアプリの作成

ここまでの成果を使って、タイムラインから二次元イラストだけを表示するTwitterクライアントっぽいWebアプリを作成します。

スクリーンショット

出来上がったものがこちらになります。

f:id:kivantium:20200423234030p:plain:w600
スクリーンショット

以下、コードと今後の課題を述べます。

コード

コード全文はGitHubを見てください。 github.com

簡単のため、ログイン済みユーザーがアクセスするたびにタイムラインから最新のツイート200件を読み込んで、二次元画像判別器が二次元イラストだと判定した画像つきツイートを表示することにしました。200件以上のツイートを同時に読み込むのはTwitter APIの制限上難しかったです。 リツイートに関しては、リツイートした人の情報ではなくリツイート元の情報を表示することにします。英語で280文字までツイートできるようにする最近の仕様変更に対応するために少し面倒な処理を行っています。(参照: Extended Tweets — tweepy 3.8.0 documentation

前回作成したRandom Forestによる判定器や、ONNX版Illustration2Vecをhello/以下に置いています。

hello/views.py

import os
import re
import urllib.request
from urllib.parse import urlparse
from PIL import Image
from joblib import dump, load
import tweepy

from django.shortcuts import render
from social_django.models import UserSocialAuth
from django.conf import settings
import more_itertools

import sys
sys.path.append(os.path.dirname(__file__))
import i2v

# ONNX版Illustration2Vec
illust2vec = i2v.make_i2v_with_onnx(os.path.join(os.path.dirname(__file__), "illust2vec_ver200.onnx"))

# 事前に作成しておいた二次元画像判別器
clf = load(os.path.join(os.path.dirname(__file__), "clf.joblib"))

def index(request):
    if request.user.is_authenticated:  # Twitterでログインしている場合
        # ユーザー情報の取得
        user = UserSocialAuth.objects.get(user_id=request.user.id)
        consumer_key = settings.SOCIAL_AUTH_TWITTER_KEY
        consumer_secret = settings.SOCIAL_AUTH_TWITTER_SECRET
        access_token = user.extra_data['access_token']['oauth_token']
        access_secret = user.extra_data['access_token']['oauth_token_secret']
        auth = tweepy.OAuthHandler(consumer_key, consumer_secret)
        auth.set_access_token(access_token, access_secret)
        api = tweepy.API(auth)
        # 全文を取得するためにextendedを指定する
        timeline = api.home_timeline(count=200, tweet_mode = 'extended')

        tweet_illust = []
        for tweet in timeline:
            if 'media' in tweet.entities:
                media  = tweet.extended_entities['media'][0]
                media_url = media['media_url']
                filename = os.path.basename(urlparse(media_url).path)
                filename = os.path.join(os.path.dirname(__file__), 'images', filename)
                urllib.request.urlretrieve(media_url, filename)
                img = Image.open(filename)
                feature = illust2vec.extract_feature([img])
                prob = clf.predict_proba(feature)[0]
                if prob[1] > 0.4:  # 二次元イラストの可能性が高い
                    if hasattr(tweet, "retweeted_status"): 
                        profile_image_url = tweet.retweeted_status.author.profile_image_url_https
                        author = {'name': tweet.retweeted_status.author.name,
                                  'screen_name': tweet.retweeted_status.author.screen_name}
                        id_str = tweet.retweeted_status.id_str
                    else:
                        profile_image_url = tweet.author.profile_image_url_https
                        author = {'name': tweet.author.name,
                                  'screen_name': tweet.author.screen_name}
                        id_str = tweet.id_str
                    # リツイート元のツイート全文の取得
                    try:
                        text = tweet.retweeted_status.full_text
                    except AttributeError:
                        text = tweet.full_text
                    # 画像URLを削除するために文末のURLを削除する
                    text = re.sub(r"https?://[\w/:%#\$&\?\(\)~\.=\+\-]+$", '', text).rstrip()
                    tweet_illust.append({'id_str': id_str, 
                                         'profile_image_url': profile_image_url,
                                         'author': author,
                                         'text': text,
                                         'image_url': media_url})
        # 表示の都合上4つずつのグループに分ける
        tweet_illust_chunked = list(more_itertools.chunked(tweet_illust, 4))
        return render(request,'hello/index.html', {'user': user, 'timeline_chunked': tweet_illust_chunked})
    else:
        return render(request,'hello/index.html')

これを表示するためのHTMLを示します。Bulmaで画像の中央を丸く切り取って並べる - kivantium活動日記の応用です。Bulmaのカード機能を使っています。 bulma.io

<!doctype html>
<html>
  <head>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <title>にじさーち</title>
    <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.8.0/css/bulma.min.css">
    <style>
img { object-fit: cover; }
.bm--card-equal-height {
   display: flex;
   flex-direction: column;
   height: 100%;
}
.bm--card-equal-height .card-footer {
   margin-top: auto;
} </style>
  </head>
  <body>
  <nav class="navbar is-primary">
    <div class="navbar-brand navbar-item">
      <h1 class="title has-text-light">にじさーち</h1>
    </div>
    {% if request.user.is_authenticated %}
    <div class="navbar-end">
      <div class="navbar-item">
      <a class="button is-light" href="/logout">ログアウト</a>
      </div>
    {% endif %}
    </div>
  </nav>
  <section class="section">
    <div class="container">
      {% if request.user.is_authenticated %}
      {% for tweets in timeline_chunked %}
      <div class="columns is-mobile">
        {% for tweet in tweets %}
        <div class="column is-3">
          <div class="card bm--card-equal-height">
            <div class="card-content">
              <div class="media">
                <div class="media-left">
                  <figure class="image is-48x48">
                    <img src="{{ tweet.profile_image_url }}" alt="Profile image">
                  </figure>
                </div>
                <div class="media-content">
                  <p class="title is-4">{{ tweet.author.name }}</p>
                  <p class="subtitle is-6">@{{ tweet.author.screen_name }}</p>
                </div>
              </div>
              <div class="card-image">
                <figure class="image is-square">
                  <a href="https://twitter.com/i/web/status/{{ tweet.id_str }}" target="_blank" rel="noopener noreferrer">
                    <img src="{{ tweet.image_url }}" alt="main image">
                  </a>
                </figure>
              </div>
              <div class="content">{{ tweet.text }}</div>
            </div>
          </div>
        </div>
        {% endfor %}
      </div>
      {% endfor %}
      {% else %}
      <p>あなたはログインしていません</p>
      <button type="button" onclick="location.href='{% url 'social:begin' 'twitter' %}'">Twitterでログイン</button>
      {% endif %}
    </div>
  </section>
  </body>
</html>

今後の課題

Illustration2Vecのモデルが重い

今回作成したアプリをサーバーにデプロイしようと思ったのですが、Illustration2Vecのモデルがサーバーのメモリサイズよりも大きかったためデプロイすることができませんでした。また、今後複数のユーザーによる使用をサポートしようとするとアクセスが来るたびにIllustration2Vecを実行していてはとても追いつかないのでモデルを軽量化することが必要になりそうです。

画像データベースとしての利用

Twitter APIのRate Limitが厳しいため、タイムラインから一度に収集できるツイートは200件くらいしかありません。これでは大量の画像を閲覧する目的に向きません。そのため、status/filterで常に画像を収集しておき画像データベースとして利用することが考えられます。しかし、類似のサービスが(利用規約に則っているにも関わらず)以前大炎上したことがあるっぽいので、Twitterクライアントとして一般に認められる以上の機能を提供するとなると面倒くさそうです。 nlab.itmedia.co.jp

自動タグ付け

Illustration2Vecでもタグ付けを行うことができますが、つけることができるタグの種類は有限です。新しく増える作品やキャラに対応するために何らかの方法で類似画像のハッシュタグからタグを推定して自動タグ付けができるとよさそうです。(これも絵師界隈の自主ルールで難癖つけられて面倒なことになりそうですが……)

Display requirementsへの適合

利用規約に則ってTwitterのコンテンツを表示する際の条件としてDisplay Requirementsというものがあります。 developer.twitter.com

ツイート本文を全文表示しないといけないとか、Twitterのロゴを右上に表示しないといけないなどといったユーザーの利便性を損なう規定なのですが、規定なので従う必要があります。 Google画像検索のようにサムネイルだけ表示してあとはツイートへのリンクにする方式にすることも含めて検討していきたいです。

広告コーナー