文章作者:张舒婷,经授权发布。
对于数据量较小而言,可能一般选择直接将数据加载进内存,然后再分batch输入网络进行训练。但是,如果数据量较大,这样的方法就不适用了,因为太耗内存,所以这时最好使用 tensorflow 提供的队列 queue,也就是第二种方法从文件读取数据。对于一些特定的读取,比如csv文件格式,官网有相关的描述。而 TFRecords 是tensorflow 的内定标准形式,更加高效的读取方法。 Tensorflow 读取数据的三种方式:
feed_dict{}
的数据不可以是tensor格式,会引起错误:TypeError: The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, or numpy ndarrays.
关于这个问题的讨论,在tensorflow也开了issue,但并没有解决方案。在TensorFlow官方github文档中example.proto的文件,这个文件详细说明了TensorFlow里面的example协议。tensorflow的example包含的是基于key-value对的存储方法,其中key是一个字符串,其映射到的是feature信息,feature包含三种类型:
BytesList:字符串列表 FloatList:浮点数列表 Int64List:64位整数列表
以上三种类型都是列表类型,意味着都能够进行拓展,但是也是因为这种弹性格式,所以在解析的时候,需要制定解析参数。
在TensorFlow中,example是按行读取,比如存储 M×NM×N矩阵,使用ByteList存储的话,需要M×NM×N大小的列表,按照每一行的读取方式存放。 官方样例:
An Example for a movie recommendation application:
features {
feature {
key: "age"
value { float_list {
value: 29.0
}}
}
feature {
key: "movie"
value { bytes_list {
value: "The Shawshank Redemption"
value: "Fight Club"
}}
}
feature {
key: "movie_ratings"
value { float_list {
value: 9.0
value: 9.7
}}
}
feature {
key: "suggestion"
value { bytes_list {
value: "Inception"
}}
}
除了单个example构成的feature外,tensorflow还提供了sequence example,表示一个或者多个example,同时还包括上下文context,其中,context表示的是feature_lists的总体特征,如数据集的长度等,feature_list包含一个key,一个value,value表示的是features集合(feature_lists),同样,官方源码也给出了sequence_example的例子:
context: {
feature: {
key : "locale"
value: {
bytes_list: {
value: [ "pt_BR" ]
}
}
}
feature: {
key : "age"
value: {
float_list: {
value: [ 19.0 ]
}
}
}
feature: {
key : "favorites"
value: {
bytes_list: {
value: [ "Majesty Rose", "Savannah Outen", "One Direction" ]
}
}
}
}
feature_lists: {
feature_list: {
key : "movie_ratings"
value: {
feature: {
float_list: {
value: [ 4.5 ]
}
}
feature: {
float_list: {
value: [ 5.0 ]
}
}
}
}
feature_list: {
key : "movie_names"
value: {
feature: {
bytes_list: {
value: [ "The Shawshank Redemption" ]
}
}
feature: {
bytes_list: {
value: [ "Fight Club" ]
}
}
}
}
feature_list: {
key : "actors"
value: {
feature: {
bytes_list: {
value: [ "Tim Robbins", "Morgan Freeman" ]
}
}
feature: {
bytes_list: {
value: [ "Brad Pitt", "Edward Norton", "Helena Bonham Carter" ]
}
}
}
}
}
除此之外,官网还有一些其他一致性的例子,可供参考。
captions_train2017.json主结构 { “info”: info, “licenses”: [license], “images”: [image], “annotations”: [annotation] } [info] { “year”: int, “version”: str, “description”: str, “contributor”: str, “url”: str, “date_created”: datetime, } [license] { “id”: int, “name”: str, “url”: str, } [image] { “id”: int, “width”: int, “height”: int, “file_name”: str, “license”: int, “flickr_url”: str, “coco_url”: str, “date_captured”: datetime, } [annotation] { “id”: int, “image_id”: int, “caption”: str } 例如 { “image_id”: 179765, “id”: 38, “caption”: “A black Honda motorcycle parked in front of a garage.” }
形成TFRecord则可以用如下代码(来自于tensorflow/models/research/im2txt/build_coco):
### other function
def _int64_feature(value):
"""Wrapper for inserting an int64 Feature into a SequenceExample proto."""
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
"""Wrapper for inserting a bytes Feature into a SequenceExample proto."""
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[str(value)]))
def _int64_feature_list(values):
"""Wrapper for inserting an int64 FeatureList into a SequenceExample proto."""
return tf.train.FeatureList(feature=[_int64_feature(v) for v in values])
def _bytes_feature_list(values):
"""Wrapper for inserting a bytes FeatureList into a SequenceExample proto."""
return tf.train.FeatureList(feature=[_bytes_feature(v) for v in values])
###
with tf.gfile.FastGFile(image.filename, "r") as f:
encoded_image = f.read()
try:
decoder.decode_jpeg(encoded_image) ## tensor->nparray
except (tf.errors.InvalidArgumentError, AssertionError):
print("Skipping file with invalid JPEG data: %s" % image.filename)
return
context = tf.train.Features(feature={
"image/image_id": _int64_feature(image.image_id),
"image/data": _bytes_feature(encoded_image),
}) ## context config
assert len(image.captions) == 1
caption = image.captions[0]
caption_ids = [vocab.word_to_id(word) for word in caption]
feature_lists = tf.train.FeatureLists(feature_list={
"image/caption": _bytes_feature_list(caption),
"image/caption_ids": _int64_feature_list(caption_ids)
}) ## feature list config
sequence_example = tf.train.SequenceExample(
context=context, feature_lists=feature_lists) ## sequence Example
实际处理过程中抽取了json文件中的”image_id”, “filename”, “captions”三个key 对应的值写成example,同时还引入了线程处理文件,以加快速度。
TFRecord解析函数常用的有三个:分别是tf.parse_example
, tf.parse_single_example
, tf.parse_single_sequence_example
,接下来分别介绍:
parse_example的方法定义:
def parse_example(serialized, features, name=None, example_names=None)
parse_example是把example解析为词典型的tensor 参数含义: serialized:一个batch的序列化的example features:解析example的规则 name:当前操作的名字 example_name:当前解析example的proto名称tf.parse_single_example
较parse_example
少了batch的参数,每一次只解析一个example。
这里重点要说的是第二个参数,也就是features,features是把serialized的example中按照键值映射到三种tensor: 1,VarlenFeature 2, SparseFeature 3,FixedLenFeature,下面对这三种映射方式做一个简要的叙述:
# serialized data
serialized = [
features
{ feature { key: "ft" value { float_list { value: [1.0, 2.0] } } } },
features
{ feature []},
features
{ feature { key: "ft" value { float_list { value: [3.0] } } }
]
# VarlenFeature的使用方法是
features={
"ft":tf.VarLenFeature(tf.float32)
}
# get
{"ft": SparseTensor(indices=[[0, 0], [0, 1], [2, 0]],
values=[1.0, 2.0, 3.0],
dense_shape=(3, 2)) }
[serilized.size(),df.shape]
的矩阵,这里的FixLenFeature指的是每个键值对应的feature的size是一样的:# serialized data as before # FixedLenFeature usage features: { "ft": FixedLenFeature([2], dtype=tf.float32, default_value=-1), } # get <maybe cause run error for varlenSeries> {"ft": [[1.0, 2.0], [3.0, -1.0]]}tf.parse_single_sequence_example
对应解析sequenceExample,具体例子如下,对于不定长数据,tensorflow也提供自动补齐的功能,在tf.train.batch, tf.train.batch_join, tf.train.shuffle_batch, tf.train.shuffle_batch_join
相关的函数中:import tensorflow as tf
import os
keys=[[1.0,2.0],[2.0,3.0]]
sess=tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
def make_example(locale,age,score,times):
example = tf.train.SequenceExample(
context=tf.train.Features(
feature={ "locale":tf.train.Feature(bytes_list=tf.train.BytesList(value=[locale])), "age":tf.train.Feature(int64_list=tf.train.Int64List(value=[age]))
}),
feature_lists=tf.train.FeatureLists(
feature_list={
"movie_rating":tf.train.FeatureList(feature=[tf.train.Feature(float_list=tf.train.FloatList(value=score)) for i in range(times)])
}
)
)
return example.SerializeToString()
context_features = {
"locale": tf.FixedLenFeature([],dtype=tf.string),
"age": tf.FixedLenFeature([],dtype=tf.int64)
}
sequence_features = {
"movie_rating": tf.FixedLenSequenceFeature([3], dtype=tf.float32,allow_missing=True)
}
context_parsed, sequence_parsed = tf.parse_single_sequence_example(make_example("china",24,[1.0,3.5,4.0],2),context_features=context_features,sequence_features=sequence_features)
print tf.contrib.learn.run_n(context_parsed)
print tf.contrib.learn.run_n(sequence_parsed)
# get
[{'locale': 'china', 'age': 24}]
[{'movie_rating': array([[ 1. , 3.5, 4. ],
[ 1. , 3.5, 4. ]], dtype=float32)}]
上图的过程是先创建一个先入先出的队列(FIFOQueue),并将其内部所有元素初始化为零。然后构建TensorFlow图,它从队列前端取走一个元素,加上1之后,放回队列的后端ref1,ref2。
除了先入先出队列,tensorflow还提供RandomShuffleQueue
实现异步计算。在实现过程中,需要所有线程都必须能被同步终止,异常必须能被正确捕获并报告,Session终止的时候, 队列必须能被正确地关闭。为了保证上述过程正常进行,Tensorflow提供了tf.Coordinator
和 tf.QueueRunner
两个实现多线程。从设计上这两个类必须被一起使用。Coordinator类可以用来同时停止多个工作线程并且向那个在等待所有工作线程终止的程序报告异常。QueueRunner类用来协调多个工作线程同时将多个张量推入同一个队列中。
should_stop()
: 如果线程应该停止则返回True。
request_stop(<exception>)
: 请求该线程停止outofRangeError
。
join(<list of threads>)
: 等待被指定的线程终止。首先创建一个Coordinator对象,然后建立一个或多个使用Coordinator对象的线程。这些线程通常一直循环运行,一直到should_stop()返回True时停止。 任何线程都可以决定计算什么时候应该停止。它只需要调用request_stop(),同时其他线程的should_stop()将会返回True,然后所有线程都停下来。
这些线程可以重复的执行Enquene操作, 他们使用同一个Coordinator来处理线程同步终止。此外,一个QueueRunner会运行一个closer thread,当Coordinator收到异常报告时,这个closer thread会自动关闭队列。
example = ...ops to create one example...
# Create a queue, and an op that enqueues examples one at a time in the queue.
queue = tf.RandomShuffleQueue(...)
enqueue_op = queue.enqueue(example)
# Create a training graph that starts by dequeuing a batch of examples.
inputs = queue.dequeue_many(batch_size)
train_op = ...use 'inputs' to build the training part of the graph...
tf.TFRecordWriter
假设serilized_object是一个已经序列化好的example,那么其写的过程如下:
writer = tf.python_io.TFRecordWriter(filename)
writer.write(serilized_object)
writer.close()
tf.TFRecordReader
在上图中,首先由一个单线程把文件名堆入FIFO队列,两个Reader同时从队列中取文件名并读取数据,Decoder(parse)读出的数据解析后堆入样本队列,最后单个或批量取出样本(图中没有展示样本出列)。
具体的将文件名列表交给tf.train.string_input_producer
函数生成一个先入先出的队列, 文件阅读器会需要它来读取数据。
string_input_producer
提供的可配置参数来设置文件名乱序和最大的训练迭代数, QueueRunner
会为每次迭代(epoch)将所有的文件名加入文件名队列中, 如果shuffle=True
的话, 会对文件名进行乱序处理。这一过程是比较均匀的,因此它可以产生均衡的文件名队列。
这个QueueRunner
的工作线程是独立于文件阅读器的线程, 因此乱序和将文件名推入到文件名队列这些过程不会阻塞文件阅读器运行。
在上图中数据输入流图的末端, 我们需要有另一个队列来执行输入样本的训练,评价和推理。因此我们使用tf.train.batch
, tf.train.batch_join
, tf.train.shuffle_batch
, tf.train.shuffle_batch_join
函数来对队列中的样本进行处理,Batch 读取实例:
Batching
def read_my_file_format(filename_queue):
reader = tf.SomeReader() # tf.TFRecordReader()
key, record_string = reader.read(filename_queue) # queue
example, label = some_decoder(record_string) # parse example opt
processed_example = some_processing(example) # example processing
return processed_example, label
def input_pipeline(filenames, batch_size, num_epochs=None):
# put TFRecord into queue, num_epochs is number of epochs and need to be local_initialized, if not specified, it will consistent read until the queue is outOfRange.
filename_queue = tf.train.string_input_producer(
filenames, num_epochs=num_epochs, shuffle=True)
example, label = read_my_file_format(filename_queue)
# min_after_dequeue defines how big a buffer we will randomly sample
# from -- bigger means better shuffling but slower start up and more
# memory used.
# capacity must be larger than min_after_dequeue and the amount larger
# determines the maximum we will prefetch. Recommendation:
# min_after_dequeue + (num_threads + a small safety margin) * batch_size
min_after_dequeue = 10000
capacity = min_after_dequeue + 3 * batch_size
example_batch, label_batch = tf.train.shuffle_batch(
[example, label], batch_size=batch_size, capacity=capacity,
min_after_dequeue=min_after_dequeue)
return example_batch, label_batch
在tensorflow计算图未开始时,实际上上述过程只是配置了队列读取的相关参数和读取方式,队列中还没有任何数据,结合上一步骤的函数定义,需要用下述方式进行调用:
import tensorflow as tf
def run_training():
with tf.Graph().as_default(), tf.Session() as sess:
datas,labels = input_pipeline("example.tfrecords",32)
c = config()
initializer = tf.random_uniform_initializer(-1*c.init_scale,1*c.init_scale)
with tf.variable_scope("model",initializer=initializer):
model = ModelFunction(config=c,data=datas,label=labels)
fetches = [model.train_op,model.accuracy,model.loss]
#init
init_op = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer())
sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord)
try:
while not coord.should_stop():
_,accuracy,loss= sess.run(fetches)
except tf.errors.OutOfRangeError:
print("done training")
finally:
coord.request_stop()
coord.join(threads)
sess.close()
def main():
run_training()
if __name__=='__main__':
main()
这个例子中并没有对输入文件序列再次重新入队操作,即只是实现了一开始流图的第一阶段,然后就进入解析,构建batch的阶段。即是单个Reader,多个样本的读取方式(train.batch相同),如果使用train.batch_join, train.batch_shuffle_join
即是多个Reader,多个样本的读取方式。单Reader时,2个线程就达到了速度的极限。多Reader时,2个Reader就达到了极限。所以并不是线程越多越快,甚至更多的线程反而会使效率下降[ref1], [ref2]。在COCO数据集处理的过程中,使用了单个Reader,单个Reader有四个线程处理(batch_join中Tensor List大小为4)。
这里需要注意的是,一定要全局初始化和局部初始化(tf.train.string_input_producer
中的num_epoch
是局部参数),但是有时候使用 tf.group()
的方式初始化可能不成功,可以单独分两次进行初始化,参见。
队列的开启使用的是tf.train.start_queue_runners
,具体有关Coordinator
的使用。
在使用tf.contrib.slim.learning.train
实现训练的时候,tf.contrib.slim.learning.train
函数中已经设置了coordinater, QueneRunner
等,并且也增加了初始化设置,因此只需要准备好batch数据即可。具体tf.contrib.slim.learning.train
的内容可以参见源码。
OutofRange()
: 未对队列读取抛出的异常进行处理OP_REQUIRES failed
数据处理过程中出现错误,包括维度不匹配 Dim error,文件读取问题 文件损坏,存在空行等UnicodeDecodeError
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。