您的位置:

详解transforms.topilimage()函数

一、函数介绍

transforms.topilimage()是pytorch中一个非常实用的函数,可以将Tensor类型的图像转换成PIL.Image类型的图像,使得我们可以在python中更方便地处理图像。

该函数需要注意的是,输入的Tensor需要满足以下条件:1.维度为3或4;2.数据类型为torch.uint8或者torch.float32;3.数值范围在[0, 1]或[0, 255]之间。

下面我们将从几个方面来详细介绍这个函数。

二、函数参数

transforms.topilimage()函数只有一个参数,即输入的Tensor类型的图像数据,这个参数是必须的。

import torch
from torchvision import transforms

# 构造一个3通道的图片,大小为3x3
image = torch.rand(3, 3, 3)

# 使用transforms.topilimage()函数将Tensor类型的图像转换成PIL.Image类型的图像
pil_image = transforms.ToPILImage()(image)

三、函数用途

transforms.topilimage()函数用途非常广泛,可以在很多场景下使用。

1、可视化图像

在使用pytorch进行图像分类任务的时候,我们通常使用matplotlib等库来进行可视化,而transforms.topilimage函数可以很方便地将Tensor类型的图像转换成PIL.Image类型,方便我们进行可视化。

import torch
from torchvision import transforms
from PIL import Image

# 构造一个3通道的图片,大小为3x3
image = torch.rand(3, 3, 3)

# 使用transforms.topilimage()函数将Tensor类型的图像转换成PIL.Image类型的图像
pil_image = transforms.ToPILImage()(image)

# 可以使用matplotlib来进行可视化
import matplotlib.pyplot as plt
plt.imshow(pil_image)
plt.show()

2、保存图像

当我们需要将pytorch中的Tensor类型的图像保存成图片的时候,transforms.topilimage函数可以很方便地完成这个任务。

import torch
import torchvision.transforms as transforms
from PIL import Image

# 构造一个3通道的图片,大小为3x3
image = torch.rand(3, 3, 3)

# 使用transforms.topilimage()函数将Tensor类型的图像转换成PIL.Image类型的图像
pil_image = transforms.ToPILImage()(image)

# 将PIL.Image类型的图像保存成文件
pil_image.save('test.png')

四、函数返回值

transforms.topilimage()函数的返回值是一个PIL.Image类型的图像。

import torch
from torchvision import transforms
from PIL import Image

# 构造一个3通道的图片,大小为3x3
image = torch.rand(3, 3, 3)

# 使用transforms.topilimage()函数将Tensor类型的图像转换成PIL.Image类型的图像
pil_image = transforms.ToPILImage()(image)

# 打印返回值的类型
print(type(pil_image)) # 

  

五、函数示例

下面的例子展示了如何读取一张图像,对其进行随机裁剪、水平翻转和归一化操作,最终将处理后的Tensor类型的图像转换成PIL.Image类型,并将其保存成文件。

from PIL import Image
import torch
from torchvision import transforms

# 读取一张图像
image = Image.open('test.jpg')

# 构造一个transforms对象,包含随机裁剪、水平翻转和归一化操作
transform = transforms.Compose([
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# 对图像进行变换
tensor_image = transform(image)

# 将Tensor类型的图像转换成PIL.Image类型的图像
pil_image = transforms.ToPILImage()(tensor_image)

# 保存PIL.Image类型的图像
pil_image.save('result.png')