“ 再小的你,也有自己的个性”
通过阅读本文,您可以学会:
1、如何保存训练好的模型
2、如何加载训练好的模型
源代码:
https://github.com/PrivateThink/tensorflow_tutorial
在前面的教程,学习了如何训练模型,如果数据量大的话,是要花很长时间来训练的,重复训练即浪费时间又浪费资源,所以在很多机器学习任务中,会将训练好的模型保存下来,下次直接加载保存的模型。
数据准备和模型创建
上述程序,从文件中读取数据,构建线性模型,利用交叉熵损失函数,采用Adam优化器,最后计算准确率。
Tensorflow中用tf.train.Saver来声明保存训练好模型的操作。初始化以后就可以进行训练和测试了。
上述程序中training_enpochs=10,只训练10次,每两次打印一次训练结果。训练完以后,就可以用saver.save保存模型。
第一次训练结果
经过第一次训练,测试准确率可达0.958.
保存模型很简单,同样加载模型也很简单。
上述将训练次数training_epochs改为20,继续迭代,用saver.restore就可以加载模型了,然后继续训练。
第二次训练结果
第二次的预测准确率为0.975,比第一次预测的准确要好。今天的教程就讲到了。
后续持续更新Tensorflow教程
欢迎关注和分享
领取专属 10元无门槛券
私享最新 技术干货