一、.load_state_dict的介绍
.load_state_dict方法是PyTorch中一个十分重要的方法,它可以将预训练模型的状态字典加载到新的模型中。模型的状态字典包含了模型的参数和缓冲器
该方法的作用是加载参数和缓冲器,并且使用严格的参数匹配,如果有对应不上的参数,会报错。
def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]', strict: bool = True) -> None: r"""Loads a model's parameter dictionary (state_dict). Arguments: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`state_dict` function. Default: ``True`` Returns: None .. note:: The :attr:`strict` parameter has home-field advantage here. See the note in :meth:`torch.nn.Module.load_state_dict` for a description of how it's used. """
二、.load_state_dict方法的应用场景
.load_state_dict方法是在训练中使用预训练模型时常用的方法。预训练模型的状态字典不能直接复制到一个新模型中,需要使用.load_state_dict方法来恢复模型。
在迁移学习中,我们可以使用已训练好的模型,将其参数作为新模型的初始参数,然后再在该基础上进行训练,从而加速我们的训练过程,提高模型的性能。
下面是一段使用.load_state_dict方法加载预训练模型并用来进行测试的代码:
import torch import torch.nn as nn import torchvision.models as models model = models.resnet18(pretrained=True) fc_inputs = model.fc.in_features model.fc = nn.Sequential( nn.Linear(fc_inputs, 1024), nn.ReLU(inplace=True), nn.Linear(1024, 10)) model.load_state_dict(torch.load('resnet18.pth')) # test the model inputs = torch.randn(1, 3, 224, 224) outputs = model(inputs) print(outputs.shape)
三、.load_state_dict方法的常用参数
1、state_dict参数
state_dict是一个包含了参数和缓冲器的字典。这个字典可以从一个已经训练好的模型中获取,也可以通过state_dict()方法获取当前模型的参数字典。
例如:
model = torchvision.models.resnet18(pretrained=True) state_dict = model.state_dict()
2、strict参数
strict参数是一个布尔类型的值,用于标记是否使用严格的参数匹配。
如果strict=True,则state_dict中的参数名称必须与新模型中的参数名称完全匹配,否则会报错。
如果strict=False,则新模型中没有指定的参数,就忽略掉,而不会报错。
四、.load_state_dict方法的注意事项
1、模型的架构需要保持一致
.load_state_dict方法的使用需要注意模型的架构必须与原始模型的架构完全相同,否则将无法加载参数。如果想要更改模型的架构,可以使用torch.nn.Sequential()重新构造模型。
2、加载预训练模型需要正确指定路径
如果我们需要加载一个预训练模型,需要正确指定预训练模型的位置。一般来说,预训练模型被保存为一个.pth文件。如果.pth文件和模型代码不在同一个文件夹中,则需要使用正确的路径来加载模型。
# 模型保存在model文件夹中的resnet18.pth文件中 model = models.resnet18(pretrained=True) model.load_state_dict(torch.load('model/resnet18.pth'))
3、.load_state_dict方法与.freeze_layers()方法的配合使用
当使用预训练模型进行迁移学习时,我们常常需要固定一些层的参数,只更新特定的层。在这种情况下,我们可以使用.freeze_layers()方法来冻结层的参数,在反向传播时不进行参数更新。在.load_state_dict()方法中,我们需要排除掉已冻结的层,否则这些层的参数将会被加载进去。
例如:
model = torchvision.models.resnet18(pretrained=True) for param in model.parameters(): param.requires_grad = False num_features = model.fc.in_features model.fc = nn.Linear(num_features, 10) # 假设已经冻结了卷积层的参数 params_to_update = [] for name, param in model.named_parameters(): if '.bn' not in name: params_to_update.append(param) optimizer = torch.optim.Adam(params_to_update)
在以上代码中,.freeze_layers()方法已经冻结了所有的卷积层,现在我们只更新全连接层的参数。所以在.load_state_dict()方法中,我们需要指定只加载全连接层的参数:
model.load_state_dict(torch.load('model_weights.pth'), strict=False)
五、总结
在本文中,我们详细讲解了PyTorch中.load_state_dict()方法的使用方法及注意事项。通过本文的介绍,我们可以清楚地知道如何在训练中使用预训练模型,并且了解了一些需要注意的问题。