首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >在K-折叠交叉验证中仅扩充训练集

在K-折叠交叉验证中仅扩充训练集
EN

Stack Overflow用户
提问于 2019-08-18 04:04:28
回答 1查看 1.2K关注 0票数 1

我正在尝试为一个不平衡的数据集(0类= 4000张图像,1类=大约250张图像)创建一个二进制CNN分类器,我想对其执行5次交叉验证。目前,我正在将训练集加载到应用转换/增强(?)的ImageLoader中。并将其加载到DataLoader中。然而,这导致我的训练拆分和验证拆分都包含了增强的数据。

我最初离线应用了转换(离线增强?)为了平衡我的数据集,但从这个线程(https://stats.stackexchange.com/questions/175504/how-to-do-data-augmentation-and-train-validate-split)来看,似乎只增加训练集是理想的。我也更喜欢在单独增强的训练数据上训练我的模型,然后在5倍交叉验证中对非增强数据进行验证。

我的数据被组织为根/标签/图像,其中有2个标签文件夹(0和1),图像被排序到各自的标签中。

到目前为止我的代码

代码语言:javascript
运行
复制
total_set = datasets.ImageFolder(ROOT, transform = data_transforms['my_transforms'])

//Eventually I plan to run cross-validation as such:
splits = KFold(cv = 5, shuffle = True, random_state = 42)

for train_idx, valid_idx in splits.split(total_set):
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(total_set, batch_size=32, sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(total_set, batch_size=32, sampler=valid_sampler)

model.train()
//Model train/eval works but may be overpredict 

我确信我在这段代码中做了一些次优或错误的事情,但我似乎找不到任何文档来专门增强交叉验证中的训练拆分!

任何帮助都将不胜感激!

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-08-18 04:40:57

一种方法是实现一个包装dataset类,该类将转换应用于ImageFolder Dataset的输出。例如

代码语言:javascript
运行
复制
class WrapperDataset:
    def __init__(self, dataset, transform=None, target_transform=None):
        self.dataset = dataset
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        image, label = self.dataset[index]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

    def __len__(self):
        return len(self.dataset)

然后,您可以通过使用不同的转换包装较大的数据集,从而在代码中使用它。

代码语言:javascript
运行
复制
total_set = datasets.ImageFolder(ROOT)

# Eventually I plan to run cross-validation as such:
splits = KFold(cv = 5, shuffle = True, random_state = 42)

for train_idx, valid_idx in splits.split(total_set):
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        WrapperDataset(total_set, transform=data_transforms['train_transforms']),
        batch_size=32, sampler=train_sampler)
    valid_loader = torch.utils.data.DataLoader(
        WrapperDataset(total_set, transform=data_transforms['valid_transforms']),
        batch_size=32, sampler=valid_sampler)

    # train/validate now

我没有测试这段代码,因为我没有你的完整代码/模型,但是概念应该是清晰的。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/57539567

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档