前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >惊!你知道PyTorch浮点数上溢问题居然会导致这些结果?!

惊!你知道PyTorch浮点数上溢问题居然会导致这些结果?!

作者头像
不可言诉的深渊
发布2023-08-28 16:10:17
8660
发布2023-08-28 16:10:17
举报
文章被收录于专栏:Python机器学习算法说书人

当我们在使用 PyTorch 中的浮点数时,我们都知道它们并不能占满整个实数集 R。这主要是由于两个原因:精度和表示范围。对于计算机处理浮点数而言,精度不够的情况一般会选择截断,而超出表示范围的情况则通常会返回无穷大。然而,一旦 PyTorch 中的浮点数变成无穷大,将会出现非常奇怪的报错。因此,我们需要思考一下如何解决 PyTorch 中浮点数超出表示范围的问题。

浮点数的存储方式

浮点数是一种用于表示实数的数据类型,在计算机编程中广泛使用。浮点数在计算机中的表示通常使用 IEEE 754 标准。这个标准规定了浮点数的位数、指数和符号等信息。浮点数是由 3 个部分组成:符号(数符)、指数(阶码)和尾数。符号表示该数是正数还是负数,尾数则是实数的一个近似值,通常用二进制小数表示。而指数则是一个整数,用于标识该数的量级。在计算机中,浮点数的表示存储在一定长度的二进制数中。单精度浮点数使用 32 位二进制数进行存储,双精度浮点数则使用 64 位二进制数进行存储。在 PyTorch 中,不仅有上述提到的单精度浮点数和双精度浮点数,而且还有 2 种半精度浮点数,均使用 16 为二进制数存储。目前我们知道了不同浮点数一共占用多少个比特,但是浮点数由 3 个部分组成,对应数据类型和 3 个部分分别占用多少比特的情况见下表:

类型

符号

指数

尾数

偏置值

bfloat16

1

8

7

127

float16

1

5

10

15

float32

1

8

23

127

float64

1

11

52

1023

其中符号、指数以及尾数对应单位都是比特,偏置值就是一个数,没有单位。接着我来示范一下如何解释 bfloat16 那一行,后面 3 种类型照着我的说法来就行。浮点数类型是 bfloat16,一共占用 16 个比特,其中符号占用 1 个比特,指数占用 8 个比特,尾数占用 7 个比特,偏置值是 127。考虑到我们需要解决浮点数表示范围的问题,因此接下来就是如何基于上述内容计算出浮点数的表示范围。假设符号的二进制表示是 S,指数的二进制表示是 E,尾数的二进制表示是 M,偏置值的十进制表示是 B,那么对应十进制浮点数的表达式如下所示:

需要注意的是代入该表达式之前需要把 S、1.M 以及 E(都没有符号位)都转化为十进制数才能参与运算。既然如此,浮点数的表示范围需要就很容易进行计算了,我们只要令 M 和 E 对应的二进制表示全是 1,就得到了最大绝对值表示,记作 MAX_A,那么浮点数的表示范围就是这样一个区间:[-MAX_A, MAX_A]。不过不对,因为在计算机中一旦 E 对应的二进制全是 1 该浮点数就会被表示成无穷大。因此,我们需要让阶码退一步,让它对应的二进制表示的最低位为 0,这个时候我们才得到了最大绝对值表示。因此浮点数表示范围就是把这个新的最大绝对值表示去替换上述区间中的 MAX_A 就可以了。

案例分析

接下来我们结合 PyTorch 来分析一下浮点数超出表示范围的问题的一些案例,在这里需要注意的是我绝对不可能通过限制范围等这样的下下策来解决这一问题的,而是从数学公式的变换角度来解决这一问题。

平均值

我们首先来看第一个案例:平均值。虽然我知道 PyTorch 中有内置函数 mean 可以实现求平均值的操作以及平均值的公式就是累加再除以总数。但是,我们需要注意的是如果我们真的按照累加再除以总数的方法来计算平均值就比较容易让结果变成无穷大,在数据很大并且很多的情况下就容易出现这个问题,比如下面这个例子:

代码语言:javascript
复制
>>> import torch
>>> a = torch.tensor(10*[1e38], dtype=torch.float)
>>> a
tensor([1.0000e+38, 1.0000e+38, 1.0000e+38, 1.0000e+38, 1.0000e+38, 1.0000e+38,
        1.0000e+38, 1.0000e+38, 1.0000e+38, 1.0000e+38])
>>> a.sum()/a.numel()
tensor(inf)

从中我们可以发现,原本 a 中的所有元素都在表示范围内部,可是一旦通过这种先累加再除以总数的方法来计算平均值结果就是无穷大,这很明显不对。出现这个问题主要是因为先累加会产生比较大的中间结果并超出表示范围。要想解决这个问题很简单,换一下操作顺序就行,我不去先累加再除以总数,我先对 a 中的每个元素除以总数,然后再把新得到的一组数据累加起来,代码如下:

代码语言:javascript
复制
>>> (a/a.numel()).sum()
tensor(1.0000e+38)

我们可以发现结果恢复正常了,这主要是因为先除以总数把原本很大的数变小了,这样对这些变小的数进行累加就不会超出表示范围了。

softmax

然后我们来看第二个案例:softmax。虽然我知道 PyTorch 中有内置函数 softmax 可以实现这样的操作以及 softmax 的公式如图所示。

按照上述公式实现 softmax 非常简单,代码如下:

代码语言:javascript
复制
>>> def softmax(x):
...     return x.exp()/x.exp().sum()

但是,我们需要注意的是如果我们真的去使用这个代码会有着非常大的概率让分子分母都变成无穷大,比如下面这个例子:

代码语言:javascript
复制
>>> a = torch.tensor(3*[90], dtype=torch.float)
>>> a
tensor([90., 90., 90.])
>>> softmax(a)
tensor([nan, nan, nan])

从中我们可以发现,原本 a 中的所有元素都比较小,才只有 90,都没过百,上溢完全不可能。在这里出现上溢主要是因为 exp 方法,exp 方法的操作是对一个数求其指数函数(以自然对数 e 为底)的值。90 虽然很小,但是 e 的 90 次方可不是一个小数目,其值约等于 1.2204e+39,这已经超出了表示范围,导致结果变成无穷大。此外,这里的输出全都是 nan 是因为当分子达到无穷大的时候,而且分母>分子,分母必定也会达到无穷大,所以这就是一个无穷比无穷,而且计算机可不会像做高等数学的极限题那样知道无穷比无穷的极限有哪些方法可以求解,因此计算机对于这种求不出来的结果就只好返回 nan 了。要想解决这个问题很简单,我们可以尝试对分子分母同时除以一个很大的数来让结果不发生上溢。到目前为止,我们还差一个问题,这个很大的数到底是多少呢?我们先假设这个数用 M 表示,对 softmax 公式变形,如图所示。

此时此刻,我们可以发现,先对 x 中的每个元素减去 lnM 再去求 softmax 和直接对 x 求 softmax 是完全等价的!到目前为止,我们只要找到这个 M 就可以解决这个问题。寻找 M 的过程中需要注意两点:第一,lnM 尽可能的大;第二,lnM 不能被计算机认为是无穷大。第一点很容易理解,重点解释为什么也要满足第二点:考虑到当 lnM 变成了无穷大,分子分母都会变成 0(e 的负无穷次方),这样子问题就从无穷比无穷变成了 0/0。接下来就是去找这个 lnM,代码如下:

代码语言:javascript
复制
>>> torch.tensor(3e38, dtype=torch.float).log()
tensor(88.5968)
>>> torch.tensor(4e38, dtype=torch.float).log()
tensor(inf)

我们可以看到,M 位于 3e38 到 4e38 之间,接下来我们可能会尝试继续精确查找,其实这完全没有任何的意义,因为 ln(3e38)=88.5968 和 ln(4e38)=88.8845,结果没有差多少。因此,就算我们对 x 中的每个元素减去 lnM,数值也还是可能会很大(还是会出现上溢),毕竟 lnM 的允许的最大值都还是太小了(连 90 都不到)!这个时候比较容易想到的做法是把 lnM 看成一个整体,而不是像之前那样通过找 M 的方法来找 lnM,这样就算 x 中的元素值再大,一减去 lnM 就会变得很小,几乎不可能出现无穷比无穷。但是,如果 x 中的元素值很小,会出现这样一减就变成了负无穷,这样就出现了之前所说的 0/0 的情况。因此,我们不仅需要避免无穷比无穷的情况,而且还需要避免 0/0 的情况。换句话说,lnM 既不能非常大,也不能非常小。从上文中我们可以得知,我们希望当 x 中的元素非常大的时候 lnM 也非常大,当 x 中的元素非常小的时候 lnM 也非常小。这个时候我们可以令 lnM=f(x),其中 f 的输入是一个向量,输出是一个数。把一个向量变成一个数有很多方法,比如平均值、模长、最小值、最大值等。

首先我们看一下平均值能不能行,显然不可以,因为如果当 x=[-50, -50, 100] 的时候平均值为 0,减去 lnM 和原来一样,求 softmax 依旧出现上溢(e 的 100 次方)。

然后看一下模长行不行,显然更不可以,因为求模长本身的操作就有可能会越界,即使不考虑这个也不可以,因为当 x 中的元素值都是很小的负数,模长就会是一个很大的正数,那么 x-lnM 就是一个比 x 更小的正数,容易出现负下溢(0/0 的情况)。

接着看一下最小值行不行,显然不可以,因为如果当 x=[0, 0, 100] 的时候最小值是 0,减去 lnM 和原来一样,求 softmax 依旧出现上溢(e 的 100 次方)。

最后看一下最大值行不行,显然可以,因为这样做分子就会位于 0 到 1 之间(e 的负无穷次方e 的 0 次方之间),分母就会位于 1 到 K 之间(一个 e 的 0 次方+K-1 个 e 的负无穷次方K 个 e 的 0 次方之间)。

综上所述,我们得出 softmax 正确的公式变形,如图所示。

公式有了,代码实现就非常简单了,如下所示:

代码语言:javascript
复制
>>> def softmax(x):
...     x -= x.max()
...     return x.exp()/x.exp().sum()
...
>>> a = torch.tensor(3*[90], dtype=torch.float)
>>> a
tensor([90., 90., 90.])
>>> softmax(a)
tensor([0.3333, 0.3333, 0.3333])
>>> a = torch.tensor(3*[1e3], dtype=torch.float)
>>> a
tensor([1000., 1000., 1000.])
>>> softmax(a)
tensor([0.3333, 0.3333, 0.3333])
>>> a = torch.tensor(3*[-1e3], dtype=torch.float)
>>> a
tensor([-1000., -1000., -1000.])
>>> softmax(a)
tensor([0.3333, 0.3333, 0.3333])

我们可以发现结果都恢复正常了,softmax 的实现可以过了。

log_softmax

接下来我们来看第 3 个案例:log_softmax。虽然我知道 PyTorch 中有内置函数 log_softmax 可以实现这样的操作以及 log_softmax 的公式如图所示。

其中,这里的 log 以 e 为底。显然,这就是先求 softmax,再对 softmax 的结果取对数,为了避免上溢问题,很明显需要借助之前实现的 softmax,公式变形如图所示。

公式有了,代码实现就非常简单了,如下所示:

代码语言:javascript
复制
>>> def softmax(x):
...     x -= x.max()
...     return x.exp()/x.exp().sum()
...
>>>
>>> def log_softmax(x):
...     return softmax(x).log()

但是,我们需要注意的是如果我们真的去使用这个代码会有着一定的概率让结果出现负无穷,比如下面这个例子:

代码语言:javascript
复制
>>> a = torch.tensor([-60, -60, 60], dtype=torch.float)
>>> a
tensor([-60., -60.,  60.])
>>> log_softmax(a)
tensor([-inf, -inf, 0.])

从中我们可以发现,原本 a 中的所有元素都比较小,最小才只有 -60,最大才只有 60,都没过百,上溢完全不可能。这里出现上溢是因为 softmax 的结果中前两个元素都为 0,这导致了后面求对数的时候出现了 log0,也就是负无穷。显然,这样的公式变形并不是最好的,因此我们需要继续对公式变形,如图所示。

这个时候只有当 x 中的某一个元素非常小,max(x) 非常大,才有可能出现上溢问题让结果变成负无穷。

公式有了,代码实现就非常简单了,如下所示:

代码语言:javascript
复制
>>> def log_softmax(x):
...     x -= x.max()
...     return x-x.logsumexp(0)
...
>>> a = torch.tensor([-60, -60, 60], dtype=torch.float)
>>> a
tensor([-60., -60.,  60.])
>>> log_softmax(a)
tensor([-120., -120.,    0.])
>>> a = torch.tensor([-1e3, -1e3, 1e3], dtype=torch.float)
>>> a
tensor([-1000., -1000.,  1000.])
>>> log_softmax(a)
tensor([-2000., -2000.,     0.])

我们可以发现结果都恢复正常了,log_softmax 的实现可以过了。

logsumexp

最后我们来看第 4 个案例:logsumexp。虽然我知道 PyTorch 中有内置函数 logsumexp 可以实现这样的操作以及 logsumexp 的公式如图所示。

虽然我们在上文中已经用到了这个操作,但是有 2 点不同:第一,我是直接用的内置函数;第二,我不是对 x 求 logsumexp 的值,而是对 x 中的每个元素减去 x 的最大值得到的新向量求 logsumexp 的值。直接基于上述公式实现代码非常简单,如下所示:

代码语言:javascript
复制
>>> def logsumexp(x):
...     return x.exp().sum().log()

但是,我们需要注意的是如果我们真的去使用这个代码会有着一定的概率让结果出现无穷大,比如下面这个例子:

代码语言:javascript
复制
>>> a = torch.tensor(3*[90], dtype=torch.float)
>>> a
tensor([90., 90., 90.])
>>> logsumexp(a)
tensor(inf)

从中我们可以发现,原本 a 中的所有元素都比较小,才只有 90,都没过百,上溢完全不可能。在这里出现上溢的原因和之前 softmax 是一回事,都是 exp 方法导致的。很明显需要进行公式变形,如图所示。

这里我们首先提出一个 M,让 exp 方法的输入变小,进而避免出现上溢问题。然后我利用指数对数相关公式把寻找 M 的问题变成了寻找 logM 的问题。和之前 softmax 的分析一样,logM 非常小会导致第 2 项变成正无穷,logM 非常大会导致第 2 项变成负无穷。因此,我们希望当 x 中的元素非常大的时候 logM 也非常大,当 x 中的元素非常小的时候 logM 也非常小。这个时候我们可以令 logM=f(x),其中 f 的输入是一个向量,输出是一个数。把一个向量变成一个数有很多方法,比如平均值、模长、最小值、最大值等。这里 4 种方法的可行性分析过程和之前 softmax 时候的分析过程异曲同工,这里就不做分析了。经过分析之后还是最大值可行。所以,直接给出正确的公式变形,如图所示。

公式有了,代码实现就非常简单了,如下所示:

代码语言:javascript
复制
>>> def logsumexp(x):
...     max_x = x.max()
...     return max_x+(x-max_x).exp().sum().log()
...
>>> a = torch.tensor(3*[90], dtype=torch.float)
>>> a
tensor([90., 90., 90.])
>>> logsumexp(a)
tensor(91.0986)
>>> a = torch.tensor(3*[1e3], dtype=torch.float)
>>> a
tensor([1000., 1000., 1000.])
>>> logsumexp(a)
tensor(1001.0986)
>>> a = torch.tensor(3*[-1e3], dtype=torch.float)
>>> a
tensor([-1000., -1000., -1000.])
>>> logsumexp(a)
tensor(-998.9014)

我们可以发现结果都恢复正常了,logsumexp 的实现可以过了。

结论

最后一定需要记住的是,千万不要自以为是地认为弄懂了上面的几个案例就弄懂了浮点数上溢问题的解决方案!其中的公式变形绝对不可能是我讲一个你们跟着学一个!如果长此以往,你们解决这个问题的能力就永远在我能力的下方!一定要自己去尝试进行公式变形、代码实现以及代码调试!

bilibili 账号:新时代的运筹帷幄,喜欢的可以关注一下,看完视频不要忘记一键三连啊!

今天的文章有不懂的可以后台回复“加群”,备注:Python 机器学习算法说书人,不备注可是会被拒绝的哦~

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2023-06-29,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Python机器学习算法说书人 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
对象存储
对象存储(Cloud Object Storage,COS)是由腾讯云推出的无目录层次结构、无数据格式限制,可容纳海量数据且支持 HTTP/HTTPS 协议访问的分布式存储服务。腾讯云 COS 的存储桶空间无容量上限,无需分区管理,适用于 CDN 数据分发、数据万象处理或大数据计算与分析的数据湖等多种场景。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档