一、作用与基本用法
torch.mul是Pytorch中的一个重要函数,用于对两个张量逐元素相乘,返回一个新的张量。
torch.mul(input, other, out=None)
- input:第一个相乘的张量
- other:第二个相乘的张量
- out :指定输出张量
若不指定out,返回的是逐元素相乘后的新张量,若指定out,则原地修改第一个张量,也就是将第二个张量逐元素相乘后结果赋值给第一个张量。
import torch a = torch.tensor([1, 2, 3]) b = torch.tensor([4, 5, 6]) # 返回新的张量 c = torch.mul(a, b) print(c) # tensor([ 4, 10, 18]) # 原地修改第一个张量 torch.mul(a, b, out=a) print(a) # tensor([ 4, 10, 18])
二、特殊使用
1.向量点乘
向量点乘是指两个向量逐元素相乘然后相加的结果,可以用torch.mul实现。
例如,我们有两个向量a=[1,2,3]和b=[4,5,6],向量a与向量b的点积为1*4 + 2*5 + 3*6 = 32。
import torch a = torch.tensor([1, 2, 3]) b = torch.tensor([4, 5, 6]) dot_product = torch.sum(torch.mul(a, b)) print(dot_product) # tensor(32)
2.矩阵乘法
矩阵乘法是对两个矩阵进行操作的一种方式,可以使用torch.mul和torch.sum函数实现。
例如,我们有两个矩阵A=[[1,2,3],[4,5,6]]和B=[[7,8],[9,10],[11,12]],则矩阵AB为:[[1*7+2*9+3*11, 1*8+2*10+3*12], [4*7+5*9+6*11, 4*8+5*10+6*12]]。
import torch A = torch.tensor([[1, 2, 3], [4, 5, 6]]) B = torch.tensor([[7, 8], [9, 10], [11, 12]]) AB = torch.zeros((A.shape[0], B.shape[1])) for i in range(A.shape[0]): for j in range(B.shape[1]): AB[i][j] = torch.sum(torch.mul(A[i], B[:,j])) print(AB) # tensor([[ 58, 64], # [139, 154]])
三、小结
torch.mul是Pytorch中一个十分常用的函数,可以对两个张量进行逐元素相乘操作得到新的张量,也可以在指定输出张量的情况下进行原地修改。此外,torch.mul还可以进行向量点乘和矩阵乘法等特殊用途。
当我们处理神经网络的深度学习中的线性变换过程时,经常会使用到torch.mul及相关操作函数,掌握好它们的使用方式也是我们提高深度学习技能的重要一步。