EnsekiTT Blog

EnsekiTTが書くブログです。

DCGANをChainerのTrainerで学習して100連MNISTガチャを回した話

つまりなにしたの?

高解像度GANができるようになったという話をきいたけど基礎が抜けてるのでDCGANをChainerのTrainerを使って作ってみた。
作ってる途中で公式がDCGANのTrainer使った実装を公開していることを知るものの写経も辞さない構えで作った。
f:id:ensekitt:20171101002149j:plain
*1

GANとは?

贋作を作るネットワーク(生成モデル)と贋作と真作を見破るネットワーク(識別モデル)にイタチごっこをさせて贋作を作るネットワークの性能を上げていくもの。*2

  • 生成モデルGはノイズzを受け取って贋作を作る
  • 識別モデルDは真作xか生成モデルが作った贋作G(z)を受け取って識別する

概念的にはこんな感じ

for イテレーション回数 do
	m個のノイズサンプルzを作る
	m個の贋作G(z)を作る
	m個の真作を取り出す
	見破るモデルにG(z)とxを渡して誤差を計算する

	m個のノイズサンプルzを作る
	m個の贋作G(z)を作る
	見破るモデルにG(z)を渡して誤差を計算する

	それぞれの誤差を最適化機のアップデータに渡してネットワークを更新する
end for

Trainerに渡すUpdaterの部分だけ切り出してきた

    def update_core(self):
        gen_optimizer = self.get_optimizer('gen') # 生成モデル用Optimizerを用意する
        dis_optimizer = self.get_optimizer('dis') # 識別モデル用Optimizerを用意する
        
        batch = self.get_iterator('main').next() # バッチで教師データを取り出す
        x_real = Variable(self.converter(batch, self.device))
        batch_size = len(x_real) #概念コードのmにあたる
        
        y_real = dis(x_real) # 真作と識別されなければいけない識別結果
        
        z = xp.random.uniform(-1, 1, (batch_size, self.gen.z_dim)) # ノイズの生成
        z = z.astype(dtype=xp.float32)
        x_fake = gen(z) # 贋作を作っている
        y_fake = dis(x_fake) # 贋作と識別されなければいけない識別結果
        
        dis_optimizer.update(self.dis_loss, y_fake, y_real) # 真作と贋作の識別結果で識別モデルのアップデート
        gen_optimizer.update(self.gen_loss, y_fake) # 贋作の識別結果で生成モデルのアップデート

実行結果

MNISTのデータセットを使って1000Epoch バッチサイズ100で学習させてみた。

100 epoch 目の100連ガチャ結果がこちら!!!

f:id:ensekitt:20171101001329p:plain
え、なんかもうほとんど出来てる

1000 epoch 目の100連ガチャの結果がこちら!!!

f:id:ensekitt:20171101001409p:plain
え、なんかあんまり変わってない。ほとんどノーマル。レア…位の読みやすさのやつはあるかも。
よく考えたらMNISTってもともとハイパー読みにくいの含まれてるよね???

とはいえ、実行結果を見るとGen側のOptimizeがうまくいってない気がする。

*1:unsplash.com

*2:GANは贋作のGANではなくGenerative Adversarial Networkの略である。

クリエイティブ・コモンズ・ライセンス
この 作品 は クリエイティブ・コモンズ 表示 4.0 国際 ライセンスの下に提供されています。