首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >[Tensorflow][原创]tensorflow保存PB模型的几种方法总结

[Tensorflow][原创]tensorflow保存PB模型的几种方法总结

作者头像
云未归来
发布2025-07-18 13:45:19
发布2025-07-18 13:45:19
1390
举报

第一种方法:(官方不推荐)

(1)引入库

from tensorflow.python.framework import graph_util

(2)一般在seession初始化全局变量下写这句代码

constant_graph=graph_util.convert_variables_to_constants(sess,

sess.graph_def, ['output_node_name'])

其中output_node_name是输出节点的名称,这个list可以包含输入输入多个节点名称

(3)保存模型:

with tf.gfile.FastGFile('./model.pb', mode='wb') as f:

        f.write(constant_graph.SerializeToString())

第二种方法:(这是官方推荐的)

直接保存模型:

tf.compat.v1.saved_model.simple_save(sess,

            "./saved_model",

            inputs={"input": x, 'keep_prob':keep_prob},

            outputs={"output": y_conv})

第三种方法:

# 保存图表并保存变量参数

from tensorflow.python.framework import graph_util

var_list=tf.global_variables()

constant_graph = graph_util.convert_variables_to_constants(sess,

sess.graph_def,output_node_names=[var_list[i].name for i in range(len(var_list))]) # 保存图表并保存变量参数

tf.train.write_graph(constant_graph, './output', 'expert-graph.pb', as_text=False)

具体参数看这:

tf.train.write_graph(graph_or_graph_def, logdir, name, as_text=True)

# Writes a graph proto to a file.

#      graph_or_graph_def: A `Graph` or a `GraphDef` protocol buffer.

#      logdir: Directory where to write the graph. This can refer to remote

#        filesystems, such as Google Cloud Storage (GCS).

#      name: Filename for the graph.

#      as_text: If `True`, writes the graph as an ASCII proto.

#    Returns:

#      The path of the output proto file.

(从内置文档摘来的,相信大家都看得懂^_^)

如果要加载的话就用这个:

tf.train.import_meta_graph(meta_graph_or_file, clear_devices=False, import_scope=None)

#参数如下

#meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including

#       the path) containing a `MetaGraphDef`.

#    clear_devices: Whether or not to clear the device field for an `Operation`

#        or `Tensor` during import.

#     import_scope: Optional `string`. Name scope to add. Only used when

#        initializing from protocol buffer.

#      **kwargs: Optional keyed arguments.

#    Returns:

#      A saver constructed from `saver_def` in `MetaGraphDef` or None.

#      A None value is returned if no variables exist in the `MetaGraphDef`

第四种方法:

# 只保留图表

graph_def = tf.get_default_graph().as_graph_def()

with gfile.GFile('./output/output_graph.pb', 'wb') as f:

    f.write(graph_def.SerializeToString())

# 或者

tf.train.write_graph(graph_def, './output', 'output_graph.pb', as_text=False)

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2019-10-22,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档