EnsekiTT Blog

EnsekiTTが書くブログです。

機械学習とかのアルゴリズムとデータのロジスティクスの話。3日目(tf.Estimatorのオレオレ解釈訳を添えて)

1日目はなにしたの?

データのアルゴリズムロジスティクスの話をまとめて、
Resnetのロジスティクスを読み解き、データセットの準備の部分まで到達した。
ensekitt.hatenablog.com

2日目はなにしたの?

Resnetの学習部分のロジスティクスを読み解き、Resnetにデータを渡して学習するまでに何をしているのか?を知った。
ensekitt.hatenablog.com

つまりなにしたの?

Resnetのロジスティクスを支えるTensorFlow.Estimatorのドキュメントを読んで、保存してある画像に対して予測するコードを書いた。
実装するにあたって読んだEstimatorのオレオレ日本語訳も載せます。

f:id:ensekitt:20171002223318j:plain*1

今回実装したコードの実行方法

python cifar10_download_and_extract.py
python cifar10_main.py
python cifar10_try.py --image=/path/of/image #←今回のコード

github.com

主にinput_fnを改造して、引数で入力した画像(32x32x3になっているものとする)を読み込んで
1枚だけど、tf.contrib.Data.Dataset()に入れて、One_shot_iteratorに変換して画像のtensorを返すようにした。

あとは、tf.Estimatorのpredictを走らせて、識別結果を標準出力に提示する。

実装するにあたって読んだTensorFlow.Estimatorのオレオレ解釈訳

tf.estimatorにはtrain, evaluate, predictとexport_savedmodelがある。
定義するときには、モデルを定義した関数(model_fn)が必要。

train

trainを実施するときには、入力関数(input_fn)が必要
 input_fnは(features, labels)のタプルを返すこと。
  featuresはTensorTensorの文字列、Tensorの文字列の辞書
  labelsはラベルがついたTensorTensor辞書

それ以外に、hooks, steps, max_stepsの定義ができる
hooksはSessionRunHookサブクラスインスタンスのリストが定義できる。
 hooksの条件に合致した場合は、Callbackが発生する
 Callbackはインスタンス作成時に定義される

Stepsは学習を行うStep数を定義できる
 Noneの場合は、OutOfRangeかStopIterationのErrorが発生するまで学習し続ける
 Stepsを定義してもOutOfRangeかStopIterationのErrorが発生したら学習は止まる
 実際はtrainを呼び出す回数*Step数が学習を行うStep数になる
  このようなインクリメンタル(何度も呼び出すよう)な実行をしたくない場合はmax_stepsを定義してstepsはNoneにする必要がある
  逆にStepsを定義したらmax_stepsはNoneにする必要がある

max_stepsは学習を行う回数の最大値を定義できる
 Noneの場合は、OutOfRangeかStopIterationのErrorが発生するまで学習し続ける
 max_stepsを定義してもOutOfRangeかStopIterationのErrorが発生したら学習は止まる
 max_stepsを定義したらstepsは必ずNoneでなければならない
 学習を呼び出す回数によらずmax_stepsの値で定義された回数しか学習が行われない
  max_stepsを100に設定したら、学習を2回呼び出してもエラーが無ければ2回目は何もしない

evaluate

evaluateを実施するときには、入力関数(input_fn)が必要
evaluateはステップ毎にinput_fnを呼び出してバッチ特徴量を取得する。Step数分実行して評価を行うか、input_fnが入力限界などのエラーを返すまで実行される。
 input_fnは(features, labels)のタブルを返すこと。
  featuresはTensorTensorの文字列、Tensorの文字列の辞書
  labelsはラベルがついたTensorTensor辞書
それ以外にsteps, hooks, checkpoint_path, nameが定義できる
Stepsは学習を行うStep数を定義できる
 Noneの場合は、OutOfRangeかStopIterationのErrorが発生するまで学習し続ける
 Stepsを定義してもOutOfRangeかStopIterationのErrorが発生したら学習は止まる

hooksはSessionRunHookサブクラスインスタンスのリストが定義できる
 hooksの条件に合致した場合は、Callbackが発生する
 Callbackはインスタンス作成時に定義される

checkpoint_pathは評価を行う特定のチェックポイントのパスが定義できる
 Noneの場合、model_dirの最新のチェックポイントが使用される

Nameは評価そのものに名前を定義できる
対象のデータセットによって名前を変えることで別々のディレクトリに結果が格納され、TensorBoardで利用できる
 

predict

predictは与えられた特徴量に対して予測(predictions tensor)を返す
predictを実施するときには、入力関数(input_fn)が必要
 input_fnは辞書に登録された文字列名で定義されているTensorかSparseTensorの特徴量を返す
  トレーニング向けなどでタプルを返す関数の場合はタプルの1つ目の要素が予測に用いられる
  input_fnがOutOfRangeかStopIteraionのErrorが発生するまで予測をつづける

それ以外にpredict_keys, hooks, checkpoint_pathが定義できる
 predict_keysは文字列のリストで予測するキーの名前を定義できる
  EstimatorSpec.predictionsが辞書型のときに用いられる
  predict_keysが定義されている場合、予測は辞書によってフィルタリングされた結果を返す
  Noneの場合はすべての結果を返す

 hooksはSessionRunHookサブクラスインスタンスのリストが定義できる
  hooksの条件に合致した場合は、Callbackが発生する
  Callbackはインスタンス作成時に定義される

 checkpoint_pathは予測を行う特定のチェックポイントのパスが定義できる
  Noneの場合、model_dirの最新のチェックポイントが使用される
 
この他にモデルをExportするための関数が用意されている。*2
https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator

雑感

正直、訳すの意味ないと思ってたけど、自分なりに動かしながら訳を考えて文字に起こしておくと、後で訳文を読めば実装した時のことを思い出すし、やはり解釈も早いし良いことばかりだった。
僕の母国語は日本語だったし、英語は読む頻度が上がっていても日本語ほどは得意じゃないってことを痛感した。

*1:Photo by unsplash.com

*2:今回自分が使わないから訳さなかった。

クリエイティブ・コモンズ・ライセンス
この 作品 は クリエイティブ・コモンズ 表示 4.0 国際 ライセンスの下に提供されています。