深度学习框架torch.unsqueeze函数的使用方法

发布时间:2023-05-18

一、什么是torch.unsqueeze函数?

在深度学习中,我们经常需要改变张量(tensor)的形状以满足不同的需求,例如,改变张量的维度、改变张量的顺序等等。而函数torch.unsqueeze就是用来改变张量维度的常用函数之一,它可以在张量的某个维度上增加一个维度。具体来说,unsqueeze函数会在该维度前插入一个为1的维度。以下是unsqueeze函数的语法:

torch.unsqueeze(input, dim)

其中input表示输入的张量(必须是torch.tensor类型),而dim则表示要在哪个维度上增加一个维度。在这里要特别注意,dim的取值范围必须在[-input.ndim - 1, input.ndim]之间(其中input.ndim表示输入张量的维度)。

二、torch.unsqueeze函数的使用方法

1. 在二维张量中增加一个维度

我们先从一个简单的例子开始,给定一个二维张量input,我们要在其第二个维度上增加一个维度。具体操作如下:

import torch
input = torch.tensor([[1, 2, 3], [4, 5, 6]])
output = torch.unsqueeze(input, dim=1)
print(output.shape)

运行结果为:

torch.Size([2, 1, 3])

可以看到,通过调用unsqueeze函数,在第二个维度上插入了一个维度,将原来的二维张量变成了三维。

2. 在高维张量中增加一个维度

除了在二维张量中增加一个维度之外,unsqueeze函数同样适用于高维张量。例如,给定一个三维张量input,我们要在其第一个维度上增加一个维度:

import torch
input = torch.tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
output = torch.unsqueeze(input, dim=0)
print(output.shape)

运行结果为:

torch.Size([1, 2, 2, 3])

可以看到,通过在第一个维度上插入一个维度,将原来的三维张量变成了四维。

3. 在张量中间的维度上增加一个维度

除了在张量的最前面或最后面增加维度之外,在中间的维度上增加维度同样也是非常有用的。例如,给定一个四维张量input,我们要在其第三个维度(从零开始计数)上增加一个维度:

import torch
input = torch.tensor([[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], [[[9, 10], [11, 12]], [[13, 14], [15, 16]]]])
output = torch.unsqueeze(input, dim=2)
print(output.shape)

运行结果为:

torch.Size([2, 2, 1, 2, 2])

可以看到,通过在第三个维度上插入一个维度,将原来的四维张量变成了五维。

三、小结

本文主要介绍了深度学习框架torch.unsqueeze函数的使用方法。我们可以通过这个函数来增加张量的维度,从而满足不同的需求。在使用该函数时,需要注意dim参数的设置,以避免出现维度设置错误的问题。