前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >TensorFlow TFRecord数据集的生成与显示

TensorFlow TFRecord数据集的生成与显示

作者头像
chaibubble
发布2018-01-02 11:11:22
6.7K0
发布2018-01-02 11:11:22
举报
文章被收录于专栏:深度学习与计算机视觉

TensorFlow提供了TFRecord的格式来统一存储数据,TFRecord格式是一种将图像数据和标签放在一起的二进制文件,能更好的利用内存,在tensorflow中快速的复制,移动,读取,存储 等等。 TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。我们可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter 写入到TFRecords文件。 从TFRecords文件中读取数据, 可以使用tf.TFRecordReader的tf.parse_single_example解析器。这个操作可以将Example协议内存块(protocol buffer)解析为张量。

将图片形式的数据生成单个TFRecord 在本地磁盘下建立一个路径用于存放图片:

路径下存放两个文件夹—NegSample和PosSample,分别存放着非车牌的图片和车牌图片,为了测试方便,每个文件夹下只分别存放14张。

利用下列代码将图片生成为一个TFRecord数据集:

代码语言:javascript
复制
import os 
import tensorflow as tf 
from PIL import Image  
import matplotlib.pyplot as plt 
import numpy as np
#路径
cwd='F:\\testdata\\'
#类别
classes={'NegSample':1,
         'PosSample':2}
#tfrecords格式文件名
writer= tf.python_io.TFRecordWriter("mydata.tfrecords") 

for index,name in enumerate(classes):
    class_path=cwd+name+'\\'
    for img_name in os.listdir(class_path): 
        img_path=class_path+img_name #每一个图片的地址

        img=Image.open(img_path)
        img_raw=img.tobytes()#将图片转化为二进制格式
        example = tf.train.Example(features=tf.train.Features(feature={
            #value=[index]决定了图片数据的类型label
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        })) #example对象对label和image数据进行封装
        writer.write(example.SerializeToString())  #序列化为字符串

writer.close()

在工程路径下将生成一个名称为mydata.tfrecords的TFRCORDS类型的文件。

将图片形式的数据生成多个TFRecord 当图片数据量很大时也可以生成多个TFRecord文件,根据TensorFlow官方的建议,一个TFRecord文件最好包含1024个左右的图片,我们可以根据一个文件内的图片个数控制最后的文件个数。 举个例子,一共有四类,one - four为路径下的文件夹的名字,也就是类别,每个文件夹内存放600个图片,一共有2400张图片。 一个TFRecord文件中存放的图片个数最多为1200个,如果超过了就会写入第二个TFRecord文件中:

代码语言:javascript
复制
import os 
import tensorflow as tf 
from PIL import Image  

#图片路径
cwd = 'F:\\bubbledata_4\\testdata\\'
#文件路径
filepath = 'F:\\bubbledata_4\\testfile\\'
#存放图片个数
bestnum = 1000
#第几个图片
num = 0
#第几个TFRecord文件
recordfilenum = 0
#类别
classes=['one',
         'two',
         'three',
         'four']
#tfrecords格式文件名
ftrecordfilename = ("testndata.tfrecords-%.3d" % recordfilenum)
writer= tf.python_io.TFRecordWriter(filepath+ftrecordfilename)
#类别和路径
for index,name in enumerate(classes):
    print(index)
    print(name)
    class_path=cwd+name+'\\'
    for img_name in os.listdir(class_path): 
        num=num+1
        if num>bestnum:
          num = 1
          recordfilenum = recordfilenum + 1
          #tfrecords格式文件名
          ftrecordfilename = ("testdata.tfrecords-%.3d" % recordfilenum)
          writer= tf.python_io.TFRecordWriter(filepath+ftrecordfilename)
        '''
        print(num)
        print(recordfilenum)
        print(img_name)
        '''
        img_path = class_path+img_name #每一个图片的地址
        img=Image.open(img_path)
        img_raw=img.tobytes()#将图片转化为二进制格式
        example = tf.train.Example(
             features=tf.train.Features(feature={
             #value=[index]决定了图片数据的类型label
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
        })) 
          #example对象对label和image数据进行封装
        writer.write(example.SerializeToString())  #序列化为字符串
writer.close()

将单个TFRecord类型数据集显示为图片

上面提到了,TFRecord类型是一个包含了图片数据和标签的合集,那么当我们生成了一个TFRecord文件后如何查看图片数据和标签是否匹配? 可以将其转化为图片的形式再显示出来,并打印其在TFRecord中对应的标签,下面是一个例子,接上面生成单个TFRecord文件代码,在F:\testdata\show路径下显示解码后的图片,名称中包含标签。 其中: 1.tf.train.string_input_producer函数用于创建输入队列,队列中的内容为TFRecord文件中的元素。定义如下:

代码语言:javascript
复制
def string_input_producer(string_tensor,
                          num_epochs=None,
                          shuffle=True,
                          seed=None,
                          capacity=32,
                          shared_name=None,
                          name=None,
                          cancel_op=None):

每次调用文件读取函数(.read)时,该函数会先判断当前是否已有打开的文件可读,如果没有或者打开的文件已经读完,这个函数会从输入队列中出队一个文件并从这个文件中读取数据。 通过设置shuffle参数,tf.train.string_input_producer函数支持随机打乱文件列表中文件的出队顺序。当shuffle=true(默认)时,文件在加入队列之前会被打乱顺序,所以出队的顺序也是随机的。随机打乱文件顺序以及加入输入队列的过程运行在一个单独的县城上,这样不会影响获取文件的速度。其生成的输入队列可以被多个文件读取线程操作。 当一个输入队列中的所有文件都被处理完后,它会讲出实话时提供的文件列表中的文件全部重新加入队列。加入的轮数可以通过num_epochs参数设置,默认为None。 2.如果TFRecord文件不止一个时,也会用到tf.train.match_filenames_once函数来获取符合一定规则的文件列表。比如:

代码语言:javascript
复制
files = tf.train.match_filenames_once(mydata.tfrecords*)

函数将获取所有的工程路径下包含mydata.tfrecords名字的TFRecord文件,如mydata.tfrecords1,mydata.tfrecords2等。但是在下面的例子中只有一个TFRecord文件,所以直接使用了string_input_producer函数。

3.tf.parse_single_example解析器,可以将Example协议内存块(protocol buffer)解析为张量。

代码语言:javascript
复制
swd = 'F:\\testdata\\show\\'
filename_queue = tf.train.string_input_producer(["mydata.tfrecords"]) #读入流中
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)   #返回文件名和文件
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'label': tf.FixedLenFeature([], tf.int64),
                                       'img_raw' : tf.FixedLenFeature([], tf.string),
                                   })  #取出包含image和label的feature对象
#tf.decode_raw可以将字符串解析成图像对应的像素数组
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [36,136,3])
label = tf.cast(features['label'], tf.int32)
with tf.Session() as sess: #开始一个会话
    init_op = tf.initialize_all_variables()
    sess.run(init_op)
    #启动多线程
    coord=tf.train.Coordinator()
    threads= tf.train.start_queue_runners(coord=coord)
    for i in range(28):
        example, l = sess.run([image,label])#在会话中取出image和label
        img=Image.fromarray(example, 'RGB')#这里Image是之前提到的
        img.save(swd+str(i)+'_''Label_'+str(l)+'.jpg')#存下图片
        print(example, l)
    coord.request_stop()
    coord.join(threads)

结果如下:

可以看到,车牌图片的Lable都为1,非车牌图片的Lable为0。通过上下两张图片可以看到,其出队顺序已经被打乱了。

将多个TFRecord类型数据集显示为图片 与读取多个文件相比,只需要加入两行代码而已:

代码语言:javascript
复制
data_path = 'F:\\bubbledata_4\\trainfile\\testdata.tfrecords*'
# 获取文件名列表
data_files = tf.gfile.Glob(data_path)       
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2017-06-10 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档