TensorFlowにおける学習結果の保存と再利用

Tensorflowを用いれば簡単にニューラルネットワークの学習を行うことができます。
そこでこの学習した結果を保存し、再利用する方法を調べました。

 

 

 

学習の保存と再利用

TensorFlowによる学習の結果を保存・再利用するためには tf.train.Saver() を用いる方法が簡単です。
tf.train.Saver()を利用すればTensorFlowで作成されたグラフと変数の値を保存することができます。

今回は最も簡単な方法として、学習の最後に一回モデルの保存を行います。

 

学習の保存

学習の保存は tf.train.Saver() の save()を利用します。方法は以下の通りです。

 

これによって dir に学習の結果が保存されます。上記の方法で保存すればモデル全体を保存することができますが、一部だけ保存したい場合にはtf.train.Saver() に保存したい変数を引数として渡すことで実現できます。

 

 

学習の再利用

学習の保存は tf.train.Saver() の restore()を利用します。方法は以下の通りです。

 

tf.train.get_checkpoint_state() は引数に指定した場所に保存したモデルがある確認できます。ある場合はそのモデルを返し、ない場合はNoneを返します。

 

save・restoreの結果を下に示します。

保存と再利用

赤い線が一回目の学習(save)、青い線が一回目を引き継いだ学習です。青い線は赤い線の終わりから学習していることがわかります。

 

関連記事

ニューラルネットワーク・ディープラーニングのまとめ