1日目はなにしたの?
データのアルゴリズムとロジスティクスの話をまとめて、
Resnetのロジスティクスを読み解き、データセットの準備の部分まで到達した。
ensekitt.hatenablog.com
実行するときのコマンド
python cifar10_download_and_extract.py python cifar10_main.py
こんな感じで動かしている。
今日はこの2行目
2行目でやっていること
python cifar10_main.py
中身を見ていくと
mainで実行されているのはこの2行
if __name__ == '__main__': tf.logging.set_verbosity(tf.logging.INFO) tf.app.run()
なるほどわからん。
Tfのログレベルを設定して、
Tfのアプリを走らせている。
で、TensorFlowのコードを読んでみると
https://github.com/tensorflow/tensorflow/blob/9dc6c17797c065796603d9259b2aa57b3c07ff71/tensorflow/python/platform/app.py#L31-L48
mainって関数があったらそれを走らせている様子。
中身を見ていく
# Using the Winograd non-fused algorithms provides a small performance boost. os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
環境変数にTensorFlow向けの設定をしてる。Winogradはフィルタサイズが3x3とかの
行列積をちょっと早く計算できるアルゴリズム。*2
cifar_classifier = tf.estimator.Estimator( model_fn=cifar10_model_fn, model_dir=FLAGS.model_dir)
TensorFlowの推定器のインスタンスを作成している。
モデル(関数)とモデルディレクトリ(モデルの途中経過と結果を保存する用)を指定している。
https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator
cifar10_model_fnについてはアルゴリズム部分なので別途実施。
for cycle in range(FLAGS.train_steps // FLAGS.steps_per_eval): tensors_to_log = { 'learning_rate': 'learning_rate', 'cross_entropy': 'cross_entropy', 'train_accuracy': 'train_accuracy' } logging_hook = tf.train.LoggingTensorHook( tensors=tensors_to_log, every_n_iter=100) cifar_classifier.train( input_fn=lambda: input_fn(tf.estimator.ModeKeys.TRAIN, batch_size=FLAGS.batch_size), steps=FLAGS.steps_per_eval, hooks=[logging_hook]) # Evaluate the model and print results eval_results = cifar_classifier.evaluate( input_fn=lambda: input_fn(tf.estimator.ModeKeys.EVAL, batch_size=FLAGS.batch_size)) print(eval_results)
定期ステップ毎に以下を実施してる。
- ログで何を出すかを決めて、ログを実施するタイミング(フック)を指定している
- 学習を行っている
- 評価を行ってその結果を表示してる
def input_fn(mode, batch_size): """Input_fn using the contrib.data input pipeline for CIFAR-10 dataset. Args: mode: Standard names for model modes (tf.estimators.ModeKeys). batch_size: The number of samples per batch of input requested. """ dataset = record_dataset(filenames(mode)) # For training repeat forever. if mode == tf.estimator.ModeKeys.TRAIN: dataset = dataset.repeat() dataset = dataset.map(dataset_parser, num_threads=1, output_buffer_size=2 * batch_size) # For training, preprocess the image and shuffle. if mode == tf.estimator.ModeKeys.TRAIN: dataset = dataset.map(train_preprocess_fn, num_threads=1, output_buffer_size=2 * batch_size) # Ensure that the capacity is sufficiently large to provide good random # shuffling. buffer_size = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * 0.4) + 3 * batch_size dataset = dataset.shuffle(buffer_size=buffer_size) # Subtract off the mean and divide by the variance of the pixels. dataset = dataset.map( lambda image, label: (tf.image.per_image_standardization(image), label), num_threads=1, output_buffer_size=2 * batch_size) # Batch results by up to batch_size, and then fetch the tuple from the # iterator. iterator = dataset.batch(batch_size).make_one_shot_iterator() images, labels = iterator.get_next() return images, labels
input_fnはモードとバッチサイズを与えると画像とラベルを返してくれる関数
モードは基本的に学習(TRAIN)、評価(EVAL)、予測(PREDICT)の3パターン。
こんな関数だった。
学習の時はかなりロジスティクス的な部分がおおい。
− バイナリデータからのパース
− 学習データへのプリプロセスの適用(学習のみ、詳細後述)
− データセットのシャッフル(学習のみ)
− データセットの正規化
− イテレータの作成
データへのプリプロセスの適用はこちら
def train_preprocess_fn(image, label): """Preprocess a single training image of layout [height, width, depth].""" # Resize the image to add four extra pixels on each side. image = tf.image.resize_image_with_crop_or_pad(image, HEIGHT + 8, WIDTH + 8) # Randomly crop a [HEIGHT, WIDTH] section of the image. image = tf.random_crop(image, [HEIGHT, WIDTH, DEPTH]) # Randomly flip the image horizontally. image = tf.image.random_flip_left_right(image) return image, label
これの中を見るとデータオーグメンテーションをしていることもわかる。
左右を反転させたり画像の一部を切り落としたりリサイズしたりしている。
これによって判別のロバスト性を向上している。
ここはアルゴリズムかも知れない。
でもまあネットワークの外でドメイン知識が反映される部分でもある。
(image操作なのにTensorFlowが用意してるんだなぁ…と思った。幅広い。)
概ねこんな感じで学習・評価までたどり着いた。
*1:Photo by unsplash.com
*2:アルゴリズムって聞くと勘違いしがちだけど、Wino Grad(勾配)ではなくDr. Shmuel Winogradのアルゴリズム。そもそも勾配法による最適化じゃなくて行列積だ。 中身読んでないけどこれが実装されてるんじゃないかな。 https://arxiv.org/pdf/1509.09308.pdf