由于大型神经网络的训练往往耗费很长的时间,可能会因为机器损坏、断电或系统崩溃等各种因素无法一次性完成模型训练而导致前面所有的训练功亏一篑。本次来介绍一种检查点机制,在训练过程中保存更新的权值到检查点文件,而再次训练时恢复检查点文件中的权值数据,继续训练模型。这样能有效的防止上述情况的发生。
首先用ipython notebook打开上一次的代码,并找到get_sart函数,在with tf.Session() as sess:后面插入一行:saver = tf.train.Saver()新建一个saver对象用于保存训练过程中的权值信息。然后再往下找到if i % 2 == 0: 插入一行:saver.save(sess,'my-model', global_step=i)表示每训练两步就将当前的会话信息(包括当前步骤的权值和偏置项)存入本地检查点文件my-model-i中,例如第二步就是my-model-2,第四步就是my-model-4等。下面来调用get_sart函数看结果:
这一次训练完前20步,我们认为中断训练过程,模拟上述的意外情况发生。来看一下saver对象保存的检查点文件,当不指定保存路径时默认存在当前目录下,即代码文件所在的目录,如下:
上图只显示了从my-model-12到20这5个文件,因为saver默认保存最后5步的检查点文件。接下来要实现接着第20步的训练结果继续训练余下的10步,下面给出完整的get_sart函数代码:
这里可以看出model_checkpoint_path是上次训练的最后一步检查点文件路径。
然后用if检查一下ckpt变量是否存在,如果存在则用saver.restore(sess, ckpt.model_checkpoint_path)恢复上次训练最后一步迭代的权值数据,保证了本次训练能够接着上次开始。接着更新initial_step把它重置为上次的最后一步。如果ckpt不存在,比如第一次训练时,才需要初始化所有变量,注意:如果在restore载入权值数据之前进行变量初始化将会报错。rsplit函数返回的是一个列表:
接下来开始训练模型,仍然每隔两步保存检查点文件,最后训练结果如下:
第二次仍然在当前目录生成了最后5步的检查点文件:
如上,tensorflow载入的参数信息来自my-model-20这个文件,并接着第20步完成了模型训练。本文只更新了get_start函数,其他函数代码与上一节相同。
历史文章回归
领取专属 10元无门槛券
私享最新 技术干货