您的位置:

从多个方面详解load_state_dict方法

一、功能概述

load_state_dict是PyTorch中一个非常重要的方法,它可以将一个已经训练好的模型的参数加载到另一个同样结构的模型中。在实际使用中,它经常用于预训练模型的迁移学习、模型参数的恢复等场景。在这一部分,我们将介绍load_state_dict方法的基本用法以及其调用的原理。

  model_dict = model.state_dict()  # 此时model还未更新过,其参数未被优化器更改
  pretrained_dict = torch.load(PATH)
  
  # filter out unnecessary keys
  pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
  
  # overwrite entries in the existing state dict
  model_dict.update(pretrained_dict) 
  model.load_state_dict(model_dict)

二、参数说明

load_state_dict方法有一个必要的参数,即pretrained_dict,表示已经训练好的模型的参数,它是一个Python字典。该参数需要满足以下两个要求:

1、字典的键值对应着模型中各层的名称

2、字典的值是一个已经训练好的张量

在使用时需要注意,预训练模型和目标模型的结构必须一致。

三、基本用法

load_state_dict方法的基本用法非常简单,只需要通过Python字典构造函数构造一个预训练模型的参数字典,然后使用load_state_dict方法将其加载到目标模型中即可。下面是一段简单的示例代码:

  model = Net()
  pretrained_dict = torch.load(PATH)
  model.load_state_dict(pretrained_dict)

四、加载部分参数

在有些情况下,我们只需要加载模型的部分参数。例如,我们想仅加载预训练模型中某些层的参数而保持目标模型中其他层的参数不变。在这种情况下,需要将pretrained_dict中不需要的部分剔除,可以使用Python字典的推导式来完成这一操作:

  model_dict = model.state_dict()
  pretrained_dict = torch.load(PATH)
  
  # filter out unnecessary keys
  pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
  
  # overwrite entries in the existing state dict
  model_dict.update(pretrained_dict) 
  model.load_state_dict(model_dict)

五、跨设备加载

在使用load_state_dict方法时,需要注意张量的设备类型和ID。如果预训练模型和目标模型的设备类型或ID不同,就需要对预训练模型中的参数进行相应的修改才能使其被成功加载。下面是一段示例代码:

  model = nn.DataParallel(model)
  pretrained_dict = torch.load(PATH)
  
  # create new OrderedDict that does not contain `module.`
  from collections import OrderedDict
  new_state_dict = OrderedDict()
  for k, v in pretrained_dict.items():
      name = k[7:] # remove `module.`
      new_state_dict[name] = v
  
  # load params
  model.load_state_dict(new_state_dict)

六、加载到指定的层

有时候,我们可能只需要把预训练模型的部分参数加载到目标模型的指定层中,而不需要覆盖整个目标模型的参数。在这种情况下,我们需要手动获取指定层的state_dict,并将预训练模型中对应的参数赋值给该state_dict。下面是一段示例代码:

  model = Net()
  pretrained_dict = torch.load(PATH)
  
  # get the dict of a module
  net_dict = model.net.state_dict()
  pretrained_dict = {'.'.join(k.split('.')[1:]): v for k, v in pretrained_dict.items() if k.split('.')[1] == 'net'}
  
  # overwrite entries in the state dict for this module
  net_dict.update(pretrained_dict)
  
  model.net.load_state_dict(net_dict)