在没有用pytorch之前,读取数据一般时写一个load_data的函数,在里面导入数据,做一些数据预处理,这一部分就显得很烦索。对于深度学习来说,还得考虑batch的读取、GPU的使用、数据增强、数据乱序读取等等,所以需要有一个模块来集中解决这些事情,所以就有了data_loader的机制。
Dataloader的处理逻辑是先通过Dataset类里面的 __getitem__ 函数获取单个的数据,然后组合成batch,再使用collate_fn所指定的函数对这个batch做一些操作,比如padding啊之类的。
直接加载torch官方的数据集
分三步:
生成实例化对象
生成dataloader
从dataloader里读数据
PyTorch用类torch.utils.data.DataLoader加载数据,并对数据进行采样,生成batch迭代器:torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
数据加载器常用参数如下:dataset:加载数据的数据集;
batch_size:每个batch要加载多少样本(默认为1);
shuffle:是否对数据集进行打乱重新排列(默认为False,即不重新排列);
总结:torch的DataLoader主要是用来装载数据,就是给定已知的数据集,把数据集装载进DataLoaer,然后送入深度学习网络进行训练。