TensorFlow从诞生以来就一直在深度学习框架中稳居老大的位置,虽然自从2018年12月PyTorch 1.0 stable版本正式发布以来,很快减小了差距,但是也难以超越。
TensorFlow的强项在于部署(包括TensorFlow Lite在移动端部署)和运行效率,另外对各种operation的支持特别齐全,基本上你能想到的算子都已经实现好了,直接调用就好。除此之外,Google Brain的各项前沿研究,以及现在DeepMind的很多研究,开源代码肯定都是基于TensorFlow,比如现在很火的AutoML技术等等,所以成为No.1也是自然而然。
但是又不得不吐槽其调试功能,真是太难用了。这也直接导致了TensorFlow的学习曲线异常之陡,和vim的类似,学起来很难很痛苦,但是学好之后,那是相当地爽。
那么,TensorFlow怎么调试呢?使用断点还是print?亦或是高大上的tfdbg?都不是。
由于TensorFlow静态图的设计(eager模式除外,这个后面单独讨论),设置断点根本无法获取实际tensor的值,具体取值都在后台以C++的方式执行。那print呢?也只能打印出tensor的shape信息。tfdbg,这个官方开发的专用工具该行了吧?不过我建议还是不要尝试了,不仅要一点一点敲命令,我在debug大型程序的时候,直接卡死。
对了,还有一种暴力方法,我最开始的时候在使用,就是把tensor拉出来sess.run一把,这样的确可以得到tensor运行的具体值,但是每次要手动改,很麻烦。
好了,神器要出来了:tf.Print. 在老版本的TensorFlow中可以这么用,非常方便:
x = tf.Print(x,[x, x,shape, x[0], …], message=“x debug info”, summarize=100)
其中,x是需要打印的tensor,注意第一个输入是x和输出相同,但其实也可以不同,做一些操作,但一般debug不需要,所以等式左边的输出也是x.
第二个输入在方括号内表示需要打印的东西,可以是tensor x的具体值,或者是其shape,slice,甚至是函数。
第三个输入message用来标识这一处打印,可以自定义字符串。
最后的summarize控制输出元素的数量,比如100就输出x的前100个元素。
对于新版的TensorFlow,使用tf.print,语法如下:
print_op = tf.print(x)
withtf.control_dependencies([print_op]):
out = tf.add(x, x)
sess.run(out)
很方便吧?
虽然不如直接在PyCharm中设置断点方便,但能把tensor打印出来定位问题也就容易多了。当然,如果是学习代码,想单步跟踪,建议使用eager模式,这就和PyTorch的方式非常相近了,当然,牺牲的是运行效率。