前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >tensorflow:AToolDeveloperGuideToTFModelFIles

tensorflow:AToolDeveloperGuideToTFModelFIles

作者头像
ke1th
发布2018-01-02 11:30:10
1.4K0
发布2018-01-02 11:30:10
举报
文章被收录于专栏:漫漫深度学习路

Tensorflow Model Files

最近闲来无聊,想深入理解一下tensorlfow,也不知从何下手,突然间发现了官方文档的Extend模块下还有这个一片文章 A Tool Developer's Guide to TensorFlow Model Files, 所以就打算边翻译,边学习了。水平有限,如发现错误,请不吝指出!

翻译开始

大多数用户不需要关心tensorflow在硬盘上存储数据的细节问题的,但是如果你是一个 Tool developer, 那就另当别论了。例如,如果你想分析模型(models),或者想在tensorflow或者其它格式之间进行来回转换。这篇指南通过试着去解释一些 如何处理 保存着模型数据的文件的细节,使得开发者们做一些格式装换的工具更加简单。

Protocol Buffers

所有的Tensorflow的文件格式都是基于Protocol Buffers的。所以了解它们是如何工作的是非常有价值的。概括来说就是,你在文本文件(text files)中定义数据结构,protobuf tools就会生成对应的C,Python和其它语言的类。我们可以用友好的方式来加载,保存,访问这些类中的数据。我们经常将 Protocol Buffers称为 protobufs,在接下来的文章中,我们将继续遵守这个约定。

可以看一下我的这篇文章,对protocol buffer进行了简单的介绍

GraphDef

tensorflow中,计算的基础是Graph对象。Graph对象保存着网络的节点,每个节点代表一个Operation(add, matmul, etc),节点之间由输入和输出链接起来。当建好了一个Graph对象之后,可以通过Graph.as_graph_def() 把它保存起来,as_graph_def() 返回一个 GraphDef对象。

GraphDef类 是由ProtoBuf库创建的对象。它的定义在tensorflow/core/framework/graph.protoprotobuf tools解析这个文本文件,然后生成代码用来加载,存储,和操作图定义。如果看到一个独立的 用于表示模型(model)的Tensorflow文件,那么它很可能是 由protobuf code 保存的序列化的GraphDef对象。

protobuf code 用来从硬盘上 保存和加载GraphDef对象。加载对象的代码看起来像是这样:

代码语言:javascript
复制
#这行代码创建了一个空的 GraphDef 对象。GraphDef类已经由 graph.proto 中定义的文本 所创建。
#我们将用文本中的数据来填充这个对象
graph_def = tf.GraphDef()

if FLAGS.input_binary:
    with open("graph_def.pb", "rb") as f:
        graph_def.ParseFromString(f.read())
else:
    with open("graph_def.pb", mode='r') as f
        text_format.Merge(f.read(), graph_def)

译者注:txt_format是一个工具模块,from google.protobuf import text_format 可以引入。 这里只是演示了如何load ProtoBuf,但是,并没有说明如何保存ProtoBuf,如果想要保存的话,tensorflow提供了一个接口 tf.train.write_graph(graph_def, "./", name='graph.pb')。用这个就可以保存成ProtoBuf。 当然,加载的话,tensorflow也提供了一个接口: def import_graph_def(graph_def, input_map=None, return_elements=None, name=None, op_dict=None, producer_op_list=None)

Text or Binary

有两种不同的文件格式可以存储ProtoBuf。一个是TextFormat,人类可以很容易的理解,而且可以很容易的进行debugging或者editing,但是如果里面包含数值数据的话,那么这个文件就会变的很大。这里有一个例子 graph_run_run2.pbtxt 尴尬的是,官方给的这个例子找不到了。。。

另一种文件格式是 BinaryFormat,它比TextFormat所需的存储空间小,但是人类读不懂。在上面提供的脚本文件中,我们要求用户提供 flag 用来指示,我们读取的文件是 TextFormat还是BinaryFormat,这样我们才能够找到正确的方法去调用。这里有一个BinaryFormat的例子inception_v3 archive inception_v3_2016_08_28_frozen.pb.

不过API的设计着实让人懵逼-对于BinaryFormat ,我们调用 ParseFromString(), 对于TextFormat,我们使用text_format模块。

Nodes

一旦将文件加载到graph_def对象,你就可以访问内部的数据了。出于实用目的,最重要的部分是存储节点成员的节点列表。下面的循环代码可以获取到它们:

代码语言:javascript
复制
for node in graph_def.node:
    print(node)

每个节点(node)是一个NodeDef对象,定义在tensorflow/core/framework/node_def.proto.这些节点是TensorflowGraph的基本构件块,每个都定义了一个operation和它的输入连接。

下面将介绍 NodeDef的成员和其所代表的含义。

name

每个节点(Node) 应该有一个唯一的标识符,图中的其它节点不能使用该标识符(这个标识符就是name属性对应的值)。在使用tensorflow Python接口的时候,如果没有显示指定name属性,那么tensorflow会自动选择一个namename的格式是 operation_name加上一个累加的数字。

name用来定义节点之间的连接 ,和在运行时为整个图形设置输入输出。

op

这个属性指明要执行哪个operation,例如"Add", "MatMul", 或者 "Conv2D"。当Graph运行起来的时候,就会在注册表中查找这些op的名称以找到其对应的实现。注册表是通过调用REGISTER_OP() 宏来填充的,就像这些tensorflow/core/ops/nn_ops.cc.

input

一个strings列表,列表中的每个元素是其它节点的名字,可选的在后面跟上一个冒号和输出端口号。例如:一个拥有两个输入的节点的input属性大概是这样的["some_node_name", "another_node_name"], 等价于["some_node_name:0", "another_node_name:0"],说明了,当前node的第一个输入是名字为"some_node_name"Node的第一个输出,当前node的第二个输入是名字为"another_node_name"Node的第一个输出。

我的测试结果是,现在的input在pdtxt中是下面这种形式,而不是文档中所说的 strings list input: “some_node_name” input: “another_node_name”

device

多数情况下,可以忽略这东西。它规定了在分布式情况下,哪个设备执行这个节点,或者是你想强制一个operationCPU上或是GPU上运行。

attr

这个属性保存了key/value键值对,用来指定节点的所有属性。这是一个节点的 永久属性,一旦指定,在运行时刻就不能再被修改了,例如:卷积核的大小,或者是constant op 的值。 由于可能有多种不同类型的属性值,从strings,到int,再到tensor 值的 arrays。这里有单独的protobuf file文件,定义着这些数据结构tensorflow/core/framework/attr_value.proto.

每个属性拥有一个唯一的名字字符串,在定义operation的时候,期望的属性会被列出来。当一个属性没有在node中出现时,但是在定义op的时候,它有一个属性的默认值,那么这个默认值将会在创建图的时候使用。

Python中,你可以 通过调用 node.name, node.op, etc 访问所有的这些成员 。在GraphDef中存储的 节点列表是模型体系结构的完整定义。

Freezing

令人困惑的一点是 在训练过程中,权值通常不保存在 file format 中。 相反,它们被保存在单独地 检查点checkpoint文件中,初始化时,图中的Variable op用于加载最近的值。在部署到生产环境的时候,用于单独的文件通常会不方便。所以,这里有一个freeze_graph.py脚本文件,用于将 graph definition和 一组checkpoints 冻结成一个文件。

在训练过程中,权值通常不保存在 file format 中, 我觉着对这句话更精确的解释是:在训练过程中保存模型的时候,是将 权值保存在 ckpt文件中的,回想一下 Saver, 在训练过程中,权值还是保存在内存中的。

它是怎么做的呢?加载GraphDef,将所有的变量从最近的 检查点文件中取出,然后将GraphDef中的Variable op 替换成 Const op, 这些Const op中保存着 检查点中保存的变量的值。然后,它去掉GraphDef中与 前向过程无关的节点,然后将处理后的GraphDef保存到输出文件中。

部署的时候,用这个玩意感觉爽的很。

Weight Formats

如果你正在处理一些 表示神经网络的 TensorFlow模型,最常见的问题之一就是 提取和 解释权重值。存储它们的常用方法就是,用freeze_graph脚本处理GraphDef,将Variable op 换成 Const op,使用Const op将这些权重作为Tensor存储起来。Tensor被定义在tensorflow/core/framework/tensor.proto, Tensor 中不仅保存了权重的值,还保存了数据类型(int,float)和size。在Python中,可以通过表示 Const opNodeDef对象中获取TensorProto对象,就像

代码语言:javascript
复制
tensorProto = some_node_def.attr['value'].tensor

这段代码会返回一个 表示权重数据的对象。数据本身会保存在一个列表中,这个列表的名字是suffix_val, suffix代表对象的数据类型,例如float_val 代表 32位浮点型。

当在不同的框架之间进行转换时,卷积权重的顺序是很难处理的。在Tensorflow中,Conv2D op的卷积核的存储在第二个输入上,期望的顺序是[filter_height, filter_width, input_depth, output_depth],在这里,filter_count增加一意味着移动到内存中的相邻值。

希望这个纲要能让你更好地了解TensorFlow模型文件中正在发生的事情,如果你需要对它们进行操作的话,将会对你有所帮助。

翻译完毕,总结

本文中提到了以下几个概念:

  • GraphDef
    • GraphDef中存储的节点列表是模型体系结构的完整定义
  • NodeDef
    • 用于代表一个op及其 输入输出
    • name: name属性表示op的名字 name:ouput_index代表输出tensor
    • input: 属性用于暴露op的输入

Demo

下面只是给出了一个简单的代码,这里也有一个示例

保存为pb

代码语言:javascript
复制
import tensorflow as tf
t = tf.constant([[[1,2,3],[4,5,6]],[[1,2,3],[4,5,6]]])
paddings = tf.constant([[1,0], [2,2], [1,2]])

paded = tf.pad(t, paddings, "CONSTANT")

graph_def = tf.get_default_graph().as_graph_def()
print(graph_def)

tf.train.write_graph(graph_def, logdir="./", name='graph.pb', as_text=True)

打印出来的结果为:

代码语言:javascript
复制
node {
  name: "Const"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 2
          }
          dim {
            size: 2
          }
          dim {
            size: 3
          }
        }
        tensor_content: "\001\000\000\000\002\000\000\000\003\000\000\000\004\000\000\000\005\000\000\000\006\000\000\000\001\000\000\000\002\000\000\000\003\000\000\000\004\000\000\000\005\000\000\000\006\000\000\000"
      }
    }
  }
}
node {
  name: "Const_1"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 3
          }
          dim {
            size: 2
          }
        }
        tensor_content: "\001\000\000\000\000\000\000\000\002\000\000\000\002\000\000\000\001\000\000\000\002\000\000\000"
      }
    }
  }
}
node {
  name: "Pad"
  op: "Pad"
  input: "Const"
  input: "Const_1"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "Tpaddings"
    value {
      type: DT_INT32
    }
  }
}
versions {
  producer: 21
}

解析pb

代码语言:javascript
复制
import tensorflow
from google.protobuf import text_format

graph_def = tf.GraphDef()
#因为是文本文件,所以mode='r',如果之前保存的是二进制文件 mode='rb'
with open("./graph.pb", mode='r') as file:
    text_format.Merge(file.read(), graph_def)

tf.import_graph_def(graph_def=graph_def, name='')

#get_tensor_by_name有一个需要注意的地方,就是 tensor的name需要是 op_name:output_index
padded = tf.get_default_graph().get_tensor_by_name("Pad:0")

with tf.Session() as sess:
    print(sess.run(padded))

当我们用这种方式只进行推断的时候,我们可以这么做:

  • 获取placeholder tensor
  • feed 这些 tensor
  • 获取最后一层的tensor,然后sess.run打印出来结果就 OK

最后说明一下前面用到的几个方法

代码语言:javascript
复制
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None)
# name : 可选的,加在GraphDef中名字的前面,默认是import ,一般情况下,直接 name=''就可以了
# input_map: 没有测试到底是干嘛的,默认值就可以。

tf.train.write_graph(graph_or_graph_def, logdir, name, as_text=True)
# logdir: 导出的文件目录
# name: 导出时的文件名
# as_text: 是以Text形式 还是 binary 形式导出, 默认为True
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • Tensorflow Model Files
    • 翻译开始
      • Protocol Buffers
        • GraphDef
          • Text or Binary
            • Nodes
              • Freezing
                • Weight Formats
                  • 翻译完毕,总结
                    • Demo
                      • 最后说明一下前面用到的几个方法
                      相关产品与服务
                      GPU 云服务器
                      GPU 云服务器(Cloud GPU Service,GPU)是提供 GPU 算力的弹性计算服务,具有超强的并行计算能力,作为 IaaS 层的尖兵利器,服务于生成式AI,自动驾驶,深度学习训练、科学计算、图形图像处理、视频编解码等场景。腾讯云随时提供触手可得的算力,有效缓解您的计算压力,提升业务效率与竞争力。
                      领券
                      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档