Illustration2VecをONNX経由で使う
趣味プロジェクトでIllustration2Vecを使いたくなったのですが、これは2015年の論文なのでモデルをCaffeかChainerで使うことになっています。 github.com
残念ながらCaffeもChainerも既に開発が終了しているため、Illustration2VecのモデルをONNXという共通フォーマットに変換して今後も使えるようにしました。 利用方法だけ知りたい人は「モデルの変換」を飛ばして「使い方」を見てください。
モデルの変換
まずはオリジナルのIllustration2Vecのモデルをダウンロードします。以下を実行するとCaffeのモデルがダウンロードできます。
git clone https://github.com/rezoo/illustration2vec.git cd illustration2vec ./get_models.sh
このスクリプトでは特徴抽出モデルのprototxtがダウンロードできなかったので、Illustration2VecのInternet ArchiveからNetwork configuration file (feature vectors) illust2vec.prototxt
を追加でダウンロードしました。
必要なライブラリをインストールします。
pip install onnx coremltools onnxmltools
以下のPythonスクリプトを実行すると、タグ予測モデルのillust2vec_tag_ver200.onnx
と特徴ベクトル抽出モデルのillust2vec_ver200.onnx
が作成されます。
import os import onnx import coremltools import onnxmltools # CaffeモデルをONNX形式で読み込む関数 # https://github.com/onnx/onnx-docker/blob/master/onnx-ecosystem/converter_scripts/caffe_coreml_onnx.ipynb def caffe_to_onnx(proto_file, input_caffe_path): output_coreml_model = 'model.mlmodel' # 中間ファイル名 # 中間ファイルが既に存在したら例外を送出 if os.path.exists(output_coreml_model): raise FileExistsError('model.mlmodel already exists') # CaffeのモデルをCore MLに変換 coreml_model = coremltools.converters.caffe.convert( (input_caffe_path, proto_file)) # Core MLモデルを保存 coreml_model.save(output_coreml_model) # Core MLモデルを読み込む coreml_model = coremltools.utils.load_spec(output_coreml_model) # Core MLモデルをONNXに変換 onnx_model = onnxmltools.convert_coreml(coreml_model) # Core MLモデルを削除 os.remove(output_coreml_model) return onnx_model # タグ予測モデルの変換・保存 onnx_tag_model = caffe_to_onnx( 'illust2vec_tag.prototxt', 'illust2vec_tag_ver200.caffemodel') onnxmltools.utils.save_model(onnx_tag_model, 'illust2vec_tag_ver200.onnx') # 特徴ベクトル抽出モデルの変換・保存 onnx_model = caffe_to_onnx('illust2vec.prototxt', 'illust2vec_ver200.caffemodel') # encode1レイヤーをONNXから利用できるようにする # https://github.com/microsoft/onnxruntime/issues/2119 intermediate_tensor_name = 'encode1' intermediate_layer_value_info = onnx.helper.ValueInfoProto() intermediate_layer_value_info.name = intermediate_tensor_name onnx_model.graph.output.extend([intermediate_layer_value_info]) onnx.save(onnx_model, 'illust2vec_ver200.onnx')
使い方
上のようにしてONNX形式に変換したモデルと、それを利用するためのコードを用意しました。 github.com
ONNX形式を他のフレームワークで読み込んで実行してもいいのですが、ONNX RuntimeというMicrosoft製のパフォーマンスを重視した推論専用のライブラリがあったのでこれを使うことにしました。 UbuntuでONNX RuntimeをCPU向けにインストールするコマンドは以下の通りです。
sudo apt install libgomp1 pip install onnxruntime
例を実行する前にコードとpre-trainedモデルのダウンロードを行ってください。
git clone https://github.com/kivantium/illustration2vec.git cd illustration2vec ./get_onnx_models.sh
タグ予測
コード
import i2v from PIL import Image from pprint import pprint illust2vec = i2v.make_i2v_with_onnx( "illust2vec_tag_ver200.onnx", "tag_list.json") img = Image.open("images/miku.jpg") pprint(illust2vec.estimate_plausible_tags([img], threshold=0.5))
入力
Hatsune Miku (初音ミク), © Crypton Future Media, INC., http://piapro.net/en_for_creators.html. This image is licensed under the Creative Commons - Attribution-NonCommercial, 3.0 Unported (CC BY-NC).
出力
[{'character': [('hatsune miku', 0.9999994039535522)], 'copyright': [('vocaloid', 0.9999999403953552)], 'general': [('thighhighs', 0.9956372976303101), ('1girl', 0.9873461723327637), ('twintails', 0.9812833666801453), ('solo', 0.9632900953292847), ('aqua hair', 0.9167952537536621), ('long hair', 0.8817101716995239), ('very long hair', 0.8326565027236938), ('detached sleeves', 0.7448851466178894), ('skirt', 0.6780778169631958), ('necktie', 0.560835063457489), ('aqua eyes', 0.5527758598327637)], 'rating': [('safe', 0.9785730242729187), ('questionable', 0.02053523063659668), ('explicit', 0.0006299614906311035)]}]
Chainer版とほとんど同じ結果が出力されました。Chainerではこの処理に6秒かかっていましたが、onnx-runtimeだと2秒で実行できたのでたしかにパフォーマンスにも優れているようです(ChainerではCaffeのモデルを変換する手間が掛かっているので1枚を処理する時間で比較するのは公平ではないですが)。
特徴ベクトルの抽出
コード
import i2v from PIL import Image illust2vec = i2v.make_i2v_with_onnx("illust2vec_ver200.onnx") img = Image.open("images/miku.jpg") # extract a 4,096-dimensional feature vector result_real = illust2vec.extract_feature([img]) print("shape: {}, dtype: {}".format(result_real.shape, result_real.dtype)) print(result_real) # i2v also supports a 4,096-bit binary feature vector result_binary = illust2vec.extract_binary_feature([img]) print("shape: {}, dtype: {}".format(result_binary.shape, result_binary.dtype)) print(result_binary)
先ほどと同じ入力に対する出力
shape: (1, 4096), dtype: float32 [[ 7.474596 3.6860986 0.537967 ... -0.14563629 2.7182112 7.3140917 ]] shape: (1, 512), dtype: uint8 [[246 215 87 107 249 190 101 32 187 18 124 90 57 233 245 243 245 54 229 47 188 147 161 149 149 232 59 217 117 112 243 78 78 39 71 45 235 53 49 77 49 211 93 136 235 22 150 195 131 172 141 253 220 104 163 220 110 30 59 182 252 253 70 178 148 152 119 239 167 226 202 58 179 198 67 117 226 13 204 246 215 163 45 150 158 21 244 214 245 251 124 155 86 250 183 96 182 90 199 56 31 111 123 123 190 79 247 99 89 233 61 105 58 13 215 159 198 92 121 39 170 223 79 245 83 143 175 229 119 127 194 217 207 242 27 251 226 38 204 217 125 175 215 165 251 197 234 94 221 188 147 247 143 247 124 230 239 34 47 195 36 39 111 244 43 166 118 15 81 177 7 56 132 50 239 134 78 207 232 188 194 122 169 215 124 152 187 150 14 45 245 27 198 120 146 108 120 250 199 178 22 86 175 102 6 237 111 254 214 107 219 37 102 104 255 226 206 172 75 109 239 189 211 48 105 62 199 238 211 254 255 228 178 189 116 86 135 224 6 253 98 54 252 168 62 23 163 177 255 58 84 173 156 84 95 205 140 33 176 150 210 231 221 32 43 201 73 126 4 127 190 123 115 154 223 79 229 123 241 154 94 250 8 236 76 175 253 247 240 191 120 174 116 229 37 117 222 214 232 175 255 176 154 207 135 183 158 136 189 84 155 20 64 76 201 28 109 79 141 188 21 222 71 197 228 155 94 47 137 250 91 195 201 235 249 255 176 245 112 228 207 229 111 232 157 6 216 228 55 153 202 249 164 76 65 184 191 188 175 83 231 174 158 45 128 61 246 191 210 189 120 110 198 126 98 227 94 127 104 214 77 237 91 235 249 11 246 247 30 152 19 118 142 223 9 245 196 249 255 0 113 2 115 149 196 59 157 117 252 190 120 93 213 77 222 215 43 223 222 106 138 251 68 213 163 57 54 252 177 250 172 27 92 115 104 231 54 240 231 74 60 247 23 242 238 176 136 188 23 165 118 10 197 183 89 199 220 95 231 61 214 49 19 85 93 41 199 21 254 28 205 181 118 153 170 155 187 60 90 148 189 218 187 172 95 182 250 255 147 137 157 225 127 127 42 55 191 114 45 238 228 222 53 94 42 181 38 254 177 232 150 99]]
Chainer版と同じbinary vectorが出力されていました。
次回はこれを使ってイラストの機械学習をします。