首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
社区首页 >问答首页 >SessionRunHook的成员函数的调用顺序是什么?

SessionRunHook的成员函数的调用顺序是什么?
EN

Stack Overflow用户
提问于 2017-08-06 13:11:01
回答 2查看 9.5K关注 0票数 13

看完API DOC后,我也不明白SessionRunHook的用法。例如,调用SessionRunHook的成员函数的顺序是什么?是after_create_session -> before_run -> begin -> after_run -> end吗?我找不到有详细示例的教程,有没有更详细的解释?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2017-10-17 16:42:52

你可以找到一个教程here,有点长,但你可以跳过构建网络的部分。或者你可以阅读我下面的小总结,基于我的经验。

首先,应该使用MonitoredSession而不是普通的Session

A SessionRunHook扩展了session.run()MonitoredSession的调用。

然后,可以在here中找到一些常见的SessionRunHook类。一个简单的示例是LoggingTensorHook,但是您可能希望在导入后添加以下行,以便在运行时查看日志:

代码语言:javascript
运行
AI代码解释
复制
tf.logging.set_verbosity(tf.logging.INFO)

或者,您可以选择实现自己的SessionRunHook类。一个简单的例子来自cifar10 tutorial

代码语言:javascript
运行
AI代码解释
复制
class _LoggerHook(tf.train.SessionRunHook):
  """Logs loss and runtime."""

  def begin(self):
    self._step = -1
    self._start_time = time.time()

  def before_run(self, run_context):
    self._step += 1
    return tf.train.SessionRunArgs(loss)  # Asks for loss value.

  def after_run(self, run_context, run_values):
    if self._step % FLAGS.log_frequency == 0:
      current_time = time.time()
      duration = current_time - self._start_time
      self._start_time = current_time

      loss_value = run_values.results
      examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
      sec_per_batch = float(duration / FLAGS.log_frequency)

      format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                    'sec/batch)')
      print (format_str % (datetime.now(), self._step, loss_value,
                           examples_per_sec, sec_per_batch))

其中loss是在类外部定义的。此_LoggerHook使用print打印信息,而LoggingTensorHook使用tf.logging.INFO打印信息。

最后,为了更好地理解它的工作原理,使用MonitoredSession here以伪代码的形式给出了执行顺序

代码语言:javascript
运行
AI代码解释
复制
  call hooks.begin()
  sess = tf.Session()
  call hooks.after_create_session()
  while not stop is requested:  # py code: while not mon_sess.should_stop():
    call hooks.before_run()
    try:
      results = sess.run(merged_fetches, feed_dict=merged_feeds)
    except (errors.OutOfRangeError, StopIteration):
      break
    call hooks.after_run()
  call hooks.end()
  sess.close()

希望这能有所帮助。

票数 28
EN

Stack Overflow用户

发布于 2018-05-11 01:12:54

tf.SessionRunHook使您能够在代码中执行的每个会话运行命令期间添加自定义代码。为了理解它,我在下面创建了一个简单的示例:

  1. 我们希望在每次更新参数后打印损失值。
  2. 我们将使用SessionRunHook来实现这一点。

创建张量流图

代码语言:javascript
运行
AI代码解释
复制
import tensorflow as tf
import numpy as np

x = tf.placeholder(shape=(10, 2), dtype=tf.float32)
w = tf.Variable(initial_value=[[10.], [10.]])
w0 = [[1], [1.]]
y = tf.matmul(x, w0)
loss = tf.reduce_mean((tf.matmul(x, w) - y) ** 2)
optimizer = tf.train.AdamOptimizer(0.001).minimize(loss)

创建钩子的

代码语言:javascript
运行
AI代码解释
复制
class _Hook(tf.train.SessionRunHook):
  def __init__(self, loss):
    self.loss = loss

  def begin(self):
    pass

  def before_run(self, run_context):
    return tf.train.SessionRunArgs(self.loss)  

  def after_run(self, run_context, run_values):
    loss_value = run_values.results
    print("loss value:", loss_value)

使用hook 创建受监视的会话

代码语言:javascript
运行
AI代码解释
复制
sess = tf.train.MonitoredSession(hooks=[_Hook(loss)])

列车

代码语言:javascript
运行
AI代码解释
复制
for _ in range(10):
  x_ = np.random.random((10, 2))
  sess.run(optimizer, {x: x_})
# Output
loss value: 21.244701
loss value: 19.39169
loss value: 16.02665
loss value: 16.717144
loss value: 15.389178
loss value: 16.23935
loss value: 14.299083
loss value: 9.624525
loss value: 5.654896
loss value: 10.689494
票数 9
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/45532365

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
查看详情【社区公告】 技术创作特训营有奖征文