Loading [MathJax]/extensions/TeX/AMSmath.js
前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >稀疏Softmax(Sparse Softmax)

稀疏Softmax(Sparse Softmax)

作者头像
mathor
发布于 2021-07-19 07:59:29
发布于 2021-07-19 07:59:29
1.9K00
代码可运行
举报
文章被收录于专栏:mathormathor
运行总次数:0
代码可运行

本文源自于SPACES:“抽取-生成”式长文本摘要(法研杯总结),原文其实是对一个比赛的总结,里面提到了很多Trick,其中有一个叫做稀疏Softmax(Sparse Softmax)的东西吸引了我的注意,查阅了很多资料以后,汇总在此

Sparse Softmax的思想源于《From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification》《Sparse Sequence-to-Sequence Models》等文章。里边作者提出了将Softmax稀疏化的做法来增强其解释性乃至提升效果

不够稀疏的Softmax

前面提到Sparse Softmax本质上是将Softmax的结果稀疏化,那么为什么稀疏化之后会有效呢?我们认稀疏化可以避免Softmax过度学习的问题。假设已经成功分类,那么我们有(目标类别的分数最大),此时我们可以推导原始交叉熵的一个不等式:

$$ \begin{aligned} \log (\sum_{i=1}^n e^{s_i})-s_{\text{max}} &= \log (e^{s_t}+\sum_{i\neq t}e^{s_i})-s_{\text{max}}\\ &= \log (e^{s_{\text{max}}} + \sum_{i\neq t}e^{s_i})-\log (e^{s_{\text{max}}})\\ &= \log (\frac{e^{s_{\text{max}}} + \sum_{i\neq t}e^{s_i}}{e^{s_{\text{max}}}})\\ &= \log (1+ \sum_{i \neq t}e^{s_i - s_{\text{max}}})\\ & \ge \log (1+ (n - 1)e^{s_{\text{min}}-s_{\text{max}}}) \end{aligned}\tag{1} $$

假设当前交叉熵值为,那么有

解得

我们以为例,这时候,那么。也就是说,为了要loss降到0.69,那么最大的logit与最小的logit的差就必须大于,当比较大时,对于分类问题来说这是一个没有必要的过大的间隔,因为我们只希望目标类的logit比所有非目标类都要大一点就行,但是并不一定需要大那么多,因此常规的交叉熵容易过度学习从而导致过拟合

稀疏的Sparsemax

前面说了这么多关于Softmax的内容,那么Sparse Softmax或者说Sparsemax是如何做到稀疏化分布的呢?原文内容大家可以直接去看论文,写的非常复杂,这里我给出苏剑林大佬设计的一个更简单的版本

$$ \begin{array}{c|c|c} \hline & \text{Origin} & \text{Sparse} \\ \hline \text{Softmax} & p_i = \frac{e^{s_i}}{\sum\limits_{j=1}^{n} e^{s_j}} & p_i=\left\{\begin{aligned}&\frac{e^{s_i}}{\sum\limits_{j\in\Omega_k} e^{s_j}},\,i\in\Omega_k\\ &\quad 0,\,i\not\in\Omega_k\end{aligned}\right.\\ \hline \text{CrossEntropy} & \log\left(\sum\limits_{i=1}^n e^{s_i}\right) - s_t & \log\left(\sum\limits_{i\in\Omega_k} e^{s_i}\right) - s_t\\ \hline \end{array} $$

其中是将从大到小排列后前个元素的下标集合。说白了,苏剑林大佬提出的Sparse Softmax就是在计算概率的时候,只保留前个,后面的直接置零,是人为选择的超参数

代码

首先我根据苏剑林大佬的思路,给出一个简单版本的PyTorch代码

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import torch
import torch.nn as nn

class Sparsemax(nn.Module):
    """Sparsemax loss"""

    def __init__(self, k_sparse=1):
        super(Sparsemax, self).__init__()
        self.k_sparse = k_sparse
        
    def forward(self, preds, labels):
        """
        Args:
            preds (torch.Tensor):  [batch_size, number_of_logits]
            labels (torch.Tensor): [batch_size] index, not ont-hot
        Returns:
            torch.Tensor
        """
        preds = preds.reshape(preds.size(0), -1) # [batch_size, -1]
        topk = preds.topk(self.k_sparse, dim=1)[0] # [batch_size, k_sparse]
        
        # log(sum(exp(topk)))
        pos_loss = torch.logsumexp(topk, dim=1)
        # s_t
        neg_loss = torch.gather(preds, 1, labels[:, None].expand(-1, preds.size(1)))[:, 0]
        
        return (pos_loss - neg_loss).sum()

再给出一个Github上找到的一个PyTorch原版代码

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
"""Sparsemax activation function.
Pytorch implementation of Sparsemax function from:
-- "From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification"
-- André F. T. Martins, Ramón Fernandez Astudillo (http://arxiv.org/abs/1602.02068)
"""

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Sparsemax(nn.Module):
    """Sparsemax function."""

    def __init__(self, dim=None):
        """Initialize sparsemax activation
        
        Args:
            dim (int, optional): The dimension over which to apply the sparsemax function.
        """
        super(Sparsemax, self).__init__()

        self.dim = -1 if dim is None else dim

    def forward(self, input):
        """Forward function.
        Args:
            input (torch.Tensor): Input tensor. First dimension should be the batch size
        Returns:
            torch.Tensor: [batch_size x number_of_logits] Output tensor
        """
        # Sparsemax currently only handles 2-dim tensors,
        # so we reshape to a convenient shape and reshape back after sparsemax
        input = input.transpose(0, self.dim)
        original_size = input.size()
        input = input.reshape(input.size(0), -1)
        input = input.transpose(0, 1)
        dim = 1

        number_of_logits = input.size(dim)

        # Translate input by max for numerical stability
        input = input - torch.max(input, dim=dim, keepdim=True)[0].expand_as(input)

        # Sort input in descending order.
        # (NOTE: Can be replaced with linear time selection method described here:
        # http://stanford.edu/~jduchi/projects/DuchiShSiCh08.html)
        zs = torch.sort(input=input, dim=dim, descending=True)[0]
        range = torch.arange(start=1, end=number_of_logits + 1, step=1, device=device, dtype=input.dtype).view(1, -1)
        range = range.expand_as(zs)

        # Determine sparsity of projection
        bound = 1 + range * zs
        cumulative_sum_zs = torch.cumsum(zs, dim)
        is_gt = torch.gt(bound, cumulative_sum_zs).type(input.type())
        k = torch.max(is_gt * range, dim, keepdim=True)[0]

        # Compute threshold function
        zs_sparse = is_gt * zs

        # Compute taus
        taus = (torch.sum(zs_sparse, dim, keepdim=True) - 1) / k
        taus = taus.expand_as(input)

        # Sparsemax
        self.output = torch.max(torch.zeros_like(input), input - taus)

        # Reshape back to original shape
        output = self.output
        output = output.transpose(0, 1)
        output = output.reshape(original_size)
        output = output.transpose(0, self.dim)

        return output

    def backward(self, grad_output):
        """Backward function."""
        dim = 1

        nonzeros = torch.ne(self.output, 0)
        sum = torch.sum(grad_output * nonzeros, dim=dim) / torch.sum(nonzeros, dim=dim)
        self.grad_input = nonzeros * (grad_output - sum.expand_as(grad_output))

        return self.grad_input
*补充

经过苏剑林大佬的许多实验发现,Sparse Softmax只适用于有预训练的场景,因为预训练模型已经训练得很充分了,因此finetune阶段要防止过拟合;但是如果从零训练一个模型,那么Sparse Softmax会造成性能下降,因为每次只有个类别被学习到,反而会存在学习不充分的情况(欠拟合)

References
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
WRF | WRF散度计算步骤及Python可视化实现
关注我获取更多学习资料,第一时间收到我的Python学习资料,也可获取我的联系方式沟通合作
用户11172986
2025/04/20
640
WRF | WRF散度计算步骤及Python可视化实现
如何用wrfout计算水汽通量散度
本文旨在实现WRFOUT的单层水汽通量散度和整层水汽通量散度计算方法。WRF(Weather Research and Forecasting)模式是一种广泛应用于天气和气候预测研究的数值模式。水汽通量散度在天气和气候研究中具有重要作用。本项目将针对WRF模式的输出数据(WRFOUT)进行处理和分析,实现单层水汽通量散度和整层水汽通量散度的计算。
用户11172986
2024/06/20
5321
如何用wrfout计算水汽通量散度
时间剖面图?so easy
在本文中,我们将利用WRFOUT数据进行处理和分析,并生成直观明了的时间剖面图。你将能够清楚地看到水汽通量散度随着时间和高度的变化趋势,从而更好地理解大气中水汽的传播与运动机制
用户11172986
2024/06/20
2060
时间剖面图?so easy
WRFOUT 单层水汽通量散度与整层水汽通量散度实现 2.0
📢 版权声明:公益性质转载需联系作者本人获取授权。转载本文时,请务必文字注明“来自:和鲸社区:酷炫用户名”,并附带本项目超链接。
用户11172986
2024/11/29
1840
WRFOUT 单层水汽通量散度与整层水汽通量散度实现 2.0
WRFOUT 涡度平流和温度平流计算与可视化
涡度平流和温度平流是两种常见的气象诊断量,可以帮助我们更好地理解大气运动和热力学过程。 以下代码将计算上述气象诊断量并可视化。
用户11172986
2024/06/20
4730
WRFOUT 涡度平流和温度平流计算与可视化
基于WRFOUT计算相对涡度,绝对涡度,位涡并可视化
版本:python3.7 数据:wrfout模拟数据 核心代码:metpy.calc.vorticity
用户11172986
2024/06/20
6950
基于WRFOUT计算相对涡度,绝对涡度,位涡并可视化
计算整层水汽通量散度是先积分后散度还是先散度后积分?
由于可视化代码过长隐藏,可点击运行Fork查看 若没有成功加载可视化图,点击运行可以查看 ps:隐藏代码在【代码已被隐藏】所在行,点击所在行,可以看到该行的最右角,会出现个三角形,点击查看即可
用户11172986
2025/03/24
1740
计算整层水汽通量散度是先积分后散度还是先散度后积分?
一刻也没有为圣诞的结束而悲伤,下一刻赶来的是ERA5数据计算700hPa水汽通量散度
由于可视化代码过长隐藏,可点击运行Fork查看 若没有成功加载可视化图,点击运行可以查看 ps:隐藏代码在【代码已被隐藏】所在行,点击所在行,可以看到该行的最右角,会出现个三角形,点击查看即可
用户11172986
2024/12/27
2820
一刻也没有为圣诞的结束而悲伤,下一刻赶来的是ERA5数据计算700hPa水汽通量散度
ERA5水汽通量散度剖面计算与绘图
由于可视化代码过长隐藏,可点击运行Fork查看 若没有成功加载可视化图,点击运行可以查看 ps:隐藏代码在【代码已被隐藏】所在行,点击所在行,可以看到该行的最右角,会出现个三角形,点击查看即可
用户11172986
2025/01/14
2770
ERA5水汽通量散度剖面计算与绘图
大气视热源的python计算尝试
大气视热源是常用于表征大气热力作用的概念,本项目会尝试使用metpy库计算大气视热源并可视化,希望能给你们一些微小的帮助。
用户11172986
2024/06/20
3200
大气视热源的python计算尝试
一瞬又一瞬,累积起来便是一生 | ERA5数据计算垂直积分整层水汽通量散度
由于可视化代码过长隐藏,可点击运行Fork查看 若没有成功加载可视化图,点击运行可以查看 ps:隐藏代码在【代码已被隐藏】所在行,点击所在行,可以看到该行的最右角,会出现个三角形,点击查看即可
用户11172986
2024/12/30
4200
一瞬又一瞬,累积起来便是一生 | ERA5数据计算垂直积分整层水汽通量散度
开工!wrfout 计算台风准地转omega方程右侧项
在本项目中,我们将使用MetPy库来计算准地转Omega方程中涡度平流项和温度平流的拉普拉斯算子。根据Bluesetein(1992;Eq.5.6.11)提出的QG-Omega方程,我们将关注方程右侧的两个主要强迫项
用户11172986
2024/06/20
2470
开工!wrfout 计算台风准地转omega方程右侧项
罕见!WRF计算LWC与IWC及可视化
气象家园帖子公式参考:https://bbs.06climate.com/forum.php?mod=viewthread&tid=90527&highlight=lwc
用户11172986
2024/06/20
4980
罕见!WRF计算LWC与IWC及可视化
数据处理于可视化 | 湿位涡剖面分析
在暴雨发生前期,形成暴雨的基本条件逐渐形成甚至完全具备。通过对形成暴雨的基本条件即:水汽条件、不稳定能量条件、上升运动条件等诊断分析,有助于判断暴雨发生的可能性。形成暴雨的主要物理条件有两个:内在因素是潮湿空气的潜在不稳定,而以充足的水汽表现为其主要方面,简称热力条件;外部因素是促使这种潜在不稳定得到充分释放的强迫抬升运动,而又以流场的配置为其主要方面,简称动力条件。有的把其分为三个条件,即把热力条件分为水汽和潜在位势不稳定两个条件。
郭好奇同学
2021/08/26
2.5K1
数据处理于可视化 | 湿位涡剖面分析
wrf-python 详解之如何使用
近几年,python在气象领域的发展也越来越快,同时出现了很多用于处理气象数据的python包。比如和NCL中的 WRF_ARWUser库类似的 wrf-python模块。
bugsuse
2020/04/20
21.2K0
在WRF中怎么算风能密度
2022 年,全国风能资源为正常略偏小年景。10 米高度年平均风速较近 10 年(2012 ~ 2021 年,下同)平均值偏小 0.82%,较2021 年偏小 0.96%。70 米高度年平均风速约 5.4m/s,年平均风功率密度约 193.1W/m2;100 米高度年平均风速约 5.7m/s,年平均风功率密度约 227.4W/m2。其中,湖北、江西、湖南、重庆较近 10 年平均值偏大,贵州、山西、宁夏、江苏、山东、河北、天津、内蒙古、西藏、河南、云南偏小,其他地区与近 10 年平均值接近。————《2022年中国风能太阳能资源年景公报》
用户11172986
2024/06/20
1660
在WRF中怎么算风能密度
Python可视化 | WRF模式模拟数据后处理(二)
导入模块 import numpy as np from netCDF4 import Dataset import matplotlib.pyplot as plt from matplotlib.cm import get_cmap from matplotlib.colors import from_levels_and_colors import cartopy.crs as crs import cartopy.feature as cfeature from cartopy.feature i
郭好奇同学
2021/08/26
4K0
Python可视化 | WRF模式模拟数据后处理(二)
WRFOUT 位温剖面和位温单格点高度图
WRF (Weather Research and Forecasting Model) 是一种广泛用于天气预报和气候模拟的数值大气模式。通过分析WRF模型的输出数据,我们可以获得各种天气变量的空间分布及其随时间的演变情况。
用户11172986
2024/06/20
4330
WRFOUT 位温剖面和位温单格点高度图
metpy绘制锋生与冷锋
https://www.heywhale.com/mw/project/65485a22d74b63fed5f03f49
用户11172986
2024/06/20
3450
metpy绘制锋生与冷锋
WRFOUT 绘制台站探空图与简单分析
实际应用中探空图可以分析所在区域的动热力特征,是预报员的好朋友 而在WRF应用中可以将其作为模式是否准确的检验工具 下面进行WRFOUT数据的探空图绘制
用户11172986
2024/06/20
1560
WRFOUT 绘制台站探空图与简单分析
推荐阅读
相关推荐
WRF | WRF散度计算步骤及Python可视化实现
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验