前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >自定义PyTorch中的Sampler

自定义PyTorch中的Sampler

作者头像
带萝卜
发布于 2020-10-23 06:51:18
发布于 2020-10-23 06:51:18
3.9K00
代码可运行
举报
运行总次数:0
代码可运行

本文使用 Zhihu On VSCode 创作并发布

在训练GAN的过程中,一次只训练一个类别据说有助于模型收敛,但是PyTorch里面没有预设这种数据加载方式,要这样训练的话,需要自己定义Sampler,即自定义数据采样方式。下面是自定义的方法:

首先,我们虚构一个Dataset类,用于测试。

这个类中的label标签是混乱的,无法通过控制index范围来实现单类别训练。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
class Data(Dataset):
    def __init__(self):
        self.img = torch.cat([torch.ones(2, 2) for i in range(100)], dim=0)
        self.num_classes = 2
        self.label = torch.tensor(
            [random.randint(0, self.num_classes - 1) for i in range(100)]
        )

    def __getitem__(self, index):
        return self.img[index], self.label[index]

    def __len__(self):
        return len(self.label)

然后,自定义一个Sampler类,这个类的作用是生成一个index列表,可以理解为重排data中的index。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
class CustomSampler(Sampler):
    def __init__(self, data):
        self.data = data

    def __iter__(self):
        indices = []
        for n in range(self.data.num_classes):
            index = torch.where(self.data.label == n)[0]
            indices.append(index)
        indices = torch.cat(indices, dim=0)
        return iter(indices)

    def __len__(self):
        return len(self.data)

定义好了之后可以封装成DataLoader并查看运行结果:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
d = Data()
s = CustomSampler(d)
dl = DataLoader(d, 8, sampler=s)
for img, label in dl:
    print(label)

结果

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1])

显然,这样的结果并不能让人满意,有一个batch中还是包含了两种不同类型的标签,为了达到目的,我们还需要再定义一个BatchSampler类

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
class CustomBatchSampler:
    def __init__(self, sampler, batch_size, drop_last):
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        i = 0
        sampler_list = list(self.sampler)
        for idx in sampler_list:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []

            if (
                i < len(sampler_list) - 1
                and self.sampler.data.label[idx]
                != self.sampler.data.label[sampler_list[i + 1]]
            ):
                if len(batch) > 0 and not self.drop_last:
                    yield batch
                    batch = []
                else:
                    batch = []
            i += 1
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size

虽然PyTorch要求Sampler需要定义成一个迭代器,但是如果你自己定义BatchSampler的话,Sampler的形式可以自己定,就算写成一个普通的列表也没关系。

再次封装成DataLoader并查看运行结果:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
d = Data()
s = CustomSampler(d)
bs = CustomBatchSampler(s, 8, False)
dl = DataLoader(d, batch_sampler=bs)
for img, label in dl:
    print(label)

drop_last = False 的结果:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1])

drop_last = True 的结果:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])

以上就是自定义Sampler的步骤了。

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
并发阻塞队列BlockingQueue解读
首先,最基本的来说, BlockingQueue 是一个先进先出的队列(Queue),为什么说是阻塞(Blocking)的呢?是因为 BlockingQueue 支持当获取队列元素但是队列为空时,会阻塞等待队列中有元素再返回;也支持添加元素时,如果队列已满,那么等到队列可以放入新元素时再放入。
大忽悠爱学习
2022/10/24
1.1K0
并发阻塞队列BlockingQueue解读
并发编程4:Java 阻塞队列源码分析(上)
上篇文章 并发编程3:线程池的使用与执行流程 中我们了解到,线程池中需要使用阻塞队列来保存待执行的任务。这篇文章我们来详细了解下 Java 中的阻塞队列究竟是什么。 什么是阻塞队列 阻塞队列其实就是生
张拭心 shixinzhang
2018/01/05
1.5K0
并发编程4:Java 阻塞队列源码分析(上)
BlockingQueue
BlockingQueue 是一个先进先出的队列(Queue), 并且当获取队列元素但是队列为空时,会阻塞等待队列中有元素再返回;也支持添加元素时,如果队列已满,那么等到队列可以放入新元素时再放入。
leobhao
2022/06/28
2890
Java 7 种阻塞队列详解
队列(Queue)是一种经常使用的集合。Queue 实际上是实现了一个先进先出(FIFO:First In First Out)的有序表。和 List、Set 一样都继承自 Collection。它和 List 的区别在于,List可以在任意位置添加和删除元素,而Queue 只有两个操作:
海星
2020/09/27
9.6K0
解读 Java 并发队列 BlockingQueue
原文出处:https://javadoop.com/post/java-concurrent-queue
Java
2018/10/23
6710
解读 Java 并发队列 BlockingQueue
JUC学习笔记(三)—同步阻塞队列
BlockingQueue 阻塞队列接口继承自Queue接口,BlockingQueue接口提供了3个添加元素方法:
Monica2333
2020/06/19
5630
【原创】Java并发编程系列32 | 阻塞队列(下)
阻塞队列在并发编程非常常用,被广泛使用在“生产者-消费者”问题中。本文是阻塞队列下篇。
java进阶架构师
2020/08/28
4470
源码剖析ThreadPoolExecutor线程池及阻塞队列
本文章对ThreadPoolExecutor线程池的底层源码进行分析,线程池如何起到了线程复用、又是如何进行维护我们的线程任务的呢?我们直接进入正题:
努力的小雨
2024/06/14
1881
解读 Java 并发队列 BlockingQueue
转自:https://javadoop.com/post/java-concurrent-queue
Java技术江湖
2019/09/25
6190
一文带你彻底掌握阻塞队列!
在之前的文章中,我们介绍了生产者和消费者模型的最基本实现思路,相信大家对它已经有一个初步的认识。
Java极客技术
2023/11/16
7920
一文带你彻底掌握阻塞队列!
阻塞队列与非阻塞队列
Java提供很多线程安全的容器,为开发人员在并发编程场景下使用,通常我们会更加关注业务实现,而不关心底层结构。但我们应该理解这些容器的原理和使用场景,以方便我们的开发和遇到问题的分析,并且有时候也能借鉴一下大神们的实现思想。
搬砖俱乐部
2019/06/15
3.3K0
聊聊 JDK 阻塞队列源码(ReentrantLock实现)
项目中用到了一个叫做 Disruptor 的队列,今天楼主并不是要介绍 Disruptor 而是想巩固一下基础扒一下 JDK 中的阻塞队列,听到队列相信大家对其并不陌生,在我们现实生活中队列随处可见,最经典的就是去银行办理业务等。 当然在计算机世界中,队列是属于一种数据结构,队列采用的FIFO(first in firstout),新元素(等待进入队列的元素)总是被插入到尾部,而读取的时候总是从头部开始读取。在计算中队列一般用来做排队(如线程池的等待排队,锁的等待排队),用来做解耦(生产者消费者模式),异步等等。
haifeiWu
2018/09/11
3470
10分钟从实现和使用场景聊聊并发包下的阻塞队列
上篇文章12分钟从Executor自顶向下彻底搞懂线程池中我们聊到线程池,而线程池中包含阻塞队列
菜菜的后端私房菜
2024/07/03
3550
多线程应用 - 阻塞队列LinkedBlockingDeque详解
在多线程阻塞队列的应用中上一篇已经讲述了ArrayBlockingQueue,在这一篇主要介绍思想与他差不多的另一个阻塞队列,基于链表的阻塞队列-LinkedBlockingDeque。基于链表的阻塞队列和基于数组的阻塞队列相同,内部都有一把可重入锁,对于该队列的写操作和读操作都会进行加锁,所以他们都是线程安全的,但是写操作和读操作都会占用锁资源所以在并发量大的情况下会降低性能。另外内部维护了读操作时和写操作时候的Condition,当队列在读取元素时,若发现队列中没有元素,会阻塞读操作,直到队列中有元素被可被读取时才会被唤醒。同理,写操作的Condition,当队列需要进行写入操作时,若发现队列容量满的时候,会阻塞写操作,直到队列中有元素被取出时才会被唤醒。
虞大大
2020/08/26
2.5K0
JUC并发—9.并发安全集合三
ArrayBlockingQueue是一个基于数组实现的阻塞队列。其构造方法可以指定:数组的长度、公平还是非公平、数组的初始集合。
东阳马生架构
2025/04/29
480
Java阻塞队列
?原文地址为https://www.cnblogs.com/haixiang/p/12354520.html,转载请注明出处! 什么是阻塞队列 原文地址为,转载请注明出处! 阻塞队列是一个支持阻塞的
海向
2020/02/25
5330
阻塞队列实现之PriorityBlockingQueue源码解析
PriorityBlockingQueue是一个支持优先级的无界阻塞队列,基于数组的二叉堆,其实就是线程安全的PriorityQueue。
烂猪皮
2023/09/04
1910
阻塞队列实现之PriorityBlockingQueue源码解析
多线程应用 - 阻塞队列ArrayBlockingQueue详解
ArrayBlockingQueue是一个阻塞式的先进先出队列。该结构具有以下三个特点:
虞大大
2020/08/26
1.5K0
回归Java基础:LinkedBlockingQueue阻塞队列解析
整理了阻塞队列LinkedBlockingQueue的学习笔记,希望对大家有帮助。有哪里不正确,欢迎指出,感谢。
捡田螺的小男孩
2020/04/15
4480
回归Java基础:LinkedBlockingQueue阻塞队列解析
Java并发队列原理剖析
LinkedBlockingQueue和ArrayBlockingQueue比较简单,不进行讲解了。下面只介绍PriorityBlockingQueue和DelayQueue。
用户4283147
2022/10/27
2670
Java并发队列原理剖析
推荐阅读
相关推荐
并发阻塞队列BlockingQueue解读
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档