一、函数介绍
transforms.topilimage()是pytorch中一个非常实用的函数,可以将Tensor类型的图像转换成PIL.Image类型的图像,使得我们可以在python中更方便地处理图像。
该函数需要注意的是,输入的Tensor需要满足以下条件:1.维度为3或4;2.数据类型为torch.uint8或者torch.float32;3.数值范围在[0, 1]或[0, 255]之间。
下面我们将从几个方面来详细介绍这个函数。
二、函数参数
transforms.topilimage()函数只有一个参数,即输入的Tensor类型的图像数据,这个参数是必须的。
import torch from torchvision import transforms # 构造一个3通道的图片,大小为3x3 image = torch.rand(3, 3, 3) # 使用transforms.topilimage()函数将Tensor类型的图像转换成PIL.Image类型的图像 pil_image = transforms.ToPILImage()(image)
三、函数用途
transforms.topilimage()函数用途非常广泛,可以在很多场景下使用。
1、可视化图像
在使用pytorch进行图像分类任务的时候,我们通常使用matplotlib等库来进行可视化,而transforms.topilimage函数可以很方便地将Tensor类型的图像转换成PIL.Image类型,方便我们进行可视化。
import torch from torchvision import transforms from PIL import Image # 构造一个3通道的图片,大小为3x3 image = torch.rand(3, 3, 3) # 使用transforms.topilimage()函数将Tensor类型的图像转换成PIL.Image类型的图像 pil_image = transforms.ToPILImage()(image) # 可以使用matplotlib来进行可视化 import matplotlib.pyplot as plt plt.imshow(pil_image) plt.show()
2、保存图像
当我们需要将pytorch中的Tensor类型的图像保存成图片的时候,transforms.topilimage函数可以很方便地完成这个任务。
import torch import torchvision.transforms as transforms from PIL import Image # 构造一个3通道的图片,大小为3x3 image = torch.rand(3, 3, 3) # 使用transforms.topilimage()函数将Tensor类型的图像转换成PIL.Image类型的图像 pil_image = transforms.ToPILImage()(image) # 将PIL.Image类型的图像保存成文件 pil_image.save('test.png')
四、函数返回值
transforms.topilimage()函数的返回值是一个PIL.Image类型的图像。
import torch from torchvision import transforms from PIL import Image # 构造一个3通道的图片,大小为3x3 image = torch.rand(3, 3, 3) # 使用transforms.topilimage()函数将Tensor类型的图像转换成PIL.Image类型的图像 pil_image = transforms.ToPILImage()(image) # 打印返回值的类型 print(type(pil_image)) #
五、函数示例
下面的例子展示了如何读取一张图像,对其进行随机裁剪、水平翻转和归一化操作,最终将处理后的Tensor类型的图像转换成PIL.Image类型,并将其保存成文件。
from PIL import Image import torch from torchvision import transforms # 读取一张图像 image = Image.open('test.jpg') # 构造一个transforms对象,包含随机裁剪、水平翻转和归一化操作 transform = transforms.Compose([ transforms.RandomCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # 对图像进行变换 tensor_image = transform(image) # 将Tensor类型的图像转换成PIL.Image类型的图像 pil_image = transforms.ToPILImage()(tensor_image) # 保存PIL.Image类型的图像 pil_image.save('result.png')