在之前的文章中我们提到了TensorFlow TensorFlow 队列与多线程的应用以及TensorFlow TFRecord数据集的生成与显示,通过这些操作我们可以得到自己的TFRecord文件,并从其中解析出单个的Image和Label作为训练数据提供给网络模型使用,而在实际的网络训练过程中,往往不是使用单个数据提供给模型训练,而是使用一个数据集(mini-batch),mini-batch中的数据个数称为batch-size。mini-batch的思想能够有效的提高模型预测的准确率。大部分的内容和之前的操作是相同的,数据队列中存放的还是单个的数据和标签,只是在最后的部分将出队的数据组合成为batch使用,下面给出从原始数据到batch的整个流程:
可以看到,截止到生成单个数据队列操作,和之前并没有什么区别,关键之处在于最后batch的组合,一般来说单个数据队列的长度(capacity)和batch_size有关: capacity = min_dequeue+3*batch_size 我是这样理解第二个队列的:入队的数据就是解析出来的单个的数据,而出队的数据组合成了batch,一般来说入队数据和出队数组应该是相同的,但是在第二个队列中不是这样。
那么在TensorFlow中如何实现数据的组合呢,其实就是一个函数: tf.train.batch 或者 tf.train.shuffle_batch 这两个函数都会生成一个队列,入队的数据是单个的Image和Label,而出队的是一个batch,也已称之为一个样例(example)。他们唯一的区别是是否将数据顺序打乱。
本文以tf.train.batch为例,定义如下:
def batch(
tensors, //张量
batch_size, //个数
num_threads=1, //线程数
capacity=32,//队列长度
enqueue_many=False,
shapes=None,
dynamic_pad=False,
allow_smaller_final_batch=False,
shared_name=None,
name=None):
下面写一个代码测试一下,工程目录下有一个TFRecord数据集文件,该代码主要做以下工作,从TFRecord中读取单个数据,每四个数据组成一个batch,一共生成10个batch,将40张图片写入指定路径下,命名规则为batch?size?Label?,batch和size决定了是第几个组合中的第几个图,label决定数据的标签。
import os
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
#路径
swd = 'F:\\testdata\\show\\'
filename_queue = tf.train.string_input_producer(["mydata.tfrecords"]) #读入流中
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue) #返回文件名和文件
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
}) #取出包含image和label的feature对象
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [36,136,3])
label = tf.cast(features['label'], tf.int32)
#组合batch
batch_size = 4
mini_after_dequeue = 100
capacity = mini_after_dequeue+3*batch_size
example_batch,label_batch = tf.train.batch([image,label],batch_size = batch_size,capacity=capacity)
with tf.Session() as sess: #开始一个会话
init_op = tf.initialize_all_variables()
sess.run(init_op)
coord=tf.train.Coordinator()
threads= tf.train.start_queue_runners(sess = sess,coord=coord)
for i in range(10):#10个batch
example, l = sess.run([example_batch,label_batch])#取出一个batch
for j in range(batch_size):#每个batch内4张图
sigle_image = Image.fromarray(example[j], 'RGB')
sigle_label = l[j]
sigle_image.save(swd+'batch_'+str(i)+'_'+'size'+str(j)+'_'+'Label_'+str(sigle_label)+'.jpg')#存下图片
print(example, l)
coord.request_stop()
coord.join(threads)