一、torch.cat()函数的基本用法
在PyTorch中,torch.cat()函数将多个张量沿着指定的维度进行拼接,并返回拼接后的新张量。它的基本语法如下所示:torch.cat(tensors, dim=0, out=None) -> Tensor
其中,tensors是要拼接的张量序列,dim是拼接的维度,out是可选的输出张量。接下来我们来看一些使用示例。
1、在维度0上拼接两个张量
import torch
x = torch.randn(2, 3)
y = torch.randn(3, 3)
z = torch.cat([x, y], dim=0)
print(z.shape) # output: torch.Size([5, 3])
在上面的例子中,我们首先定义了两个张量x和y,它们的形状分别为(2, 3)和(3, 3)。然后,我们使用torch.cat()函数在维度0上将它们拼接起来。由于x和y在维度0上长度之和为5,因此拼接后的张量形状为(5, 3)。
2、在维度1上拼接两个张量
import torch
x = torch.randn(2, 3)
y = torch.randn(2, 4)
z = torch.cat([x, y], dim=1)
print(z.shape) # output: torch.Size([2, 7])
在上面的例子中,我们定义了两个张量x和y,它们的形状分别为(2, 3)和(2, 4)。然后,我们使用torch.cat()函数在维度1上将它们拼接起来。由于x和y在维度1上长度之和为7,因此拼接后的张量形状为(2, 7)。
二、torch.cat()函数的高级用法
除了基本用法外,torch.cat()函数还有一些高级用法,包括指定输出张量、支持可变长度张量拼接、支持不同类型的张量拼接等。1、指定输出张量
在默认情况下,torch.cat()函数会返回一个新的张量。但是,我们也可以指定输出张量。例如:import torch
x = torch.randn(2, 3)
y = torch.randn(2, 4)
z = torch.zeros_like(x)
torch.cat([x, y], dim=1, out=z)
print(z.shape) # output: torch.Size([2, 7])
在上面的例子中,我们首先定义了两个张量x和y。然后,我们定义了一个与x形状相同的空张量z,并使用torch.cat()函数在维度1上将x和y拼接到z中,得到拼接后的张量z。
2、支持可变长度张量拼接
在实际应用中,我们可能遇到需要拼接的张量长度不一的情况。对于这种情况,PyTorch也提供了支持。例如:import torch
x = torch.randn(2, 3)
y = torch.randn(3, 4)
z = torch.randn(4, 2, 3)
w = torch.cat([x, y, z], dim=0)
print(w.shape) # output: torch.Size([9, 2, 3])
在上面的例子中,我们定义了三个张量x、y和z,它们的长度分别为2、3和4。然后,我们使用torch.cat()函数在维度0上将它们拼接起来。由于它们在维度0上长度之和为9,因此拼接后的张量形状为(9, 2, 3)。
3、支持不同类型的张量拼接
除了支持同一类型的张量拼接外,torch.cat()函数还支持拼接不同类型的张量。例如:import torch
x = torch.randn(2, 3)
y = torch.randn(2, 4).int()
z = torch.cat([x, y], dim=1)
print(z) # output: tensor([[ 0.2306, -0.9291, -1.0282, 0, 1, 0, 1], [ 1.3855, -0.1479, 1.3322, 0, 0, 1, 1]])
在上面的例子中,我们定义了两个张量x和y,它们的类型分别为float和int。然后,我们使用torch.cat()函数在维度1上将它们拼接起来。注意,由于y的类型为int,因此向拼接后的张量中填充时需要将它转换为float类型。
三、torch.cat()函数的注意点
虽然torch.cat()函数非常实用,但是在使用时需要注意一些细节。1、拼接维度必须存在
torch.cat()函数只能在输入张量共同拥有的维度上进行拼接。举个例子,如果我们想在两个张量的第2维上进行拼接,那么它们必须在第2维上具有相同的长度,否则会报错。例如:import torch
x = torch.randn(2, 3, 4)
y = torch.randn(2, 4, 5)
z = torch.cat([x, y], dim=1) # 报错!
在上面的例子中,我们想在张量x和y的第2维上进行拼接,但是它们在第2维上的长度不同,因此会报错。
2、torch.cat()函数不改变输入张量
torch.cat()函数返回的是一个新的张量,而不是对输入张量进行原地修改。如果要实现原地修改,可以使用inplace=True参数。例如:import torch
x = torch.randn(2, 3)
y = torch.randn(2, 4)
x = torch.cat([x, y], dim=1)
在上面的例子中,我们使用torch.cat()函数拼接x和y,得到新的张量x。要注意的是,这里我们将新的张量x赋值给了原来的x。如果不赋值,原来的张量x还是不变的。
3、torch.cat()函数不适合大型数据集
由于torch.cat()函数需要在内存中创建一个新的张量,因此在拼接大型数据集时可能会导致内存不足。如果遇到这种情况,可以考虑使用torch.utils.data.Dataset和torch.utils.data.ConcatDataset来处理数据集。四、torch.cat()函数的其他衍生函数
除了torch.cat()函数外,PyTorch还提供了一些其他的拼接函数,包括torch.stack()、torch.split()、torch.chunk()等。1、torch.stack()函数
torch.stack()函数用于在新的维度上堆叠输入张量。它的基本语法如下所示:torch.stack(tensors, dim=0, out=None) -> Tensor
其中,tensors是指要堆叠的输入张量,dim是堆叠的维度,out是可选的输出张量。例如:
import torch
x = torch.randn(2, 3)
y = torch.randn(2, 3)
z = torch.stack([x, y], dim=0)
print(z.shape) # output: torch.Size([2, 2, 3])
2、torch.split()函数
torch.split()函数用于将输入张量沿着指定的维度分割为多个张量。它的基本语法如下所示:torch.split(tensor, split_size_or_sections, dim=0) -> List of Tensors
其中,tensor是要分割的输入张量,split_size_or_sections是分割的大小或者分割的位置,dim是分割的维度。例如:
import torch
x = torch.randn(2, 6)
y1, y2, y3 = torch.split(x, 2, dim=1)
print(y1.shape) # output: torch.Size([2, 2])
print(y2.shape) # output: torch.Size([2, 2])
print(y3.shape) # output: torch.Size([2, 2])
3、torch.chunk()函数
torch.chunk()函数是torch.split()函数的逆操作,用于将输入张量沿着指定的维度分割为多个张量。它的基本语法如下所示:torch.chunk(tensor, chunks, dim=0) -> List of Tensors
其中,tensor是要分割的输入张量,chunks是分割的块数,dim是分割的维度。例如:
import torch
x = torch.randn(2, 6)
y1, y2, y3 = torch.chunk(x, 3, dim=1)
print(y1.shape) # output: torch.Size([2, 2])
print(y2.shape) # output: torch.Size([2, 2])
print(y3.shape) # output: torch.Size([2, 2])