前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >盘一盘 Python 特别篇 23 - 爱因斯坦求和 einsum

盘一盘 Python 特别篇 23 - 爱因斯坦求和 einsum

作者头像
用户5753894
发布2022-12-18 16:08:05
2K0
发布2022-12-18 16:08:05
举报
文章被收录于专栏:王的机器

本文含 10083 字,18 图表截屏

建议阅读 52 分钟

本文是 Python 系列的特别篇的第二十三篇

0

引言

最近我以电子版的形式出了第二本书《Python 从入门到入迷》,然后定期更新书中的内容,最先想到的便是 einsum。

在 NumPy 包中,有一个函数叫做 einsum,它做的事情就是加总 (summation),但是是以爱因斯坦加总惯例 (Einstein's summation convention) 进行,因此得以此名。在深度学习框架 Tensorflow 和 PyTorch 也有这个函数,而且用法几乎一样,使用 einsum 首先需要从各自包中引用:

代码语言:javascript
复制
from numpy import einsum
from torch import einsum
from tensorflow import einsum

本文只拿 NumPy 包中的 einsum 来举例,并按照 what-how-why 主线来讲解,首先介绍什么 (what) 是 einsum,再展示怎么 (how) 用 einsum,最后来说明为什么 (why) 会有 einsum。相信这三部曲过后,我们可以把 einsum 整得明明白白的。

1

What is einsum?

1.1

爱因斯坦标记法

以下是一个矩阵相乘的具体例子,我们都知道结果矩阵第 2 行第 1 列的元素 4 是由“第一个矩阵第 2 行的元素”依次乘以“第二个矩阵第 1 列的元素”再加总,即 4 = 2*0 + 2*1 + 2*1。

矩阵相乘的通用形式如下,用字母代替数字得到 c21 = a21*b11 + a22*b21 + a23*b31。

写成通式就是

上式中的下指标可分成两类:

  • 出现两次的指标被称作哑指标 (dummy index),比如 j
  • 在单项式中只出现一次的指标被称作自由指标 (free index),比如 i 和 k

爱因斯坦对于式中出现的哑指标,约定默认对其进行求和。有了这个约定之后,上面表达式可简化成:

有了爱因斯坦约定得到的简写 (注意上面表达式的下标) ,用 einsum('ij,jk->ik',A,B)可以表达矩阵相乘,其中参数

  • 'ij,jk->ik' 是表示在爱因斯坦约定下的矩阵相乘字符串,箭头 -> 把字符串分成两部分,左侧部分表示输入矩阵,'ij' 标记 A 以及 'jk' 标记 B;右侧部分 'ik' 标记输出矩阵 C
  • A 和 B 是用于相乘的两个矩阵

下面用代码来看几个例子。

1.2

代码展示

首先创建矩阵 A 和 B。

代码语言:javascript
复制
A = np.array([[1, 1, 1],
              [2, 2, 2],
              [5, 5, 5]])

B = np.array([[0, 1, 0],
              [1, 1, 0],
              [1, 1, 1]])

用 einsum 函数来求矩阵相乘。

代码语言:javascript
复制
einsum('ij,jk->ik', A, B)
代码语言:javascript
复制
array([[ 2, 3, 1],
       [ 4, 6, 2],
       [10, 15, 5]])

用 np.matmul(A,B) 验证上面的语法确实做的是矩阵相乘。

代码语言:javascript
复制
np.matmul( A, B )
代码语言:javascript
复制
array([[ 2, 3, 1],
       [ 4, 6, 2],
       [10, 15, 5]])

自由指标和哑指标用任何字母字符都可以的,只要哑指标的位置写对即可,比如:

代码语言:javascript
复制
einsum('bF,FG->bG', A, B)
代码语言:javascript
复制
array([[ 2, 3, 1],
       [ 4, 6, 2],
       [10, 15, 5]])

用 'ij,jk->ik' 只不过字母 i,j 和 k 在数学中的下标表示中更常见。

爱因斯坦求和容易吧,你觉得你会了么?觉得会的话来看看下面的各种组合

  • 'ij,jk->ki'
  • 'ij,jk->ij'
  • 'ij,jk->ji'
  • 'ij,jk->jk'
  • 'ij,jk->kj'

是不是越看越困惑?

字符串 'ij,jk->ki' 得到的结果还好理解,就是矩阵乘完之后做个转置,因为箭头 -> 右边是 ki,正好和上例的 ik 反过来了。

代码语言:javascript
复制
einsum('ij,jk->ki', A, B)
代码语言:javascript
复制
array([[ 2, 4, 10],
       [ 3, 6, 15],
       [ 1, 2, 5]])

字符串 'ij,jk->ij' 和 'ij,jk->ij' 得到的结果就不好理解了,虽然我们看出来两种字符串得到的矩阵互为转置,但怎么得到的却不清楚,这个第三节会细讲。

代码语言:javascript
复制
einsum('ij,jk->ij', A, B)
einsum('ij,jk->ji', A, B)
代码语言:javascript
复制
array([[ 1, 2, 3],
       [ 2, 4, 6],
       [ 5, 10, 15]])
代码语言:javascript
复制
array([[ 1, 2, 5],
       [ 2, 4, 10],
       [ 3, 6, 15]])

同样,字符串 'ij,jk->jk' 和 'ij,jk->jk' 得到的结果更不好理解,两种字符串得到的矩阵互为转置,但怎么得到的却不清楚,这个第三节也会细讲。

代码语言:javascript
复制
einsum('ij,jk->jk', A, B)
einsum('ij,jk->kj', A, B)
代码语言:javascript
复制
array([[0, 8, 0],
       [8, 8, 0],
       [8, 8, 8]])
array([[0, 8, 8],
       [8, 8, 8],
       [0, 0, 8]])

如果你有对以上结果有困惑,那么请继续看下去,让我们来深挖 einsum 来总结其函数的一些通用规则。

2

How to use einsum?

当你学习一个新东西时,最好的方法是从最基础的部分开始,对于 einsum 这样基于数组的运算函数,我们就依次从 0 维 (标量),1 维 (向量),2 维 (矩阵) 到高维 (张量) 数组一步步来探索。

具体来说,einsum 函数的功能是

  1. 单数组不同轴上的元素求和
  2. 多数组相同轴上的元素相乘再求和

2.1

标量

0 维单数组

首先创建标量 arr0。

代码语言:javascript
复制
arr0 = 3

标量中没有轴的概念,按轴求和得到的结果就是它本身而已。

代码语言:javascript
复制
einsum("->", arr0)
代码语言:javascript
复制
3

注意字符串 "->" 可以看成 " -> ",箭头的左边和右边都是空字符,因为标量是 0 维度,如果用字母 i 来表示会报错。

代码语言:javascript
复制
einsum("i->", arr0)

如果在字符串中去掉箭头,得到的结果和上例是一样的,但是表示的含义有细微的区别。

代码语言:javascript
复制
einsum("", arr0)
代码语言:javascript
复制
3

上例的操作是对数组求和,本例的操作是返回该数组,只不过当数组为标量时,两者看起来是一样的 (对于非标量的数组就不是这样子了,后面读者会看到)。

规则总结:箭头 -> 表示求和。

0 维多数组

首先创建标量 A 和 B。

代码语言:javascript
复制
A = 3
B = 5

注意字符串 ",->" 可以看成 " , -> ",箭头的左边两个空字符代表用于相乘的两个标量,箭头右边的空字符代表结果。

代码语言:javascript
复制
einsum(",->", A, B)
代码语言:javascript
复制
15

去掉箭头也可以。

代码语言:javascript
复制
einsum(",", A, B)
代码语言:javascript
复制
15

去掉逗号会报错,因为后面跟着两个参数 A 和 B,因此需要逗号来分隔出来来描述 A 和 B 的两个字符串,即空字符串。

代码语言:javascript
复制
einsum(",", A, B)

根据为两个数组相乘设定的字符串,对于三个数组相乘,加一个逗号 ",,->" 就可以了。

代码语言:javascript
复制
C = 2
einsum(",,->", A, B, C)
代码语言:javascript
复制
30

规则总结:逗号 , 用来分隔数组,数组相乘在 einsum 函数的设置如下 (以 3 个数组举例):

三个颜色对应三个输入张量,分别是 2 维数组 (2 个红框),3 维数组 (3 个紫框) 和 2 维数组 (2 个蓝框),而输出张量是 2 维数组 (2 个绿框)。

从标量可以猜想出以上规则,但标量没有轴的概念,而且求和与其本身也看不来区别,因此我们需要用向量、矩阵和张量来验证或完善上面的规则。

2.2

向量

1 维单数组

首先创建向量 arr1。

代码语言:javascript
复制
arr1 = np.array([0, 1, 2])

向量只有一个轴,按轴求和得到的结果就是它包含所有元素的和。

代码语言:javascript
复制
einsum("i->", arr1)
代码语言:javascript
复制
3

注意字符串 "i->" 可以看成 "i-> ",箭头的左边字符 i 表示向量的轴 0 维度,

箭头右边的空字符表示求和得到标量。

代码语言:javascript
复制
einsum("i->", arr0)

如果在字符串中去掉箭头,得到的结果是该向量本身。

代码语言:javascript
复制
einsum("i", arr1)
代码语言:javascript
复制
array([0, 1, 2])

如果用字符串 "i->i",得到的结果也是该向量本身。

代码语言:javascript
复制
einsum("i->i", arr1)
代码语言:javascript
复制
array([0, 1, 2])

1 维多数组

首先创建向量 A 和 B。

代码语言:javascript
复制
A = np.array([1, 2, 3])
B = np.array([4, 5, 6])

字符串 "i,i->i" 指的数组 A 和 B 相同轴 (轴 0 i,i) 的元素依次相乘 (注意没有乘后相加) 得到的轴 0 维度 (i) 上的数组。

代码语言:javascript
复制
einsum("i,i->i", A, B)
代码语言:javascript
复制
array([ 4, 10, 18])

字符串 "i,i" 相当于 "i,i->",箭头右边是一个空字符,代表是标量,那么将上例得到的数组所有元素求和就可得到一个标量了。

代码语言:javascript
复制
einsum("i,i", A, B)
代码语言:javascript
复制
32

对于两个向量,字符串 "i,i" 代表它们的内积或点积操作。

代码语言:javascript
复制
np.inner(A, B)
np.dot(A, B)
代码语言:javascript
复制
32
32

接下来的字符串 "i,j" 有些难度了,下面两个语句的结果都是矩阵,两个向量怎么都能生成矩阵呢?难道是外积?

代码语言:javascript
复制
einsum("i,j", A, B)
einsum("i,j->ij", A, B)
代码语言:javascript
复制
array([[ 4, 5, 6],
       [ 8, 10, 12],
       [12, 15, 18]])
array([[ 4, 5, 6],
       [ 8, 10, 12],
       [12, 15, 18]])

从下面代码来看,确实是这样的。叉积的结果是矩阵是二维数组,而用于外积的两个向量是一维数组,这个升维操作其实是由 "i,j" 来实现的。用不同字母 i 和 j 就代表不同的维度,对应着结果矩阵中的轴 0 和轴 1 维度。

代码语言:javascript
复制
np.outer(A, B)
代码语言:javascript
复制
array([[ 4, 5, 6],
       [ 8, 10, 12],
       [12, 15, 18]])

现在知道外积的结果是个二维矩阵,那么当然可以沿着轴 0 ("i,j->i"),轴 1 ("i,j->j") 和对所有元素 ("i,j") 求和了,代码如下:

代码语言:javascript
复制
einsum("i,j->i", A, B) # 沿着轴 0 求和
einsum("i,j->j", A, B) # 沿着轴 1 求和
einsum("i,j->", A, B)  # 对所有元素求和
代码语言:javascript
复制
array([15, 30, 45]) # 沿着轴 0 求和
array([24, 30, 36]) # 沿着轴 1 求和
90                  # 对所有元素求和

规则总结:字符串 "i,j->x" 箭头 -> 右边的字符 x 来确定求和的方式,如果:

  • x 是 i,那么沿着轴 0 求和,因为字母 i 处在字符串 "i,j" 逗号前面
  • x 是 j,那么沿着轴 1 求和,因为字母 j 处在字符串 "i,j" 逗号后面
  • x 是 空字符,那么对所有元素求和,因为空字符对应着零维的标量

2.3

矩阵

2 维单数组

首先创建矩阵 arr2。

代码语言:javascript
复制
arr2 = np.array([[ 0,  1,  2],
                 [ 3,  4,  5],
                 [ 6,  7,  8]])

用字符串 "ij" 和 "ji" 分别生成矩阵本身和其转置。

代码语言:javascript
复制
einsum("ij", arr2)
einsum("ji", arr2)
代码语言:javascript
复制
array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]])
array([[0, 3, 6],
       [1, 4, 7],
       [2, 5, 8]])

字符串 "ii->i" 生成矩阵的对角线上的元素,即一维向量,和函数 np.diag(arr2) 等效。

代码语言:javascript
复制
einsum("ii->i", arr2)
np.diag(arr2)
代码语言:javascript
复制
array([0, 4, 8])
array([0, 4, 8])

字符串 "ii" 生成矩阵的对角线上的元素再求和,即零维标量,和函数 np.trace(arr2) 等效,求的是矩阵的迹。

代码语言:javascript
复制
einsum("ii", arr2)
np.trace(arr2)
代码语言:javascript
复制
12
12

巩固一下上节归纳出来的规则,对于二维矩阵,可以沿着轴 0 ("i,j->i"),轴 1 ("i,j->j") 和对所有元素 ("i,j") 求和了,代码如下:

代码语言:javascript
复制
einsum("ij->i", A, B) # 沿着轴 0 求和
einsum("ij->j", A, B) # 沿着轴 1 求和
einsum("ij->", A, B)  # 对所有元素求和
代码语言:javascript
复制
array([3, 12, 21]) # 沿着轴 0 求和
array([9, 12, 15]) # 沿着轴 1 求和
36                 # 对所有元素求和

注意:当求矩阵对角线时,返回的结果是矩阵的视图 (view),而不是复制 (copy)。以下面代码为例,当改变 c 中的元素,对应的 arr2 也会改变。

代码语言:javascript
复制
c = einsum("ii->i", arr2)
c[-1] = 10000
arr2
代码语言:javascript
复制
array([[ 0, 1, 2],
       [ 3, 4, 5],
       [ 6, 7, 10000]])

再对矩阵 arr2 (已改变了) 求迹结果已经变成 10004 了。

代码语言:javascript
复制
einsum("ii->", arr2)
代码语言:javascript
复制
10004

2 维多数组

首先创建矩阵 A 和 B。

代码语言:javascript
复制
A = np.array([[1, 1, 1],
              [2, 2, 2],
              [5, 5, 5]])

B = np.array([[0, 1, 0],
              [1, 1, 0],
              [1, 1, 1]])

用字符串 "ij,jk->ik" 来求矩阵相乘,和用 np.matmul(A,B) 等效。

代码语言:javascript
复制
einsum('ij,jk->ik', A, B)
代码语言:javascript
复制
array([[ 2, 3, 1],
       [ 4, 6, 2],
       [10, 15, 5]])

进一步理解一下上面的操作,就是把矩阵 A 轴 1 (列,ij 中的 j) 和矩阵 B 轴 0 (行,jk 中的 j) 每个元素相乘,然后沿着 j 代表的轴 (字符串只包含 ik) 求和。 而如上描述的操作刚好也是矩阵相乘的定义。

现在问题来了,那么在没有沿着 j 代表的轴求和之前的产出是什么呢?

代码语言:javascript
复制
einsum('ij,jk->ijk', A, B)
代码语言:javascript
复制
array([[[0, 1, 0],
        [1, 1, 0],
        [1, 1, 1]],

       [[0, 2, 0],
        [2, 2, 0],
        [2, 2, 2]],

       [[0, 5, 0],
        [5, 5, 0],
        [5, 5, 5]]])

结果是个三维数组,从 "ij,jk->ijk" 箭头右边的 ijk 也看得出来。由于结果比 A 和 B 高一维,它背后的操作实际上是

  • 将 A 在轴 2 上升一维 (从 ij 到 ijk)
  • 将 B 在轴 0 上升一维 (从 jk 到 ijk)

然后在元素层面上相乘。

先打印出升过维度的 A 和 B:

代码语言:javascript
复制
A[:,None]
B[None,:]
代码语言:javascript
复制
array([[[1, 1, 1]],
       [[2, 2, 2]],
       [[5, 5, 5]]])

array([[[0, 1, 0],
        [1, 1, 0],
        [1, 1, 1]]])

然后在元素层面上相乘,得到的结果和 einsum('ij,jk->ijk', A, B) 一致。

代码语言:javascript
复制
A[:,None] * B[None,:]
代码语言:javascript
复制
array([[[0, 1, 0],
        [1, 1, 0],
        [1, 1, 1]],

       [[0, 2, 0],
        [2, 2, 0],
        [2, 2, 2]],

       [[0, 5, 0],
        [5, 5, 0],
        [5, 5, 5]]])

有了这个三维数组,那么就好理解小节 1 里的各种组合结果了。

  1. 'ij,jk->ik'
  2. 'ij,jk->ki'
  3. 'ij,jk->ij'
  4. 'ij,jk->ji'
  5. 'ij,jk->jk'
  6. 'ij,jk->kj'

第一种 'ij,jk->ik' 就是在三维数组上沿着 j 轴求和,示意图如下。

那么第三种 'ij,jk->ij' 和第五种 'ij,jk->jk' 分别就是在三维数组上沿着 k 轴和 i 轴求和,对应着上面的三维数组图和下面的代码,我相信读者可以理解为什么结果是这样子了。

代码语言:javascript
复制
einsum('ij,jk->ij', A, B) # 沿着 k 轴求和,i 轴和 j 轴成了轴 0 和轴 1
代码语言:javascript
复制
array([[ 1, 2, 3],
       [ 2, 4, 6],
       [ 5, 10, 15]])
代码语言:javascript
复制
einsum('ij,jk->jk', A, B) # 沿着 i 轴求和,j 轴和 k 轴成了轴 0 和轴 1
代码语言:javascript
复制
array([[0, 8, 0],
       [8, 8, 0],
       [8, 8, 8]])

趁热打铁,来看看下面这个语句的结果,是一个四维数组!

代码语言:javascript
复制
einsum('ij,kl->ijkl', A, B)
代码语言:javascript
复制
array([[[[0, 1, 0],
         [1, 1, 0],
         [1, 1, 1]],

        [[0, 1, 0],
         [1, 1, 0],
         [1, 1, 1]],

        [[0, 1, 0],
         [1, 1, 0],
         [1, 1, 1]]],


       [[[0, 2, 0],
         [2, 2, 0],
         [2, 2, 2]],

        [[0, 2, 0],
         [2, 2, 0],
         [2, 2, 2]],

        [[0, 2, 0],
         [2, 2, 0],
         [2, 2, 2]]],


       [[[0, 5, 0],
         [5, 5, 0],
         [5, 5, 5]],

        [[0, 5, 0],
         [5, 5, 0],
         [5, 5, 5]],

        [[0, 5, 0],
         [5, 5, 0],
         [5, 5, 5]]]])

相信读者已经很难可视化该过程了,我们来捋捋,从 "ij,kl->ijkl" 箭头右边的 ijkl 看得出来结果是四维数组。由于结果比 A 和 B 高两维,它背后的操作实际上是

  • 将 A 在轴 2-3 上升两维 (从 ij 到 ijkl)
  • 将 B 在轴 0-1 上升两维 (从 kl 到 ijkl)

然后在元素层面上相乘。根据结果也可以把 "ij,kl->ijkl" 理解成 A 的每一个元素乘以 B。

用下面的代码得到的结果和 einsum('ij,kl->ijkl', A, B) 一致。

代码语言:javascript
复制
A[:,:,None,None] * B[None,None,:,:]

规则总结:在字符串"ij,jk->ik" 中

  • 箭头 -> 左边的重复指标 j 指的是该轴上的元素会相乘,这里有个隐含假设,那就是两个矩阵在轴 j 上的元素个数相等,不然会报错。
  • 箭头 -> 右边消失了指标 j 指的是沿着该轴求和。

2.4

张量

多维单数组

上节已经讲完了,从 'ijk' 到 'ij','jk' 和 'ik' 其实就是三维数组分别在轴 k、轴 i 和周 j 上做求和,因此把对应的轴“打掉”降了一维。

字符串 "ijk->" 对三维数组所有元素求和,得到标量 48。

代码语言:javascript
复制
C = A[:,None] * B[None,:]
einsum('ijk->', C)
代码语言:javascript
复制
48

多维多数组

首先创建三维张量 A 和 B。

代码语言:javascript
复制
A = np.arange(60.).reshape(3,4,5)
B = np.arange(24.).reshape(4,3,2)
代码语言:javascript
复制
array([[[ 0., 1., 2., 3., 4.],
        [ 5., 6., 7., 8., 9.],
        [10., 11., 12., 13., 14.],
        [15., 16., 17., 18., 19.]],

       [[20., 21., 22., 23., 24.],
        [25., 26., 27., 28., 29.],
        [30., 31., 32., 33., 34.],
        [35., 36., 37., 38., 39.]],

       [[40., 41., 42., 43., 44.],
        [45., 46., 47., 48., 49.],
        [50., 51., 52., 53., 54.],
        [55., 56., 57., 58., 59.]]])

array([[[ 0., 1.],
        [ 2., 3.],
        [ 4., 5.]],

       [[ 6., 7.],
        [ 8., 9.],
        [10., 11.]],

       [[12., 13.],
        [14., 15.],
        [16., 17.]],

       [[18., 19.],
        [20., 21.],
        [22., 23.]]])

字符串 "ijk,jil->kl" 将 A 切片轴 0-1 得到一个形状为 (3, 4) 的二维矩阵,比如 a;将 B 切片轴 0-1 得到一个形状为 (4, 3) 的二维矩阵,比如 b;然后用 a 乘以 b 的转置 ("ijk,jil->kl") 并对所有元素求和。这样的操作重复做最终填满形状为 (5, 2) 的二维矩阵 ("ijk,jil->kl") ,因为 A 沿轴 2 的元素个数是 5,B 沿轴 2 的元素个数是 2。

代码语言:javascript
复制
einsum('ijk,jil->kl', A, B)
代码语言:javascript
复制
array([[4400., 4730.],
       [4532., 4874.],
       [4664., 5018.],
       [4796., 5162.],
       [4928., 5306.]])

让我们用代码来明晰上面的文字解释。我们只关注上面数组 [0, 0] 位置的 4400 是怎么计算出来的。首先对 A 和 B 沿着轴 2 切片:

代码语言:javascript
复制
a = A[:,:,0]
b = B[:,:,0]
代码语言:javascript
复制
array([[ 0., 5., 10., 15.],
       [20., 25., 30., 35.],
       [40., 45., 50., 55.]])

array([[ 0., 2., 4.],
       [ 6., 8., 10.],
       [12., 14., 16.],
       [18., 20., 22.]])

然后用 a 乘以 b 的转置并对所有元素求和。

代码语言:javascript
复制
(a * b.T).sum()
代码语言:javascript
复制
4400

这样结果数组 [0, 0] 位置就知道怎么来的了,同理对 [0, 1], [1, 0], [1, 1], ..., [4, 0], [4, 1] 做上述同样的操作,就可以得到 einsum('ijk,jil->kl', A, B) 的结果了。

上述操作和 np.tensordot( A, B, axes=([0,1],[1,0]) ) 等效。

代码语言:javascript
复制
np.tensordot( A, B, axes=([0,1],[1,0]) )
代码语言:javascript
复制
array([[4400., 4730.],
       [4532., 4874.],
       [4664., 5018.],
       [4796., 5162.],
       [4928., 5306.]])

3

Why use einsum?

首先 einsum 一招鲜吃遍天,可以满足数组所有类型的运算,比如转置、内积、外积、对角线、迹、轴上求和,所有元素求和等。除此之外还有以下优点。

高效

代码语言:javascript
复制
A = np.array([0, 1, 2])
B = np.array([[ 0,  1,  2,  3],
              [ 4,  5,  6,  7],
              [ 8,  9, 10, 11]])
print(A.shape, B.shape)
代码语言:javascript
复制
(3,) (3, 4)

向量 A 不能直接乘以矩阵 B,不满足广播机制,因为 “A 形状最后一维元素个数 3 和 B 形状最后一维元素个数 4” 不匹配。

代码语言:javascript
复制
A * B

要让 A 和 B 可以相乘必须对 A 在轴 1 升一维,A[:,None]。

代码语言:javascript
复制
A[:,None] * B
代码语言:javascript
复制
array([[ 0, 0, 0, 0],
       [ 4, 5, 6, 7],
       [16, 18, 20, 22]])

假如我们想得到按轴 1 求和得到一个向量,可用代码

代码语言:javascript
复制
(A[:,None] * B).sum(axis=1)
代码语言:javascript
复制
array([ 0, 22, 76])

但用 einsum 能简约而轻松的得到以上结果。

代码语言:javascript
复制
einsum("i,ij->i", A, B)
代码语言:javascript
复制
array([ 0, 22, 76])

字符串 "i,ij->i" 由 -> 分成了两部分,它左边的 i,ij 对应两个输入,而右边的 i 对应输出。输出中没有下标 j,说明对两个输入沿着这个下标求和,而 i 所在的轴仍然保留。而 i 下标对应的维度的元素个数为 3,因此最终得到一个有 3 个元素的向量。

除了方便,einsum 也比传统做法高效。

代码语言:javascript
复制
%timeit (A[:,None] * B).sum(axis=1)
%timeit einsum("i,ij->i", A, B)
代码语言:javascript
复制
3.72 µs ± 117 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
2.11 µs ± 155 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

偷懒

代码语言:javascript
复制
arr3 = np.ones((4,3,2))
代码语言:javascript
复制
array([[[1., 1.],
        [1., 1.],
        [1., 1.]],

       [[1., 1.],
        [1., 1.],
        [1., 1.]],

       [[1., 1.],
        [1., 1.],
        [1., 1.]],

       [[1., 1.],
        [1., 1.],
        [1., 1.]]])

如果待处理的张量不止三维,我们还可以“偷懒”地将多个彼此相连的维度格式字符串用省略号 (...) 代替,以表示剩下的所有维度。

代码语言:javascript
复制
einsum("ijk->jk", arr3)
einsum("i...->...", arr3)
代码语言:javascript
复制
array([[4., 4.],
       [4., 4.],
       [4., 4.]])
array([[4., 4.],
       [4., 4.],
       [4., 4.]])

简约

在注意力机制实现方式中,当考虑 Batch 维度时,公式如下:

用 einsum 函数可以非常简约的实现上面表达式:

代码语言:javascript
复制
from numpy.random import normal

Q = normal(size=(8,10))    # batch_size,query_features
K = normal(size=(8,10))    # batch_size,key_features
W = normal(size=(5,10,10)) # out_features,query_features,key_features
b = normal(size=(5,))      # out_features

A = einsum('bq,oqk,bk->bo',Q,W,K) + b
print("A.shape:",A.shape)
代码语言:javascript
复制
A.shape: (8, 5)

一个字符串 "bq,oqk,bk->bo" 就可以搞定,只要确保箭头左边 "bq,oqk,bk" 重复指标对应维度中的元素个数相等即可,在本例中:

  • 指标 q 对应维度中的元素个数为 10
  • 指标 k 对应维度中的元素个数为 10

最后 A 的形状为 (8, 5),结果合理,因为用字符串 "bo" 来描述 A,

  • 指标 b 对应维度中的元素个数为 8
  • 指标 o 对应维度中的元素个数为 5

4

总结

NumPy 包中的 einsum 可以替代如下常用的运算,

  • 矩阵求迹: trace
  • 求矩阵对角线: diag
  • 张量(沿轴)求和: sum
  • 张量转置: transopose
  • 矩阵乘法: dot
  • 张量乘法: tensordot
  • 向量内积: inner
  • 外积: outer

另外两表胜千言!

对于一维数组,即向量

对于二维数组,即矩阵

Stay Tuned!

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-09-18,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 王的机器 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档