本文含 10083 字,18 图表截屏
建议阅读 52 分钟
本文是 Python 系列的特别篇的第二十三篇
0
引言
最近我以电子版的形式出了第二本书《Python 从入门到入迷》,然后定期更新书中的内容,最先想到的便是 einsum。
在 NumPy 包中,有一个函数叫做 einsum,它做的事情就是加总 (summation),但是是以爱因斯坦加总惯例 (Einstein's summation convention) 进行,因此得以此名。在深度学习框架 Tensorflow 和 PyTorch 也有这个函数,而且用法几乎一样,使用 einsum 首先需要从各自包中引用:
from numpy import einsum
from torch import einsum
from tensorflow import einsum
本文只拿 NumPy 包中的 einsum 来举例,并按照 what-how-why 主线来讲解,首先介绍什么 (what) 是 einsum,再展示怎么 (how) 用 einsum,最后来说明为什么 (why) 会有 einsum。相信这三部曲过后,我们可以把 einsum 整得明明白白的。
1
What is einsum?
1.1
爱因斯坦标记法
以下是一个矩阵相乘的具体例子,我们都知道结果矩阵第 2 行第 1 列的元素 4 是由“第一个矩阵第 2 行的元素”依次乘以“第二个矩阵第 1 列的元素”再加总,即 4 = 2*0 + 2*1 + 2*1。
矩阵相乘的通用形式如下,用字母代替数字得到 c21 = a21*b11 + a22*b21 + a23*b31。
写成通式就是
上式中的下指标可分成两类:
爱因斯坦对于式中出现的哑指标,约定默认对其进行求和。有了这个约定之后,上面表达式可简化成:
有了爱因斯坦约定得到的简写 (注意上面表达式的下标) ,用 einsum('ij,jk->ik',A,B)可以表达矩阵相乘,其中参数
下面用代码来看几个例子。
1.2
代码展示
首先创建矩阵 A 和 B。
A = np.array([[1, 1, 1],
[2, 2, 2],
[5, 5, 5]])
B = np.array([[0, 1, 0],
[1, 1, 0],
[1, 1, 1]])
用 einsum 函数来求矩阵相乘。
einsum('ij,jk->ik', A, B)
array([[ 2, 3, 1],
[ 4, 6, 2],
[10, 15, 5]])
用 np.matmul(A,B) 验证上面的语法确实做的是矩阵相乘。
np.matmul( A, B )
array([[ 2, 3, 1],
[ 4, 6, 2],
[10, 15, 5]])
自由指标和哑指标用任何字母字符都可以的,只要哑指标的位置写对即可,比如:
einsum('bF,FG->bG', A, B)
array([[ 2, 3, 1],
[ 4, 6, 2],
[10, 15, 5]])
用 'ij,jk->ik' 只不过字母 i,j 和 k 在数学中的下标表示中更常见。
爱因斯坦求和容易吧,你觉得你会了么?觉得会的话来看看下面的各种组合
是不是越看越困惑?
字符串 'ij,jk->ki' 得到的结果还好理解,就是矩阵乘完之后做个转置,因为箭头 -> 右边是 ki,正好和上例的 ik 反过来了。
einsum('ij,jk->ki', A, B)
array([[ 2, 4, 10],
[ 3, 6, 15],
[ 1, 2, 5]])
字符串 'ij,jk->ij' 和 'ij,jk->ij' 得到的结果就不好理解了,虽然我们看出来两种字符串得到的矩阵互为转置,但怎么得到的却不清楚,这个第三节会细讲。
einsum('ij,jk->ij', A, B)
einsum('ij,jk->ji', A, B)
array([[ 1, 2, 3],
[ 2, 4, 6],
[ 5, 10, 15]])
array([[ 1, 2, 5],
[ 2, 4, 10],
[ 3, 6, 15]])
同样,字符串 'ij,jk->jk' 和 'ij,jk->jk' 得到的结果更不好理解,两种字符串得到的矩阵互为转置,但怎么得到的却不清楚,这个第三节也会细讲。
einsum('ij,jk->jk', A, B)
einsum('ij,jk->kj', A, B)
array([[0, 8, 0],
[8, 8, 0],
[8, 8, 8]])
array([[0, 8, 8],
[8, 8, 8],
[0, 0, 8]])
如果你有对以上结果有困惑,那么请继续看下去,让我们来深挖 einsum 来总结其函数的一些通用规则。
2
How to use einsum?
当你学习一个新东西时,最好的方法是从最基础的部分开始,对于 einsum 这样基于数组的运算函数,我们就依次从 0 维 (标量),1 维 (向量),2 维 (矩阵) 到高维 (张量) 数组一步步来探索。
具体来说,einsum 函数的功能是
2.1
标量
0 维单数组
首先创建标量 arr0。
arr0 = 3
标量中没有轴的概念,按轴求和得到的结果就是它本身而已。
einsum("->", arr0)
3
注意字符串 "->" 可以看成 " -> ",箭头的左边和右边都是空字符,因为标量是 0 维度,如果用字母 i 来表示会报错。
einsum("i->", arr0)
如果在字符串中去掉箭头,得到的结果和上例是一样的,但是表示的含义有细微的区别。
einsum("", arr0)
3
上例的操作是对数组求和,本例的操作是返回该数组,只不过当数组为标量时,两者看起来是一样的 (对于非标量的数组就不是这样子了,后面读者会看到)。
规则总结:箭头 -> 表示求和。
0 维多数组
首先创建标量 A 和 B。
A = 3
B = 5
注意字符串 ",->" 可以看成 " , -> ",箭头的左边两个空字符代表用于相乘的两个标量,箭头右边的空字符代表结果。
einsum(",->", A, B)
15
去掉箭头也可以。
einsum(",", A, B)
15
去掉逗号会报错,因为后面跟着两个参数 A 和 B,因此需要逗号来分隔出来来描述 A 和 B 的两个字符串,即空字符串。
einsum(",", A, B)
根据为两个数组相乘设定的字符串,对于三个数组相乘,加一个逗号 ",,->" 就可以了。
C = 2
einsum(",,->", A, B, C)
30
规则总结:逗号 , 用来分隔数组,数组相乘在 einsum 函数的设置如下 (以 3 个数组举例):
三个颜色对应三个输入张量,分别是 2 维数组 (2 个红框),3 维数组 (3 个紫框) 和 2 维数组 (2 个蓝框),而输出张量是 2 维数组 (2 个绿框)。
从标量可以猜想出以上规则,但标量没有轴的概念,而且求和与其本身也看不来区别,因此我们需要用向量、矩阵和张量来验证或完善上面的规则。
2.2
向量
1 维单数组
首先创建向量 arr1。
arr1 = np.array([0, 1, 2])
向量只有一个轴,按轴求和得到的结果就是它包含所有元素的和。
einsum("i->", arr1)
3
注意字符串 "i->" 可以看成 "i-> ",箭头的左边字符 i 表示向量的轴 0 维度,
箭头右边的空字符表示求和得到标量。
einsum("i->", arr0)
如果在字符串中去掉箭头,得到的结果是该向量本身。
einsum("i", arr1)
array([0, 1, 2])
如果用字符串 "i->i",得到的结果也是该向量本身。
einsum("i->i", arr1)
array([0, 1, 2])
1 维多数组
首先创建向量 A 和 B。
A = np.array([1, 2, 3])
B = np.array([4, 5, 6])
字符串 "i,i->i" 指的数组 A 和 B 相同轴 (轴 0 i,i) 的元素依次相乘 (注意没有乘后相加) 得到的轴 0 维度 (i) 上的数组。
einsum("i,i->i", A, B)
array([ 4, 10, 18])
字符串 "i,i" 相当于 "i,i->",箭头右边是一个空字符,代表是标量,那么将上例得到的数组所有元素求和就可得到一个标量了。
einsum("i,i", A, B)
32
对于两个向量,字符串 "i,i" 代表它们的内积或点积操作。
np.inner(A, B)
np.dot(A, B)
32
32
接下来的字符串 "i,j" 有些难度了,下面两个语句的结果都是矩阵,两个向量怎么都能生成矩阵呢?难道是外积?
einsum("i,j", A, B)
einsum("i,j->ij", A, B)
array([[ 4, 5, 6],
[ 8, 10, 12],
[12, 15, 18]])
array([[ 4, 5, 6],
[ 8, 10, 12],
[12, 15, 18]])
从下面代码来看,确实是这样的。叉积的结果是矩阵是二维数组,而用于外积的两个向量是一维数组,这个升维操作其实是由 "i,j" 来实现的。用不同字母 i 和 j 就代表不同的维度,对应着结果矩阵中的轴 0 和轴 1 维度。
np.outer(A, B)
array([[ 4, 5, 6],
[ 8, 10, 12],
[12, 15, 18]])
现在知道外积的结果是个二维矩阵,那么当然可以沿着轴 0 ("i,j->i"),轴 1 ("i,j->j") 和对所有元素 ("i,j") 求和了,代码如下:
einsum("i,j->i", A, B) # 沿着轴 0 求和
einsum("i,j->j", A, B) # 沿着轴 1 求和
einsum("i,j->", A, B) # 对所有元素求和
array([15, 30, 45]) # 沿着轴 0 求和
array([24, 30, 36]) # 沿着轴 1 求和
90 # 对所有元素求和
规则总结:字符串 "i,j->x" 箭头 -> 右边的字符 x 来确定求和的方式,如果:
2.3
矩阵
2 维单数组
首先创建矩阵 arr2。
arr2 = np.array([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]])
用字符串 "ij" 和 "ji" 分别生成矩阵本身和其转置。
einsum("ij", arr2)
einsum("ji", arr2)
array([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
array([[0, 3, 6],
[1, 4, 7],
[2, 5, 8]])
字符串 "ii->i" 生成矩阵的对角线上的元素,即一维向量,和函数 np.diag(arr2) 等效。
einsum("ii->i", arr2)
np.diag(arr2)
array([0, 4, 8])
array([0, 4, 8])
字符串 "ii" 生成矩阵的对角线上的元素再求和,即零维标量,和函数 np.trace(arr2) 等效,求的是矩阵的迹。
einsum("ii", arr2)
np.trace(arr2)
12
12
巩固一下上节归纳出来的规则,对于二维矩阵,可以沿着轴 0 ("i,j->i"),轴 1 ("i,j->j") 和对所有元素 ("i,j") 求和了,代码如下:
einsum("ij->i", A, B) # 沿着轴 0 求和
einsum("ij->j", A, B) # 沿着轴 1 求和
einsum("ij->", A, B) # 对所有元素求和
array([3, 12, 21]) # 沿着轴 0 求和
array([9, 12, 15]) # 沿着轴 1 求和
36 # 对所有元素求和
注意:当求矩阵对角线时,返回的结果是矩阵的视图 (view),而不是复制 (copy)。以下面代码为例,当改变 c 中的元素,对应的 arr2 也会改变。
c = einsum("ii->i", arr2)
c[-1] = 10000
arr2
array([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 10000]])
再对矩阵 arr2 (已改变了) 求迹结果已经变成 10004 了。
einsum("ii->", arr2)
10004
2 维多数组
首先创建矩阵 A 和 B。
A = np.array([[1, 1, 1],
[2, 2, 2],
[5, 5, 5]])
B = np.array([[0, 1, 0],
[1, 1, 0],
[1, 1, 1]])
用字符串 "ij,jk->ik" 来求矩阵相乘,和用 np.matmul(A,B) 等效。
einsum('ij,jk->ik', A, B)
array([[ 2, 3, 1],
[ 4, 6, 2],
[10, 15, 5]])
进一步理解一下上面的操作,就是把矩阵 A 轴 1 (列,ij 中的 j) 和矩阵 B 轴 0 (行,jk 中的 j) 每个元素相乘,然后沿着 j 代表的轴 (字符串只包含 ik) 求和。 而如上描述的操作刚好也是矩阵相乘的定义。
现在问题来了,那么在没有沿着 j 代表的轴求和之前的产出是什么呢?
einsum('ij,jk->ijk', A, B)
array([[[0, 1, 0],
[1, 1, 0],
[1, 1, 1]],
[[0, 2, 0],
[2, 2, 0],
[2, 2, 2]],
[[0, 5, 0],
[5, 5, 0],
[5, 5, 5]]])
结果是个三维数组,从 "ij,jk->ijk" 箭头右边的 ijk 也看得出来。由于结果比 A 和 B 高一维,它背后的操作实际上是
然后在元素层面上相乘。
先打印出升过维度的 A 和 B:
A[:,None]
B[None,:]
array([[[1, 1, 1]],
[[2, 2, 2]],
[[5, 5, 5]]])
array([[[0, 1, 0],
[1, 1, 0],
[1, 1, 1]]])
然后在元素层面上相乘,得到的结果和 einsum('ij,jk->ijk', A, B) 一致。
A[:,None] * B[None,:]
array([[[0, 1, 0],
[1, 1, 0],
[1, 1, 1]],
[[0, 2, 0],
[2, 2, 0],
[2, 2, 2]],
[[0, 5, 0],
[5, 5, 0],
[5, 5, 5]]])
有了这个三维数组,那么就好理解小节 1 里的各种组合结果了。
第一种 'ij,jk->ik' 就是在三维数组上沿着 j 轴求和,示意图如下。
那么第三种 'ij,jk->ij' 和第五种 'ij,jk->jk' 分别就是在三维数组上沿着 k 轴和 i 轴求和,对应着上面的三维数组图和下面的代码,我相信读者可以理解为什么结果是这样子了。
einsum('ij,jk->ij', A, B) # 沿着 k 轴求和,i 轴和 j 轴成了轴 0 和轴 1
array([[ 1, 2, 3],
[ 2, 4, 6],
[ 5, 10, 15]])
einsum('ij,jk->jk', A, B) # 沿着 i 轴求和,j 轴和 k 轴成了轴 0 和轴 1
array([[0, 8, 0],
[8, 8, 0],
[8, 8, 8]])
趁热打铁,来看看下面这个语句的结果,是一个四维数组!
einsum('ij,kl->ijkl', A, B)
array([[[[0, 1, 0],
[1, 1, 0],
[1, 1, 1]],
[[0, 1, 0],
[1, 1, 0],
[1, 1, 1]],
[[0, 1, 0],
[1, 1, 0],
[1, 1, 1]]],
[[[0, 2, 0],
[2, 2, 0],
[2, 2, 2]],
[[0, 2, 0],
[2, 2, 0],
[2, 2, 2]],
[[0, 2, 0],
[2, 2, 0],
[2, 2, 2]]],
[[[0, 5, 0],
[5, 5, 0],
[5, 5, 5]],
[[0, 5, 0],
[5, 5, 0],
[5, 5, 5]],
[[0, 5, 0],
[5, 5, 0],
[5, 5, 5]]]])
相信读者已经很难可视化该过程了,我们来捋捋,从 "ij,kl->ijkl" 箭头右边的 ijkl 看得出来结果是四维数组。由于结果比 A 和 B 高两维,它背后的操作实际上是
然后在元素层面上相乘。根据结果也可以把 "ij,kl->ijkl" 理解成 A 的每一个元素乘以 B。
用下面的代码得到的结果和 einsum('ij,kl->ijkl', A, B) 一致。
A[:,:,None,None] * B[None,None,:,:]
规则总结:在字符串"ij,jk->ik" 中
2.4
张量
多维单数组
上节已经讲完了,从 'ijk' 到 'ij','jk' 和 'ik' 其实就是三维数组分别在轴 k、轴 i 和周 j 上做求和,因此把对应的轴“打掉”降了一维。
字符串 "ijk->" 对三维数组所有元素求和,得到标量 48。
C = A[:,None] * B[None,:]
einsum('ijk->', C)
48
多维多数组
首先创建三维张量 A 和 B。
A = np.arange(60.).reshape(3,4,5)
B = np.arange(24.).reshape(4,3,2)
array([[[ 0., 1., 2., 3., 4.],
[ 5., 6., 7., 8., 9.],
[10., 11., 12., 13., 14.],
[15., 16., 17., 18., 19.]],
[[20., 21., 22., 23., 24.],
[25., 26., 27., 28., 29.],
[30., 31., 32., 33., 34.],
[35., 36., 37., 38., 39.]],
[[40., 41., 42., 43., 44.],
[45., 46., 47., 48., 49.],
[50., 51., 52., 53., 54.],
[55., 56., 57., 58., 59.]]])
array([[[ 0., 1.],
[ 2., 3.],
[ 4., 5.]],
[[ 6., 7.],
[ 8., 9.],
[10., 11.]],
[[12., 13.],
[14., 15.],
[16., 17.]],
[[18., 19.],
[20., 21.],
[22., 23.]]])
字符串 "ijk,jil->kl" 将 A 切片轴 0-1 得到一个形状为 (3, 4) 的二维矩阵,比如 a;将 B 切片轴 0-1 得到一个形状为 (4, 3) 的二维矩阵,比如 b;然后用 a 乘以 b 的转置 ("ijk,jil->kl") 并对所有元素求和。这样的操作重复做最终填满形状为 (5, 2) 的二维矩阵 ("ijk,jil->kl") ,因为 A 沿轴 2 的元素个数是 5,B 沿轴 2 的元素个数是 2。
einsum('ijk,jil->kl', A, B)
array([[4400., 4730.],
[4532., 4874.],
[4664., 5018.],
[4796., 5162.],
[4928., 5306.]])
让我们用代码来明晰上面的文字解释。我们只关注上面数组 [0, 0] 位置的 4400 是怎么计算出来的。首先对 A 和 B 沿着轴 2 切片:
a = A[:,:,0]
b = B[:,:,0]
array([[ 0., 5., 10., 15.],
[20., 25., 30., 35.],
[40., 45., 50., 55.]])
array([[ 0., 2., 4.],
[ 6., 8., 10.],
[12., 14., 16.],
[18., 20., 22.]])
然后用 a 乘以 b 的转置并对所有元素求和。
(a * b.T).sum()
4400
这样结果数组 [0, 0] 位置就知道怎么来的了,同理对 [0, 1], [1, 0], [1, 1], ..., [4, 0], [4, 1] 做上述同样的操作,就可以得到 einsum('ijk,jil->kl', A, B) 的结果了。
上述操作和 np.tensordot( A, B, axes=([0,1],[1,0]) ) 等效。
np.tensordot( A, B, axes=([0,1],[1,0]) )
array([[4400., 4730.],
[4532., 4874.],
[4664., 5018.],
[4796., 5162.],
[4928., 5306.]])
3
Why use einsum?
首先 einsum 一招鲜吃遍天,可以满足数组所有类型的运算,比如转置、内积、外积、对角线、迹、轴上求和,所有元素求和等。除此之外还有以下优点。
高效
A = np.array([0, 1, 2])
B = np.array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
print(A.shape, B.shape)
(3,) (3, 4)
向量 A 不能直接乘以矩阵 B,不满足广播机制,因为 “A 形状最后一维元素个数 3 和 B 形状最后一维元素个数 4” 不匹配。
A * B
要让 A 和 B 可以相乘必须对 A 在轴 1 升一维,A[:,None]。
A[:,None] * B
array([[ 0, 0, 0, 0],
[ 4, 5, 6, 7],
[16, 18, 20, 22]])
假如我们想得到按轴 1 求和得到一个向量,可用代码
(A[:,None] * B).sum(axis=1)
array([ 0, 22, 76])
但用 einsum 能简约而轻松的得到以上结果。
einsum("i,ij->i", A, B)
array([ 0, 22, 76])
字符串 "i,ij->i" 由 -> 分成了两部分,它左边的 i,ij 对应两个输入,而右边的 i 对应输出。输出中没有下标 j,说明对两个输入沿着这个下标求和,而 i 所在的轴仍然保留。而 i 下标对应的维度的元素个数为 3,因此最终得到一个有 3 个元素的向量。
除了方便,einsum 也比传统做法高效。
%timeit (A[:,None] * B).sum(axis=1)
%timeit einsum("i,ij->i", A, B)
3.72 µs ± 117 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
2.11 µs ± 155 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
偷懒
arr3 = np.ones((4,3,2))
array([[[1., 1.],
[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.],
[1., 1.]]])
如果待处理的张量不止三维,我们还可以“偷懒”地将多个彼此相连的维度格式字符串用省略号 (...) 代替,以表示剩下的所有维度。
einsum("ijk->jk", arr3)
einsum("i...->...", arr3)
array([[4., 4.],
[4., 4.],
[4., 4.]])
array([[4., 4.],
[4., 4.],
[4., 4.]])
简约
在注意力机制实现方式中,当考虑 Batch 维度时,公式如下:
用 einsum 函数可以非常简约的实现上面表达式:
from numpy.random import normal
Q = normal(size=(8,10)) # batch_size,query_features
K = normal(size=(8,10)) # batch_size,key_features
W = normal(size=(5,10,10)) # out_features,query_features,key_features
b = normal(size=(5,)) # out_features
A = einsum('bq,oqk,bk->bo',Q,W,K) + b
print("A.shape:",A.shape)
A.shape: (8, 5)
一个字符串 "bq,oqk,bk->bo" 就可以搞定,只要确保箭头左边 "bq,oqk,bk" 重复指标对应维度中的元素个数相等即可,在本例中:
最后 A 的形状为 (8, 5),结果合理,因为用字符串 "bo" 来描述 A,
4
总结
NumPy 包中的 einsum 可以替代如下常用的运算,
另外两表胜千言!
对于一维数组,即向量
对于二维数组,即矩阵
Stay Tuned!