.load_state_dict方法详解

发布时间:2023-05-19

一、.load_state_dict 的介绍

load_state_dict 方法是 PyTorch 中一个十分重要的方法,它可以将预训练模型的状态字典加载到新的模型中。模型的状态字典包含了模型的参数和缓冲器。 该方法的作用是加载参数和缓冲器,并且使用严格的参数匹配,如果有对应不上的参数,会报错。

def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',
                    strict: bool = True) -> None:
    r"""Loads a model's parameter dictionary (state_dict).
    Arguments:
        state_dict (dict): a dict containing parameters and
            persistent buffers.
        strict (bool, optional): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`state_dict` function. Default: ``True``
    Returns:
        None
    .. note::
        The :attr:`strict` parameter has home-field advantage here. See the
        note in :meth:`torch.nn.Module.load_state_dict` for a
        description of how it's used.
    """

二、.load_state_dict 方法的应用场景

.load_state_dict 方法是在训练中使用预训练模型时常用的方法。预训练模型的状态字典不能直接复制到一个新模型中,需要使用 .load_state_dict 方法来恢复模型。 在迁移学习中,我们可以使用已训练好的模型,将其参数作为新模型的初始参数,然后再在该基础上进行训练,从而加速我们的训练过程,提高模型的性能。 下面是一段使用 .load_state_dict 方法加载预训练模型并用来进行测试的代码:

import torch
import torch.nn as nn
import torchvision.models as models
model = models.resnet18(pretrained=True)
fc_inputs = model.fc.in_features
model.fc = nn.Sequential(
            nn.Linear(fc_inputs, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 10))
model.load_state_dict(torch.load('resnet18.pth'))
# test the model
inputs = torch.randn(1, 3, 224, 224)
outputs = model(inputs)
print(outputs.shape)

三、.load_state_dict 方法的常用参数

1. state_dict 参数

state_dict 是一个包含了参数和缓冲器的字典。这个字典可以从一个已经训练好的模型中获取,也可以通过 state_dict() 方法获取当前模型的参数字典。 例如:

model = torchvision.models.resnet18(pretrained=True)
state_dict = model.state_dict()

2. strict 参数

strict 参数是一个布尔类型的值,用于标记是否使用严格的参数匹配。

  • 如果 strict=True,则 state_dict 中的参数名称必须与新模型中的参数名称完全匹配,否则会报错。
  • 如果 strict=False,则新模型中没有指定的参数,就忽略掉,而不会报错。

四、.load_state_dict 方法的注意事项

1. 模型的架构需要保持一致

.load_state_dict 方法的使用需要注意模型的架构必须与原始模型的架构完全相同,否则将无法加载参数。如果想要更改模型的架构,可以使用 torch.nn.Sequential() 重新构造模型。

2. 加载预训练模型需要正确指定路径

如果我们需要加载一个预训练模型,需要正确指定预训练模型的位置。一般来说,预训练模型被保存为一个 .pth 文件。如果 .pth 文件和模型代码不在同一个文件夹中,则需要使用正确的路径来加载模型。

# 模型保存在 model 文件夹中的 resnet18.pth 文件中
model = models.resnet18(pretrained=True)
model.load_state_dict(torch.load('model/resnet18.pth'))

3. .load_state_dict 方法与 .freeze_layers() 方法的配合使用

当使用预训练模型进行迁移学习时,我们常常需要固定一些层的参数,只更新特定的层。在这种情况下,我们可以使用 .freeze_layers() 方法来冻结层的参数,在反向传播时不进行参数更新。在 .load_state_dict() 方法中,我们需要排除掉已冻结的层,否则这些层的参数将会被加载进去。 例如:

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 10)
# 假设已经冻结了卷积层的参数
params_to_update = []
for name, param in model.named_parameters():
    if '.bn' not in name:
        params_to_update.append(param)
optimizer = torch.optim.Adam(params_to_update)

在以上代码中,.freeze_layers() 方法已经冻结了所有的卷积层,现在我们只更新全连接层的参数。所以在 .load_state_dict() 方法中,我们需要指定只加载全连接层的参数:

model.load_state_dict(torch.load('model_weights.pth'), strict=False)

五、总结

在本文中,我们详细讲解了 PyTorch 中 .load_state_dict() 方法的使用方法及注意事项。通过本文的介绍,我们可以清楚地知道如何在训练中使用预训练模型,并且了解了一些需要注意的问题。