当将numpy数组输入估计器时,有一种很好的方法来设置纪元。
tf.estimator.inputs.numpy_input_fn(
x,
y=None,
batch_size=128,
num_epochs=1 ,
shuffle=None,
queue_capacity=1000,
num_threads=1
)
但我不能用TFRecords找到类似的方法,大多数人似乎只是把它放在一个循环中
i = 0
while ( i < 100000):
model.train(input_fn=input_fn, steps=100)
有没有一种干净的方法可以用估计器显式地设置TFRecords的纪元数?
发布于 2019-03-23 21:20:42
您可以使用dataset.repeat(num_epochs)
设置纪元数。数据集管道输出数据集对象,即批处理大小的元组(要素、标注),并将其输入到model.train()
dataset = tf.data.TFRecordDataset(file.tfrecords)
dataset = tf.shuffle().repeat()
...
dataset = dataset.batch()
为了使其正常工作,您设置了model.train(steps=None, max_steps=None)
在本例中,您让Dataset API通过在到达num_epoch时生成tf.errors.OutOfRange
错误或StopIteration
异常来处理纪元计数。
https://stackoverflow.com/questions/55312887
复制相似问题