一、概述
在 Pytorch 中,我们经常需要处理不同维度的张量数据。unsqueeze() 方法就是用来增加张量的维度的,它会在指定位置增加一维。而其中的 unsqueeze(0) 就是在索引位置 0 上增加一维。
下面我们将从多个方面详细阐述 unsqueeze(0) 方法。
二、增加维度
unsqueeze(0) 的主要作用就是在张量最前面增加一维。
举个例子,我们有一个 1 维张量 tensor1 = torch.tensor([1, 2, 3]),如果我们想将其转换成 2 维张量,可以使用 unsqueeze(0) 方法,在索引位置 0 上增加一维。
import torch
tensor1 = torch.tensor([1, 2, 3])
tensor1_2d = tensor1.unsqueeze(0)
print(tensor1_2d.shape) # 输出 torch.Size([1, 3])
可以看到,原先的 1 维张量变成了 2 维张量,第一个维度的大小变成了 1。
同理,我们还可以进行多次 unsqueeze(0) 操作,增加多个维度:
import torch
tensor1 = torch.tensor([1, 2, 3])
tensor2 = tensor1.unsqueeze(0).unsqueeze(0)
print(tensor2.shape) # 输出 torch.Size([1, 1, 3])
可以看到,这次我们进行了两次 unsqueeze(0),在原先的基础上增加了两个维度。
三、在模型中的应用
unsqueeze(0) 方法在深度学习模型中也是常用的操作之一。比如,在卷积神经网络中,输入通常是 4 维张量,分别表示 batch_size, channel, height, width。
如果我们的数据集只有一张图片,那么 batch_size 就为 1。为了将数据集格式化成网络所需要的输入格式,我们就需要将单张图片的 3 维张量转换成 4 维张量。这时候 unsqueeze(0) 就能派上用场了。
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv = nn.Conv2d(1, 10, kernel_size=3)
def forward(self, input):
x = input.unsqueeze(0) # 将 3 维张量转换成 4 维张量
out = self.conv(x)
return out
net = Net()
input = torch.randn(1, 28, 28)
output = net(input)
print(output.shape) # 输出 torch.Size([1, 10, 26, 26])
可以看到,通过 unsqueeze(0),我们将输入张量从 3 维转换成了 4 维,成功地将数据集格式化成了网络所需要的输入格式。
四、拼接操作
unsqueeze(0) 方法还能和其他张量拼接操作一起使用。
比如,我们有两个 2 维张量 tensor1 和 tensor2,如果想在第一个维度上进行拼接,就需要对它们进行 unsqueeze(0) 操作,然后再使用 cat() 方法进行拼接。
import torch
tensor1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
tensor2 = torch.tensor([[7, 8, 9], [10, 11, 12]])
# 在第一个维度上进行拼接
tensor3 = torch.cat((tensor1.unsqueeze(0), tensor2.unsqueeze(0)), dim=0)
print(tensor3.shape) # 输出 torch.Size([2, 2, 3])
可以看到,通过 unsqueeze(0) 和 cat() 方法,我们成功地在第一个维度上将两个 2 维张量拼接成了一个 3 维张量。
五、实现 broadcast_to
unsqueeze(0) 还能用来实现 broadcast_to 操作。broadcast_to 操作是指将一个张量的形状扩展成指定的形状。
import torch
def broadcast_to(input, shape):
# 先求出原始形状和目标形状的差距
diff = len(shape) - len(input.shape)
# 在 input 最前面增加与目标形状相差的维数个维度
for _ in range(diff):
input = input.unsqueeze(0)
# 使用 expand 方法扩展形状
return input.expand(shape)
x = torch.tensor([1, 2, 3])
y = broadcast_to(x, [2, 3])
print(y)
可以看到,使用 unsqueeze(0) 和 expand() 方法,我们成功地将 1 维张量 x 扩展成了形状为 [2, 3] 的张量 y。
六、总结
unsqueeze(0) 方法是 Pytorch 中常用的增加张量维度的方法之一。它能在指定位置上增加一维,可以与其他拼接操作一起使用,也可以用来实现 broadcast_to 操作。在深度学习模型中,使用 unsqueeze(0) 能够方便地将数据集格式化成网络所需要的输入格式。
使用 unsqueeze(0) 方法需要注意,增加的维度大小是 1,如果需要增加其他大小的维度,需要使用 unsqueeze() 方法,并制定对应的索引位置。