一、作用与基本用法
torch.mul
是 PyTorch 中的一个重要函数,用于对两个张量逐元素相乘,返回一个新的张量。
torch.mul(input, other, out=None)
input
:第一个相乘的张量other
:第二个相乘的张量out
:指定输出张量 如果不指定out
,返回的是逐元素相乘后的新张量;如果指定out
,则原地修改第一个张量,也就是将第二个张量逐元素相乘后的结果赋值给第一个张量。
import torch
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
# 返回新的张量
c = torch.mul(a, b)
print(c) # tensor([ 4, 10, 18])
# 原地修改第一个张量
torch.mul(a, b, out=a)
print(a) # tensor([ 4, 10, 18])
二、特殊使用
1. 向量点乘
向量点乘是指两个向量逐元素相乘然后相加的结果,可以用 torch.mul
实现。
例如,我们有两个向量 a=[1,2,3]
和 b=[4,5,6]
,向量 a
与向量 b
的点积为 1*4 + 2*5 + 3*6 = 32
。
import torch
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
dot_product = torch.sum(torch.mul(a, b))
print(dot_product) # tensor(32)
2. 矩阵乘法
矩阵乘法是对两个矩阵进行操作的一种方式,可以使用 torch.mul
和 torch.sum
函数实现。
例如,我们有两个矩阵 A=[[1,2,3],[4,5,6]]
和 B=[[7,8],[9,10],[11,12]]
,则矩阵 AB
为:
[[1*7+2*9+3*11, 1*8+2*10+3*12],
[4*7+5*9+6*11, 4*8+5*10+6*12]]
import torch
A = torch.tensor([[1, 2, 3], [4, 5, 6]])
B = torch.tensor([[7, 8], [9, 10], [11, 12]])
AB = torch.zeros((A.shape[0], B.shape[1]))
for i in range(A.shape[0]):
for j in range(B.shape[1]):
AB[i][j] = torch.sum(torch.mul(A[i], B[:,j]))
print(AB)
# tensor([[ 58, 64],
# [139, 154]])
三、小结
torch.mul
是 PyTorch 中一个十分常用的函数,可以对两个张量进行逐元素相乘操作得到新的张量,也可以在指定输出张量的情况下进行原地修改。此外,torch.mul
还可以进行向量点乘和矩阵乘法等特殊用途。
当我们处理神经网络的深度学习中的线性变换过程时,经常会使用到 torch.mul
及相关操作函数,掌握好它们的使用方式也是我们提高深度学习技能的重要一步。