【新智元导读】文本将介绍一些 TensorFlow 的操作技巧,旨在提高你的模型性能和训练水平。文章将从预处理和输入管道开始,覆盖图、调试和性能优化的问题。
预处理和输入管道
保持预处理干净简洁
训练一个相对简单的模型也需要很长时间?检查一下你的预处理!任何麻烦的预处理(比如将数据转换成神经网络的输入),都会显著降低你的推理速度。对于我个人来说,我会创建所谓的“距离地图”(distant map),也就是用于“深层交互对象选择”的灰度图像作为附加输入,使用自定义python函数。我的训练速度最高是每秒大约处理 2.4 幅图像,切换到更强大的GTX 1080 后也没有提升。后来我注意到这个瓶颈,修复后训练速度就变成每秒50幅图像。
当你注意到这样的瓶颈时,一般首先会想到优化代码。但是,将计算时间从你的训练管道中去除还有一个更有效的方法,那就是将预处理移动到生成TFRecord文件的一次性操作当中。繁重的预处理只需执行一次,就能为所有的训练数据创建 TFRecords,你的管道本质上做的也就是加载记录。就算你想引入某种随机性来增强数据,一次创建不同的版本,而不是让你的管道变得庞大臃肿也是值得考虑的,不是吗?
注意队列
有一种发现昂贵的预处理管道的方法是查看 Tensorboard 的队列图。如果你使用框架 QueueRunners并将摘要存储在文件中,这些图都是自动生成的。这些图会显示你的计算机是否能够保持队列处在排满的状态。如果你发现图当中出现了负峰值,则系统无法在计算机要处理一个批次的时间内生成新的数据。其中的一个原因上面已经说过了。根据我的经验,最常见的原因是 min_after_dequeue 值很大。如果队列试图在内存中保留大量记录,你的容量很容易就饱和了,这会导致交换(swapping),并且显著降低队列的速度。其他的原因还包括硬盘问题(例如磁盘速度慢),以及单纯的是数据大,大过了你系统可以处理的程度。无论原因为何,修复这个问题都会加快你的训练过程。
图(graph)的构建和训练
把图固定
TensorFlows把图的构建和图的计算模型分开处理,这在日常编程中是非常罕见的,可能会导致初学者产生一些混乱。例如调试和发送错误消息,可能最初构建图的时候在代码里出现一次,然后在实际评估的时候又出现一次,当你习惯于代码只被评估一次后,这就有些别扭。
另一个问题是图的构建是和训练回路(loop)结合在一起的。这些循环通常是“标准”的python循环,因此可以改变图并向其中添加新的操作。在连续评估图的过程中对图进行改动,会产生重大的性能损失,但这一点在最开始的时候很难注意到。幸运的是这很容易解决。只需要在开始训练循环之前,把图固定(finalize)就行——调用tf.getDefaultGraph().finalize() 把图锁定,之后想要添加任何新的操作都会产生错误。看吧,问题解决了。
彻底分析图
实际上 TensorFlow 的分析功能是很强的,不过似乎没有得到那么多宣传。TensorFlow 里有一种机制,可以记录图操作的运行时间和内存消耗。如果你正在寻找瓶颈在哪里,或者需要弄清你的机器不更换硬盘驱动器的话能不能运行一个模型,这个功能就可以派上用场了。
要生成分析数据,你需要在启用跟踪的情况下把图整个跑一遍:
之后,一个 timeline.json 文件会被保存到当前文件夹,跟踪数据可以在 Tensorboard 找到。现在,你可以很容易地看到一个操作花了多长时间来计算,以及这个操作消耗了多少内存。打开Tensorboard的图视图,选择左侧的最新运行,你就能在右边看到性能的详细信息。一方面,这方便你调整模型,尽可能多地使用机器;另一方面,这方便你在训练管道中发现瓶颈。如果你更喜欢时间轴视图,在 Google Chromes 跟踪事件分析工具(Trace Event Profiling Tool)中加载timeline.json 文件就行了。
另一个不错的工具是 tfprof,tfprof 使用相同的功能做内存和执行时间分析,不过提供了更多的便利功能(feature)。额外的统计信息需要更改代码。
注意内存
就像上一节说的那样,分析可以让你了解特定操作的内存使用情况。但是,观察整个模型的内存消耗更加重要。你必须确保不会超过你机器的内存,因为 swapping 绝对会降低你输入管道的速度,这样 GPU 就会等着处理新的数据。要检测这种行为,用简单的 top 或者 Tensorboard 队列图应该足够了。要详细研究可以参照前面说的方法。
调试
善用打印
在调试问题时,比如停滞丢失或产生了奇怪的输出,我主要使用的工具是 tf.Print。考虑到神经网络的性质,看你的模型里面张量的原始值一般没有什么意义。没有人能看懂数百万的浮点数,看出什么地方错了。但是,有些方法,尤其是把形状或平均值打印出来,就能提供很多的信息。如果你要实现一些现有的模型,把东西打印出来能让你把模型的值和论文或文章里的值进行比较,还能帮助你解决一些棘手的问题,或者论文里的拼写错误。
TensorFlow 1.0 推出了新的 TFDebugger,看起来很有用。我现在还没有使用这个功能,但接下来几个星期肯定会用。
设置一个操作执行超时
好,现在你已经实现了你的模型,session 也启动了,但没有事情都没有什么发生?这通常是由空队列引起的。但是,如果你不知道是哪一个队列导致的,那么有一个简单的修复方法:只需在创建会话时启用一个操作执行超时,这样当操作超过限制时,脚本就会崩溃:
使用堆栈跟踪,你就可以找出是哪个操作产生了问题,修复错误,继续训练吧。
希望这篇文章对同样使用 TensorFlow 的你有用。如果你发现了错误,或者有建议或意见,欢迎在评论里和大家分享哦~~
编译来源:
http://www.deeplearningweekly.com/blog/tensorflow-quick-tips