Tensorflowを用いれば簡単にニューラルネットワークの学習を行うことができます。
そこでこの学習した結果を保存し、再利用する方法を調べました。
Contents
学習の保存と再利用
TensorFlowによる学習の結果を保存・再利用するためには tf.train.Saver() を用いる方法が簡単です。
tf.train.Saver()を利用すればTensorFlowで作成されたグラフと変数の値を保存することができます。
今回は最も簡単な方法として、学習の最後に一回モデルの保存を行います。
学習の保存
学習の保存は tf.train.Saver() の save()を利用します。方法は以下の通りです。
1 2 3 4 5 6 7 8 9 10 11 12 13 |
#モデル定義 ... saver = tf.train.Saver() ... with tf.sess() as sess: #学習 ... saver.save(sess, "dir") |
これによって dir に学習の結果が保存されます。上記の方法で保存すればモデル全体を保存することができますが、一部だけ保存したい場合にはtf.train.Saver() に保存したい変数を引数として渡すことで実現できます。
1 |
tf.train.Saver([W11, W12, W21]) |
学習の再利用
学習の保存は tf.train.Saver() の restore()を利用します。方法は以下の通りです。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
#モデル定義 .... with tf.Sess() as sess: ckpt_state = tf.train.get_checkpoint_state("dir/") if ckpt_state: restore_model = ckpt_state.model_checkpoint_path saver.restore(sess, restore_model) else: init = tf.global_variables_initializer() sess.run(init) #学習 ... |
tf.train.get_checkpoint_state() は引数に指定した場所に保存したモデルがある確認できます。ある場合はそのモデルを返し、ない場合はNoneを返します。
save・restoreの結果を下に示します。
赤い線が一回目の学習(save)、青い線が一回目を引き継いだ学習です。青い線は赤い線の終わりから学習していることがわかります。
関連記事