一、imagefolder函数
一般用于处理图像的数据集,包括从文件夹读取图像,对图像进行预处理,并生成可供pytorch训练的数据结构。imagefolder函数常见的参数及含义如下:
torchvision.datasets.ImageFolder(root, transform=None, target_transform=None,
loader=default_loader, is_valid_file=None)
root:数据集的根目录,在该目录下以文件夹的形式存放着各个类别的图像。
transform:图像的转换,常用的操作有:裁剪、旋转、翻转、缩放等等,这些操作可以通过该参数进行实现。
target_transform:标签的转换,常用的操作有:独热编码(one-hot)、标签映射等等,这些操作可以通过该参数进行实现。
loader:读取图像的方法,默认为default_loader方法。
is_valid_file:文件是否有效的判断方法,默认为默认方法。
二、imagefolder加载的数据类型
在使用imagefolder函数时,经常会涉及到数据类型的转换。其中常用的数据类型有Tensor和numpy数组。Tensor是pytorch中最常见的数据类型,numpy数组则是python中处理科学计算最常用的数据类型。
三、imagefolder数据集
使用imagefolder函数可以获取到的数据集对象类型为torch.utils.data.Dataset,它是pytorch中表示数据集的一个类,由该类构建的对象可以被pytorch中的DataLoader所使用。数据集对象通常包含三个方法:
__len__(): 获取数据集的长度
__getitem__(): 获取某个数据的索引
classes: 获取数据集的类别
四、imagefolder怎么用
使用imagefolder函数生成的数据对象可以被pytorch中的DataLoader所使用,如下所示:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
root = "/path/to/dataset"
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor()])
dataset = datasets.ImageFolder(root=root, transform=transform)
dataloader = DataLoader(dataset, batch_size=64,
shuffle=True, num_workers=4)
五、imagefolder混淆矩阵
混淆矩阵是在分类问题中常用的评价方法,其主要思想是将实际类别与预测类别分别作为行和列,矩阵中每一个元素表示将某个实际类别预测为某个预测类别的样本数。在pytorch中,使用imagefolder加载数据时,可以通过如下方法生成混淆矩阵:
from sklearn.metrics import confusion_matrix
# 加载数据集
test_data = datasets.ImageFolder(test_dir, transform=valid_transform)
# 获取预测结果
y_pred = []
for images, labels in test_loader:
images = images.to(device)
outputs = model(images)
_, pred = torch.max(outputs, 1)
y_pred += list(pred.cpu().numpy())
# 获取实际结果
y_true = []
for _, labels in test_loader:
y_true += list(labels.numpy())
# 生成混淆矩阵
conf_mat = confusion_matrix(y_true, y_pred)
print(conf_mat)
六、imagefolder划分交叉验证
在使用imagefolder函数时,经常需要将数据集划分为训练集、验证集和测试集。pytorch提供了SubsetRandomSampler和SubsetSequentialSampler用于实现手动划分训练集、验证集和测试集。
from torch.utils.data import SubsetRandomSampler, SubsetSequentialSampler
# 定义训练、验证和测试数据集比例
train_ratio = 0.8
valid_ratio = 0.1
test_ratio = 0.1
# 定义数据集对象
data = datasets.ImageFolder(root_dir, transform=transform)
# 定义数据集的长度和索引
data_len = len(data)
indices = list(range(data_len))
# 计算训练、验证和测试数据集的长度
train_size = int(train_ratio * data_len)
valid_size = int(valid_ratio * data_len)
test_size = data_len - train_size - valid_size
# 分割训练、验证和测试数据集
train_idx, valid_idx, test_idx = indices[:train_size], indices[train_size:train_size+valid_size], indices[-test_size:]
# 创建数据集的Sampler
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetSequentialSampler(valid_idx)
test_sampler = SubsetSequentialSampler(test_idx)
# 创建数据集的DataLoader
train_loader = DataLoader(data, batch_size=batch_size, sampler=train_sampler, num_workers=num_workers)
valid_loader = DataLoader(data, batch_size=batch_size, sampler=valid_sampler, num_workers=num_workers)
test_loader = DataLoader(data, batch_size=batch_size, sampler=test_sampler, num_workers=num_workers)
七、imagefolder处理tif图片
在使用imagefolder函数时,经常会遇到处理扩展名为tif的图像文件的情况。pytorch中提供了PIL库用于处理图像,可以通过该库的Image.open方法打开tif图片。
from PIL import Image
class TIFFImageLoader(object):
def __call__(self, filename):
img = Image.open(filename)
img.load()
return img
data_transforms = transforms.Compose([transforms.CenterCrop(1000), transforms.ToTensor()])
trainset = datasets.ImageFolder(root=os.path.join(train_path), transform=data_transforms, loader=TIFFImageLoader())
八、imagefolder怎么分训练集与测试集选取
在使用imagefolder函数时,经常需要将数据集划分为训练集和测试集。pytorch提供了random_split方法用于实现随机划分训练集和测试集。
from torchvision import transforms, datasets
data_transforms = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
])
data = datasets.ImageFolder(root="/path/to/data", transform=data_transforms)
train_size = int(len(data) * 0.8)
test_size = len(data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(data, [train_size, test_size])
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)