一、什么是torch.repeat
torch.repeat 是 pyTorch 中的一个函数,它能将张量沿着指定的维度重复指定次数。重复张量的维度称为repeat dims,这个函数的参数是一个torch.Size 的元组,包含了每个维度重复的次数。举个例子,假如有一个形状为(3,4)的张量,维度为0沿着重复2次,维度为1沿着重复3次,那么该函数返回一个新的张量,形状为 (6,12)
二、如何使用torch.repeat
torch.repeat 有个需要注意的地方是它复制张量来产生新的张量,所以需要使用完整的内存。这意味着你需要在使用该函数前将所需要重复的张量复制到GPU或 CPU上。接下来让我们看一下如何使用这个函数。
# 导入torch
import torch
# 创建一个形状为(2,2)的张量
x = torch.Tensor([[1,2],[3,4]])
# 沿着第0维和第1维分别重复2次和3次
y = x.repeat(2, 3)
# 打印结果
print(y)
本代码中,我们首先导入了 pyTorch 库,并创建了一个形状为(2,2)的张量 x。接下来,我们使用 repeat 函数对 x 进行重复,其中第一个参数 2 表示第1维将被重复两次,第二个参数 3 表示第2维将被重复三次。最后,我们打印出了结果 y。 输出结果如下:
[[1. 2. 1. 2. 1. 2.]
[3. 4. 3. 4. 3. 4.]
[1. 2. 1. 2. 1. 2.]
[3. 4. 3. 4. 3. 4.]]
通过打印结果,我们可以看到张量 x 沿着第0维重复了两次,沿着第1维重复了三次。重复后的张量 y 的形状为 (4, 6), 并包含了重复后的值。
三、torch.repeat常见使用场景
torch.repeat 函数的常见应用场景分为以下两种:
1、将张量复制多次并拼接成一个大张量
假设有一个形状为(1,3)的张量 x,并将它重复3次并沿着第0维拼接成一个形状为(3,3)的张量 y。
# 创建一个形状为(1,3)的张量
x = torch.Tensor([[1,2,3]])
# 沿着第0维重复3次
y = x.repeat(3, 1)
# 打印结果
print(y)
输出结果如下:
[[1. 2. 3.]
[1. 2. 3.]
[1. 2. 3.]]
2、将张量进行扩维并重复
使用 repeat 函数可以将原始张量扩展为新的张量。举个例子,假如有一个形状为(1,3)的张量 x,并将它重复3次并沿着第0维拼接成一个形状为(3,3)的张量 y。
# 创建一个形状为(1,3)的张量
x = torch.Tensor([[1,2,3]])
# 在第0维上添加一个新的维度
xx = x.unsqueeze(0)
# 沿着第0维和第1维进行重复
y = xx.repeat(3, 1, 1)
# 打印结果
print(y)
输出结果如下:
[[[1. 2. 3.]]
[[1. 2. 3.]]
[[1. 2. 3.]]
[[1. 2. 3.]]
[[1. 2. 3.]]
[[1. 2. 3.]]
[[1. 2. 3.]]
[[1. 2. 3.]]
[[1. 2. 3.]]]
该例子中,我们首先创建了一个形状为(1,3)的张量 x。接下来,使用 unsqueeze 函数在第0维上添加一个新的维度。最后,我们使用 repeat 函数沿着第0维和第1维进行重复并打印输出结果。