DataLoader简单介绍
DataLoader是Pytorch中用来处理模型输入数据的一个工具类。通过使用DataLoader,我们可以方便地对数据进行相关操作,比如我们可以很方便地设置batch_size,对于每一个epoch是否随机打乱数据,是否使用多线程等等。
咱们先通过下图先来窥探DataLoader的基本处理流程。
1. 首先会将原始数据加载到DataLoader中去,如果需要shuffle的话,会对数据进行随机打乱操作,这样能够输入顺序对于数据的影响。
2. 再使用一个迭代器来按照设置好的batch大小来迭代输出shuffle之后的数据。 Tips: 通过使用迭代器能够有效地降低内存的损耗,会在需要使用的时候才将数据加载到内存中去。
好了,知道了DataLoader的基本使用流程,下面开始正式进入我们的介绍。
使用Dataset来创建自己的数据类
当我们拿到数据之后,首先需要做的就是写一个属于自己的数据类。
我们通过继承torch.utils.data.Dataset这个类来构造。因为Dataset这个类比较简单,我们可以先来看看源码。
其中, __getitem__ 和 __len__ 这两个方法在我们每次自定义自己的类的时候是需要去复写的。
下面结合一个例子来进行介绍:
简单分析如下:
1. 继承Dataset来创建自己的数据类。将数据的下载,加载等,写入到这个类的初始化方法__init__中去,这样后面直接通过创建这个类即可获得数据并直接进行使用。
2. 通过复写 __getitem__ 方法可以通过索引来访问数据,能够同时返回数据和对应的标签(label)。
3. 通过复写 __len__ 方法来获取数据的个数。
使用DataLoader来控制数据的输入输出
结合上一节自己创建的Dataset,DataLoader的使用方式如下:
下面来对DataLoader中的常用参数进行介绍:
这样,我们就可以通过循环来迭代来高效地获取数据啦。