在TensorFlow中,可以通过使用tf.train.Saver
类来保存模型的pb文件。在保存pb文件之前,可以通过os.path.dirname()
函数访问文件夹名称。
具体步骤如下:
import tensorflow as tf
import os
# 定义模型
x = tf.placeholder(tf.float32, name='input')
y = tf.add(x, 1, name='output')
# 定义变量
var1 = tf.Variable(0, name='var1')
var2 = tf.Variable(1, name='var2')
saver = tf.train.Saver()
with tf.Session() as sess:
# 初始化变量
sess.run(tf.global_variables_initializer())
# 保存模型
save_path = saver.save(sess, os.path.join(os.path.dirname(__file__), 'model.pb'))
print("Model saved in file: %s" % save_path)
在上述代码中,os.path.dirname(__file__)
用于获取当前文件所在的文件夹路径,然后通过os.path.join()
函数将文件夹路径和文件名拼接起来,最终得到保存pb文件的完整路径。
这样,就可以在保存pb文件之前访问文件夹名称,并将模型保存在指定的文件夹中。
注意:以上代码示例中的__file__
是Python中的内置变量,表示当前脚本的文件名。如果代码不在脚本中运行,而是在交互式环境中运行,可以手动指定文件夹路径。
领取专属 10元无门槛券
手把手带您无忧上云