kivantium活動日記

プログラムを使っていろいろやります

Django ChennelsアプリをNginxとSupervisorでデプロイする

DjangoでWebsocketを使うときにはChannelsというライブラリがよく使われています。これまではHerokuにデプロイをしてきましたが、HerokuとChannelsの相性が良くないのかすぐに接続が切れてしまうので、これからはAWS上で開発しようと思いました。公式ドキュメントを読んでもデプロイ方法がよく分からなかったのでメモしておきます。

AWS LightsailでUbuntu 18.04のインスタンスを立てたとして、SSHで入ってからHello, world!するところまでを見ていきます。

ライブラリのインストール

Daphneの起動を楽にするためにvenvを使います。(参考: Djangoのインストール · Django Girls Tutorial

ssh ubuntu@xxx.xxx.xxx.xxx
sudo apt update
sudo apt install python3-venv
python3 -m venv env
source env/bin/activate

requirements.txtに以下を記述します。

django~=3.0.5
channels~=2.4.0

pipをアップデートしてから必要なライブラリをインストールします。

python -m pip install --upgrade pip
pip install -r requirements.txt

Hello, world!アプリの設定

Django プロジェクトを作成します。

django-admin startproject mysite
cd mysite

基本的にはChannels公式ドキュメントのInstallationに従って設定します。簡単のために本番環境と開発環境の設定の分離などは無視します。

mysite/settings.py を以下のように編集します。diffを示しています。

-ALLOWED_HOSTS = []
+ALLOWED_HOSTS = ['*']  # 本当は適切なホストを指定するべきだが簡単のため全て許可
 
(略)

INSTALLED_APPS = [
     'django.contrib.sessions',
     'django.contrib.messages',
     'django.contrib.staticfiles',
+    'channels',
 ]
 
+ASGI_APPLICATION = "mysite.routing.application"
+
 MIDDLEWARE = [
     'django.middleware.security.SecurityMiddleware',
     'django.contrib.sessions.middleware.SessionMiddleware',

mysite/routing.py を以下の内容で作成します。

from channels.routing import ProtocolTypeRouter

application = ProtocolTypeRouter({
    # Empty for now (http->django views is added by default)
})

mysite/urls.py を以下の内容で作成します。

from django.urls import path

from . import views

urlpatterns = [
    path('', views.index, name='index'),
]

mysite/views.py を以下の内容で作成します。

from django.http import HttpResponse


def index(request):
    return HttpResponse("Hello, world!")

ここまで来たら python manage.py runserver を実行してエラーが出ないことだけ確認します。(サーバー上にあるのでこの時点ではブラウザで表示確認ができません)

NginxとSupervisorの設定

ここもChannels公式ドキュメントのDeployingに従って設定するだけなのですが、この通りにやっても動かなかったので以下のStackOverflowに従ってアレンジしました。

stackoverflow.com

まずはNginxとSupervisorをインストールします。

sudo apt install nginx supervisor

mysite/asgi.pyを以下の内容で作成します。

"""
ASGI entrypoint. Configures Django and then runs the application
defined in the ASGI_APPLICATION setting.
"""

import os
import django
from channels.routing import get_default_application

os.environ.setdefault("DJANGO_SETTINGS_MODULE", "mysite.settings")
django.setup()
application = get_default_application()

/etc/supervisor/conf.d/asgi.confを以下の内容で作成します。

[fcgi-program:asgi]
# TCP socket used by Nginx backend upstream
socket=tcp://localhost:8000

# Directory where your site's project files are located
directory=/home/ubuntu/mysite

# Each process needs to have a separate socket file, so we use process_num
# Make sure to update "mysite.asgi" to match your project name
command=/home/ubuntu/env/bin/daphne --fd 0 --access-log - --proxy-headers mysite.asgi:application
# Number of processes to startup, roughly the number of CPUs you have
numprocs=4

# Give each process a unique name so they can be told apart
process_name=asgi%(process_num)d

# Automatically start and recover processes
autostart=true
autorestart=true

# Choose where you want your log to go
stdout_logfile=/var/log/asgi.log
redirect_stderr=true

設定を読み込みます。

sudo supervisorctl reread
sudo supervisorctl update

/etc/nginx/sites-available/defaultを以下のように編集します。

upstream channels-backend {
    server localhost:8000;
}
...
server {
    ...
    location / {
        try_files $uri @proxy_to_app;
    }
    ...
    location @proxy_to_app {
        proxy_pass http://channels-backend;

        proxy_http_version 1.1;
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection "upgrade";

        proxy_redirect off;
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Host $server_name;
    }
    ...
}

設定を読み込みます。

sudo service nginx reload

ブラウザのアドレス欄にこのサーバーのIPアドレスを入力すれば、Hello, world! と出力されたページを確認することができます。 次回はこのサーバーを使って簡単なWebsocketを使ったプログラムを書きます。

2020年4月

4月は死んだ、もういない!

出来事

インターンについて

まともに労働したのが学部4年の夏休みにインターンしたとき以来だったので記憶が薄れていたのですが、やっぱり労働は厳しいです。時給制で働く場合はどれだけ頑張っても働いた時間に対してしか賃金が支払われないので、頑張らずにゆっくり働くほうが自分にとって得になってしまいます。また、昼寝などをして一時的に職場を離れたほうが結果的に一日のパフォーマンスが上がる状況であっても、昼寝に対しては賃金が発生しません。その結果、眠い目をこすりながら動かない頭で働き続けることになり、非効率な労働をだらだらと続けることになります。非効率な行動を取ることが経済的に得であるという状況が発生してしまう労働という構造が正しくないことに気づいた後、その状況に甘んじている自分が嫌になるというのが労働が嫌になるときのパターンな気がします。業務委託契約で働いていたときは、仕様を満たすものが作れないと何も賃金がもらえないというプレッシャーがある一方、最低限のものを作って納品すればいいだろうという雑な態度になっている自分が嫌になっていた記憶があります。

業務内容が面白くても労働という構造をはさんだ瞬間につらさが発生するのですが、業務内容が面白くない場合はなおさらつらくなります。今回のインターンの目的はロボットの開発技術を学ぶことだったのですが、与えられた仕事はその会社で開発しているソフトの上でロボットの動かし方を考えるというもので、退社した瞬間に役に立たなくなりそうな技術ばかり学んでいます。さらに、コロナウィルスの影響で出社3日目から在宅勤務になっているため、自分が考えたロボットの動かし方を実機で試すこともできず、本当に何をやっているのか分からない状態になっています。だからといって別の仕事をよこせと言える状況でもないので毎日のように退職を考えています。SHIROBAKO 9話でこのままタイヤを作る仕事を続けていいのか悩むみーちゃんの心境がとてもよく分かるようになりました。

Webサービス開発について

このまま労働を続けていても何にもならないと思ったので、業務後の時間を使ってWebサービス開発の勉強をすることにしました。題材は何でも良かったのですが、自分が欲しいサービスを作るのが良いだろうということでTwitter上の画像検索サービスを作っています。(画像の元ネタ検索に便利だった TwiGaTen というサービスが終了していたのが思いついたきっかけですが、いま確認したら再開していました……。サーベイを怠って失敗するのはいつものことです。)

せっかくなので機械学習と絡めようと思って、タイムラインからイラストを含むツイートのみを選んで表示する機能を最初に実装しました。http://kivantium-playground.herokuapp.com/ で公開していますが、依存しているライブラリがHeroku上での動作に問題があるらしく、すぐ落ちてしまいます。タイムラインから画像を自動収集する機能などを実装したらAWSに移行しようと思っています。いずれはレコメンデーションや類似画像検索等も実装していきたいと思っていますが、そこまでやる気が続くかは分かりません。

読んだ本

インターンで通勤時間が発生したらたくさん本が読めると思っていたのですが、在宅勤務になったので結局あまり読めませんでした。

無能なナナがアニメ化することになりました

進撃の巨人(31) (講談社コミックス)

進撃の巨人(31) (講談社コミックス)

  • 作者:諫山 創
  • 発売日: 2020/04/09
  • メディア: コミック

化合物でもDeep Learningがしたい!

この記事は2017年12月15日に https://kivantium.net/deep-for-chem/ に投稿したものです。 情報が古くなっていますが、まだ参照されているようなので再掲します。

この記事はDeep Learningやっていき Advent Calendar 2017の15日目です。

Deep Learningの威力が有名になったのは画像認識コンテストでの圧勝がきっかけでしたが、今ではDeep Learningはあらゆる分野に応用され始めています。NIPS2017でもMachine Learning for Molecules and Materialsが開催されたように、物質化学における機械学習の存在感が高まりつつあります。この記事ではその一例として化学の研究にDeep Learningが使われている例を紹介していきます。

化学物質の研究に機械学習が使われる主なパターンには

  • 分子を入力するとその分子の性質を出力する
  • 分子の性質を入力するとその性質を持った分子を出力する
  • 分子を入力するとその反応を出力する

の3つがあります。それぞれについて詳しく説明します。

分子から性質を予測する

Deep Learning以前

Deep Learning以前の性質予測では、職人の温かい手作りによる特徴量が使われていました。分子の特徴ベクトルはmolecular fingerprintsと呼ばれます。molecular fingerprintsは化合物の特徴的な一部分(fragmentと呼ばれる)がその分子にあるかどうかを0/1で表したbitを並べて作られます。 (画像はFingerprints in the RDKit p.4より引用)

どのfragmentを用いるのが有効かはデータセット・問題に依存するので様々な種類のfingerprintが提案されてきました。

主なfingerprintを挙げると

などがあります。 fingerprintはRDKitなどのライブラリを使うと簡単に計算できます。(各ソフトで計算できるfingerprintのリスト

このようなfingerprintを使ってSVMやRandomforestなどでその分子がある性質を持つ/持たないを予測する研究がたくさんあります。化学の分野でDeep Learningが大きく注目されるきっかけになったのは、kaggleの薬の活性予測のコンペでHintonらのチームが優勝したことですが、論文を見ると特徴量には上のように設計されたものを使っており、ニューラルネットワークで設計されたものではなかったようです。

graph convolutionの登場

fingerprintの設計にニューラルネットワークが導入されたのが[Duvenaud+, 2015]です。この研究ではcircular fingerprint (上のECFPのこと)をもとにneural graph fingerprint (NFP)を提案しています。以下にアルゴリズムを示します。 従来のfingerprint設計でhashやmodになっていた部分が重みを調整できる演算に変更されています。これにより、予測にとって重要なfragmentの寄与は大きく、重要ではないfragmentの寄与は小さくなるような特徴量が設計できるようになりました。実際に分子の水への溶けやすさをNFPで予測したところ水への溶けやすさに影響するR-OHのような構造の重みが大きくなったことが報告されています。 NFPの他にも分子のグラフ構造に基づいたニューラルネットワークベースの特徴量設計の研究が行われています。これらはグラフ構造に注目したニューラルネットワークなので総称としてgraph convolutionと呼ばれています。一番有名なのはGoogle BrainのNeural Message Passing for Quantum Chemistryでしょう。この論文ではMessage Passing Neural Network (MPNN) というグラフ上のニューラルネットワークを提案し、分子のニューラルネットワークの先行研究の多くがMPNNで一般的に記述できることを主張した上で、MPNNが分子の性質を予測する上で高い性能を発揮すると主張しています。 MPNNは という式で表されます。グラフ上で隣接するエッジからのメッセージ$M$を足し合わせるような処理をしていることが分かります。$M$,$U$などをうまく定めることで各種のgraph convolutionを表すことができます。詳細は論文を読んでください。Google ResearchによるブログPredicting Properties of Molecules with Machine Learningも役立つかもしれません。(ちなみにこの論文のラストオーサーは先述したkaggleコンペの論文のファーストオーサーです)

graph convolutionのイメージとしてよく使われる絵が[Han Altae-Tran+, 2017]にあります。一枚引用します。

Graph convolutionではない方法としては[Goh+, 2017]のような分子を画像にしてCNNで予測するようなものもあります。

ちなみに、同じ人がつい先日SMILES2Vecという文字列から化合物の性質を予測する論文も書いていました。

実装

分子に対するDeep Learningのライブラリで最も有名なのはDeepChemでしょう。DeepChemはTensorFlowでgraph convolutionを実装しています。Graph Convolutions For Tox21などのチュートリアルを読むとだいたい使い方が分かるのではないでしょうか(私も使ったことはないです)。ちなみに、なぜかPong in DeepChem with A3Cのようなチュートリアルもあり何がしたいのか謎です……

また、PFNが最近Chainer Chemistryを公開しました。NFP, GGNN, Weave, SchNetなどのgraph convolution手法が実装されているほか、QM9, Tox21などの有名どころのデータセットを使うコードも揃っており、普段Chainerを使っている人はこれを試してみるのもよいかもしれません。

性質から分子を作る

創薬などの応用においては、「タンパク質Xの動きを抑制する」などの特定の性質を持った分子を作ることが必要になります。化学物質の構造と生物学的な活性の関係のことをQSARと呼びますが、逆に活性から構造を予測する問題をinverse-QSARのように言うことがあります。

分子設計の難しさの一つは、可能な分子の数が非常にたくさんあることです。[Bohacek+, 1996]1098-1128(199601)16:1%3C3::AID-MED1%3E3.0.CO;2-6/abstract)によれば、C,N,O,Sを30個以下しか持たない分子に限っても$10^{60}$種類の分子が存在できるとされています。そのため全探索は不可能なので何らかの効率的な探索法を考える必要があります。

Deep Learning以外の方法

創薬は重要な研究分野なので以前から研究が行われていました。多くの手法は[Nishibata+, 1991][Pierce+, 2004]のように既に知られている部分構造を組み合わせることで分子を設計しています。最近の研究では[Kawai+, 2014]のように構造の組み合わせに遺伝的アルゴリズム構造を使ったり、[Podlewska+, 2017]のように目的関数を機械学習の予測値にしたりするなどの工夫がなされています。

Deep Learningによる方法

分子設計にDeep Leaningを持ち込んだ研究が[Gómez-Bombarelli+, 2016]です。この研究では分子の文字列表現であるSMILES記法をvariational autoencoder (VAE) を用いて実数ベクトルに変換し、ベイズ最適化で最適化したベクトルをSMILESに戻すことで分子を設計しています。この手法の問題点はVAE空間上で最適化ベクトルをSMILESに戻したときに生成される文字列が文法的に正しくないなどの理由で分子と対応しなくなる率が非常に高かったことです。

SMILES記法は、グラフ構造として表される化合物を環を切り開くなどして文字列として表現できるようにしています。OpenSMILES specificationのように文脈自由文法で規定される文法を持っており、文法に従わない文字列は分子を表しません。(なお、文法に従っていても対応する分子が化学的に存在できるかは別の問題です)。例えば下のような図で表される分子のSMILESはO1C=C[C@H]([C@H]1O2)c3c2cc(OC)c4c3OC(=O)C5=C4CCC(=O)5となります。同じ数字はそこで環を形成していることを表し、カッコは分岐を表しています。

文法的に正しくないSMILESの文字列が生成される問題を解決するために、VAEの入出力にSMILESの文字列をそのまま使うのではなくSMILESを生成する文脈自由文法の生成規則列を使うことにしたのが[Kusner+, 2017]のGrammar Variational Autoencoderです。この研究で技術的に面白いところはVAE表現から文字列を生成する際にプッシュダウンオートマトンを考えて、現在スタックの一番上にある文字から選択できない生成規則の確率を0にする工夫を導入しているところです。この工夫により生成される文字列はSMILESの文法的に正しいものに限定することができるためデコードの効率が上がるほか、潜在空間自体もよりよいものになったと主張されています。

これらのアプローチに影響されたのかは分かりませんが、分子の構造を直接設計するのではなく、分子を表すSMILESを生成する研究が盛んに行われています。

  • [Segler+, 2017] はChEMBLのSMILESを学習したLSTMで新しいSMILESを生成しています。また、薬の候補になりそうな分子を入力としたRNNのファインチューニングなども行っています。

  • [Guimaraes+, 2017] はGANに強化学習を組み込むことで偏った性質を持つ分子のSMILESを生成しています

  • [Yang+, 2017] ではモンテカルロ木探索とRNNを組み合わせることで分子の設計を行っています。

分子から反応を予測する

分子からの反応予測には、複数の分子を入力して反応結果を出力するものと、一つの分子を入力してその分子を作るのに必要な反応を予測するものがあります。

反応結果の予測

反応予測をコンピュータで行う試みは1960年代から行われていますが、従来の手法では専門家がルールをたくさん記述することで実現しています。この分野にもDeep Learningの波が来ています。

[Schwaller+, 2017] は反応物のSMILESを入力に生成物のSMILESを出力する言語モデルを用いて反応の予測を行っています。 SMILESによる反応の記述は、反応物の文字列を入力して生成物の文字を出力する処理なので、英語を入力してフランス語を入力する処理に似ていると彼らは考えましたアメリカの特許にある反応のデータベースから入力と出力のペアを作り、seq2seqという翻訳に使われるRNNモデルを適用して反応の予測を行いました。 結果としてtop-1で80%という先行研究を上回る精度の予測ができるようになったと主張されています。

逆合成の予測

目的の化合物を合成するための反応経路を求めることをretrosynthesisといいますが、実際に化合物を生産する上では非常に重要な技術です。この研究でもDeep Learningを使った論文が出ています。

[Segler+, 2017]ではAlphaGoと似た手法でretrosynthesisを行っています。(図は論文のFigure 1) (a)は目的の化合物(図ではIbuprofen)からはじめて分子をばらしていき、全てが既知の入手可能な分子(図では赤で示されている)にまで還元できたら逆合成が完了するというコンセプトを示しています。 (b)は(a)で用いられた既知の反応を示しています。 (c)は(a)の結果得られた反応経路から実際に目的の化合物を合成する過程を示しています。 (d)がこの論文の中心となるアイデアを表しています。現在の分子をばらすのに使える既知の反応はいくつもあります。反応の各段階を一つの状態ととらえると反応はグラフ上の状態遷移と考えることができ、逆合成はグラフ上の最適な経路を探す問題と解釈できます。そこでゲームの状態を表す木から最適な手を探すのと同じような方法を用いて、最適な次の反応を選ぶことで逆合成を解くことができると考えられます。 (e)のように分子の状態を入力すると良さそうな反応を返すDNNの確率をガイドにしたモンテカルロ木探索を実行することで逆合成を行うことができそうです。

論文の実験ではモンテカルロ木探索を用いた提案手法が先行研究よりも高い性能を示したと主張されています。

私が知っている主な研究はこれくらいですが、他にも面白い研究を知っている方がいらっしゃったらコメントなどで教えて下さい。

PyTorchでファインチューニングしたモデルをONNXで利用する

昨日の作業の結果、Illustration2Vecのモデルが大きすぎて貧弱なサーバーでは使えないことが分かりました。今のところ二次元画像判別器の特徴量抽出にしか使っていないので、もっと軽いモデルでも代用できるはずです。軽いモデルとして有名なSqueezenetをこれまで集めたデータでファインチューニングして様子を見てみることにします。

ファインチューニングとONNXへのエキスポート

PyTorchのチュートリアルが丁寧に説明してくれているので、これをコピペして継ぎ接ぎするだけです。

継ぎ接ぎしたものがこちらになります。これを実行するとmodel.onnxというファイルが作成されます。 ONNX版Illustration2Vecのモデルサイズが910Mに対して、このモデルは2.8MBなのでだいぶ小さくなりました。精度もだいたい同じくらいだと思います。

from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

# Top level data directory. Here we assume the format of the directory conforms
# to the ImageFolder structure
data_dir = "./data/images"

# Models to choose from [resnet, alexnet, vgg, squeezenet, densenet, inception]
model_name = "squeezenet"

# Number of classes in the dataset
num_classes = 2

# Batch size for training (change depending on how much memory you have)
batch_size = 8

# Number of epochs to train for
num_epochs = 1

# Flag for feature extracting. When False, we finetune the whole model,
#   when True we only update the reshaped layer params
feature_extract = True

def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False):
    since = time.time()

    val_acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    # Get model outputs and calculate loss
                    # Special case for inception because in training it has an auxiliary output. In train
                    #   mode we calculate the loss by summing the final output and the auxiliary output
                    #   but in testing we only consider the final output.
                    if is_inception and phase == 'train':
                        # From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4*loss2
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'val':
                val_acc_history.append(epoch_acc)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history

def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # Initialize these variables which will be set in this if statement. Each of these
    #   variables is model specific.
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        """ Resnet18
        """
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "alexnet":
        """ Alexnet
        """
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224

    elif model_name == "vgg":
        """ VGG11_bn
        """
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        """ Squeezenet
        """
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        """ Densenet
        """
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "inception":
        """ Inception v3
        Be careful, expects (299,299) sized images and has auxiliary output
        """
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Handle the auxilary net
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        # Handle the primary net
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs,num_classes)
        input_size = 299

    else:
        print("Invalid model name, exiting...")
        exit()

    return model_ft, input_size

# Initialize the model for this run
model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)

# Print the model we just instantiated
print(model_ft)

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

print("Initializing Datasets and Dataloaders...")

# Create training and validation datasets
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
# Create training and validation dataloaders
dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4) for x in ['train', 'val']}

# Detect if we have a GPU available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_ft = model_ft.to(device)

# Gather the parameters to be optimized/updated in this run. If we are
#  finetuning we will be updating all parameters. However, if we are
#  doing feature extract method, we will only update the parameters
#  that we have just initialized, i.e. the parameters with requires_grad
#  is True.
params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t",name)
else:
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t",name)

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9)

# Setup the loss fxn
criterion = nn.CrossEntropyLoss()

# Train and evaluate
model_ft, hist = train_model(model_ft, dataloaders_dict, criterion, optimizer_ft, num_epochs=num_epochs, is_inception=(model_name=="inception"))

# Save PyTorch model to file
torch.save(model_ft.to('cpu').state_dict(), 'model.pth')

# Input to the model
x = torch.randn(1, 3, 224, 224, requires_grad=True)
torch_out = model_ft(x)

# Export the model
torch.onnx.export(model_ft,               # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "model.onnx",                # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=10,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output'], # the model's output names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable lenght axes
                                'output' : {0 : 'batch_size'}})

ONNXモデルの利用

こうして作成したONNXモデルをPyTorchを使わずに利用するコードはこんな感じです。

import os
from PIL import Image

import numpy as np
import onnxruntime

# 中心を正方形に切り抜いてリサイズ
def crop_and_resize(img, size):
    width, height = img.size
    crop_size = min(width, height)
    img_crop = img.crop(((width - crop_size) // 2, (height - crop_size) // 2,
                         (width + crop_size) // 2, (height + crop_size) // 2))
    return img_crop.resize((size, size))

img_mean = np.asarray([0.485, 0.456, 0.406])
img_std = np.asarray([0.229, 0.224, 0.225])

ort_session = onnxruntime.InferenceSession(
    os.path.join(os.path.dirname(__file__), "model.onnx"))

img = Image.open('image.jpg').convert('RGB')
img = crop_and_resize(img, 224)

# 画像の正規化
img_np = np.asarray(img).astype(np.float32)/255.0
img_np_normalized = (img_np - img_mean) / img_std

# (H, W, C) -> (C, H, W)
img_np_transposed = img_np_normalized.transpose(2, 0, 1)

batch_img = [img_np_transposed]

ort_inputs = {ort_session.get_inputs()[0].name: batch_img}
ort_outs = ort_session.run(None, ort_inputs)[0]
batch_result = np.argmax(ort_outs, axis=1)
print(batch_result)

このSqueezenetモデルを使って昨日と同じようなことをするviews.pyが以下のようになります。全文はGitHubを見てください。 github.com

import os
import re
import urllib.request
from urllib.parse import urlparse
from PIL import Image
from joblib import dump, load
import tweepy

from django.shortcuts import render
from social_django.models import UserSocialAuth
from django.conf import settings
import more_itertools

import numpy as np
import onnxruntime
import torchvision.transforms as transforms


def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


ort_session = onnxruntime.InferenceSession(
    os.path.join(os.path.dirname(__file__), "model.onnx"))

data_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


def index(request):
    if request.user.is_authenticated:
        user = UserSocialAuth.objects.get(user_id=request.user.id)
        consumer_key = settings.SOCIAL_AUTH_TWITTER_KEY
        consumer_secret = settings.SOCIAL_AUTH_TWITTER_SECRET
        access_token = user.extra_data['access_token']['oauth_token']
        access_secret = user.extra_data['access_token']['oauth_token_secret']
        auth = tweepy.OAuthHandler(consumer_key, consumer_secret)
        auth.set_access_token(access_token, access_secret)
        api = tweepy.API(auth)
        timeline = api.home_timeline(count=200, tweet_mode='extended')

        tweet_media = []
        for tweet in timeline:
            if 'media' in tweet.entities:
                tweet_media.append(tweet)

        batch_size = 4
        tweet_illust = []
        for batch_tweet in more_itertools.chunked(tweet_media, batch_size):
            batch_img = []
            for tweet in batch_tweet:
                media_url = tweet.extended_entities['media'][0]['media_url']
                filename = os.path.basename(urlparse(media_url).path)
                filename = os.path.join(
                    os.path.dirname(__file__), 'images', filename)
                urllib.request.urlretrieve(media_url, filename)
                img = Image.open(filename).convert('RGB')
                img = data_transforms(img)
                batch_img.append(to_numpy(img))

            ort_inputs = {ort_session.get_inputs()[0].name: batch_img}
            ort_outs = ort_session.run(None, ort_inputs)[0]
            batch_result = np.argmax(ort_outs, axis=1)
            for tweet, result in zip(batch_tweet, batch_result):
                if result == 1:
                    media_url = tweet.extended_entities['media'][0]['media_url']
                    if hasattr(tweet, "retweeted_status"):
                        profile_image_url = tweet.retweeted_status.author.profile_image_url_https
                        author = {'name': tweet.retweeted_status.author.name,
                                  'screen_name': tweet.retweeted_status.author.screen_name}
                        id_str = tweet.retweeted_status.id_str
                    else:
                        profile_image_url = tweet.author.profile_image_url_https
                        author = {'name': tweet.author.name,
                                  'screen_name': tweet.author.screen_name}
                        id_str = tweet.id_str
                    try:
                        text = tweet.retweeted_status.full_text
                    except AttributeError:
                        text = tweet.full_text
                    text = re.sub(
                        r"https?://[\w/:%#\$&\?\(\)~\.=\+\-]+$", '', text).rstrip()
                    tweet_illust.append({'id_str': id_str,
                                         'profile_image_url': profile_image_url,
                                         'author': author,
                                         'text': text,
                                         'image_url': media_url})
        tweet_illust_chunked = list(more_itertools.chunked(tweet_illust, 4))
        return render(request, 'hello/index.html', {'user': user, 'timeline_chunked': tweet_illust_chunked})
    else:
        return render(request, 'hello/index.html')

モデルがだいぶ小さくなったので、貧弱サーバーでも動かすことができました。

デプロイ

このモデルサイズならHerokuで動かせると思ったのですが、torchvisionなどの依存ライブラリの容量だけでHerokuの500MBの制限を超えてしまうようなので自分のサーバーで動かすことにしました。一応動いてはいるのですが、コールバックの設定がうまくいかないので後日直します。

2.9MBのモデルなら一瞬で推論できると思ったのですが、それでもまだ貧弱サーバーには荷が重いようで読み込みにだいぶ時間がかかります。もっと軽いモデルを作るかディープラーニングに頼らない方法を考えるのが良さそうです。

PyTorchのCPU版をrequirements.txtで指定すればHerokuにデプロイできました。

https://kivantium-playground.herokuapp.com/ から試すことができます。(開発状況によっては違うものがデプロイされているかもしれません)

タイムラインから二次元イラストだけを表示するWebアプリの作成

ここまでの成果を使って、タイムラインから二次元イラストだけを表示するTwitterクライアントっぽいWebアプリを作成します。

スクリーンショット

出来上がったものがこちらになります。

f:id:kivantium:20200423234030p:plain:w600
スクリーンショット

以下、コードと今後の課題を述べます。

コード

コード全文はGitHubを見てください。 github.com

簡単のため、ログイン済みユーザーがアクセスするたびにタイムラインから最新のツイート200件を読み込んで、二次元画像判別器が二次元イラストだと判定した画像つきツイートを表示することにしました。200件以上のツイートを同時に読み込むのはTwitter APIの制限上難しかったです。 リツイートに関しては、リツイートした人の情報ではなくリツイート元の情報を表示することにします。英語で280文字までツイートできるようにする最近の仕様変更に対応するために少し面倒な処理を行っています。(参照: Extended Tweets — tweepy 3.8.0 documentation

前回作成したRandom Forestによる判定器や、ONNX版Illustration2Vecをhello/以下に置いています。

hello/views.py

import os
import re
import urllib.request
from urllib.parse import urlparse
from PIL import Image
from joblib import dump, load
import tweepy

from django.shortcuts import render
from social_django.models import UserSocialAuth
from django.conf import settings
import more_itertools

import sys
sys.path.append(os.path.dirname(__file__))
import i2v

# ONNX版Illustration2Vec
illust2vec = i2v.make_i2v_with_onnx(os.path.join(os.path.dirname(__file__), "illust2vec_ver200.onnx"))

# 事前に作成しておいた二次元画像判別器
clf = load(os.path.join(os.path.dirname(__file__), "clf.joblib"))

def index(request):
    if request.user.is_authenticated:  # Twitterでログインしている場合
        # ユーザー情報の取得
        user = UserSocialAuth.objects.get(user_id=request.user.id)
        consumer_key = settings.SOCIAL_AUTH_TWITTER_KEY
        consumer_secret = settings.SOCIAL_AUTH_TWITTER_SECRET
        access_token = user.extra_data['access_token']['oauth_token']
        access_secret = user.extra_data['access_token']['oauth_token_secret']
        auth = tweepy.OAuthHandler(consumer_key, consumer_secret)
        auth.set_access_token(access_token, access_secret)
        api = tweepy.API(auth)
        # 全文を取得するためにextendedを指定する
        timeline = api.home_timeline(count=200, tweet_mode = 'extended')

        tweet_illust = []
        for tweet in timeline:
            if 'media' in tweet.entities:
                media  = tweet.extended_entities['media'][0]
                media_url = media['media_url']
                filename = os.path.basename(urlparse(media_url).path)
                filename = os.path.join(os.path.dirname(__file__), 'images', filename)
                urllib.request.urlretrieve(media_url, filename)
                img = Image.open(filename)
                feature = illust2vec.extract_feature([img])
                prob = clf.predict_proba(feature)[0]
                if prob[1] > 0.4:  # 二次元イラストの可能性が高い
                    if hasattr(tweet, "retweeted_status"): 
                        profile_image_url = tweet.retweeted_status.author.profile_image_url_https
                        author = {'name': tweet.retweeted_status.author.name,
                                  'screen_name': tweet.retweeted_status.author.screen_name}
                        id_str = tweet.retweeted_status.id_str
                    else:
                        profile_image_url = tweet.author.profile_image_url_https
                        author = {'name': tweet.author.name,
                                  'screen_name': tweet.author.screen_name}
                        id_str = tweet.id_str
                    # リツイート元のツイート全文の取得
                    try:
                        text = tweet.retweeted_status.full_text
                    except AttributeError:
                        text = tweet.full_text
                    # 画像URLを削除するために文末のURLを削除する
                    text = re.sub(r"https?://[\w/:%#\$&\?\(\)~\.=\+\-]+$", '', text).rstrip()
                    tweet_illust.append({'id_str': id_str, 
                                         'profile_image_url': profile_image_url,
                                         'author': author,
                                         'text': text,
                                         'image_url': media_url})
        # 表示の都合上4つずつのグループに分ける
        tweet_illust_chunked = list(more_itertools.chunked(tweet_illust, 4))
        return render(request,'hello/index.html', {'user': user, 'timeline_chunked': tweet_illust_chunked})
    else:
        return render(request,'hello/index.html')

これを表示するためのHTMLを示します。Bulmaで画像の中央を丸く切り取って並べる - kivantium活動日記の応用です。Bulmaのカード機能を使っています。 bulma.io

<!doctype html>
<html>
  <head>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <title>にじさーち</title>
    <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.8.0/css/bulma.min.css">
    <style>
img { object-fit: cover; }
.bm--card-equal-height {
   display: flex;
   flex-direction: column;
   height: 100%;
}
.bm--card-equal-height .card-footer {
   margin-top: auto;
} </style>
  </head>
  <body>
  <nav class="navbar is-primary">
    <div class="navbar-brand navbar-item">
      <h1 class="title has-text-light">にじさーち</h1>
    </div>
    {% if request.user.is_authenticated %}
    <div class="navbar-end">
      <div class="navbar-item">
      <a class="button is-light" href="/logout">ログアウト</a>
      </div>
    {% endif %}
    </div>
  </nav>
  <section class="section">
    <div class="container">
      {% if request.user.is_authenticated %}
      {% for tweets in timeline_chunked %}
      <div class="columns is-mobile">
        {% for tweet in tweets %}
        <div class="column is-3">
          <div class="card bm--card-equal-height">
            <div class="card-content">
              <div class="media">
                <div class="media-left">
                  <figure class="image is-48x48">
                    <img src="{{ tweet.profile_image_url }}" alt="Profile image">
                  </figure>
                </div>
                <div class="media-content">
                  <p class="title is-4">{{ tweet.author.name }}</p>
                  <p class="subtitle is-6">@{{ tweet.author.screen_name }}</p>
                </div>
              </div>
              <div class="card-image">
                <figure class="image is-square">
                  <a href="https://twitter.com/i/web/status/{{ tweet.id_str }}" target="_blank" rel="noopener noreferrer">
                    <img src="{{ tweet.image_url }}" alt="main image">
                  </a>
                </figure>
              </div>
              <div class="content">{{ tweet.text }}</div>
            </div>
          </div>
        </div>
        {% endfor %}
      </div>
      {% endfor %}
      {% else %}
      <p>あなたはログインしていません</p>
      <button type="button" onclick="location.href='{% url 'social:begin' 'twitter' %}'">Twitterでログイン</button>
      {% endif %}
    </div>
  </section>
  </body>
</html>

今後の課題

Illustration2Vecのモデルが重い

今回作成したアプリをサーバーにデプロイしようと思ったのですが、Illustration2Vecのモデルがサーバーのメモリサイズよりも大きかったためデプロイすることができませんでした。また、今後複数のユーザーによる使用をサポートしようとするとアクセスが来るたびにIllustration2Vecを実行していてはとても追いつかないのでモデルを軽量化することが必要になりそうです。

画像データベースとしての利用

Twitter APIのRate Limitが厳しいため、タイムラインから一度に収集できるツイートは200件くらいしかありません。これでは大量の画像を閲覧する目的に向きません。そのため、status/filterで常に画像を収集しておき画像データベースとして利用することが考えられます。しかし、類似のサービスが(利用規約に則っているにも関わらず)以前大炎上したことがあるっぽいので、Twitterクライアントとして一般に認められる以上の機能を提供するとなると面倒くさそうです。 nlab.itmedia.co.jp

自動タグ付け

Illustration2Vecでもタグ付けを行うことができますが、つけることができるタグの種類は有限です。新しく増える作品やキャラに対応するために何らかの方法で類似画像のハッシュタグからタグを推定して自動タグ付けができるとよさそうです。(これも絵師界隈の自主ルールで難癖つけられて面倒なことになりそうですが……)

Display requirementsへの適合

利用規約に則ってTwitterのコンテンツを表示する際の条件としてDisplay Requirementsというものがあります。 developer.twitter.com

ツイート本文を全文表示しないといけないとか、Twitterのロゴを右上に表示しないといけないなどといったユーザーの利便性を損なう規定なのですが、規定なので従う必要があります。 Google画像検索のようにサムネイルだけ表示してあとはツイートへのリンクにする方式にすることも含めて検討していきたいです。

二次元画像判別器に対するActive Learning導入の検討

前回の記事では、Twitter上の画像から二次元画像を選ぼうとすると二次元とも三次元とも言い難い画像が入ってくる問題があることを見ました。今回は、Active Learningという方法を使って境界領域の画像をうまく扱う方法を適用したいと思います。

Active Learningについて

Active Learningという言葉は教育業界と機械学習業界の両方で使われているので混乱がありますが、ここでは機械学習でのActive learningを指します。通常の機械学習の問題設定では学習データは既に与えられたものとして扱うことが多いですが、Active Leaningではどのデータを学習するかを選ぶことができるという設定のもとで学習を行います。これにより、少ないデータ数で学習が行えるようになることが期待できます。

f:id:kivantium:20200418181517p:plain:w600
Active Learningでは、境界に近いデータを能動的に選ぶことで効率的に学習を行うことを目指す。
ICML 2019のActive Learningチュートリアルのスライドより。)

以下、Active Learning Literature Surveyの内容に沿って話を進めます。

Active Learningの主なシナリオには3つあります。

  • Membership Query Synthesis: 学習器が入力空間中の任意のラベルなしインスタンスについてラベル付けを要求できる(新しく生成したインスタンスでも良い)
  • Stream-Based Selective Sampling: 1つずつ流れてくるデータそれぞれについてラベルを要求するか破棄するかを決める
  • Pool-Based Sampling: ラベル付きデータとラベルなしデータが与えられ、ラベルなしデータの中からどのデータにラベル付けを要求するか決める

どのデータに対してラベルを要求するかを決定する基準として最もよく使われているのがUncertainty Samplingという方式で、主なものが3種類あります。

  • least confident: 一番確信度が低いものを選ぶ。数式で書くと、 1 − P(\hat{y}|x) が最大のものを選ぶ(\hat{y} = \mathrm{arg max}_y P(y|x))。
  • margin sampling: 一番可能性が高いクラスと二番目に可能性が高いクラスの分類確率の差が一番小さいものを選ぶ。数式で書くと、 P(\hat{y}_1|x) − P(\hat{y}_2|x) が最小のものを選ぶ。
  • entropy: エントロピーが最大のものを選ぶ。数式で書くと、 -\sum_{i} P(y_i|x) \log{P(y_i|x)}が最大のものを選ぶ。

二次元画像判別に対する応用

今回は、ラベル付けを行った画像とラベルがついていない画像が与えられているのでPool-Based Samplingのシナリオになります。とりあえず一番簡単そうなmargin samplingを使って、昨日ラベル付けをサボったデータに対してActive Learningをやってみようと思ったのですが、1個ずつラベル付けするのは面倒なので、分類確率の差が0.3より小さいデータがどんな感じのデータになるのかを見てみることにします。

import os
import shutil

import numpy as np
from PIL import Image
import more_itertools
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from tqdm import tqdm

import i2v

illust2vec = i2v.make_i2v_with_onnx("illust2vec_ver200.onnx")

# 学習データの準備
X = []
y = []
batch_size = 4

negative_path = '0'
negative_list = os.listdir(negative_path)
for batch in tqdm(list(more_itertools.chunked(negative_list, batch_size))):
    img_list = [Image.open(os.path.join(negative_path, f)) for f in batch]
    features = illust2vec.extract_feature(img_list)
    X.extend(features)
    y.extend([0] * len(batch))

positive_path = '1'
positive_list = os.listdir(positive_path)
for batch in tqdm(list(more_itertools.chunked(positive_list, batch_size))):
    img_list = [Image.open(os.path.join(positive_path, f)) for f in batch]
    features = illust2vec.extract_feature(img_list)
    X.extend(features)
    y.extend([1] * len(batch))

# Random Forestの学習
clf = RandomForestClassifier(max_depth=2, random_state=0)
clf.fit(X, y)

# Unlabeled データをフォルダ分けする
pool_path = 'unlabeled'
pool_list = os.listdir(pool_path)
for filename in pool_list:
    filename = os.path.join(pool_path, filename)
    img = Image.open(filename)
    feature = illust2vec.extract_feature([img])
    prob = clf.predict_proba(feature)[0]
    # 確率値の差が0.3以下ならラベル付けを要求する
    if np.abs(prob[0]-prob[1]) < 0.3:
        shutil.move(filename, 'uncertain')
    elif prob[0] > prob[1]:
        shutil.move(filename, 'negative')
    else:
        shutil.move(filename, 'positive')

Unlabeledデータ2021枚のうち、uncertainに分類されたものが193枚、negativeに分類されたものが1783枚、positiveに分類されたものが52枚でした。

f:id:kivantium:20200418191154p:plain:w600
紛らわしいと判定された画像

uncertainに分類された画像をさらに詳しく見てみました。

前回の記事で述べた紛らわしい種類の画像がきちんとuncertainに分類されており、Random Forestによる分類確率が紛らわしさをきちんと捉えていることが確認できました。

positiveに分類された画像はアニメのスクリーンショット1枚を除いて全てイラストでした。

f:id:kivantium:20200418190731p:plain:w600
二次元イラストだと判定された画像

negativeに分類された画像のうちイラストは34枚でした。これらの画像は、コントラストが薄めである・人間がたくさん書かれているなどの理由から漫画と間違えられた可能性が高いと思っています。(今回のラベリングではコマ割りがあるまたは白黒の画像は全て二次元イラストではないとしています)

f:id:kivantium:20200418190429p:plain:w600
二次元イラストではないと間違えて判定されたイラスト

以上の結果から、margin samplingは二次元画像分類の境界ケースをきちんと集めることができそうだという感触を得ました。これを学習データに加えたら精度が上がったという実験結果を出せればよかったのですが、ランダムサンプリングでも95%くらいの精度が出ていたのでActive Learningで有意差を出すことが難しそうでした。Active Learningをするというよりは、棄却オプションをつけて不確かな画像は人手で分類するようにするのが良さそうです。

次回はこの結果を使って二次元画像だけのタイムラインを表示するアプリを作ろうと思います。

二次元画像判別器の作成に関する基礎検討

Abstract

Twitterに流れる大量の画像の中から二次元画像を集めることは私のQoL向上の上で非常に重要な問題である。 本研究では、著者のタイムラインに実際に流れてきた画像を分析し、二次元画像分類という問題の定義が難しいことを示した。 また、独自に定義した問題設定に基づいてデータセットを作成し、Illustration2VecとRandom Forestを組み合わせることで accuracy 0.98 を達成するモデルを作成した。 これにより今後の二次元画像収集に関する研究の方針が明らかになった。

Introduction

Twitterのタイムラインには非常に多くの画像が流れており、その中から自分の好みの二次元画像を収集することは非常に困難である。神絵師アカウントが投稿する画像はイラストばかりとは限らず、むしろ焼肉等の飯テロ画像やソシャゲのガチャ結果などの非イラスト画像の割合のほうが高い場合が多い。そのため、神絵師を集めたリストを眺めても目的の二次元画像が得られる割合が少なく、快適な画像収集ライフの妨げとなっている。

二次元画像の自動判定の問題は、オタク機械学習界隈ではよく知られている。Ideyoshiらは画素値ヒストグラムを特徴量とした機械学習モデルを用いて二次元画像判別器を作成した [1]。また、TachibanaはIllustration2Vecを元にしたニューラルネットワークを用いて二次元画像判別器を作成した [2]。

これらの研究はいずれも二次元画像として人が収集したTumblrデータ、三次元画像としてImagenetのデータを用いていた。この画像データセットにおいては二次元・三次元の区別が明確であるが、実際にタイムラインに流れてくる画像は二次元・三次元の区別が明確ではないものが数多く存在する。

本研究では、実際にタイムラインに流れてくる画像を収集し、二次元画像判別問題を正しく定義することが難しいことを示した。また、困難を回避できる独自のデータ収集基準を設定した上でデータセットの作成を行い、Illustration2VecとRandom Forestを組み合わせたモデルを用いてaccuracy 0.98を達成した。

  • [1] Mori Ideyoshi, Falsita Fawcett, Fall Through, Makoto Sawatar. 機械学習による二次元/三次元画像判別. SIG2D'13, pp. 1-6, 2013.
  • [2] Hazuki Tachibana. Illustration2Vec に基づく高精度な二次元画像判別器の作成. SIG2D'15, pp. 9-12, 2015.

Method

Twitter画像の収集

以前の記事で説明したstatuses/filterを用いて、followにkivantiumのフォロワー全員を指定して画像の収集を行った。収集期間は4月16日深夜〜4月18日昼頃である。

kivantium.hateblo.jp

二次元画像判別器の作成

以前の記事で作成した、ONNX経由でIllustration2Vecを作成するライブラリを用いて画像から特徴量を抽出し、Random Forestで分類するモデルを作成した。

ネガティブデータを0、ポジティブデータを1というディレクトリに入れて、次のようなPythonプログラムを実行した。

import os
from pprint import pprint

import numpy as np
from PIL import Image
import more_itertools
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from tqdm import tqdm

import i2v

illust2vec = i2v.make_i2v_with_onnx("illust2vec_ver200.onnx")

X = []
y = []
batch_size = 4

negative_path = '0'
negative_list = os.listdir(negative_path)
for batch in tqdm(list(more_itertools.chunked(negative_list, batch_size))):
    img_list = [Image.open(os.path.join(negative_path, f)) for f in batch]
    features = illust2vec.extract_feature(img_list)
    X.extend(features)
    y.extend([0] * len(batch))

positive_path = '1'
positive_list = os.listdir(positive_path)
for batch in tqdm(list(more_itertools.chunked(positive_list, batch_size))):
    img_list = [Image.open(os.path.join(positive_path, f)) for f in batch]
    features = illust2vec.extract_feature(img_list)
    X.extend(features)
    y.extend([1] * len(batch))

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

clf = RandomForestClassifier(max_depth=2, random_state=0)
clf.fit(X_train, y_train)
print(clf.score(X_test, y_test))

Result

Twitter画像の収集

収集された画像のうち、明らかに二次元画像と判定されるものの例を示す。

かわいい。

このように、明らかに二次元画像と判定できる画像も存在したが、一方で二次元画像と呼ぶべきなのかよく分からないものが数多く存在した。一例を上げると

  • 写真とイラストを並べた画像
LƏGS from r/Animemes

また、既存の漫画のコマやアニメ・ソシャゲのスクリーンショットなどは二次元画像であるが、特に収集したい対象ではない。

以上を踏まえて、以下のデータセット作成基準を制定した。

  • かわいい二次元画像だと私が判定したカラーの一枚絵のイラストをポジティブデータとする
  • 二次元画像だが収集したくないと判定した二次元画像と、コマ割りがあるか白黒の二次元画像は保留データとする
  • それ以外の画像をネガティブデータとする。

このような非科学的な根拠に基づいてデータセットを作成しているので以下の評価は全てデタラメである。

保留データをネガティブデータとして利用すると、ポジティブ・ネガティブの差があいまいになって学習が進みにくくなりそうだったので、とりあえず明確に区別できそうなポジティブとネガティブのデータだけを用いて評価することにした。

二次元画像判別器の評価

収集した6000枚くらいの画像のうち、以上の基準に基づいて700枚くらいをポジティブデータ、2000枚くらいをネガティブデータ、1300枚くらいを保留データとした。残りの2000枚くらいの画像はラベリングに飽きたので放置してある。

そのうち200枚のポジティブデータと350枚のネガティブデータを用いて上記のプログラムを動かして評価したところ、accuracy 0.98という結果を得た。せっかくラベリングしたのに一部のデータしか使わなかったのは思ったより特徴量抽出に時間が掛かって寝る前に学習が終わらなさそうだったからである。

Discussion

抽出した特徴量を何もチューニングしてないRandom Forestに入れるだけでaccuracy 0.98が出ることから、二次元画像判別問題は難しい入力ケースを除けばだいぶ簡単な問題であったことが分かる。(自明なことを言っているだけだな)

上記で学習したランダムフォレストを試しに保留データに適用してみたところ、アニメやソシャゲのスクリーンショットは二次元画像として判定されてしまった。学習データ中にソシャゲのスクリーンショットは必要ではないという情報がない以上、これは当然の結果である。現実のデータでは二次元画像と三次元画像が明確に分かれていないことが判明した以上、必要となるのは二次元画像のうち収集したいタイプの画像を区別する何らかの方法を考えることである。人間が見ればアニメ・ソシャゲのスクリーンショットと創作イラストの区別は明確であることがほとんどであるため、機械学習さんにも区別できる可能性はある。

ポジティブとネガティブの境界をどうするかというのは当然機械学習で重要な問題であるため、先行研究が存在する。そのうち特に面白そうなのはアクティブ・ラーニングと呼ばれる分野である。アクティブラーニングでは確信が持てない訓練データを選び出して人間にラベリングさせることで少ないラベル数で効率よく学習させることを目指しているらしいので、境界が問題になる現状を解決する手がかりになる気がする。今週末のテーマとして勉強してみるのがいいだろう。

広告コーナー