EnsekiTT Blog

EnsekiTTが書くブログです。

Caffe Model ZooのモデルをChainerで読み込んで実行した話

スポンサーリンク

つまりなにしたの?

Caffe model zooのモデルを拾ってきてChainerで使ってみることにしたけど、
とりあえず読み込むところと使うところは出来たっぽいので一旦記事にした。

2017年11月23日追記: 画像の前処理を失敗していました。

ensekitt.hatenablog.com

f:id:ensekitt:20171120010542j:plain

Caffe model zoo

Caffe Model zoo にはCaffeモデル情報をパッケージ化するための標準フォーマットが用意されている。
.caffemodelバイナリとして既に学習済をダウンロード・アップロードすることができる。
github.com

モデル情報のフォーマット

Caffeモデルは、ディレクトリとして配布される。

  • solver/model prototype
  • readme.md

- YAML front matter
 - Caffeのバージョン
 - 学習済.caffemodel のfile URLとSHA1のハッシュ(オプション)
 - github gistのID(オプション)
- データと学習に関する情報
- ライセンス情報

具体的には
ILSVRC-2014 model (VGG team) with 19 weight layers · GitHub

.caffemodelをダウンロードして読み込む。

Chainerには.caffemodelを読み込むモジュールがある。
chainer.links.caffe.CaffeFunction — Chainer 3.2.0 documentation

だた、CaffeFunctionは結構時間がかかるので一度読み込んだら、Pickleにしてしまうのが良いみたい。

import pickle
from chainer.links.caffe import CaffeFunction

MODEL = 'VGG_ILSVRC_19_layers.caffemodel'
PICKLE = 'vgg.pkl'
if os.path.exists(PICKLE):
    print("Load pickle")
    with open(PICKLE, 'rb') as pkl:
        model = pickle.load(pkl)
else:
    print("Load caffemodel and make pickle")
    if os.path.exists(MODEL):
        model = CaffeFunction(MODEL)
        with open(PICKLE, 'wb') as pkl:
            pickle.dump(model, pkl)
    print(MODEL + " not found.")

Pickleがあればそれをロードして、なかったらcaffemodelを読み込んで、Pickleとして保存する。
ここではmodelという変数にcaffemodelがロードされている。

画像データを準備する

VGG-19だから画像をモデルに突っ込む必要がある。
そのための画像を準備する。

# 画像を用意する
img = Image.open('datas/95134bd6cdc3059f1a0d58e0af462242e341ff63.jpg').convert('RGB')
resize_img = img.resize((224,224))

# 平均画像を用意する
mean_image = np.ndarray((3, 224, 224), dtype=np.float32)
mean_image[0] = 103.939
mean_image[1] = 116.779
mean_image[2] = 123.68

# 画像をモデルに入れる準備をする(ここ間違ってます→http://ensekitt.hatenablog.com/entry/2017/11/23/180000)
npar_img = np.asarray(resize_img, dtype=np.uint8).reshape(224,224,3)
npar_img = npar_img.transpose(2,0,1)
X = npar_img.reshape(3, 224, 224)
X = X-mean_image
X = np.ndarray((1,3,224,224), dtype=np.float32)

modelを使う

y, = model(inputs={'data': Variable(X)}, outputs=['fc8'])
prediction = F.softmax(y)
np.argmax(prediction.data)

モデルに適用してその結果をSoftmaxにかけて最大インデックスを取り出してみる。

犬のはずがモスクとかでてきた()

よく考えたら出力のベクトルと対応するラベルリストがわからなくて、ちょっと調べてみたらこんなラベルリストが出てきた。blog/label.txt at master · EnsekiTT/blog · GitHub
これを信じるとすると犬っぽいのを入れたのにモスクっぽい扱いされているっぽい。
ちょっと別のモデルでも試してみたい。

コードはこちら
github.com

クリエイティブ・コモンズ・ライセンス
この 作品 は クリエイティブ・コモンズ 表示 4.0 国際 ライセンスの下に提供されています。