主要内容
“相信90%的小伙伴在使用python的矩阵运算函数时都遇到过问题,尤其是在神经网络编程时,matrix/array计算总是出现dimension不匹配错误。本文对常见的几个函数进行简单总结,供大家参考”
一、x*y
要求数组x与y的行数相同,y的列数为1或与x列数相同。如果二者列数相同,则y的每列与x的对应列做对位乘法;如果y的列数为1,则使用broadcast机制将y与x的每一列对应相乘。
二、np.multiply(x,y)
当使用broadcast机制时,与x*y功能相同。
三、x.dot(y)
矩阵的点乘,元素Aij=∑(Xik*Ykj)(k为X的列数/Y的行数)
四、np.matmul(x,y)
(1)如果x,y都是2维的,则按普通的矩阵点乘计算,要求x的列==y的行
(2)如果x,y至少有一个是N>2维,则使用broadcast机制
下图展示了多维矩阵的计算方式,此处没有使用broadcast机制,因为Y的shape[0]为2,与X的shape[0]相等。如果Y为2×1的数组,则需要将Y与X的shape[0]维度上每一个数组相乘(即broadcast)。
(3)如果x是一维的,则给x增加一个维度,在矩阵乘法后再去掉这个维度
此处x的元素个数必须是3,等于y的行数
(4)如果y是一维的,将y看做一个向量,与x的每一行向量进行内积运算,所以要求y的元素个数必须==x的列数。
matmul与dot类似但又有不同:
(1)不允许乘以标量,此时推荐使用x*y运算
(2)把矩阵也当成一个个元素一样进行broadcast
心得:在编程中遇到矩阵/数组运算时,通过输出矩阵/数组变量的shape来跟踪运算的正确与否是较好的方法。
领取专属 10元无门槛券
私享最新 技术干货