一、.load_state_dict
的介绍
load_state_dict
方法是 PyTorch 中一个十分重要的方法,它可以将预训练模型的状态字典加载到新的模型中。模型的状态字典包含了模型的参数和缓冲器。
该方法的作用是加载参数和缓冲器,并且使用严格的参数匹配,如果有对应不上的参数,会报错。
def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',
strict: bool = True) -> None:
r"""Loads a model's parameter dictionary (state_dict).
Arguments:
state_dict (dict): a dict containing parameters and
persistent buffers.
strict (bool, optional): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`state_dict` function. Default: ``True``
Returns:
None
.. note::
The :attr:`strict` parameter has home-field advantage here. See the
note in :meth:`torch.nn.Module.load_state_dict` for a
description of how it's used.
"""
二、.load_state_dict
方法的应用场景
.load_state_dict
方法是在训练中使用预训练模型时常用的方法。预训练模型的状态字典不能直接复制到一个新模型中,需要使用 .load_state_dict
方法来恢复模型。
在迁移学习中,我们可以使用已训练好的模型,将其参数作为新模型的初始参数,然后再在该基础上进行训练,从而加速我们的训练过程,提高模型的性能。
下面是一段使用 .load_state_dict
方法加载预训练模型并用来进行测试的代码:
import torch
import torch.nn as nn
import torchvision.models as models
model = models.resnet18(pretrained=True)
fc_inputs = model.fc.in_features
model.fc = nn.Sequential(
nn.Linear(fc_inputs, 1024),
nn.ReLU(inplace=True),
nn.Linear(1024, 10))
model.load_state_dict(torch.load('resnet18.pth'))
# test the model
inputs = torch.randn(1, 3, 224, 224)
outputs = model(inputs)
print(outputs.shape)
三、.load_state_dict
方法的常用参数
1. state_dict
参数
state_dict
是一个包含了参数和缓冲器的字典。这个字典可以从一个已经训练好的模型中获取,也可以通过 state_dict()
方法获取当前模型的参数字典。
例如:
model = torchvision.models.resnet18(pretrained=True)
state_dict = model.state_dict()
2. strict
参数
strict
参数是一个布尔类型的值,用于标记是否使用严格的参数匹配。
- 如果
strict=True
,则state_dict
中的参数名称必须与新模型中的参数名称完全匹配,否则会报错。 - 如果
strict=False
,则新模型中没有指定的参数,就忽略掉,而不会报错。
四、.load_state_dict
方法的注意事项
1. 模型的架构需要保持一致
.load_state_dict
方法的使用需要注意模型的架构必须与原始模型的架构完全相同,否则将无法加载参数。如果想要更改模型的架构,可以使用 torch.nn.Sequential()
重新构造模型。
2. 加载预训练模型需要正确指定路径
如果我们需要加载一个预训练模型,需要正确指定预训练模型的位置。一般来说,预训练模型被保存为一个 .pth
文件。如果 .pth
文件和模型代码不在同一个文件夹中,则需要使用正确的路径来加载模型。
# 模型保存在 model 文件夹中的 resnet18.pth 文件中
model = models.resnet18(pretrained=True)
model.load_state_dict(torch.load('model/resnet18.pth'))
3. .load_state_dict
方法与 .freeze_layers()
方法的配合使用
当使用预训练模型进行迁移学习时,我们常常需要固定一些层的参数,只更新特定的层。在这种情况下,我们可以使用 .freeze_layers()
方法来冻结层的参数,在反向传播时不进行参数更新。在 .load_state_dict()
方法中,我们需要排除掉已冻结的层,否则这些层的参数将会被加载进去。
例如:
model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
param.requires_grad = False
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 10)
# 假设已经冻结了卷积层的参数
params_to_update = []
for name, param in model.named_parameters():
if '.bn' not in name:
params_to_update.append(param)
optimizer = torch.optim.Adam(params_to_update)
在以上代码中,.freeze_layers()
方法已经冻结了所有的卷积层,现在我们只更新全连接层的参数。所以在 .load_state_dict()
方法中,我们需要指定只加载全连接层的参数:
model.load_state_dict(torch.load('model_weights.pth'), strict=False)
五、总结
在本文中,我们详细讲解了 PyTorch 中 .load_state_dict()
方法的使用方法及注意事项。通过本文的介绍,我们可以清楚地知道如何在训练中使用预训练模型,并且了解了一些需要注意的问题。