kivantium活動日記

プログラムを使っていろいろやります

Django ChannelsでWebsocket通信を行ってVue.jsで表示する

前回の続きです。 kivantium.hateblo.jp

サーバー上で重たい処理をする場合、処理が全て終わってからHTMLを生成しているとかなりレスポンスが悪くなってしまいます。先に結果表示ページのHTMLを表示しておいて、中身を後からWebsocketで順次通信するといい感じになりそうです。Websocket通信したデータを表示するためにVue.jsを使うことにします。

前回からの変更点

前回のHello, worldで使ったディレクトリ構成を編集する形で作業を進めます。

mysite/routing.py を以下の内容で書き換えます。ws/test/にアクセスすることでWebsocket通信を行えるように設定しています。

from channels.auth import AuthMiddlewareStack
from channels.routing import ProtocolTypeRouter, URLRouter
from django.urls import path

from . import consumers

websocket_urlpatterns = [
    path('ws/test/', consumers.Consumer),
]

application = ProtocolTypeRouter({
    # (http->django views is added by default)
    'websocket': AuthMiddlewareStack(
        URLRouter(
            websocket_urlpatterns
        )
    ),
})

mysite/consumers.py にWebsocket通信で行う処理を記述します。本当は非同期処理をするべきらしいのですが、動く書き方が分からなかったので同期処理で書いています。(参考: django で Websocket - 空のブログ

import json
import time
import threading

from channels.generic.websocket import WebsocketConsumer

class Consumer(WebsocketConsumer):
    # 接続されたときの処理
    def connect(self):
        # 接続を許可する
        self.accept()
        # メッセージ送信関数を新しいスレッドで呼び出す
        # 本当は非同期で書くべきだがうまく動かなかった
        self.sending = True
        self.sender = threading.Thread(
            target=self.send_message, args=('Hello', ))
        self.sender.start()

    # 接続が切断されたときの処理
    def disconnect(self, close_code):
        # スレッドを終了するフラグを立てる
        self.sending = False
        # スレッドの終了を待つ
        self.sender.join()

    # メッセージを受け取ったときの処理
    def receive(self, text_data):
        text_data_json = json.loads(text_data)
        message = text_data_json['message']
        print("Received message:", message)

    # メッセージ送信関数
    def send_message(self, message):
        while True:
            # 終了フラグが立っていたら終了する
            if not self.sending:
                break
            self.send(text_data=json.dumps({
                'message': message,
            }))
            time.sleep(1)

mysite/templates/mysite/index.html にこれと通信するためのHTMLとJavaScriptを書きます。

<!doctype html>
<html>
  <head>
    <title>test</title>
    <script src="https://cdn.jsdelivr.net/npm/vue/dist/vue.js"></script>
  </head>
  <body>
    <div id="app">
      <ul >
        <li v-for="message in messages">[[ message ]]</li>
      </ul>
    </div>
    <script>
      var vm = new Vue({
        el: '#app',
        // Vueの記号がDjangoと被らないようにする
        delimiters: ['[[', ']]'],
        data: {
          messages: [],
          ws: new WebSocket(
            // http or https の判定
            // https://www.koatech.info/blog/vue-websocket-sample/
            (window.location.protocol == "https:" ? "wss" : "ws")
            + '://' + window.location.host + '/ws/test/'
          ),
        },
        created: function() {
          // Vueオブジェクトにアクセスするために変数化する
          // https://www.koatech.info/blog/vue-websocket-sample/
          const self = this;

          // メッセージを受け取ったときの処理
          self.ws.onmessage = function(e) {
            const data = JSON.parse(e.data);
            console.log(data.message)
            console.log(vm.messages)
            // メッセージを表示する
            vm.messages.push(data.message);
            // サーバーにメッセージを送る
            self.ws.send(JSON.stringify({
              'message': "Konnichiwa",
            }));
          };

          // 接続が切断されたときの処理
          self.ws.onclose = function(e) {
            console.error('Websocket has been closed unexpectedly');
          };
        }
      })
    </script>
  </body>
</html>

mysite/views.py を次のように書き換えてindex.htmlを表示するようにします。

from django.shortcuts import render

def index(request):
    return render(request, 'mysite/index.html')

最後に、テンプレートを認識させるために mysite/settings.pyINSTALLED_APPSmysite を追加します。

INSTALLED_APPS = [
    ......
    'channels',
    'mysite',
]

以上の変更を行ったものをデプロイしてブラウザで閲覧すると1秒ごとにHelloが追加される様子が確認できます。

f:id:kivantium:20200502150709g:plain
スクリーンショット

自分の目的にはこれで十分でしたが、チャットなどのより高度なことをしたい場合はChennelsのドキュメントを読んで下さい。 channels.readthedocs.io

Django ChennelsアプリをNginxとSupervisorでデプロイする

DjangoでWebsocketを使うときにはChannelsというライブラリがよく使われています。これまではHerokuにデプロイをしてきましたが、HerokuとChannelsの相性が良くないのかすぐに接続が切れてしまうので、これからはAWS上で開発しようと思いました。公式ドキュメントを読んでもデプロイ方法がよく分からなかったのでメモしておきます。

AWS LightsailでUbuntu 18.04のインスタンスを立てたとして、SSHで入ってからHello, world!するところまでを見ていきます。

ライブラリのインストール

Daphneの起動を楽にするためにvenvを使います。(参考: Djangoのインストール · Django Girls Tutorial

ssh ubuntu@xxx.xxx.xxx.xxx
sudo apt update
sudo apt install python3-venv
python3 -m venv env
source env/bin/activate

requirements.txtに以下を記述します。

django~=3.0.5
channels~=2.4.0

pipをアップデートしてから必要なライブラリをインストールします。

python -m pip install --upgrade pip
pip install -r requirements.txt

Hello, world!アプリの設定

Django プロジェクトを作成します。

django-admin startproject mysite
cd mysite

基本的にはChannels公式ドキュメントのInstallationに従って設定します。簡単のために本番環境と開発環境の設定の分離などは無視します。

mysite/settings.py を以下のように編集します。diffを示しています。

-ALLOWED_HOSTS = []
+ALLOWED_HOSTS = ['*']  # 本当は適切なホストを指定するべきだが簡単のため全て許可
 
(略)

INSTALLED_APPS = [
     'django.contrib.sessions',
     'django.contrib.messages',
     'django.contrib.staticfiles',
+    'channels',
 ]
 
+ASGI_APPLICATION = "mysite.routing.application"
+
 MIDDLEWARE = [
     'django.middleware.security.SecurityMiddleware',
     'django.contrib.sessions.middleware.SessionMiddleware',

mysite/routing.py を以下の内容で作成します。

from channels.routing import ProtocolTypeRouter

application = ProtocolTypeRouter({
    # Empty for now (http->django views is added by default)
})

mysite/urls.py を以下の内容で作成します。

from django.urls import path

from . import views

urlpatterns = [
    path('', views.index, name='index'),
]

mysite/views.py を以下の内容で作成します。

from django.http import HttpResponse


def index(request):
    return HttpResponse("Hello, world!")

ここまで来たら python manage.py runserver を実行してエラーが出ないことだけ確認します。(サーバー上にあるのでこの時点ではブラウザで表示確認ができません)

NginxとSupervisorの設定

ここもChannels公式ドキュメントのDeployingに従って設定するだけなのですが、この通りにやっても動かなかったので以下のStackOverflowに従ってアレンジしました。

stackoverflow.com

まずはNginxとSupervisorをインストールします。

sudo apt install nginx supervisor

mysite/asgi.pyを以下の内容で作成します。

"""
ASGI entrypoint. Configures Django and then runs the application
defined in the ASGI_APPLICATION setting.
"""

import os
import django
from channels.routing import get_default_application

os.environ.setdefault("DJANGO_SETTINGS_MODULE", "mysite.settings")
django.setup()
application = get_default_application()

/etc/supervisor/conf.d/asgi.confを以下の内容で作成します。

[fcgi-program:asgi]
# TCP socket used by Nginx backend upstream
socket=tcp://localhost:8000

# Directory where your site's project files are located
directory=/home/ubuntu/mysite

# Each process needs to have a separate socket file, so we use process_num
# Make sure to update "mysite.asgi" to match your project name
command=/home/ubuntu/env/bin/daphne --fd 0 --access-log - --proxy-headers mysite.asgi:application
# Number of processes to startup, roughly the number of CPUs you have
numprocs=4

# Give each process a unique name so they can be told apart
process_name=asgi%(process_num)d

# Automatically start and recover processes
autostart=true
autorestart=true

# Choose where you want your log to go
stdout_logfile=/var/log/asgi.log
redirect_stderr=true

設定を読み込みます。

sudo supervisorctl reread
sudo supervisorctl update

/etc/nginx/sites-available/defaultを以下のように編集します。

upstream channels-backend {
    server localhost:8000;
}
...
server {
    ...
    location / {
        try_files $uri @proxy_to_app;
    }
    ...
    location @proxy_to_app {
        proxy_pass http://channels-backend;

        proxy_http_version 1.1;
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection "upgrade";

        proxy_redirect off;
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Host $server_name;
    }
    ...
}

設定を読み込みます。

sudo service nginx reload

ブラウザのアドレス欄にこのサーバーのIPアドレスを入力すれば、Hello, world! と出力されたページを確認することができます。 次回はこのサーバーを使って簡単なWebsocketを使ったプログラムを書きます。

2020年4月

4月は死んだ、もういない!

出来事

インターンについて

まともに労働したのが学部4年の夏休みにインターンしたとき以来だったので記憶が薄れていたのですが、やっぱり労働は厳しいです。時給制で働く場合はどれだけ頑張っても働いた時間に対してしか賃金が支払われないので、頑張らずにゆっくり働くほうが自分にとって得になってしまいます。また、昼寝などをして一時的に職場を離れたほうが結果的に一日のパフォーマンスが上がる状況であっても、昼寝に対しては賃金が発生しません。その結果、眠い目をこすりながら動かない頭で働き続けることになり、非効率な労働をだらだらと続けることになります。非効率な行動を取ることが経済的に得であるという状況が発生してしまう労働という構造が正しくないことに気づいた後、その状況に甘んじている自分が嫌になるというのが労働が嫌になるときのパターンな気がします。業務委託契約で働いていたときは、仕様を満たすものが作れないと何も賃金がもらえないというプレッシャーがある一方、最低限のものを作って納品すればいいだろうという雑な態度になっている自分が嫌になっていた記憶があります。

業務内容が面白くても労働という構造をはさんだ瞬間につらさが発生するのですが、業務内容が面白くない場合はなおさらつらくなります。今回のインターンの目的はロボットの開発技術を学ぶことだったのですが、与えられた仕事はその会社で開発しているソフトの上でロボットの動かし方を考えるというもので、退社した瞬間に役に立たなくなりそうな技術ばかり学んでいます。さらに、コロナウィルスの影響で出社3日目から在宅勤務になっているため、自分が考えたロボットの動かし方を実機で試すこともできず、本当に何をやっているのか分からない状態になっています。だからといって別の仕事をよこせと言える状況でもないので毎日のように退職を考えています。SHIROBAKO 9話でこのままタイヤを作る仕事を続けていいのか悩むみーちゃんの心境がとてもよく分かるようになりました。

Webサービス開発について

このまま労働を続けていても何にもならないと思ったので、業務後の時間を使ってWebサービス開発の勉強をすることにしました。題材は何でも良かったのですが、自分が欲しいサービスを作るのが良いだろうということでTwitter上の画像検索サービスを作っています。(画像の元ネタ検索に便利だった TwiGaTen というサービスが終了していたのが思いついたきっかけですが、いま確認したら再開していました……。サーベイを怠って失敗するのはいつものことです。)

せっかくなので機械学習と絡めようと思って、タイムラインからイラストを含むツイートのみを選んで表示する機能を最初に実装しました。http://kivantium-playground.herokuapp.com/ で公開していますが、依存しているライブラリがHeroku上での動作に問題があるらしく、すぐ落ちてしまいます。タイムラインから画像を自動収集する機能などを実装したらAWSに移行しようと思っています。いずれはレコメンデーションや類似画像検索等も実装していきたいと思っていますが、そこまでやる気が続くかは分かりません。

読んだ本

インターンで通勤時間が発生したらたくさん本が読めると思っていたのですが、在宅勤務になったので結局あまり読めませんでした。

無能なナナがアニメ化することになりました

進撃の巨人(31) (講談社コミックス)

進撃の巨人(31) (講談社コミックス)

  • 作者:諫山 創
  • 発売日: 2020/04/09
  • メディア: コミック

化合物でもDeep Learningがしたい!

この記事は2017年12月15日に https://kivantium.net/deep-for-chem/ に投稿したものです。 情報が古くなっていますが、まだ参照されているようなので再掲します。

この記事はDeep Learningやっていき Advent Calendar 2017の15日目です。

Deep Learningの威力が有名になったのは画像認識コンテストでの圧勝がきっかけでしたが、今ではDeep Learningはあらゆる分野に応用され始めています。NIPS2017でもMachine Learning for Molecules and Materialsが開催されたように、物質化学における機械学習の存在感が高まりつつあります。この記事ではその一例として化学の研究にDeep Learningが使われている例を紹介していきます。

化学物質の研究に機械学習が使われる主なパターンには

  • 分子を入力するとその分子の性質を出力する
  • 分子の性質を入力するとその性質を持った分子を出力する
  • 分子を入力するとその反応を出力する

の3つがあります。それぞれについて詳しく説明します。

分子から性質を予測する

Deep Learning以前

Deep Learning以前の性質予測では、職人の温かい手作りによる特徴量が使われていました。分子の特徴ベクトルはmolecular fingerprintsと呼ばれます。molecular fingerprintsは化合物の特徴的な一部分(fragmentと呼ばれる)がその分子にあるかどうかを0/1で表したbitを並べて作られます。

(画像はFingerprints in the RDKit p.4より引用)

どのfragmentを用いるのが有効かはデータセット・問題に依存するので様々な種類のfingerprintが提案されてきました。

主なfingerprintを挙げると

などがあります。 fingerprintはRDKitなどのライブラリを使うと簡単に計算できます。(各ソフトで計算できるfingerprintのリスト

このようなfingerprintを使ってSVMやRandomforestなどでその分子がある性質を持つ/持たないを予測する研究がたくさんあります。化学の分野でDeep Learningが大きく注目されるきっかけになったのは、kaggleの薬の活性予測のコンペでHintonらのチームが優勝したことですが、論文を見ると特徴量には上のように設計されたものを使っており、ニューラルネットワークで設計されたものではなかったようです。

graph convolutionの登場

fingerprintの設計にニューラルネットワークが導入されたのが[Duvenaud+, 2015]です。この研究ではcircular fingerprint (上のECFPのこと)をもとにneural graph fingerprint (NFP)を提案しています。以下にアルゴリズムを示します。

従来のfingerprint設計でhashやmodになっていた部分が重みを調整できる演算に変更されています。これにより、予測にとって重要なfragmentの寄与は大きく、重要ではないfragmentの寄与は小さくなるような特徴量が設計できるようになりました。実際に分子の水への溶けやすさをNFPで予測したところ水への溶けやすさに影響するR-OHのような構造の重みが大きくなったことが報告されています。

NFPの他にも分子のグラフ構造に基づいたニューラルネットワークベースの特徴量設計の研究が行われています。これらはグラフ構造に注目したニューラルネットワークなので総称としてgraph convolutionと呼ばれています。一番有名なのはGoogle BrainのNeural Message Passing for Quantum Chemistryでしょう。この論文ではMessage Passing Neural Network (MPNN) というグラフ上のニューラルネットワークを提案し、分子のニューラルネットワークの先行研究の多くがMPNNで一般的に記述できることを主張した上で、MPNNが分子の性質を予測する上で高い性能を発揮すると主張しています。 MPNNは

という式で表されます。グラフ上で隣接するエッジからのメッセージ Mを足し合わせるような処理をしていることが分かります。 M, Uなどをうまく定めることで各種のgraph convolutionを表すことができます。詳細は論文を読んでください。Google ResearchによるブログPredicting Properties of Molecules with Machine Learningも役立つかもしれません。(ちなみにこの論文のラストオーサーは先述したkaggleコンペの論文のファーストオーサーです)

graph convolutionのイメージとしてよく使われる絵が[Han Altae-Tran+, 2017]にあります。一枚引用します。

Graph convolutionではない方法としては[Goh+, 2017]のような分子を画像にしてCNNで予測するようなものもあります。

ちなみに、同じ人がつい先日SMILES2Vecという文字列から化合物の性質を予測する論文も書いていました。

実装

分子に対するDeep Learningのライブラリで最も有名なのはDeepChemでしょう。DeepChemはTensorFlowでgraph convolutionを実装しています。Graph Convolutions For Tox21などのチュートリアルを読むとだいたい使い方が分かるのではないでしょうか(私も使ったことはないです)。ちなみに、なぜかPong in DeepChem with A3Cのようなチュートリアルもあり何がしたいのか謎です……

また、PFNが最近Chainer Chemistryを公開しました。NFP, GGNN, Weave, SchNetなどのgraph convolution手法が実装されているほか、QM9, Tox21などの有名どころのデータセットを使うコードも揃っており、普段Chainerを使っている人はこれを試してみるのもよいかもしれません。

性質から分子を作る

創薬などの応用においては、「タンパク質Xの動きを抑制する」などの特定の性質を持った分子を作ることが必要になります。化学物質の構造と生物学的な活性の関係のことをQSARと呼びますが、逆に活性から構造を予測する問題をinverse-QSARのように言うことがあります。

分子設計の難しさの一つは、可能な分子の数が非常にたくさんあることです。[Bohacek+, 1996]によれば、C,N,O,Sを30個以下しか持たない分子に限っても 10^{60}種類の分子が存在できるとされています。そのため全探索は不可能なので何らかの効率的な探索法を考える必要があります。

Deep Learning以外の方法

創薬は重要な研究分野なので以前から研究が行われていました。多くの手法は[Nishibata+, 1991][Pierce+, 2004]のように既に知られている部分構造を組み合わせることで分子を設計しています。最近の研究では[Kawai+, 2014]のように構造の組み合わせに遺伝的アルゴリズム構造を使ったり、[Podlewska+, 2017]のように目的関数を機械学習の予測値にしたりするなどの工夫がなされています。

Deep Learningによる方法

分子設計にDeep Leaningを持ち込んだ研究が[Gómez-Bombarelli+, 2016]です。この研究では分子の文字列表現であるSMILES記法をvariational autoencoder (VAE) を用いて実数ベクトルに変換し、ベイズ最適化で最適化したベクトルをSMILESに戻すことで分子を設計しています。この手法の問題点はVAE空間上で最適化ベクトルをSMILESに戻したときに生成される文字列が文法的に正しくないなどの理由で分子と対応しなくなる率が非常に高かったことです。

SMILES記法は、グラフ構造として表される化合物を環を切り開くなどして文字列として表現できるようにしています。OpenSMILES specificationのように文脈自由文法で規定される文法を持っており、文法に従わない文字列は分子を表しません。(なお、文法に従っていても対応する分子が化学的に存在できるかは別の問題です)。例えば下のような図で表される分子のSMILESはO1C=C[C@H]([C@H]1O2)c3c2cc(OC)c4c3OC(=O)C5=C4CCC(=O)5となります。同じ数字はそこで環を形成していることを表し、カッコは分岐を表しています。

文法的に正しくないSMILESの文字列が生成される問題を解決するために、VAEの入出力にSMILESの文字列をそのまま使うのではなくSMILESを生成する文脈自由文法の生成規則列を使うことにしたのが[Kusner+, 2017]のGrammar Variational Autoencoderです。この研究で技術的に面白いところはVAE表現から文字列を生成する際にプッシュダウンオートマトンを考えて、現在スタックの一番上にある文字から選択できない生成規則の確率を0にする工夫を導入しているところです。この工夫により生成される文字列はSMILESの文法的に正しいものに限定することができるためデコードの効率が上がるほか、潜在空間自体もよりよいものになったと主張されています。

これらのアプローチに影響されたのかは分かりませんが、分子の構造を直接設計するのではなく、分子を表すSMILESを生成する研究が盛んに行われています。

  • [Segler+, 2017] はChEMBLのSMILESを学習したLSTMで新しいSMILESを生成しています。また、薬の候補になりそうな分子を入力としたRNNのファインチューニングなども行っています。

分子から反応を予測する

分子からの反応予測には、複数の分子を入力して反応結果を出力するものと、一つの分子を入力してその分子を作るのに必要な反応を予測するものがあります。

反応結果の予測

反応予測をコンピュータで行う試みは1960年代から行われていますが、従来の手法では専門家がルールをたくさん記述することで実現しています。この分野にもDeep Learningの波が来ています。

[Schwaller+, 2017] は反応物のSMILESを入力に生成物のSMILESを出力する言語モデルを用いて反応の予測を行っています。 SMILESによる反応の記述は、反応物の文字列を入力して生成物の文字を出力する処理なので、英語を入力してフランス語を入力する処理に似ていると彼らは考えましたアメリカの特許にある反応のデータベースから入力と出力のペアを作り、seq2seqという翻訳に使われるRNNモデルを適用して反応の予測を行いました。

結果としてtop-1で80%という先行研究を上回る精度の予測ができるようになったと主張されています。

逆合成の予測

目的の化合物を合成するための反応経路を求めることをretrosynthesisといいますが、実際に化合物を生産する上では非常に重要な技術です。この研究でもDeep Learningを使った論文が出ています。

[Segler+, 2017]ではAlphaGoと似た手法でretrosynthesisを行っています。(図は論文のFigure 1) (a)は目的の化合物(図ではIbuprofen)からはじめて分子をばらしていき、全てが既知の入手可能な分子(図では赤で示されている)にまで還元できたら逆合成が完了するというコンセプトを示しています。 (b)は(a)で用いられた既知の反応を示しています。 (c)は(a)の結果得られた反応経路から実際に目的の化合物を合成する過程を示しています。 (d)がこの論文の中心となるアイデアを表しています。現在の分子をばらすのに使える既知の反応はいくつもあります。反応の各段階を一つの状態ととらえると反応はグラフ上の状態遷移と考えることができ、逆合成はグラフ上の最適な経路を探す問題と解釈できます。そこでゲームの状態を表す木から最適な手を探すのと同じような方法を用いて、最適な次の反応を選ぶことで逆合成を解くことができると考えられます。 (e)のように分子の状態を入力すると良さそうな反応を返すDNNの確率をガイドにしたモンテカルロ木探索を実行することで逆合成を行うことができそうです。

論文の実験ではモンテカルロ木探索を用いた提案手法が先行研究よりも高い性能を示したと主張されています。

私が知っている主な研究はこれくらいですが、他にも面白い研究を知っている方がいらっしゃったらコメントなどで教えて下さい。

PyTorchでファインチューニングしたモデルをONNXで利用する

昨日の作業の結果、Illustration2Vecのモデルが大きすぎて貧弱なサーバーでは使えないことが分かりました。今のところ二次元画像判別器の特徴量抽出にしか使っていないので、もっと軽いモデルでも代用できるはずです。軽いモデルとして有名なSqueezenetをこれまで集めたデータでファインチューニングして様子を見てみることにします。

ファインチューニングとONNXへのエキスポート

PyTorchのチュートリアルが丁寧に説明してくれているので、これをコピペして継ぎ接ぎするだけです。

継ぎ接ぎしたものがこちらになります。これを実行するとmodel.onnxというファイルが作成されます。 ONNX版Illustration2Vecのモデルサイズが910Mに対して、このモデルは2.8MBなのでだいぶ小さくなりました。精度もだいたい同じくらいだと思います。

from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

# Top level data directory. Here we assume the format of the directory conforms
# to the ImageFolder structure
data_dir = "./data/images"

# Models to choose from [resnet, alexnet, vgg, squeezenet, densenet, inception]
model_name = "squeezenet"

# Number of classes in the dataset
num_classes = 2

# Batch size for training (change depending on how much memory you have)
batch_size = 8

# Number of epochs to train for
num_epochs = 1

# Flag for feature extracting. When False, we finetune the whole model,
#   when True we only update the reshaped layer params
feature_extract = True

def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False):
    since = time.time()

    val_acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    # Get model outputs and calculate loss
                    # Special case for inception because in training it has an auxiliary output. In train
                    #   mode we calculate the loss by summing the final output and the auxiliary output
                    #   but in testing we only consider the final output.
                    if is_inception and phase == 'train':
                        # From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4*loss2
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'val':
                val_acc_history.append(epoch_acc)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history

def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # Initialize these variables which will be set in this if statement. Each of these
    #   variables is model specific.
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        """ Resnet18
        """
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "alexnet":
        """ Alexnet
        """
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224

    elif model_name == "vgg":
        """ VGG11_bn
        """
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        """ Squeezenet
        """
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        """ Densenet
        """
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "inception":
        """ Inception v3
        Be careful, expects (299,299) sized images and has auxiliary output
        """
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Handle the auxilary net
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        # Handle the primary net
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs,num_classes)
        input_size = 299

    else:
        print("Invalid model name, exiting...")
        exit()

    return model_ft, input_size

# Initialize the model for this run
model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)

# Print the model we just instantiated
print(model_ft)

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

print("Initializing Datasets and Dataloaders...")

# Create training and validation datasets
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
# Create training and validation dataloaders
dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4) for x in ['train', 'val']}

# Detect if we have a GPU available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_ft = model_ft.to(device)

# Gather the parameters to be optimized/updated in this run. If we are
#  finetuning we will be updating all parameters. However, if we are
#  doing feature extract method, we will only update the parameters
#  that we have just initialized, i.e. the parameters with requires_grad
#  is True.
params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t",name)
else:
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t",name)

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9)

# Setup the loss fxn
criterion = nn.CrossEntropyLoss()

# Train and evaluate
model_ft, hist = train_model(model_ft, dataloaders_dict, criterion, optimizer_ft, num_epochs=num_epochs, is_inception=(model_name=="inception"))

# Save PyTorch model to file
torch.save(model_ft.to('cpu').state_dict(), 'model.pth')

# Input to the model
x = torch.randn(1, 3, 224, 224, requires_grad=True)
torch_out = model_ft(x)

# Export the model
torch.onnx.export(model_ft,               # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "model.onnx",                # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=10,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output'], # the model's output names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable lenght axes
                                'output' : {0 : 'batch_size'}})

ONNXモデルの利用

こうして作成したONNXモデルをPyTorchを使わずに利用するコードはこんな感じです。

import os
from PIL import Image

import numpy as np
import onnxruntime

# 中心を正方形に切り抜いてリサイズ
def crop_and_resize(img, size):
    width, height = img.size
    crop_size = min(width, height)
    img_crop = img.crop(((width - crop_size) // 2, (height - crop_size) // 2,
                         (width + crop_size) // 2, (height + crop_size) // 2))
    return img_crop.resize((size, size))

img_mean = np.asarray([0.485, 0.456, 0.406])
img_std = np.asarray([0.229, 0.224, 0.225])

ort_session = onnxruntime.InferenceSession(
    os.path.join(os.path.dirname(__file__), "model.onnx"))

img = Image.open('image.jpg').convert('RGB')
img = crop_and_resize(img, 224)

# 画像の正規化
img_np = np.asarray(img).astype(np.float32)/255.0
img_np_normalized = (img_np - img_mean) / img_std

# (H, W, C) -> (C, H, W)
img_np_transposed = img_np_normalized.transpose(2, 0, 1)

batch_img = [img_np_transposed]

ort_inputs = {ort_session.get_inputs()[0].name: batch_img}
ort_outs = ort_session.run(None, ort_inputs)[0]
batch_result = np.argmax(ort_outs, axis=1)
print(batch_result)

このSqueezenetモデルを使って昨日と同じようなことをするviews.pyが以下のようになります。全文はGitHubを見てください。 github.com

import os
import re
import urllib.request
from urllib.parse import urlparse
from PIL import Image
from joblib import dump, load
import tweepy

from django.shortcuts import render
from social_django.models import UserSocialAuth
from django.conf import settings
import more_itertools

import numpy as np
import onnxruntime
import torchvision.transforms as transforms


def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


ort_session = onnxruntime.InferenceSession(
    os.path.join(os.path.dirname(__file__), "model.onnx"))

data_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


def index(request):
    if request.user.is_authenticated:
        user = UserSocialAuth.objects.get(user_id=request.user.id)
        consumer_key = settings.SOCIAL_AUTH_TWITTER_KEY
        consumer_secret = settings.SOCIAL_AUTH_TWITTER_SECRET
        access_token = user.extra_data['access_token']['oauth_token']
        access_secret = user.extra_data['access_token']['oauth_token_secret']
        auth = tweepy.OAuthHandler(consumer_key, consumer_secret)
        auth.set_access_token(access_token, access_secret)
        api = tweepy.API(auth)
        timeline = api.home_timeline(count=200, tweet_mode='extended')

        tweet_media = []
        for tweet in timeline:
            if 'media' in tweet.entities:
                tweet_media.append(tweet)

        batch_size = 4
        tweet_illust = []
        for batch_tweet in more_itertools.chunked(tweet_media, batch_size):
            batch_img = []
            for tweet in batch_tweet:
                media_url = tweet.extended_entities['media'][0]['media_url']
                filename = os.path.basename(urlparse(media_url).path)
                filename = os.path.join(
                    os.path.dirname(__file__), 'images', filename)
                urllib.request.urlretrieve(media_url, filename)
                img = Image.open(filename).convert('RGB')
                img = data_transforms(img)
                batch_img.append(to_numpy(img))

            ort_inputs = {ort_session.get_inputs()[0].name: batch_img}
            ort_outs = ort_session.run(None, ort_inputs)[0]
            batch_result = np.argmax(ort_outs, axis=1)
            for tweet, result in zip(batch_tweet, batch_result):
                if result == 1:
                    media_url = tweet.extended_entities['media'][0]['media_url']
                    if hasattr(tweet, "retweeted_status"):
                        profile_image_url = tweet.retweeted_status.author.profile_image_url_https
                        author = {'name': tweet.retweeted_status.author.name,
                                  'screen_name': tweet.retweeted_status.author.screen_name}
                        id_str = tweet.retweeted_status.id_str
                    else:
                        profile_image_url = tweet.author.profile_image_url_https
                        author = {'name': tweet.author.name,
                                  'screen_name': tweet.author.screen_name}
                        id_str = tweet.id_str
                    try:
                        text = tweet.retweeted_status.full_text
                    except AttributeError:
                        text = tweet.full_text
                    text = re.sub(
                        r"https?://[\w/:%#\$&\?\(\)~\.=\+\-]+$", '', text).rstrip()
                    tweet_illust.append({'id_str': id_str,
                                         'profile_image_url': profile_image_url,
                                         'author': author,
                                         'text': text,
                                         'image_url': media_url})
        tweet_illust_chunked = list(more_itertools.chunked(tweet_illust, 4))
        return render(request, 'hello/index.html', {'user': user, 'timeline_chunked': tweet_illust_chunked})
    else:
        return render(request, 'hello/index.html')

モデルがだいぶ小さくなったので、貧弱サーバーでも動かすことができました。

デプロイ

このモデルサイズならHerokuで動かせると思ったのですが、torchvisionなどの依存ライブラリの容量だけでHerokuの500MBの制限を超えてしまうようなので自分のサーバーで動かすことにしました。一応動いてはいるのですが、コールバックの設定がうまくいかないので後日直します。

2.9MBのモデルなら一瞬で推論できると思ったのですが、それでもまだ貧弱サーバーには荷が重いようで読み込みにだいぶ時間がかかります。もっと軽いモデルを作るかディープラーニングに頼らない方法を考えるのが良さそうです。

PyTorchのCPU版をrequirements.txtで指定すればHerokuにデプロイできました。

https://kivantium-playground.herokuapp.com/ から試すことができます。(開発状況によっては違うものがデプロイされているかもしれません)

タイムラインから二次元イラストだけを表示するWebアプリの作成

ここまでの成果を使って、タイムラインから二次元イラストだけを表示するTwitterクライアントっぽいWebアプリを作成します。

スクリーンショット

出来上がったものがこちらになります。

f:id:kivantium:20200423234030p:plain:w600
スクリーンショット

以下、コードと今後の課題を述べます。

コード

コード全文はGitHubを見てください。 github.com

簡単のため、ログイン済みユーザーがアクセスするたびにタイムラインから最新のツイート200件を読み込んで、二次元画像判別器が二次元イラストだと判定した画像つきツイートを表示することにしました。200件以上のツイートを同時に読み込むのはTwitter APIの制限上難しかったです。 リツイートに関しては、リツイートした人の情報ではなくリツイート元の情報を表示することにします。英語で280文字までツイートできるようにする最近の仕様変更に対応するために少し面倒な処理を行っています。(参照: Extended Tweets — tweepy 3.8.0 documentation

前回作成したRandom Forestによる判定器や、ONNX版Illustration2Vecをhello/以下に置いています。

hello/views.py

import os
import re
import urllib.request
from urllib.parse import urlparse
from PIL import Image
from joblib import dump, load
import tweepy

from django.shortcuts import render
from social_django.models import UserSocialAuth
from django.conf import settings
import more_itertools

import sys
sys.path.append(os.path.dirname(__file__))
import i2v

# ONNX版Illustration2Vec
illust2vec = i2v.make_i2v_with_onnx(os.path.join(os.path.dirname(__file__), "illust2vec_ver200.onnx"))

# 事前に作成しておいた二次元画像判別器
clf = load(os.path.join(os.path.dirname(__file__), "clf.joblib"))

def index(request):
    if request.user.is_authenticated:  # Twitterでログインしている場合
        # ユーザー情報の取得
        user = UserSocialAuth.objects.get(user_id=request.user.id)
        consumer_key = settings.SOCIAL_AUTH_TWITTER_KEY
        consumer_secret = settings.SOCIAL_AUTH_TWITTER_SECRET
        access_token = user.extra_data['access_token']['oauth_token']
        access_secret = user.extra_data['access_token']['oauth_token_secret']
        auth = tweepy.OAuthHandler(consumer_key, consumer_secret)
        auth.set_access_token(access_token, access_secret)
        api = tweepy.API(auth)
        # 全文を取得するためにextendedを指定する
        timeline = api.home_timeline(count=200, tweet_mode = 'extended')

        tweet_illust = []
        for tweet in timeline:
            if 'media' in tweet.entities:
                media  = tweet.extended_entities['media'][0]
                media_url = media['media_url']
                filename = os.path.basename(urlparse(media_url).path)
                filename = os.path.join(os.path.dirname(__file__), 'images', filename)
                urllib.request.urlretrieve(media_url, filename)
                img = Image.open(filename)
                feature = illust2vec.extract_feature([img])
                prob = clf.predict_proba(feature)[0]
                if prob[1] > 0.4:  # 二次元イラストの可能性が高い
                    if hasattr(tweet, "retweeted_status"): 
                        profile_image_url = tweet.retweeted_status.author.profile_image_url_https
                        author = {'name': tweet.retweeted_status.author.name,
                                  'screen_name': tweet.retweeted_status.author.screen_name}
                        id_str = tweet.retweeted_status.id_str
                    else:
                        profile_image_url = tweet.author.profile_image_url_https
                        author = {'name': tweet.author.name,
                                  'screen_name': tweet.author.screen_name}
                        id_str = tweet.id_str
                    # リツイート元のツイート全文の取得
                    try:
                        text = tweet.retweeted_status.full_text
                    except AttributeError:
                        text = tweet.full_text
                    # 画像URLを削除するために文末のURLを削除する
                    text = re.sub(r"https?://[\w/:%#\$&\?\(\)~\.=\+\-]+$", '', text).rstrip()
                    tweet_illust.append({'id_str': id_str, 
                                         'profile_image_url': profile_image_url,
                                         'author': author,
                                         'text': text,
                                         'image_url': media_url})
        # 表示の都合上4つずつのグループに分ける
        tweet_illust_chunked = list(more_itertools.chunked(tweet_illust, 4))
        return render(request,'hello/index.html', {'user': user, 'timeline_chunked': tweet_illust_chunked})
    else:
        return render(request,'hello/index.html')

これを表示するためのHTMLを示します。Bulmaで画像の中央を丸く切り取って並べる - kivantium活動日記の応用です。Bulmaのカード機能を使っています。 bulma.io

<!doctype html>
<html>
  <head>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <title>にじさーち</title>
    <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.8.0/css/bulma.min.css">
    <style>
img { object-fit: cover; }
.bm--card-equal-height {
   display: flex;
   flex-direction: column;
   height: 100%;
}
.bm--card-equal-height .card-footer {
   margin-top: auto;
} </style>
  </head>
  <body>
  <nav class="navbar is-primary">
    <div class="navbar-brand navbar-item">
      <h1 class="title has-text-light">にじさーち</h1>
    </div>
    {% if request.user.is_authenticated %}
    <div class="navbar-end">
      <div class="navbar-item">
      <a class="button is-light" href="/logout">ログアウト</a>
      </div>
    {% endif %}
    </div>
  </nav>
  <section class="section">
    <div class="container">
      {% if request.user.is_authenticated %}
      {% for tweets in timeline_chunked %}
      <div class="columns is-mobile">
        {% for tweet in tweets %}
        <div class="column is-3">
          <div class="card bm--card-equal-height">
            <div class="card-content">
              <div class="media">
                <div class="media-left">
                  <figure class="image is-48x48">
                    <img src="{{ tweet.profile_image_url }}" alt="Profile image">
                  </figure>
                </div>
                <div class="media-content">
                  <p class="title is-4">{{ tweet.author.name }}</p>
                  <p class="subtitle is-6">@{{ tweet.author.screen_name }}</p>
                </div>
              </div>
              <div class="card-image">
                <figure class="image is-square">
                  <a href="https://twitter.com/i/web/status/{{ tweet.id_str }}" target="_blank" rel="noopener noreferrer">
                    <img src="{{ tweet.image_url }}" alt="main image">
                  </a>
                </figure>
              </div>
              <div class="content">{{ tweet.text }}</div>
            </div>
          </div>
        </div>
        {% endfor %}
      </div>
      {% endfor %}
      {% else %}
      <p>あなたはログインしていません</p>
      <button type="button" onclick="location.href='{% url 'social:begin' 'twitter' %}'">Twitterでログイン</button>
      {% endif %}
    </div>
  </section>
  </body>
</html>

今後の課題

Illustration2Vecのモデルが重い

今回作成したアプリをサーバーにデプロイしようと思ったのですが、Illustration2Vecのモデルがサーバーのメモリサイズよりも大きかったためデプロイすることができませんでした。また、今後複数のユーザーによる使用をサポートしようとするとアクセスが来るたびにIllustration2Vecを実行していてはとても追いつかないのでモデルを軽量化することが必要になりそうです。

画像データベースとしての利用

Twitter APIのRate Limitが厳しいため、タイムラインから一度に収集できるツイートは200件くらいしかありません。これでは大量の画像を閲覧する目的に向きません。そのため、status/filterで常に画像を収集しておき画像データベースとして利用することが考えられます。しかし、類似のサービスが(利用規約に則っているにも関わらず)以前大炎上したことがあるっぽいので、Twitterクライアントとして一般に認められる以上の機能を提供するとなると面倒くさそうです。 nlab.itmedia.co.jp

自動タグ付け

Illustration2Vecでもタグ付けを行うことができますが、つけることができるタグの種類は有限です。新しく増える作品やキャラに対応するために何らかの方法で類似画像のハッシュタグからタグを推定して自動タグ付けができるとよさそうです。(これも絵師界隈の自主ルールで難癖つけられて面倒なことになりそうですが……)

Display requirementsへの適合

利用規約に則ってTwitterのコンテンツを表示する際の条件としてDisplay Requirementsというものがあります。 developer.twitter.com

ツイート本文を全文表示しないといけないとか、Twitterのロゴを右上に表示しないといけないなどといったユーザーの利便性を損なう規定なのですが、規定なので従う必要があります。 Google画像検索のようにサムネイルだけ表示してあとはツイートへのリンクにする方式にすることも含めて検討していきたいです。

二次元画像判別器に対するActive Learning導入の検討

前回の記事では、Twitter上の画像から二次元画像を選ぼうとすると二次元とも三次元とも言い難い画像が入ってくる問題があることを見ました。今回は、Active Learningという方法を使って境界領域の画像をうまく扱う方法を適用したいと思います。

Active Learningについて

Active Learningという言葉は教育業界と機械学習業界の両方で使われているので混乱がありますが、ここでは機械学習でのActive learningを指します。通常の機械学習の問題設定では学習データは既に与えられたものとして扱うことが多いですが、Active Leaningではどのデータを学習するかを選ぶことができるという設定のもとで学習を行います。これにより、少ないデータ数で学習が行えるようになることが期待できます。

f:id:kivantium:20200418181517p:plain:w600
Active Learningでは、境界に近いデータを能動的に選ぶことで効率的に学習を行うことを目指す。
ICML 2019のActive Learningチュートリアルのスライドより。)

以下、Active Learning Literature Surveyの内容に沿って話を進めます。

Active Learningの主なシナリオには3つあります。

  • Membership Query Synthesis: 学習器が入力空間中の任意のラベルなしインスタンスについてラベル付けを要求できる(新しく生成したインスタンスでも良い)
  • Stream-Based Selective Sampling: 1つずつ流れてくるデータそれぞれについてラベルを要求するか破棄するかを決める
  • Pool-Based Sampling: ラベル付きデータとラベルなしデータが与えられ、ラベルなしデータの中からどのデータにラベル付けを要求するか決める

どのデータに対してラベルを要求するかを決定する基準として最もよく使われているのがUncertainty Samplingという方式で、主なものが3種類あります。

  • least confident: 一番確信度が低いものを選ぶ。数式で書くと、 1 − P(\hat{y}|x) が最大のものを選ぶ(\hat{y} = \mathrm{arg max}_y P(y|x))。
  • margin sampling: 一番可能性が高いクラスと二番目に可能性が高いクラスの分類確率の差が一番小さいものを選ぶ。数式で書くと、 P(\hat{y}_1|x) − P(\hat{y}_2|x) が最小のものを選ぶ。
  • entropy: エントロピーが最大のものを選ぶ。数式で書くと、 -\sum_{i} P(y_i|x) \log{P(y_i|x)}が最大のものを選ぶ。

二次元画像判別に対する応用

今回は、ラベル付けを行った画像とラベルがついていない画像が与えられているのでPool-Based Samplingのシナリオになります。とりあえず一番簡単そうなmargin samplingを使って、昨日ラベル付けをサボったデータに対してActive Learningをやってみようと思ったのですが、1個ずつラベル付けするのは面倒なので、分類確率の差が0.3より小さいデータがどんな感じのデータになるのかを見てみることにします。

import os
import shutil

import numpy as np
from PIL import Image
import more_itertools
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from tqdm import tqdm

import i2v

illust2vec = i2v.make_i2v_with_onnx("illust2vec_ver200.onnx")

# 学習データの準備
X = []
y = []
batch_size = 4

negative_path = '0'
negative_list = os.listdir(negative_path)
for batch in tqdm(list(more_itertools.chunked(negative_list, batch_size))):
    img_list = [Image.open(os.path.join(negative_path, f)) for f in batch]
    features = illust2vec.extract_feature(img_list)
    X.extend(features)
    y.extend([0] * len(batch))

positive_path = '1'
positive_list = os.listdir(positive_path)
for batch in tqdm(list(more_itertools.chunked(positive_list, batch_size))):
    img_list = [Image.open(os.path.join(positive_path, f)) for f in batch]
    features = illust2vec.extract_feature(img_list)
    X.extend(features)
    y.extend([1] * len(batch))

# Random Forestの学習
clf = RandomForestClassifier(max_depth=2, random_state=0)
clf.fit(X, y)

# Unlabeled データをフォルダ分けする
pool_path = 'unlabeled'
pool_list = os.listdir(pool_path)
for filename in pool_list:
    filename = os.path.join(pool_path, filename)
    img = Image.open(filename)
    feature = illust2vec.extract_feature([img])
    prob = clf.predict_proba(feature)[0]
    # 確率値の差が0.3以下ならラベル付けを要求する
    if np.abs(prob[0]-prob[1]) < 0.3:
        shutil.move(filename, 'uncertain')
    elif prob[0] > prob[1]:
        shutil.move(filename, 'negative')
    else:
        shutil.move(filename, 'positive')

Unlabeledデータ2021枚のうち、uncertainに分類されたものが193枚、negativeに分類されたものが1783枚、positiveに分類されたものが52枚でした。

f:id:kivantium:20200418191154p:plain:w600
紛らわしいと判定された画像

uncertainに分類された画像をさらに詳しく見てみました。

前回の記事で述べた紛らわしい種類の画像がきちんとuncertainに分類されており、Random Forestによる分類確率が紛らわしさをきちんと捉えていることが確認できました。

positiveに分類された画像はアニメのスクリーンショット1枚を除いて全てイラストでした。

f:id:kivantium:20200418190731p:plain:w600
二次元イラストだと判定された画像

negativeに分類された画像のうちイラストは34枚でした。これらの画像は、コントラストが薄めである・人間がたくさん書かれているなどの理由から漫画と間違えられた可能性が高いと思っています。(今回のラベリングではコマ割りがあるまたは白黒の画像は全て二次元イラストではないとしています)

f:id:kivantium:20200418190429p:plain:w600
二次元イラストではないと間違えて判定されたイラスト

以上の結果から、margin samplingは二次元画像分類の境界ケースをきちんと集めることができそうだという感触を得ました。これを学習データに加えたら精度が上がったという実験結果を出せればよかったのですが、ランダムサンプリングでも95%くらいの精度が出ていたのでActive Learningで有意差を出すことが難しそうでした。Active Learningをするというよりは、棄却オプションをつけて不確かな画像は人手で分類するようにするのが良さそうです。

次回はこの結果を使って二次元画像だけのタイムラインを表示するアプリを作ろうと思います。

広告コーナー