您的位置:

详解 Pytorch 中的 unsqueeze(0)

一、概述

在 Pytorch 中,我们经常需要处理不同维度的张量数据。unsqueeze() 方法就是用来增加张量的维度的,它会在指定位置增加一维。而其中的 unsqueeze(0) 就是在索引位置 0 上增加一维。

下面我们将从多个方面详细阐述 unsqueeze(0) 方法。

二、增加维度

unsqueeze(0) 的主要作用就是在张量最前面增加一维。

举个例子,我们有一个 1 维张量 tensor1 = torch.tensor([1, 2, 3]),如果我们想将其转换成 2 维张量,可以使用 unsqueeze(0) 方法,在索引位置 0 上增加一维。

import torch

tensor1 = torch.tensor([1, 2, 3])
tensor1_2d = tensor1.unsqueeze(0)
print(tensor1_2d.shape)   # 输出 torch.Size([1, 3])

可以看到,原先的 1 维张量变成了 2 维张量,第一个维度的大小变成了 1。

同理,我们还可以进行多次 unsqueeze(0) 操作,增加多个维度:

import torch

tensor1 = torch.tensor([1, 2, 3])
tensor2 = tensor1.unsqueeze(0).unsqueeze(0)
print(tensor2.shape)    # 输出 torch.Size([1, 1, 3])

可以看到,这次我们进行了两次 unsqueeze(0),在原先的基础上增加了两个维度。

三、在模型中的应用

unsqueeze(0) 方法在深度学习模型中也是常用的操作之一。比如,在卷积神经网络中,输入通常是 4 维张量,分别表示 batch_size, channel, height, width。

如果我们的数据集只有一张图片,那么 batch_size 就为 1。为了将数据集格式化成网络所需要的输入格式,我们就需要将单张图片的 3 维张量转换成 4 维张量。这时候 unsqueeze(0) 就能派上用场了。

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 10, kernel_size=3)
        
    def forward(self, input):
        x = input.unsqueeze(0)   # 将 3 维张量转换成 4 维张量
        out = self.conv(x)
        return out
        
net = Net()
input = torch.randn(1, 28, 28)
output = net(input)
print(output.shape)   # 输出 torch.Size([1, 10, 26, 26])

可以看到,通过 unsqueeze(0),我们将输入张量从 3 维转换成了 4 维,成功地将数据集格式化成了网络所需要的输入格式。

四、拼接操作

unsqueeze(0) 方法还能和其他张量拼接操作一起使用。

比如,我们有两个 2 维张量 tensor1 和 tensor2,如果想在第一个维度上进行拼接,就需要对它们进行 unsqueeze(0) 操作,然后再使用 cat() 方法进行拼接。

import torch

tensor1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
tensor2 = torch.tensor([[7, 8, 9], [10, 11, 12]])
# 在第一个维度上进行拼接
tensor3 = torch.cat((tensor1.unsqueeze(0), tensor2.unsqueeze(0)), dim=0)
print(tensor3.shape)   # 输出 torch.Size([2, 2, 3])

可以看到,通过 unsqueeze(0) 和 cat() 方法,我们成功地在第一个维度上将两个 2 维张量拼接成了一个 3 维张量。

五、实现 broadcast_to

unsqueeze(0) 还能用来实现 broadcast_to 操作。broadcast_to 操作是指将一个张量的形状扩展成指定的形状。

import torch

def broadcast_to(input, shape):
    # 先求出原始形状和目标形状的差距
    diff = len(shape) - len(input.shape)
    # 在 input 最前面增加与目标形状相差的维数个维度
    for _ in range(diff):
        input = input.unsqueeze(0)
    # 使用 expand 方法扩展形状
    return input.expand(shape)

x = torch.tensor([1, 2, 3])
y = broadcast_to(x, [2, 3])
print(y)

可以看到,使用 unsqueeze(0) 和 expand() 方法,我们成功地将 1 维张量 x 扩展成了形状为 [2, 3] 的张量 y。

六、总结

unsqueeze(0) 方法是 Pytorch 中常用的增加张量维度的方法之一。它能在指定位置上增加一维,可以与其他拼接操作一起使用,也可以用来实现 broadcast_to 操作。在深度学习模型中,使用 unsqueeze(0) 能够方便地将数据集格式化成网络所需要的输入格式。

使用 unsqueeze(0) 方法需要注意,增加的维度大小是 1,如果需要增加其他大小的维度,需要使用 unsqueeze() 方法,并制定对应的索引位置。