前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >基于tensorflow的图像处理(四) 数据集处理

基于tensorflow的图像处理(四) 数据集处理

作者头像
狼啸风云
修改2022-09-04 21:04:23
2.3K0
修改2022-09-04 21:04:23
举报
文章被收录于专栏:计算机视觉理论及其实现

除队列以外,tensorflow还提供了一套更高的数据处理框架。在新的框架中,每一个数据来源被抽象成一个“数据集”,开发者可以以数据集为基本对象,方便地进行batching、随机打乱(shuffle)等操作。

一、数据集的基本使用方法

在数据集框架中,每一个数据集代表一个数据来源:数据可能来自一个张量,一个TFRecord文件,一个文本文件,或者经过sharding的一系列文件,等等。由于训练数据集通常无法全部写入内存中,从数据中读取数据时需要使用一个迭代器(iterator)按顺序进行读取,这点与队列的dequeue()操作和Reader的read()操作相似。与队列相似,数据集也是计算图上的一个点。

下面先看一个简单的例子,这个例子从一个张量创建一个数据集,遍历这个数据集,并对每个输入输出y=x^2的值。

代码语言:javascript
复制
import tensorflow as tf

# 从一个数组创建数据集。
input_data = [1, 2, 3, 4, 5, 6]
dataset = tf.data.Dataset.from_tensor_slices(input_data)


# 定义一个迭代器用于遍历数据集。因为上面定义的数据集没有用placeholder
# 作为输入参数,所以这里可以使用最简单的one_shot_iterator。
iterator = dataset.make_one_shot_iterator()

# get_next() 返回代表一个输入数据的张量,类似于队列中dequeue()。

x = iterator.get_next()
y = x * x


with tf.Session() as sess:
   for i in range(len(input_data)):
      print(sess.run(y))

运行以上程序可以得到以下输出:

1 4 9 25 64

从以上例子可以看到,利用数据集读取数据有三个基本步骤。

1.定义数据集的构造方法

这个例子使用了tf.data.Dataset.from_tensor_slice(),表明数据集是从一个张量中构建的。如果数据集是从文件中构建的,则需要相应调用不同的构造方法。

2.定义遍历器

这个例子使用了最简单的one_shot_iterator来遍历数据集。

3.使用get_next()方法从遍历器中读取数据张量,作为计算图其他部分的输入

在真实项目中,训练数据通常是保存在硬盘文件上的。比如在自然语言处理的任务中,训练数据通常是以每行一条数据的形式存在文本文件中,这时可以用TextLineDataset来更方便地读取数据:

代码语言:javascript
复制
import tensorflow as tf

# 从文本创建数据集。假定每行文字是一个训练例子。注意这里可以提供多个文件。
input_files = ["/path/to/input_filel", "/path/to/input_file2"]
dataset = tf.data.TextLIneDataset(input_files)

# 定义迭代器用于遍历数据集。
iterator = dataset.make_one_shot_iterator()
# 这里get_next()返回一个字符串类型的张量,代表文件中的一行
x = iterator.get_next()
with tf.Session() as sess:
   for i in range(3):
      print(sess.run(x))

在图像相关任务中,输入数据通常以TFRecord形式存储,这时可以用TFRecordDataset来读取数据。与文本文件不同, 每一个TFRecord都有自己不同的feature格式,因此在读取TFRecord时,需要提供一个parser函数来解析所读取的TFRecord的数据格式。

代码语言:javascript
复制
import tensorflow as tf

# 解析一个TFRecord的方法。record是从文件中读取的一个样例。
def parser(record):
    # 解析读入的一个样例
    features = tf.parse_single_example(
        record,
        features={
            'feat1': tf.FixedLenFeature([], tf.int64),
            'feat2': tf.FixedLenFeature([], tf.int64),
            }) 
    return features['feat1'],  features['feat2']


# 从TFRecord文件创建数据集
input_files = ["/path/to/input_file1", "/path/to/input_file2"]
dataset = tf.data.TFRecordDataset(inputfiles)


# map()函数表示对数据集中每一条数据进行调用响应的方法。使用TFRecordDataset读出的
# 是二进制的数据,这里需要通过map()来调用parser()对二进制数据进行解析。类似地,
# map()函数也可以用来完成其他的数据预处理工作。
dataset = dataset.map(parser)


# 定义遍历数据集的迭代器
iterator = dataset.make_one_shot_iterator()


# feat1, feat2是parser()返回的一维int64型张量,可以作为输入用于进一步的计算
feat1, feat2 = iterator.get_next()



with tf.Session() as sess:
     for i in range(10):
         f1, f2 = sess.run([feat1, feat2])

以上例子使用了最简单的one_shot_iterator来遍历数据集。在使用one_shot_iterator时,数据集的所有参数必须已经确定,因此one_shot_iterator不需要特别的初始化过程。如果需要用placeholder来初始化数据集,那就说明用到initializable_iterator。以下代码给出了用initializable_iterator来动态初始化数据集的例子。

代码语言:javascript
复制
import tensorflow as tf


# 解析一个TFRecord的方法。与上面的例子相同,不再重复。
def parser(record):

'''
# 从TFRecord文件创建数据集,具体文件路径是一个placeholder,稍后再提供具体路径。
input_flies = tf.placeholder(tf.string)
dataset = tf.data.TFRecordDataset(input_files)
dataset = dataset.map(parser)

# 定义遍历dataset的initializable_iterator。
iterator = dataset.make_initializable_iterator()
feat1, feat2 = iterator.get_next()



with tf.Session() as sess:
    # 首先初始化interator,并给出input_file的值
    sess.run(iterator.initializer,
             feed_dict={input_files:["/path/to/input_files", "/path/to/input_file2"]})
    
     # 遍历所有数据一个epoch。当遍历结束时,程序会抛出OutOfRangeError。
     while True:
        try:
           sess.run([feat1, feat2]) 
        expect tf.errors.OutOfRangeError:
           break

在上面的例子中,文件路径使用placeholder和feed_dict的方式传给数据集。使用这种方法,在实际项目中就不需要总是将参数写入计算图的定义,而可以使用程序参数的方式动态指定参数。

另外注意到,上面的例子中的循环体不是指定循环运行10次sess.run,而是使用while(True)try-expect的形式来将所有数据遍历一遍(即一个epoch)。这是因为在动态指定输入数据时,不同数据来源的数据量大小难以预知,而这个方法使我们不必提起知道数据量的精确大小。

以上介绍的两种iterator足以满足大多数项目的需求。除这两种方法外,tensorflow还提供了reinitializable_iterator和feedable_iterator两种更加灵活的迭代器。前者可以多次initialize用于遍历不同的数据来源,而后者可以用feed_dict的方式动态指定运行哪个iterator。

二、数据集的高层操作

下面介绍数据集框架提供的一些方便使用的高层API。前文介绍过map方法对TFRecord进行解析操作:

代码语言:javascript
复制
dataset = dataset.map(parser)

map是在数据集上进行操作的最常用的方法之一。在这里,map(parser)方法表示对数据集中的每一条数据调用的每一条数据调用参数中指定的parser方法。对每一条数据进行处理后,map将处理后的数据包装成一个新的数据集返回,map函数非常灵活,可以用于对数据的任何预处理操作。在队列框架下曾使用如下方法来对数据进行预处理:

代码语言:javascript
复制
distorted_image = preprocess_for_train(decoded_image, image_size, image_size, None)

而在数据集框架中,可以通过map来对每一条数据调用prepeocess_for_train方法:

代码语言:javascript
复制
dataset = dataset.map(lambda x : preprocess_for_train(x, image_size, image_size, None))

在上面的代码中,lambda表达式的作用是将原来有4个参数的函数转化为只有1个参数的函数。preprocess_for_train函数的第一个参数decoded_image变成了lambda表达式中的x。这个参数就是原来函数中的参数decoded_image。preprocess_for_train函数中后3个参数都换成了具体的数值。注意这里的image_size是一个变量,有具体数值,该值需要在程序的上文中给出。

从表面上看,新的代码在长度上似乎并没有缩短,然而由于map方法返回一个新的数据集,可以直接继续调用其他高层操作。在队列框架中,预处理、shuffle、batch等操作有的在队列上进行,有的在图片张量上进行,整个处理流程在处理队列和张量的代码片段中来回切换。而在数据集操作中,所有操作都在数据集上进行,这样的代码结构将非常的干净、整洁。

队列框架下的tf.train.batch和tf.train.shuffle_batch方法、在数据集框架中,shuffle和batch操作由两个方法独立实现:

代码语言:javascript
复制
dataset = dataset.shuffle(buffer_size)  # 随机打乱顺序
dataset = dataset.batch(batch_size)     # 将数据组合成batch

其中shuffle方法的参数buffer_size等效于tf.train.shuffle_batch的min_after_dequeue参数。shuffle算法在内部使用一个缓冲区保存buffer_size条数据,每读入一条新数据时,从这个缓冲区中随机选择一条数据进行输出。缓冲区的大小越大,随机性能越好,但占用的内存也越多。

batch方法的参数batch_size代表要输出的每个batch由多少条数据组成。如果数据集中包含多个张量,那么batch操作将对每一个张量分开进行。举例而言,如果数据集中的每一个数据(即iterator.get_next()的返回值)是image、label两个张量,其中image的维度是[],batch_size是128,那么经过batch操作后的数据集的每一个输出将包含两个维度分别是[128, 300, 300]和[128]的张量。

repeat是另一个常用的操作方法。这个方法数据集中的数据复制多份,其中每一份数据被称为一个epoch。

代码语言:javascript
复制
dataset = dataset.repeat(N)   # 将数据集重复N份。

需要指出的是,如果数据集在repeat前已经进行了shuffle操作,输出的每个epoch中随机shuffle的结果并不会相同。例如,如果输入数据是[1, 2, 3],shuffle后输出的第一个epoch是[2, 1, 3],而第二个epoch则有可能是[3, 2, 1]。repeat和map、shuffle、batch等操作一样,都只是计算图中的一个计算节点。repeat只代表重复相同的处理过程,并不会记录前一epoch的处理结果。

除这些方法以外,数据集还提供了其他多种操作。例如,concatenate( )将两个数据集顺序连接起来,take(N)从数据集中读取前N项数据,skip(N)在数据集中跳过前N项数据,flap_map()从多个数据集中轮流读取数据,等等。这里不再一一介绍,有需要的读者可以查询tensorflow相关文档。

以下例子将这些方法组合起来,使用数据集实现数据输入流程,该例子从文件中读取原始数据,进行预处理、shuffle、batching等操作,并通过repeat方法训练多个epoch。不同的是,以下例子在训练数据集之外,还另外读取了数据集,并对测试集和数据集进行了略微不同的预处理。在训练时,调用preprocess_for_train 方法对图像进行随机反转等预处理操作;而在测试时,测试数据以原本的样子直接输入测试。

代码语言:javascript
复制
import tensorflow as tf

# 列举输入文件,训练和测试使用不同的数据。
train_files = tf.train.match_filenames_once("/path/to/train_file-*")
test_files = tf.train.match_filenames_once("/path/to/test_file-*")


# 定义parser方法从TFRecord中解析数据。这里假设image中存储的是图像的原始数据,
# label为该样例所对应的标签。height、width和channel给出了图片的维度。
def parser(record):
    features = tf.parse_single_example(
       record,
       features={
           'image': tf.FixedLenFeature([], tf.string),
           'label': tf.FixedLenFeature([], tf.int64),
           'height': tf.FixedLenFeature([], tf.int64),
           'width': tf.FixedLenFeature([], tf.int64),
           'channels': tf.FixedLenFeature([], tf.int64)
    })

    # 从原始数据中解析出像素矩阵,并根据图像尺寸还原图像。
    decoded_image = tf.decode_raw(features['image'], tf.unit8)
    decoded_image.set_shape([features['height'],features['width']
                             features['channels']])
    label = features['label']
    return decoded_image, label

image_size = 299         # 定义神经网络输入层图像的大小。
batch_size = 100         # 定义组合数据batch的大小。
shuffle_buffer = 10000   # 定义随机打乱数据时buffer的大小。


# 定义读取训练数据的数据集。
dataset = tf.data.TFRecordDataset(train_files)
dataset = dataset.map(parser)


# 数据集依次进行预处理、shuffle和batching操作。
# preprocess_for_train为之前介绍的图像预处理程序,因为上一个map得到的数据集中提供了
# decoded_image和label两个结果,所以这个map需要提供一个有2个参数的函数来
# 处理数据。
# 在下面的代码中,lambda中的image代表的就是第一个map返回的
# decoded_image, label代表的就是第一个map返回的label。在这个lambda表达式中
# 我们首先将decoded_image在传入preprocess_for_train来进一步对图像数据进行预处理。
# 然后再将处理好的图像和label组成最终输出。

dataset = dataset.map(
    lambda image, label : (
       preprocess_for_train(image, image_size, image_size, None), label))

dataset = dataset.shuffle(shuffle_buffer).batch(batch_size)


# 重复NUM_EPOCHS个epoch。在前述的中TRAINING_ROUNDS指定了训练的轮数,
# 而这里指定了整个数据集重复的次数,它也间接的确定了训练的轮数。
NUM_EPOCHS = 10
dataset = dataset.repeat(NUM_EPOCHS)


# 定义数据集迭代器。虽然定义数据集时没有直接使用palceholder来提供文件地址,但是
# tf.train.match_filename_once方法得到的结果和与palceholder的机制类似,
# 也需要初始化,所以这里使用的是initializable_iterator。
iterator = dataset.make_initializable_iterator()
image_batch, label_batch = iterator.get_next()


# 定义神经网络的结构以及优化过程。
learning_rate = 0.01
logit = inference(image_batch)
loss = calc_loss(logit, label_batch)
train_step = tf.train.GradientDescentOptimizer(learning_rate)\.minimize(loss)



# 定义测试用的dataset。与训练时不同,测试数据的dataset不需要经过随机翻转等预处理
# 操作,也不需要打乱顺序和重复多个epoch。这里使用预训练数据相同的parser进行解析,
# 调整分辨率到网络输入层大小,然后直接进行batching操作。
test_dataset = tf.data.TFRecordDataset(test_files)
test_dataset = test_dataset.map(parser).map(
     lambda image, label : (
        tf.image.resize_images(images, [image_size, image_size], label))
test_dataset = test_dataset.batch(batch_size)


# 定义测试数据上的迭代器。
test_iterator = test_dataset.make_initializable_iterator()
test_image_batch, test_label_batch = test_iterator.get_next()


# 定义预测结果为logit值最大的分类。
test_logit = inference(test_image_batch)
predictions = tf.argmax(test_logit, axis=-1, output_type=tf.int32)


# 声明会话并运行神经网络的优化过程。
with tf.Session() as sess:
    # 初始化变量。
    sess.run((tf.global_variables_initializer(),
              tf.local_variable_initializer()))
    
    
    # 初始化训练数据的迭代器
    sess.run(iterator.initializer)
    
    
    # 循环进行训练,直到数据集完成输入,抛出OutOfRangeError错误。
    while True:
      try:
         sess.run(train_step)
      expect tf.errors.OutOfRangeError:
         break
    
    # 初始化测试数据的迭代器。
    sess.run(test_iterator.initializer)
    # 获取预测结果
    test_results = []
    test_labels  = []
    
    while True:
        try:
           pred, label = sess.run([predictions, test_label_batch])
           test_results.extend(pred)
           test_labels.extend(label)
        expect tf.errors.OutOfRangeError:
           break
 
# 计算准确率
correct = [float(y == y_) for (y , y_) in zip (test_results, test_labels)]
accuracy = sum(correct) / len(correct)
print("Test accuray is:", accuracy)

在MINST数据上运行以上程序,可以得到类似下面的结果 Test accuracy is: 0.9052

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2019年06月28日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一、数据集的基本使用方法
  • 二、数据集的高层操作
相关产品与服务
批量计算
批量计算(BatchCompute,Batch)是为有大数据计算业务的企业、科研单位等提供高性价比且易用的计算服务。批量计算 Batch 可以根据用户提供的批处理规模,智能地管理作业和调动其所需的最佳资源。有了 Batch 的帮助,您可以将精力集中在如何分析和处理数据结果上。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档