前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >(转载非原创)全面解析Pytorch框架下模型存储,加载以及冻结

(转载非原创)全面解析Pytorch框架下模型存储,加载以及冻结

作者头像
xlj
修改2021-07-08 14:13:22
6390
修改2021-07-08 14:13:22
举报
文章被收录于专栏:XLJ的技术专栏

最近在做试验中遇到了一些深度网络模型加载以及存储的问题,因此整理了一份比较全面的在 PyTorch 框架下有关模型的问题。首先咱们先定义一个网络来进行后续的分析:

1、本文通用的网络模型
代码语言:javascript
复制
import torch
import torch.nn as nn
'''
定义网络中第一个网络模块 Net1
'''
class Net1(nn.Module):
    def __init__(self):
        super().__init__()
        
        # input size [B, 1, 3, 3] ==> [B, 1, 3, 3]
        self.n = nn.Conv2d(1, 2, 3, padding=1)
    def forward(self, x):
        x = self.n(x)
        return x
'''
定义网络中第二个网络模块 Net2
'''
class Net2(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.n = nn.Sequential(
            # input size [B, 1, 3, 3] ==> [B, 2, 3, 3]
            nn.Conv2d(2, 2, 3, padding=1),
            
            # input size [B, 2, 3, 3] ==> [B, 1, 1, 1]
            nn.Conv2d(2, 1, 3, padding=0),
            )
    def forward(self, x):
        x = self.n(x)
        return x
'''
定义网络中主网络模块 Network
'''
class Network(nn.Module):
    def __init__(self):
        super().__init__()     
        self.head = Net1()
        self.tail = Net2()   
    def forward(self, x):
        x = self.head(x)
        x = self.tail(x)
        return x

网络模块已经搭建好,我们先实例化一个模型然后打印看一下网络结构是否正确:

代码语言:javascript
复制
model = Network()	# 实例化网络模型
print(model)	# 输出网络结构
Input = torch.randn(1,1,3,3)	# 自定义数据输入
Output = model(Input)	# 计算网络输出
print("Input 的维度为:{},Output 的维度为:{}".format(Input.shape, Output.shape))

则输出结果为:

代码语言:javascript
复制
Network(
  (head): Net1(
    (n): Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (tail): Net2(
    (n): Sequential(
      (0): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): Conv2d(2, 1, kernel_size=(3, 3), stride=(1, 1))
    )
  )
)
Input 的维度为:torch.Size([1, 1, 3, 3]),Output 的维度为:torch.Size([1, 1, 1, 1])

从输出结果看,网络包含两个子模块 headtail,这两个子模块分别是类 Net1Net2 的实例化对象。在 Net2 的定义中,使用了 nn.Sequential() 函数,它能够将包含在里面的网络按照输入顺序进行组合,封装成一个新的模块,适用于网络中大量重复的结构,比如 Conv-ReLU-Conv 等模块。

2、对模型进行训练得到权重

我们先对网络做一个简单的训练,训练代码如下:

代码语言:javascript
复制
model = Network()	# 实例化网络模型
print(model) # 输出网络结构

torch.manual_seed(0) # 固定随机种子,确保每次产生的随机输入一致,方便我们评估训练结果
Input = torch.randn(1,1,3,3) # 自定义数据输入

Iter_num = 10	# 定义最大的迭代次数
Label = torch.tensor(1.0) # 定义有监督训练的label,这里的label必须是float类型的Tensor,否则会出错
criterion = nn.MSELoss()	# 定义损失函数,这里选用MSE

import torch.optim as optim
optimizer = optim.SGD(model.parameters(), lr = 0.01)	#定义优化器,这里采用随机梯度下降(SGD)

for index in range(Iter_num):
    Output = model(Input)	# 计算网络输出
    loss = criterion(Output, Label) # 计算loss
    loss.backward()	# 反向传播计算梯度
    optimizer.step()	# 梯度更新
    print("Iter:{}/{}\tloss:{}\tOutput:{}".format(index, Iter_num, loss.data, Output.data))

训练过程如下:

代码语言:javascript
复制
Iter:0/10	loss:1.4089158773422241	Output:tensor([[[[-0.1870]]]])
Iter:1/10	loss:1.3796569108963013	Output:tensor([[[[-0.1746]]]])
Iter:2/10	loss:1.323099136352539	Output:tensor([[[[-0.1503]]]])
Iter:3/10	loss:1.2428957223892212	Output:tensor([[[[-0.1149]]]])
Iter:4/10	loss:1.143916130065918	Output:tensor([[[[-0.0695]]]])
Iter:5/10	loss:1.0316702127456665	Output:tensor([[[[-0.0157]]]])
Iter:6/10	loss:0.9117376208305359	Output:tensor([[[[0.0452]]]])
Iter:7/10	loss:0.7892979979515076	Output:tensor([[[[0.1116]]]])
Iter:8/10	loss:0.6688111424446106	Output:tensor([[[[0.1822]]]])
Iter:9/10	loss:0.5538586378097534	Output:tensor([[[[0.2558]]]])
3、模型存储
3.1 模型参数一起存储与加载
代码语言:javascript
复制
'''
这种方式存储模型的参数,而非整个模型
'''
torch.save(model.state_dict(), model_path)	# 存储网络模型的参数
checkpoint = torch.load(model_path)	# 先加载模型的参数
model.load_state_dict(checkpoint)	# 再将加载的参数填入实例化的网络模型中
'''
这种方式存储整个模型
'''
torch.save(model,model_path)	# 直接存储整个模型,包括模型结构和参数
model = torch.load(model_path)	# 不用实例化,直接加载就可以用

存储整个模型与存储模型参数的区别:

  1. 整个模型:是保存整个网络结构和参数,使用时会加载结构和其中的参数,即边搭框架边填充参数;
  2. 仅参数:仅保存网络模型中的参数,在使用时需要先用训练时的模型实例化,再往里面填入参数,即需要先搭好框架再往框架里填参数。

下面我们就分别通过这两种方式进行模型存储与加载:

代码语言:javascript
复制
model_path_dict = './ckpt_dict.pth'	# 模型参数的存储路径
torch.save(model.state_dict(), model_path_dict)

model_path_model = './ckpt_model.pth' # 整个模型的存储路径
torch.save(model, model_path_model)

model_test = Network()	# 重新实例化一个网络对象
test_out = model_test(Input)	# 先看一下初始化输出
print("test_out: ", test_out.data)
 
checkpoint = torch.load(model_path_dict)	# 采用加载参数的方式加载与训练模型
model_test.load_state_dict(checkpoint)
print("test_out1: ", model_test(Input).data)	# 查看预训练模型加载后的输出

model_test2 = torch.load(model_path_model)	# 直接加载整个模型
print("test_out1: ", model_test2(Input).data)	# 查看预训练模型加载后的输出

对应的输出结果如下:

代码语言:javascript
复制
test_out:   tensor([[[[0.1190]]]])  # 网络刚开始的输出结果
test_out1:  tensor([[[[0.2558]]]])	# 加载参数后的网络输出
test_out2:  tensor([[[[0.2558]]]])  # 加载整个模型后的网络输出

从结果中可以看出,这两种方式加载网络模型的效果是一样的,但是只存储参数的模型所占空间为 2731字节,整个模型所占的空间为4071字节,所以一般建议采取第一种方法。

3.2 模型参数分开存储
代码语言:javascript
复制
model_path_dict2 = './ckpt_dict2.pth'	# 模型的存储路径
torch.save({
    'net1':model.head.state_dict(),
    'net2':model.tail.state_dict(),
     }, model_path_dict2)	# 将模型的head和tail模块分开存储
model3 = Network()	# 实例化一个新的网络
print("test_out: ", model3(Input).data)	# 测试一下原始输出

checkpoint = torch.load(model_path_dict2)
model3.head.load_state_dict(checkpoint['net1'])	# 给不同的模块分别加载不同的模型
model3.tail.load_state_dict(checkpoint['net2'])	
print("test_out: ", model3(Input).data)	#测试一下最后的输出
代码语言:javascript
复制
test_out:  tensor([[[[-0.1870]]]])
test_out:  tensor([[[[0.2558]]]])
4、加载模型的部分参数

很多时候我们在训练过程中或多或少都会遇到如下问题:

  1. 已经有了与网络匹配的预训练模型,根据情况需要在网络中添加一个小模块,但是还想利用之前的与训练模型
  2. 虽然用的是同一个网络结构,但是由于定义的方法不一样,导致与训练模型的 key 对应不上

在这些情况下,上述加载模型的方式不能很好地解决这些问题,因此在加载模型时需要更精细的控制才能满足我们的要求。首先我们要先了解一下网络加载模型的实质,其实网络和模型都是按照字典的格式进行存储的,如下所示:

代码语言:javascript
复制
net_dic = model.state_dict()	# 加载网络的字典
for key, value in net_dic.items():	# 显示网络的 key value 值
    print(key)
    print(value)
for key, value in checkpoint.items():	# 显示模型的 key value 值
    print(key)
    print(value)

输出结果如下:

代码语言:javascript
复制
"""
这是网络的key-value
"""
head.n.weight
tensor([[[[-0.2744,  0.2048, -0.0635],
          [-0.1417,  0.2827, -0.2909],
          [ 0.0396, -0.0686,  0.2342]]],
          ...])
head.n.bias
tensor([-0.2389,  0.0188])
tail.n.0.weight
tensor([[[[-0.1658, -0.1408, -0.1394],
          [ 0.1010, -0.1735, -0.0215],
          [ 0.0153,  0.1298, -0.2054]]
          ...]])
tail.n.0.bias
tensor([0.0328, 0.1939])
tail.n.1.weight
tensor([[[[ 0.0598,  0.2197,  0.1340],
          [-0.1290,  0.1500, -0.1595],
          [-0.1066,  0.0536,  0.1065]],
          ...]])
tail.n.1.bias
tensor([0.0029])
代码语言:javascript
复制
"""
这是与训练模型的key-value
"""
head.n.weight
tensor([[[[-0.2744,  0.2048, -0.0635],
          [-0.1417,  0.2827, -0.2909],
          [ 0.0396, -0.0686,  0.2342]]],
       ...])
head.n.bias
tensor([-0.2389,  0.0188])
tail.n.0.weight
tensor([[[[-0.1658, -0.1408, -0.1394],
          [ 0.1010, -0.1735, -0.0215],
          [ 0.0153,  0.1298, -0.2054]],
        ...]])
tail.n.0.bias
tensor([0.0328, 0.1939])
tail.n.1.weight
tensor([[[[ 0.0598,  0.2197,  0.1340],
          [-0.1290,  0.1500, -0.1595],
          [-0.1066,  0.0536,  0.1065]],
			...]])
tail.n.1.bias
tensor([0.0029])

因此模型加载的实质可以总结为:找到网络与模型相同的key,将模型对应的参数填入到网络中去。因此若要解决上述问题,只需要在加载模型参数时,进行 if-else 判断进行选择特定的网络层或者筛选特定的模型参数。所以 3.1节中加载模型参数可以写成:

代码语言:javascript
复制
checkpoint = torch.load(model_path_dict)	# 采用加载参数的方式加载与训练模型
model_stic = model.state_dict()	# 提取网络的字典
state_dic = {k:v for k,v in checkpoint.items() if k in model_stic.keys()}	# 找出待加载模型中与网络key一样的参数
model_stic.update(state_dic) # 更新网络参数
print("test_out1: ", model_test(Input).data)	# 查看预训练模型加载后的输出
5、冻结模型的部分参数

在训练网络的时候,有的时候不一定需要网络的每个结构都按照同一个学习率更新,或者有的模块干脆不更新,因此这就需要冻结部分模型参数的梯度,但是又不能截断反向传播的梯度流,不然就会导致网络无法正常训练。

5.1 方法一:requires_grad = false
代码语言:javascript
复制
for name, para in model.named_parameters():
    if 'tail' in name:
        para.requires_grad = False	# 将 tail 模块的梯度更新关闭,即冻结tail的参数
  
for para in model.parameters():	# 在训练前输出一下网络参数,与训练后进行对比
    print(para)
    
for index in range(Iter_num):
    Output = model(Input)
    loss = criterion(Output, Label)
    loss.backward()
    optimizer.step()
    print("Iter:{}/{}\tloss:{}\tOutput:{}".format(index, Iter_num, loss.data, Output.data))
    
for para in model.parameters():	# 输出训练后的模型参数
    print(para)

训练前的网络的部分参数:

代码语言:javascript
复制
Parameter containing:
tensor([[[[ 0.1211,  0.2768, -0.0686],
          [ 0.2494, -0.0537,  0.0353],
          [ 0.3018, -0.3092, -0.2098]]],
					...], requires_grad=True)
Parameter containing:
tensor([0.1487, 0.1616], requires_grad=True)
Parameter containing:
tensor([[[[ 0.0124, -0.1208,  0.0399],
          [-0.2201, -0.1703, -0.1215],
          [ 0.1487,  0.1382, -0.1045]],
				...]])
Parameter containing:
tensor([ 0.0469, -0.2050])
Parameter containing:
tensor([[[[ 0.0217, -0.1475, -0.2197],
          [ 0.2094,  0.1792, -0.2351],
          [ 0.0441, -0.0397, -0.0388]],
				...]])
Parameter containing:
tensor([0.1177])

训练后网络的参数:

代码语言:javascript
复制
Parameter containing:
tensor([[[[ 0.1256,  0.2754, -0.0720],
          [ 0.2429, -0.0717,  0.0461],
          [ 0.2887, -0.3248, -0.2124]]],
					...], requires_grad=True)
Parameter containing:
tensor([0.1525, 0.1894], requires_grad=True)
Parameter containing:
tensor([[[[ 0.0124, -0.1208,  0.0399],
          [-0.2201, -0.1703, -0.1215],
          [ 0.1487,  0.1382, -0.1045]],
         ...]])
Parameter containing:
tensor([ 0.0469, -0.2050])
Parameter containing:
tensor([[[[ 0.0217, -0.1475, -0.2197],
          [ 0.2094,  0.1792, -0.2351],
          [ 0.0441, -0.0397, -0.0388]],
					...]])
Parameter containing:
tensor([0.1177])

通过对比可以发现,网络只更新了 head 层的参数,被冻结的 tail 层参数并没有更新。

5.2 从优化器中设置更新的网络层
代码语言:javascript
复制
import torch.optim as optim
optimizer = optim.SGD(model.head.parameters(), lr = 0.001)	# 在优化器中只填入head层的参数
for para in model.parameters():	# 在训练前输出一下网络参数,与训练后进行对比
    print(para)
    
for index in range(Iter_num):
    Output = model(Input)
    loss = criterion(Output, Label)
    loss.backward()
    optimizer.step()
    print("Iter:{}/{}\tloss:{}\tOutput:{}".format(index, Iter_num, loss.data, Output.data))
    
for para in model.parameters():	# 输出训练后的模型参数
    print(para)

训练前的网络的部分参数:

代码语言:javascript
复制
Parameter containing:
tensor([[[[ 0.1211,  0.2768, -0.0686],
          [ 0.2494, -0.0537,  0.0353],
          [ 0.3018, -0.3092, -0.2098]]],
					...], requires_grad=True)
Parameter containing:
tensor([0.1487, 0.1616], requires_grad=True)
Parameter containing:
tensor([[[[ 0.0124, -0.1208,  0.0399],
          [-0.2201, -0.1703, -0.1215],
          [ 0.1487,  0.1382, -0.1045]],
					...]], requires_grad=True)
Parameter containing:
tensor([ 0.0469, -0.2050], requires_grad=True)
Parameter containing:
tensor([[[[ 0.0217, -0.1475, -0.2197],
          [ 0.2094,  0.1792, -0.2351],
          [ 0.0441, -0.0397, -0.0388]],
					...]], requires_grad=True)
Parameter containing:
tensor([0.1177], requires_grad=True)

训练后的网络的部分参数:

代码语言:javascript
复制
Parameter containing:
tensor([[[[ 0.1256,  0.2754, -0.0720],
          [ 0.2429, -0.0717,  0.0461],
          [ 0.2887, -0.3248, -0.2124]]],
					...], requires_grad=True)
Parameter containing:
tensor([0.1525, 0.1894], requires_grad=True)
Parameter containing:
tensor([[[[ 0.0124, -0.1208,  0.0399],
          [-0.2201, -0.1703, -0.1215],
          [ 0.1487,  0.1382, -0.1045]],
					...]], requires_grad=True)
Parameter containing:
tensor([ 0.0469, -0.2050], requires_grad=True)
Parameter containing:
tensor([[[[ 0.0217, -0.1475, -0.2197],
          [ 0.2094,  0.1792, -0.2351],
          [ 0.0441, -0.0397, -0.0388]],
					...]], requires_grad=True)
Parameter containing:
tensor([0.1177], requires_grad=True)

对比这两种方法都能够实现网络某一层参数的冻结而不影响其它层的梯度更新,但是仔细观察发现方法一中不更新参数的网络层的 requires_grad = False,而方法二中所有层的 requires_grad = True。由于个人知识水平有限,难免有错误的地方,还请不吝指正,相互学习,共同进步。

本文系转载,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文系转载前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1、本文通用的网络模型
  • 2、对模型进行训练得到权重
  • 3、模型存储
    • 3.1 模型参数一起存储与加载
      • 3.2 模型参数分开存储
      • 4、加载模型的部分参数
      • 5、冻结模型的部分参数
        • 5.1 方法一:requires_grad = false
          • 5.2 从优化器中设置更新的网络层
          相关产品与服务
          对象存储
          对象存储(Cloud Object Storage,COS)是由腾讯云推出的无目录层次结构、无数据格式限制,可容纳海量数据且支持 HTTP/HTTPS 协议访问的分布式存储服务。腾讯云 COS 的存储桶空间无容量上限,无需分区管理,适用于 CDN 数据分发、数据万象处理或大数据计算与分析的数据湖等多种场景。
          领券
          问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档