EnsekiTT Blog

EnsekiTTが書くブログです。

Chainerで転移学習するときに、新たなデータで既に学習済のところを変更されないように固定する話

つまりなにしたの?

転移学習をすすめるにあたって、最後の層だけ学習して、それ以外の層はそのままにしたい。
1つの手としては、

model(inputs={'data': X}, outputs=['loss3/classifier'])

のoutputsを最終層の手前までにして事前にデータを変換してからそれに対して最終層を学習するモデルにする方法。
それとは別に、特定の層の最適化を行わない方法が検討できる。
今回は、特定の層の最適化を行わない方法をメモしておく。

f:id:ensekitt:20171130022359p:plain

転移部分をコピーする

例なので、全部はやっていないけど適度にやってみた。(不十分)

ggn.conv1.W = model['conv1/7x7_s2'].W
ggn.conv1.b = model['conv1/7x7_s2'].b
ggn.conv2.W = model['conv2/3x3'].W
ggn.conv2.b = model['conv2/3x3'].b
ggn.conv2_reduce.W = model['conv2/3x3_reduce'].W
ggn.conv2_reduce.b = model['conv2/3x3_reduce'].b
ggn.inception_3a.conv1.W = model['inception_3a/1x1'].W
ggn.inception_3a.conv1.b = model['inception_3a/1x1'].b
# 続く

オプティマイザの定義する

今回はGoogLeNetを定義してそれをAdamに適用する準備をしてみた。

optimizer = chainer.optimizers.AdaGrad()
optimizer.setup(ggn)

更新を許容するか決める

ggn.conv1の層のupdateを許容するか決めるコードがこちら。

ggn.conv1.enable_update()
ggn.conv1.update_enabled
# True

ggn.conv1.disable_update()
ggn.conv1.update_enabled
# False

とりあえず、これで更新されなくなる。

ggn.disable_update()
ggn.update_enabled
# False

これでネットワーク全部が一旦enableになるので、転移学習の場合は、
変化させるところをEnableする方が良い気がする。

作業ログ的過ぎる()

ちょっと全体の話が出来なくてすごくこまい話ばっかりになってる。

スポンサーリンク