Overview
前两次推送解决了使用底层API完成用户自己的数据集构建和读取的问题,但是有一个明显的问题:太繁碎,细节太多
虽然主页菌自己一直在使用底层API,但是考虑到高层API的简洁性,还是有必要探索一下的。实际项目中诸位根据自己的喜好自行选择吧。
Tensorflow封装了一组API来处理数据的读入,它们都属于模块tf.data。这个模块中包含了5个类:4种Dataset和1个迭代器类型。
使用方法很简单:(1)构建Dataset (2)构造这个Dataset的迭代器 (3)操作迭代器读出数据。(听起来是不是很像一个标准的面向对象编程思路?)
本次教程使用最基本的tf.data.Dataset,使用的数据和推送 Tensrofow-6 相同,400张尺寸为180 x 180的灰度图像。部分截图如下:
构建Dataset
tf.data.Dataset主要提供了以下几种功能:
(1)构建数据集
(2)对数据集进行预处理
(3)将数据集打混(shuffle)和分批(mini-batch)
01
构建数据集
这一步其实就是实例化一个tf.data.Dataset对象,Dataset的原始数据来源自程序内存。这就涉及一个问题:如果预先把全部图像数据都装载进内存,势必是十分低效的,而且浪费大量资源,所以我们选择另一种方案:将全部图像数据的路径装载进内存并用其实例化Dataset,然后把实际从磁盘读入图像的操作放在“预处理”这个阶段。
首先载入图像路径和标签。由于这个demo只是随便一组图像,没有类别标签,所以我们随机生成400个标签,权当模拟。
然后就可以实例化Dataset对象了,Dataset中每一个“元素”是一个元组(图像路径,标签)
02
预处理
注意这里的“预处理”是广义的,可以是对Dataset中的“元素”进行任何操作。例如这里我们的“预处理”实际上是根据路径从磁盘中读取图片文件。
对Dataset进行预处理需要使用成员函数map(),传入参数是某个函数对象,map() 函数将会把这个函数作用在Dataset中每一个元素之上。
在我们这个例子中,有两个需要着重注意的细节:
(1) 由于这里我们的Dataset中每个元素指一个二元组,因此对应的预处理函数应该有两个输入参数,返回一个二元组。
(2)预处理函数只能包含tensorflow所提供的Tensor操作符,而这里我们难以避免的要使用opencv/PIL等python原生模块读取图像(前者处理的数据是tf.Tensor类型,后者处理的数据是numpy-ndarray类型,二者不兼容)。因此我们需要使用tf.py_func将python函数转换成tensorflow操作符。
首先定义python函数。注意输入两个参数,返回一个二元组。虽然label不需要任何操作原样返回,但是依然要这样写。
然后使用这个函数来“预处理”前面构建的Dataset对象。
注意这一句代码其实包含了两层操作:
(1)将python函数“包装”成tensorflow操作符。tf.py_func有三个输入参数:python函数,操作符的输入(实参),返回类型。
(2)将tensorflow操作符封装成lambda函数,将这个lambda函数对象传入map()
注意:作用于Dataset的是传入map()的函数对象(这里就是lambda函数),不是最初定义的python函数,也不是tensorflow操作符。
03
shuffle & batch
训练机器学习模型的时候,会将数据集遍历多轮(epoch),每一轮遍历都以mini-batch的形式送入模型(batch),数据遍历应该随机进行(shuffle)。在高层API中,这三个功能各自用一行代码就能搞定:
到这里我们就完成了数据集的构建,可以明显感觉到代码量减少了,而且只需要调用API和指定参数,没有诸如数据类型匹配、queue & dequeue之类的细节。
构建迭代器
说明:
(1)老规矩:next_element只是一个“符号”,需要用Session运行才能真正得到数据。
(2)每次用Session运行next_element,活得下一个mini-batch的数据,也就是获得一个尺寸为 N x 180 x 180 x 1 的Tensor。
主体程序——读取数据
说明:
当迭代器移动到尾部以后,再次运行next_element会产生OutOfRangeError,因此我们使用了try-catch语句。
其他代码和前两次推送没有差别:将读取的数据载入Tensorboard可视化。
运行效果
Tensorboard显示结果:
下期预告
现在你可能会问一个问题:在使用高层API的时候是不是就不需要和TFrecords文件打交道了?答案是No。
在很多图像处理任务中,我们不希望直接拿整张图作为训练数据,而是需要将图像切割成若干子图(patch),然后把全部的子图混合在一起作为数据集。这个时候如果像现在这样以整图为单元读入,就难以满足我们在子图层面充分打混数据的要求。
这种情况下,我们需要事先切割子图,然后依然将它们统一存放在TFrecords中等待读入。
tf.data接口也提供了从TFrecords读入的API,下次推送就来尝试一下。
https://saoyan.github.io/
We can do all things!
Simon Fraser University
Vancouver, Canada
领取专属 10元无门槛券
私享最新 技术干货