深入理解torch.mul

发布时间:2023-05-20

一、作用与基本用法

torch.mul 是 PyTorch 中的一个重要函数,用于对两个张量逐元素相乘,返回一个新的张量。

torch.mul(input, other, out=None)
  • input:第一个相乘的张量
  • other:第二个相乘的张量
  • out:指定输出张量 如果不指定 out,返回的是逐元素相乘后的新张量;如果指定 out,则原地修改第一个张量,也就是将第二个张量逐元素相乘后的结果赋值给第一个张量。
import torch
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
# 返回新的张量
c = torch.mul(a, b)
print(c)   # tensor([ 4, 10, 18])
# 原地修改第一个张量
torch.mul(a, b, out=a)
print(a)   # tensor([ 4, 10, 18])

二、特殊使用

1. 向量点乘

向量点乘是指两个向量逐元素相乘然后相加的结果,可以用 torch.mul 实现。 例如,我们有两个向量 a=[1,2,3]b=[4,5,6],向量 a 与向量 b 的点积为 1*4 + 2*5 + 3*6 = 32

import torch
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
dot_product = torch.sum(torch.mul(a, b))
print(dot_product)  # tensor(32)

2. 矩阵乘法

矩阵乘法是对两个矩阵进行操作的一种方式,可以使用 torch.multorch.sum 函数实现。 例如,我们有两个矩阵 A=[[1,2,3],[4,5,6]]B=[[7,8],[9,10],[11,12]],则矩阵 AB 为:

[[1*7+2*9+3*11, 1*8+2*10+3*12],
 [4*7+5*9+6*11, 4*8+5*10+6*12]]
import torch
A = torch.tensor([[1, 2, 3], [4, 5, 6]])
B = torch.tensor([[7, 8], [9, 10], [11, 12]])
AB = torch.zeros((A.shape[0], B.shape[1]))
for i in range(A.shape[0]):
    for j in range(B.shape[1]):
        AB[i][j] = torch.sum(torch.mul(A[i], B[:,j]))
print(AB)
# tensor([[ 58,  64],
#         [139, 154]])

三、小结

torch.mul 是 PyTorch 中一个十分常用的函数,可以对两个张量进行逐元素相乘操作得到新的张量,也可以在指定输出张量的情况下进行原地修改。此外,torch.mul 还可以进行向量点乘和矩阵乘法等特殊用途。 当我们处理神经网络的深度学习中的线性变换过程时,经常会使用到 torch.mul 及相关操作函数,掌握好它们的使用方式也是我们提高深度学习技能的重要一步。