您的位置:

深入理解torch.mul

一、作用与基本用法

torch.mul是Pytorch中的一个重要函数,用于对两个张量逐元素相乘,返回一个新的张量。

torch.mul(input, other, out=None)

  • input:第一个相乘的张量
  • other:第二个相乘的张量
  • out :指定输出张量

若不指定out,返回的是逐元素相乘后的新张量,若指定out,则原地修改第一个张量,也就是将第二个张量逐元素相乘后结果赋值给第一个张量。

import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])

# 返回新的张量
c = torch.mul(a, b)
print(c)   # tensor([ 4, 10, 18])

# 原地修改第一个张量
torch.mul(a, b, out=a)
print(a)   # tensor([ 4, 10, 18])

二、特殊使用

1.向量点乘

向量点乘是指两个向量逐元素相乘然后相加的结果,可以用torch.mul实现。

例如,我们有两个向量a=[1,2,3]和b=[4,5,6],向量a与向量b的点积为1*4 + 2*5 + 3*6 = 32。

import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])

dot_product = torch.sum(torch.mul(a, b))

print(dot_product)  # tensor(32)

2.矩阵乘法

矩阵乘法是对两个矩阵进行操作的一种方式,可以使用torch.mul和torch.sum函数实现。

例如,我们有两个矩阵A=[[1,2,3],[4,5,6]]和B=[[7,8],[9,10],[11,12]],则矩阵AB为:[[1*7+2*9+3*11, 1*8+2*10+3*12], [4*7+5*9+6*11, 4*8+5*10+6*12]]。

import torch

A = torch.tensor([[1, 2, 3], [4, 5, 6]])
B = torch.tensor([[7, 8], [9, 10], [11, 12]])

AB = torch.zeros((A.shape[0], B.shape[1]))

for i in range(A.shape[0]):
    for j in range(B.shape[1]):
        AB[i][j] = torch.sum(torch.mul(A[i], B[:,j]))

print(AB)
# tensor([[ 58,  64],
#         [139, 154]])

三、小结

torch.mul是Pytorch中一个十分常用的函数,可以对两个张量进行逐元素相乘操作得到新的张量,也可以在指定输出张量的情况下进行原地修改。此外,torch.mul还可以进行向量点乘和矩阵乘法等特殊用途。

当我们处理神经网络的深度学习中的线性变换过程时,经常会使用到torch.mul及相关操作函数,掌握好它们的使用方式也是我们提高深度学习技能的重要一步。