发布
社区首页 >问答首页 >Pytorch:如何获取子集的所有数据和目标

Pytorch:如何获取子集的所有数据和目标
EN

Stack Overflow用户
提问于 2021-08-23 15:36:43
回答 1查看 592关注 0票数 1

我使用以下代码从特定文件夹中读取数据集,并将其划分为训练和测试子集。我可以使用列表理解来获得每个子集的所有数据和目标,但对于大数据来说,它非常慢。有没有其他快速的方法来做到这一点?

代码语言:javascript
代码运行次数:0
复制
def train_test_dataset(dataset, test_split=0.20):
    train_idx, test_idx = train_test_split(list(range(len(dataset))), test_size=test_split, stratify=dataset.targets)
    datasets = {}
    train_dataset = Subset(dataset, train_idx)
    test_dataset = Subset(dataset, test_idx)

    return train_dataset, test_dataset


dataset = dset.ImageFolder("/path_to_folder", transform = transform)
    
train_set, test_set = train_test_dataset(dataset)

train_data = [data for data, _ in train_set]
train_labels = [label for _, label in train_set]

我已经使用DataLoader尝试过这种方法,它更好,但也需要一些时间:PyTorch Datasets: Converting entire Dataset to NumPy

谢谢。

EN

回答 1

Stack Overflow用户

发布于 2021-08-23 16:00:47

您提供的链接中的answer基本上违背了拥有数据加载器的目的:数据加载器的目的是逐块地将数据加载到内存中。这有一个明显的优点,那就是不必在给定时刻加载整个数据集。

您可以使用torch.utils.data.random_split函数从ImageFolder数据集中拆分数据:

代码语言:javascript
代码运行次数:0
复制
>>> def train_test_dataset(dataset, test_split=.2):
...    test_len = int(len(dataset)*test_split)
...    train_len = len(dataset) - test_len 
...    return random_split(dataset, [train_len, test_len])

然后,您可以将这些数据集插入单独的DataLoader中:

代码语言:javascript
代码运行次数:0
复制
>>> train_set, test_set = train_test_dataset(dataset)

>>> train_dl = DataLoader(train_set, batch_size=16, shuffle=True)
>>> test_dl  = DataLoader(train_set, batch_size=32 shuffle=False)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68895377

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档