PyTorch的DataLoader
是一个用于加载数据并批量处理的实用程序。它与Dataset
类一起工作,后者定义了如何访问数据集中的样本。DataLoader
负责将数据集分割成批次,并且可以并行加载数据以提高效率。
以下是一个简单的示例,展示如何使用DataLoader
与自定义的Dataset
交互:
import torch
from torch.utils.data import Dataset, DataLoader
# 自定义数据集类
class CustomDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
# 创建数据集实例
data = [torch.randn(3, 32, 32) for _ in range(100)] # 示例数据
dataset = CustomDataset(data)
# 创建DataLoader实例
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
# 使用DataLoader进行迭代
for batch in dataloader:
print(batch.shape) # 输出批次形状
问题:DataLoader加载数据速度慢。
原因:
解决方法:
DataLoader
的num_workers
参数,以使用多线程或多进程加载数据。dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
dataset = CustomDataset(data)
dataset = torch.utils.data.Subset(dataset, range(100)) # 示例子集
dataset = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, pin_memory=True)
领取专属 10元无门槛券
手把手带您无忧上云