Tensordot详解:从多个角度深入理解

发布时间:2023-05-20

一、tensordot概述

tensordot 是一种 numpy 中的数学函数,它旨在实现高维张量的乘法操作。在实际深度学习的应用中,特别是卷积神经网络中,tensordot 是一项核心技术,因此学习如何使用它是至关重要的。 tensordot 最基本的使用形式为:np.tensordot(a, b, axes),其中 ab 都是具有多个轴的张量。在这个基本形式中,tensordotab 中的轴进行匹配,然后对它们进行乘法操作,最终返回一个新的张量 c

import numpy as np
a = np.random.rand(3, 4, 5)
b = np.random.rand(4, 5, 6)
c = np.tensordot(a, b, axes=([1, 2], [0, 1]))
print(c.shape)  # 输出(3, 6)

在此示例中,我们定义了两个张量 ab,分别是 shape 为 (3, 4, 5)(4, 5, 6) 的张量。我们对 a 的最后两个维度(4 和 5)和 b 的第一个和第二个维度(4 和 5)进行了匹配,然后执行了张量相乘,得到了一个新的张量 c,它的 shape 为 (3, 6)

二、理解 tensordot 的 axes 参数

tensordotaxes 参数用于指定张量 a 和张量 b 的维度匹配方式。在基本形式中,它采用了默认值,即 axes=2,它会从 ab 中的最后两个维度开始匹配两个张量,并输出其他维度的乘积。实际上,axes 接受一个元组 (x, y),其中 xy 分别表示 ab 的维度索引,表示我们要将 a 的第 x 个维度和 b 的第 y 个维度进行匹配。因此,当我们将 axes 设置为 ([1, 2], [0, 1]) 时,它将从 ab 中的第 1 和第 2 个维度开始匹配,并输出其他维度的乘积。 下面通过一个更高级的例子,来进一步理解 axes 参数的作用。

import numpy as np
a = np.random.rand(3, 4, 5)
b = np.random.rand(4, 5, 6)
c = np.tensordot(a, b, axes=([1], [0]))
print(c.shape)  # 输出(3, 6, 6)

在此示例中,我们设置了 axes=([1], [0]),这意味着我们要从 a 的第 1 个维度开始匹配,从 b 的第 0 个维度开始匹配。此时,a 的第 1 个维度大小为 4,b 的第 0 个维度的大小也为 4,因此,这种匹配方式是合法的。然后,我们执行 [a[:,i,:] * b[i,:,:] for i in range(4)] 操作,将这些张量相加,得到一个新的张量,它的 shape 为 (3, 6, 6)

三、tensordot 的高级操作

在深度学习中,tensordot 还有很多高级用法。

1. tensordot 的 broadcasting 行为

tensordot 类似于广播操作,它可以自动扩展输入张量的形状,以适应要执行的操作。因此,我们可以使用不同形状的张量来执行 tensordot 操作,根据 axes 参数的设置,可以自动调整张量的形状,以执行正确的操作。

import numpy as np
x = np.random.rand(2, 3)
y = np.random.rand(3, 4, 5)
z = np.tensordot(x, y, axes=(1, 0))
print(z.shape)  # 输出(2, 4, 5)

在本例中,我们定义了一个形状为 (2, 3) 的张量 x,和一个形状为 (3, 4, 5) 的张量 y。我们设置 axes=(1, 0),这意味着通过将 x 的第 1 个维度与 y 的第 0 个维度相匹配并相乘来计算 tensordotx 的第 1 个维度大小为 3,与 y 的第 0 个维度的大小相同,因此它们能正确匹配。我们得到的新张量的形状是 (2, 4, 5)

2. tensordot 的 reshape 操作

在某些情况下,我们需要将张量的维度进行重新排列,以使它们可以在 tensordot 操作中正确匹配。这个过程在 numpy 中的实现非常简单,我们可以使用 reshape 函数来轻松地重塑张量的形状。

import numpy as np
a = np.random.rand(3, 4, 5)
b = np.random.rand(4, 5, 6)
a = np.reshape(a, (3, 20))
b = np.reshape(b, (20, 6))
c = np.tensordot(a, b, axes=1)
print(c.shape)  # 输出(3, 6)

在此示例中,我们定义了两个张量 ab,分别是形状为 (3, 4, 5)(4, 5, 6) 的张量。然后,我们使用 reshape 函数将张量 ab 的形状分别改变为 (3, 20)(20, 6),这使它们可以正确匹配,进行 tensordot 操作。我们得到的新张量的形状是 (3, 6)

3. tensordot 的内积实现

tensordot 还可以用于计算内积。对于两个形状都为 (N,) 的张量,它们的内积可以通过 tensordot 来计算。

import numpy as np
x = np.random.rand(3)
y = np.random.rand(3)
ip = np.tensordot(x, y, axes=0)
print(ip)  # 输出单个实数

在此示例中,我们定义了两个形状为 (3,) 的张量 xy。我们将 axes 设置为 0,这意味着我们要计算两个张量的内积,即 sum(x[i] * y[i]),得到的结果是一个单个的实数。

四、总结

tensordot 是 numpy 中的一种高级操作,可用于计算张量的乘法。在深度学习中,tensordot 是卷积神经网络的核心技术之一。通过本文,我们深入理解了 numpy 中 tensordot 的基本用法和高级用法。可以根据具体的需求来选择合适的 axes 参数,轻松实现高维张量的乘法操作。