在PyTorch中,你可以通过自定义数据集和数据加载器来实现批量大小为1的手动排序的MNIST数据集。以下是一个详细的示例,展示了如何实现这一目标。
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
# 定义自定义数据集类
class SortedMNIST(Dataset):
def __init__(self, root, train=True, transform=None, download=False):
self.mnist = datasets.MNIST(root=root, train=train, transform=transform, download=download)
self.sorted_indices = self.sort_indices()
def sort_indices(self):
# 获取所有标签
labels = self.mnist.targets
# 获取排序后的索引
sorted_indices = torch.argsort(labels)
return sorted_indices
def __len__(self):
return len(self.mnist)
def __getitem__(self, idx):
sorted_idx = self.sorted_indices[idx]
image, label = self.mnist[sorted_idx]
return image, label
# 定义数据转换
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
# 创建自定义数据集
sorted_mnist_train = SortedMNIST(root='./data', train=True, transform=transform, download=True)
sorted_mnist_test = SortedMNIST(root='./data', train=False, transform=transform, download=True)
# 创建数据加载器,批量大小为1
train_loader = DataLoader(sorted_mnist_train, batch_size=1, shuffle=False)
test_loader = DataLoader(sorted_mnist_test, batch_size=1, shuffle=False)
# 测试数据加载器
for batch_idx, (data, target) in enumerate(train_loader):
print(f'Batch {batch_idx}: Label {target.item()}')
if batch_idx >= 10: # 只打印前10个批次
break
torch
和torch.utils.data
用于数据处理。torchvision.datasets
用于加载MNIST数据集。torchvision.transforms
用于数据预处理和转换。SortedMNIST
类继承自torch.utils.data.Dataset
。__init__
方法中,加载MNIST数据集并调用sort_indices
方法获取排序后的索引。sort_indices
方法根据标签对数据进行排序,并返回排序后的索引。__len__
方法返回数据集的长度。__getitem__
方法根据排序后的索引返回图像和标签。transforms.Compose
定义一系列数据转换,包括将图像转换为张量和标准化。SortedMNIST
类创建训练和测试数据集。DataLoader
创建数据加载器,并设置批量大小为1。领取专属 10元无门槛券
手把手带您无忧上云