看完API DOC后,我也不明白SessionRunHook的用法。例如,调用SessionRunHook的成员函数的顺序是什么?是after_create_session -> before_run -> begin -> after_run -> end
吗?我找不到有详细示例的教程,有没有更详细的解释?
发布于 2017-10-17 16:42:52
你可以找到一个教程here,有点长,但你可以跳过构建网络的部分。或者你可以阅读我下面的小总结,基于我的经验。
首先,应该使用MonitoredSession
而不是普通的Session
。
A SessionRunHook扩展了
session.run()
对MonitoredSession
的调用。
然后,可以在here中找到一些常见的SessionRunHook
类。一个简单的示例是LoggingTensorHook
,但是您可能希望在导入后添加以下行,以便在运行时查看日志:
tf.logging.set_verbosity(tf.logging.INFO)
或者,您可以选择实现自己的SessionRunHook
类。一个简单的例子来自cifar10 tutorial
class _LoggerHook(tf.train.SessionRunHook):
"""Logs loss and runtime."""
def begin(self):
self._step = -1
self._start_time = time.time()
def before_run(self, run_context):
self._step += 1
return tf.train.SessionRunArgs(loss) # Asks for loss value.
def after_run(self, run_context, run_values):
if self._step % FLAGS.log_frequency == 0:
current_time = time.time()
duration = current_time - self._start_time
self._start_time = current_time
loss_value = run_values.results
examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
sec_per_batch = float(duration / FLAGS.log_frequency)
format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
'sec/batch)')
print (format_str % (datetime.now(), self._step, loss_value,
examples_per_sec, sec_per_batch))
其中loss
是在类外部定义的。此_LoggerHook
使用print
打印信息,而LoggingTensorHook
使用tf.logging.INFO
打印信息。
最后,为了更好地理解它的工作原理,使用MonitoredSession
here以伪代码的形式给出了执行顺序
call hooks.begin()
sess = tf.Session()
call hooks.after_create_session()
while not stop is requested: # py code: while not mon_sess.should_stop():
call hooks.before_run()
try:
results = sess.run(merged_fetches, feed_dict=merged_feeds)
except (errors.OutOfRangeError, StopIteration):
break
call hooks.after_run()
call hooks.end()
sess.close()
希望这能有所帮助。
发布于 2018-05-11 01:12:54
tf.SessionRunHook
使您能够在代码中执行的每个会话运行命令期间添加自定义代码。为了理解它,我在下面创建了一个简单的示例:
SessionRunHook
来实现这一点。创建张量流图
import tensorflow as tf
import numpy as np
x = tf.placeholder(shape=(10, 2), dtype=tf.float32)
w = tf.Variable(initial_value=[[10.], [10.]])
w0 = [[1], [1.]]
y = tf.matmul(x, w0)
loss = tf.reduce_mean((tf.matmul(x, w) - y) ** 2)
optimizer = tf.train.AdamOptimizer(0.001).minimize(loss)
创建钩子的
class _Hook(tf.train.SessionRunHook):
def __init__(self, loss):
self.loss = loss
def begin(self):
pass
def before_run(self, run_context):
return tf.train.SessionRunArgs(self.loss)
def after_run(self, run_context, run_values):
loss_value = run_values.results
print("loss value:", loss_value)
使用hook
创建受监视的会话
sess = tf.train.MonitoredSession(hooks=[_Hook(loss)])
列车
for _ in range(10):
x_ = np.random.random((10, 2))
sess.run(optimizer, {x: x_})
# Output
loss value: 21.244701
loss value: 19.39169
loss value: 16.02665
loss value: 16.717144
loss value: 15.389178
loss value: 16.23935
loss value: 14.299083
loss value: 9.624525
loss value: 5.654896
loss value: 10.689494
https://stackoverflow.com/questions/45532365
复制相似问题