您的位置:

从多个角度深入解析importtorchvision

一、常见的importtorchvision模块

import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.utils as utils

torchvision是PyTorch的一个计算机视觉包,其中最为常用的模块有transforms、datasets、models、utils。transforms模块提供了常用的图像预处理方法;datasets模块提供了常见的视觉数据集(如CIFAR10、MNIST等);models模块提供了经典的预训练模型(如ResNet、VGG等);utils模块提供了一些常用的工具函数。

使用这些模块,我们可以方便地搭建计算机视觉领域的深度学习模型。

二、利用transforms模块进行数据预处理

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])])

trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)

transforms模块提供了方便的图像预处理方法,可以在训练网络前对输入数据进行必要的处理。在上述示例中,我们使用了transforms.Compose方法将多个变换组合在一起来处理输入数据。其中:

  • transforms.RandomCrop随机裁剪图像
  • transforms.RandomHorizontalFlip随机水平翻转图像
  • transforms.ToTensor将图像转换为张量类型
  • transforms.Normalize对图像张量进行标准化处理,即将图像张量减去均值再除以标准差,使得图像张量的值在(-1, 1)之间

三、通过datasets模块加载数据集

trainset = datasets.ImageFolder(root='path/to/data', 
                                  transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
                                          shuffle=True, num_workers=4)

datasets模块提供了常见的视觉数据集,使用datasets.ImageFolder方法可以加载自己的图像数据集。其中,root参数指定图像数据的根目录,以文件夹的形式将不同类别的图像分别存储在不同的子目录中。transform参数指定对输入图像进行的预处理方法。

使用torch.utils.data.DataLoader方法可以将数据集转换为可供训练功能使用的批量数据。

四、使用models模块搭建深度学习模型

import torch.nn as nn
import torch.optim as optim

model = models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

models模块提供了经典的预训练模型,使用这些模型可以快速搭建自己的深度学习模型。在上述示例中,我们使用了pretrained=True指定使用预训练的ResNet18模型,并且修改了全连接层的输出结构。nn.CrossEntropyLoss指定损失函数,optim.SGD指定优化器。

五、通过utils模块进行结果可视化

import numpy as np
import torchvision.transforms.functional as F

def imshow(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

dataiter = iter(trainloader)
images, labels = dataiter.next()

imshow(utils.make_grid(images))
print(' '.join('%5s' % classes[labels[j]] for j in range(32)))

utils模块提供了一些常用的工具函数,例如可视化结果。在上述示例中,我们使用了utils.make_grid将输入数据制成网格图,imshow函数将该图可视化。结果如下所示:

  cat   dog   dog  deer  deer plane horse   dog  deer   cat  ship  frog truck  deer   dog truck horse  deer  deer   dog  deer truck   car   cat truck  deer  deer   car   dog truck   dog truck plane plane   car

六、总结

importtorchvision模块是PyTorch计算机视觉方向的重要组成部分,提供了丰富的预处理方法(transforms)、数据集加载方法(datasets)、预训练模型(models)以及工具函数(utils)。我们可以基于这些组件快速搭建自己的计算机视觉深度学习模型,并且通过utils模块提供的可视化工具进行结果展示。