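I have trained a network (on a GPU), and now I want to run it on the CPU for inference. To do this, I use the following code to load the meta graph and then restore the network parameters.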
config = tf.ConfigProto(
device_count = {'GPU': 0}
)
sess = tf.Session(config=config)
meta_graph=".../graph-0207-190023.meta"
model=".../model.data-00000-of-00001"
new_saver = tf.train.import_meta_graph(meta_graph)
new_saver.restore(sess, model)
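The problem is that, because the graph was defined for training, it contains some specific ops that do not run on the CPU. One example is "MaxBytesInUse" (https://www.tensorflow.org/api_docs/python/tf/contrib/memory_stats/MaxBytesInUse), which tracks peak GPU memory usage.
That is why I get the following error when I try to run this code: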
InvalidArgumentError: No OpKernel was registered to support Op 'MaxBytesInUse' with these attrs. Registered devices: [CPU], Registered kernels:
device='GPU'
[[Node: PeakMemoryTracker/MaxBytesInUse = MaxBytesInUse[_device="/device:GPU:0"]()]]
Is there a simple way to remove the GPU-specific ops and run the graph on the CPU?
Posted on 2019-02-08 12:18:50
I think something like this should solve your problem:
import tensorflow as tf
from tensorflow.python.framework import kernels


def remove_no_cpu_ops(graph_def):
    # Remove all ops that cannot run on the CPU
    removed = set()
    for node in list(graph_def.node):  # iterate over a copy while removing
        if not can_run_on_cpu(node):
            graph_def.node.remove(node)
            removed.add(node.name)
    # Repeatedly remove ops that depend on already removed ops
    while removed:
        removed, prev_removed = set(), removed
        for node in list(graph_def.node):  # iterate over a copy, not graph_def.node itself
            # Inputs look like "name", "name:1" or "^name" (control dependency),
            # so strip the prefix/suffix to get the name of the producing node
            if any(inp.lstrip('^').split(':')[0] in prev_removed for inp in node.input):
                graph_def.node.remove(node)
                removed.add(node.name)


def can_run_on_cpu(node):
    # Check if there is a CPU kernel registered for the node's op
    for kernel in kernels.get_registered_kernels_for_op(node.op).kernel:
        if kernel.device_type == 'CPU':
            return True
    return False
config = tf.ConfigProto(
device_count = {'GPU': 0}
)
sess = tf.Session(config=config)
meta_graph = ".../graph-0207-190023.meta"
model = ".../model"  # Saver.restore() takes the checkpoint prefix, not the .data-00000-of-00001 file
# Load metagraph definition
meta_graph_def = tf.MetaGraphDef()
with open(meta_graph, 'rb') as f:
meta_graph_def.MergeFromString(f.read())
# Remove GPU ops
remove_no_cpu_ops(meta_graph_def.graph_def)
# Make saver from modified metagraph definition
new_saver = tf.train.import_meta_graph(meta_graph_def, clear_devices=True)
new_saver.restore(sess, model)
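To make the two passes in remove_no_cpu_ops above easier to follow, here is a minimal sketch that models them without TensorFlow. The Node tuple, source_node, prune and the toy graph are hypothetical stand-ins for a GraphDef, and "having a CPU kernel" is modelled by a plain set of op names. Because GraphDef input strings can be "name:1" (output index) or "^name" (control dependency), the sketch normalizes them before comparing.

```python
from typing import List, NamedTuple, Set


class Node(NamedTuple):
    """Hypothetical stand-in for a GraphDef NodeDef (name, op type, input strings)."""
    name: str
    op: str
    inputs: List[str]


def source_node(inp: str) -> str:
    """Map an input string ("x", "x:1" or control dependency "^x") to the producing node."""
    return inp.lstrip("^").split(":")[0]


def prune(nodes: List[Node], cpu_ops: Set[str]) -> List[Node]:
    """Drop nodes whose op is not in cpu_ops, then everything that depends on them."""
    newly_removed = {n.name for n in nodes if n.op not in cpu_ops}
    kept = [n for n in nodes if n.name not in newly_removed]
    while newly_removed:
        # Nodes consuming (data or control) anything removed in the previous round
        newly_removed = {
            n.name for n in kept
            if any(source_node(i) in newly_removed for i in n.inputs)
        }
        kept = [n for n in kept if n.name not in newly_removed]
    return kept


# Toy graph: a GPU-only memory op feeding a summary that a NoOp depends on
graph = [
    Node("x", "Placeholder", []),
    Node("w", "VariableV2", []),
    Node("y", "MatMul", ["x", "w"]),
    Node("PeakMemoryTracker/MaxBytesInUse", "MaxBytesInUse", []),
    Node("log", "Identity", ["PeakMemoryTracker/MaxBytesInUse"]),
    Node("summary", "ScalarSummary", ["log:0"]),
    Node("train", "NoOp", ["^summary"]),
]
cpu_ops = {"Placeholder", "VariableV2", "MatMul", "Identity", "ScalarSummary", "NoOp"}

print([n.name for n in prune(graph, cpu_ops)])  # ['x', 'w', 'y']
```

Each round only checks against the nodes removed in the previous round, which is enough: anything depending on an earlier removed node was already dropped in that earlier round.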
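The idea is to go through all the nodes in the graph definition and remove those that have no CPU kernel, together with every node that depends on them. You could make can_run_on_cpu more accurate by checking whether there is a CPU kernel for both the node's op and its input types, i.e. by checking the constraint field of the kernel definition, but this is probably enough for your case. I also added clear_devices=True to tf.train.import_meta_graph, which clears any device directives that force ops onto a specific device (if your graph has them).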
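A rough sketch of that stricter check follows, written in plain Python so it runs without TensorFlow. KernelInfo and has_matching_cpu_kernel are hypothetical names. The constraints dict stands in for the KernelDef constraint field (each entry has a name and allowed_values), and node_attrs stands in for the node's type attrs. With the real protos the comparison would presumably be something like node.attr[c.name].type in c.allowed_values.list.type, but treat that mapping as an assumption to verify against your TensorFlow version. List-typed attrs are not handled here.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Set


@dataclass
class KernelInfo:
    """Hypothetical stand-in for a registered KernelDef."""
    device_type: str
    # attr name -> dtypes this kernel is registered for (models KernelDef.constraint)
    constraints: Dict[str, Set[str]] = field(default_factory=dict)


def has_matching_cpu_kernel(node_attrs: Dict[str, str], kernels: List[KernelInfo]) -> bool:
    """True if some CPU kernel accepts the dtype of every attr it constrains."""
    for kernel in kernels:
        if kernel.device_type != "CPU":
            continue
        if all(node_attrs.get(name) in allowed
               for name, allowed in kernel.constraints.items()):
            return True
    return False


# Toy registry: CPU kernel only for float32/float64, GPU kernel also for float16
kernels = [
    KernelInfo("CPU", {"T": {"float32", "float64"}}),
    KernelInfo("GPU", {"T": {"float16", "float32"}}),
]
print(has_matching_cpu_kernel({"T": "float32"}, kernels))  # True
print(has_matching_cpu_kernel({"T": "float16"}, kernels))  # False
```

With this check, a node whose op has a CPU kernel only for other dtypes is treated as GPU-only, a case that the simpler device_type test in can_run_on_cpu would let through.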
https://stackoverflow.com/questions/54590442