一、基本概念
ImageFolder是PyTorch中一个非常实用的类,它可以将一个文件夹中的图片按照预先定义好的transform操作转换为PyTorch中可以使用的Tensor。
示例代码:
import torch.utils.data as data
from torchvision.datasets.folder import IMG_EXTENSIONS, default_loader
class ImageFolder(data.Dataset):
def __init__(self, root, transform=None, target_transform=None,
loader=default_loader, is_valid_file=None):
...
def __getitem__(self, index):
...
def __len__(self):
...
其中root指定需要读取的文件夹路径,transform指定需要进行的数据转换操作(如样本随机旋转、随机裁剪等),target_transform指定目标标签的转换操作,比如将字符串类型转换为数字类型;loader指定需要使用哪个文件读取器,默认为PIL.ImageLoader。
二、应用场景
ImageFolder可以方便地获取文件夹中的所有图片,并进行各种数据增强操作,通常应用于图像分类、目标检测、图像分割等领域。比如在图像分类中,可以通过ImageFolder将每一类图片存放在一个文件夹中,文件夹的名称即为类别名,提高了文件夹的结构化程度。
示例代码:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
# 对数据进行随机旋转和裁剪
train_transform = transforms.Compose([
transforms.RandomRotation(30),
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 加载数据集
train_dataset = datasets.ImageFolder('./train', transform=train_transform)
首先定义了一个train_transform,其中包括了随机旋转、随机裁剪、随机水平翻转等操作;接着使用ImageFolder加载文件夹中的所有图片,并将train_transform作为参数传入,实现对数据集的增强操作。
三、数据预处理
在使用ImageFolder时,需要对数据进行预处理,包括数据增强、归一化等操作,以加快训练速度、提升模型精度。具体可以根据不同的实际应用场景选择不同的预处理方式。
示例代码:
import torchvision.transforms as transforms
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
以上代码定义了一个名为transform的预处理操作,首先将图片resize到256×256,然后进行中心裁剪到224×224,接着将图片转换成Tensor格式,并进行归一化操作。
四、可视化数据
ImageFolder还可以用于可视化数据,方便我们观察样本的特点和变化。以下代码展示了如何将10张图片随机可视化出来。
示例代码:
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import numpy as np
# 加载数据集
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
train_dataset = datasets.ImageFolder('./train', transform=transform)
# 随机可视化数据集中的10张图片
def display_imgs(imgs, labels):
# 将Tensor转换为numpy数组
imgs = imgs.numpy().transpose((0, 2, 3, 1))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
imgs = std * imgs + mean
imgs = np.clip(imgs, 0, 1)
# 可视化图片
fig = plt.figure(figsize=(25, 20))
for i in range(10):
ax = fig.add_subplot(2, 5, i + 1, xticks=[], yticks=[])
ax.imshow(imgs[i])
ax.set_title(train_dataset.classes[labels[i]])
fig.show()
data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=10,
shuffle=True, num_workers=4)
data_iter = iter(data_loader)
imgs, labels = data_iter.next()
display_imgs(imgs, labels)
以上代码首先使用ImageFolder加载数据集并进行预处理,然后随机选择10张样本进行可视化。其中display_imgs函数用于将Tensor格式的图片转换为numpy数组,并进行可视化。最终结果为10张随机选择的图片以及其对应的类别。
五、小结
ImageFolder是PyTorch中一个非常实用的类,可以方便地读取文件夹中的图片,并进行数据增强、数据预处理、可视化等操作。在图像分类、目标检测、图像分割等领域中都有广泛的应用。