首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

TensorFlow入门必看:Google AI实习生经验谈

作者:JacobBuckman译者:王强、无明【新智元导读】本文作者Jacob来自GoogleAIResident项目,他在2017年夏天开启了为期一年的Google研究型实习,在此之前他虽然有很多编程经验和机器学习经验,但没有使用过TensorFlow。这篇文章是Jacob为TensorFlow写的一个实用教程,作者表示,要是在开启TensorFlow学习之前有人告诉他这些知识就好了。

前言:“我叫Jacob,是谷歌AIResidency项目的学者。2017年夏天我进入这个项目的时候,我自己的编程经验很丰富,对机器学习理解也很深刻,但以前我从未使用过Tensorflow。当时我认为凭自己的能力可以很快掌握Tensorflow,但没想到我学习它的过程竟然如此跌宕起伏。甚至加入项目几个月后我还偶尔会感到困惑,不知道怎样用Tensorflow代码实现自己的新想法。

快看!我们得到了一个节点,它包含常量:2。我知道你很惊讶,惊讶的是一个名为tf.constant的函数。当我们打印这个变量时,我们看到它返回一个tf.Tensor对象,它是一个指向我们刚创建的节点的指针。为了强调这一点,这里是另一个例子:

每次我们调用tf.constant的时候,我们都会在图中创建一个新节点。即使节点在功能上与现有节点完全相同,即使我们将节点重新分配给同一个变量,甚至我们根本没有将其分配给变量,结果都一样。相反,如果创建一个新变量并将其设置为与现有节点相等,则只需将该指针复制到该节点,并且不会向该图添加任何内容:

好的,我们更进一步。

现在我们来看——这才是我们要的真正的计算图表!请注意,+操作在Tensorflow中过载,所以同时添加两个张量会在图中增加一个节点,尽管它看起来不像是Tensorflow操作。好的,所以two_node指向包含2的节点,three_node指向包含3的节点,而sum_node指向包含...+的节点?什么情况?它不是应该包含5吗?事实证明,没有。

精彩!我们还可以传递一个列表,sess.run([node1,node2,...]),并让它返回多个输出:

一般来说,sess.run调用往往是最大的TensorFlow瓶颈之一,所以调用它的次数越少越好。可以的话在一个sess.run调用中返回多个项目,而不是进行多个调用。占位符和feed_dict我们迄今为止所做的计算一直很乏味:没有机会获得输入,所以它们总是输出相同的东西。一个实用的应用可能涉及构建这样一个计算图:它接受输入,以某种(一致)方式处理它,并返回一个输出。

……这是个糟糕的例子,因为它引发了一个异常。占位符预计会被赋予一个值,但我们没有提供,因此Tensorflow崩溃了。为了提供一个值,我们使用sess.run的feed_dict属性。

好多了。注意传递给feed_dict的数值格式。这些键应该是与图中占位符节点相对应的变量(如前所述,它实际上意味着指向图中占位符节点的指针)。相应的值是要分配给每个占位符的数据元素——通常是标量或Numpy数组。第三个关键抽象:计算路径下面是另一个使用占位符的例子:

为什么第二次调用sess.run会失败?我们并没有在检查input_placeholder,为什么会引发与input_placeholder相关的错误?答案在于最终的关键Tensorflow抽象:计算路径。还好这个抽象非常直观。当我们在依赖于图中其他节点的节点上调用sess.run时,我们也需要计算这些节点的值。

所有三个节点都需要评估以计算sum_node的值。最重要的是,这里面包含了我们未填充的占位符,并解释了例外情况!相反,考察three_node的计算路径:

根据图的结构,我们不需要计算所有的节点也可以评估我们想要的节点!因为我们不需要评估placeholder_node来评估three_node,所以运行sess.run(three_node)不会引发异常。Tensorflow仅通过必需的节点自动路由计算这一事实是它的巨大优势。如果计算图非常大并且有许多不必要的节点,它就能节约大量运行时间。

发现另一个异常。一个变量节点在首次创建时,它的值基本上就是“null”,任何尝试对它进行计算的操作都会抛出这个异常。我们只能先给一个变量赋值后才能用它做计算。有两种主要方法可以用于给变量赋值:初始化器和tf.assign。我们先看看tf.assign:

与我们迄今为止看到的节点相比,tf.assign(target,value)有一些独特的属性:标识操作。tf.assign(target,value)不做任何计算,它总是与value相等。副作用。当计算“流经”assign_node时,就会给图中的其他节点带来副作用。在这种情况下,副作用就是用保存在zero_node中的值替换count_variable的值。非依赖边。

当计算流经图中的任何节点时,它还会让该节点控制的副作用(绿色所示)起效。由于tf.assign的特殊副作用,与count_variable(之前为“null”)关联的内存现在被永久设置为0。这意味着,当我们下一次调用sess.run(count_variable)时,不会抛出任何异常。相反,我们将得到0。接下来,让我们来看看初始化器:

这里都发生了什么?为什么初始化器不起作用?问题在于会话和图之间的分隔。我们已经将get_variable的initializer属性指向const_init_node,但它只是在图中的节点之间添加了一个新的连接。我们还没有做任何与导致异常有关的事情:与变量节点(保存在会话中,而不是图中)相关联的内存仍然为“null”。我们需要通过会话让const_init_node更新变量。

为此,我们添加了另一个特殊节点:init=tf.global_variables_initializer。与tf.assign类似,这是一个带有副作用的节点。与tf.assign不一样的是,我们实际上并不需要指定它的输入!tf.global_variables_initializer将在其创建时查看全局图,自动将依赖关系添加到图中的每个tf.initializer上。

正如你所看到的,损失基本上没有变化,而且我们对真实参数有了很好的估计。这部分代码只有一两行对你来说是新的:既然你对Tensorflow的基本概念已经有了很好的理解,这段代码应该很容易解释!第一行,optimizer=tf.train.GradientDescentOptimizer(1e-3)不会向图中添加节点。它只是创建了一个Python对象,包含了一些有用的函数。

我们看到了结果是5。但是,如果我们想检查中间值two_node和three_node,该怎么办?检查中间值的一种方法是向sess.run添加一个返回参数,该参数指向要检查的每个中间节点,然后在返回后打印它。

这样做通常没有问题,但当代码变得越来越复杂时,这可能有点尴尬。更方便的方法是使用tf.Print语句。令人困惑的是,tf.Print实际上是Tensorflow的一种节点,它有输出和副作用!它有两个必需的参数:一个要复制的节点和一个要打印的内容列表。“要复制的节点”可以是图中的任何节点,tf.Print是与“要复制的节点”相关的标识操作,也就是说,它将输出其输入的副本。

有关tf.Print的一个重要却有些微妙的点:打印其实只是它的一个副作用。与所有其他副作用一样,只有在计算流经tf.Print节点时才会进行打印。如果tf.Print节点不在计算路径中,则不会打印任何内容。即使tf.Print节点正在复制的原始节点位于计算路径上,但tf.Print节点本身可能不是。这个问题要注意!

这里(https://wookayin.github.io/tensorflow-talk-debugging/#1)有一个很好的资源,提供了更多实用的调试建议。结论希望这篇文章能够帮助你更好地理解Tensorflow,了解它的工作原理以及如何使用它。毕竟,这里介绍的概念对所有Tensorflow程序来说都很重要,但这些还都只是表面上的东西。

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20180703A02THS00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券