当前位置:网站首页>Pytorch预训练模型和修改——记录
Pytorch预训练模型和修改——记录
2022-08-09 00:25:00 【helloworld_Fly】
加载模型
一般从torchvision的models中加载常用模型,如alexnet、densenet、inception、resnet、squeezenet、vgg等常用网络结构,并提供预训练模型,调用方便。
from torchvision import models
resnet = models.resnet50(pretrain=True)
print(resnet) # 打印网络结构
读取预训练模型
另一种是读取自己预训练模型,而不是使用官方自带。
import torch
resnet18 = models.resnet18(pretrained=False) #pretrained参数默认是False,为了代码清晰,最好还是加上参数赋值.
resnet18.load_state_dict(torch.load(path_params.pkl))
load_state_dict方法还有一个重要的参数是strict,该参数默认是True,表示预训练模型的层和自己定义的网络结构层严格对应相等(比如层名和维度)。
当新定义的网络(model_dict)和预训练网络(pretrained_dict)的层名不严格相等时,需要先将pretrained_dict里不属于model_dict的键剔除掉 :
pretrained_dict = {
k: v for k, v in pretrained_dict.items() if k in model_dict}
再用预训练模型参数更新model_dict,最后用load_state_dict方法初始化自己定义的新网络结构。
整体代码:
print resnet18 #打印的还是网络结构
# 注意: cnn = resnet18.load_state_dict(torch.load( path_params.pkl )) #是错误的,这样cnn将是nonetype
pre_dict = resnet18.state_dict() #按键值对将模型参数加载到pre_dict
print for k, v in pre_dict.items(): # 打印模型参数
for k, v in pre_dict.items():
print k #打印模型每层命名
# model是自己定义好的新网络模型,将pretrained_dict和model_dict中命名一致的层加入
# pretrained_dict(包括参数)。
pretrained_dict = {
k: v for k, v in pretrained_dict.items() if k in model_dict}
模型参数修改
对模型特定层进行修改,一般直接调用对应层名并赋予新的层结构,常用是修改全连接层输出类别或特征数。
import torchvision.models as models
#调用模型
model = models.resnet50(pretrained=True)
#提取fc层中固定的参数
fc_features = model.fc.in_features
#修改类别为9
model.fc = nn.Linear(fc_features, 9)
训练特定层(冻结层)
对于自己定义的网络结构,需要选择特定层进行冻结时,往往需要用到 requires_grad函数来定义是否计算梯度。
count = 0
para_optim = []
for k in model.children():
count += 1
if count > 6: # 6 should be changed properly
for param in k.parameters():
para_optim.append(param)
else:
for param in k.parameters():
param.requires_grad = False
optimizer = optim.RMSprop(para_optim, lr)#只对特定的层的参数进行优化更新,即选择特定的层进行finetune。
此代码实现了PyTorch中使用预训练的模型初始化网络的一部分参数,主要是:
- 设置选择条件
- 使用children函数得到子模型名,使用parameters得到参数
- 使用requires_grad确定不计算梯度,冻结层
PyTorch的Module.modules()和Module.children()
在PyTorch中,所有的neural network module都是class torch.nn.Module的子类,在Modules中可以包含其它的Modules,以一种树状结构进行嵌套。当需要返回神经网络中的各个模块时,Module.modules()方法返回网络中所有模块的一个iterator,而Module.children()方法返回所有直接子模块的一个iterator。具体而言:
list ( nn.Sequential(nn.Linear(10, 20), nn.ReLU()).modules() )
Out[9]:
[Sequential (
(0): Linear (10 -> 20)
(1): ReLU ()
), Linear (10 -> 20), ReLU ()]
In [10]: list( nn.Sequential(nn.Linear(10, 20), nn.ReLU()) .children() )
Out[10]: [Linear (10 -> 20), ReLU ()]
另一个例子:
下载好的模型,可以用下面这段代码看一下模型参数,并且改一下模型。在vgg19.pth同级目录建立一个test.py。
import torch
import torch.nn as nn
import torchvision.models as models
vgg16 = models.vgg16(pretrained=False)
#打印出预训练模型的参数
vgg16.load_state_dict(torch.load('vgg16-397923af.pth'))
print('vgg16:\n', vgg16)
modified_features = nn.Sequential(*list(vgg16.features.children())[:-1])
# to relu5_3
print('modified_features:\n', modified_features )#打印修改后的模型参数
修改好之后features就可以拿去做Faster-RCNN提取特征用了
边栏推荐
猜你喜欢

控件限制总结

整流十二 -有效值、平均值、瞬时值、幅值的关系以及相关方法

VsCode配置自己喜欢的字体,背景,妈妈再也不担心我写代码枯燥了

最优化问题——线性规划模型

JS data types

mysql 批量修改表及字段字符集

VsCode configures your favorite fonts and backgrounds. Mom no longer worries about my boring code writing.

注意:服务器

在Ubuntu/Linux环境下使用MySQL:修改数据库sql_mode,可解决“this is incompatible with sql_mode=only_full_group_by”问题

GaN图腾柱无桥 Boost PFC(单相)三(预测模型)
随机推荐
SyntaxError line:3546,column:96577,SyntaxError: Unexpected token '...'. Expected a property name.
在Ubuntu/Linux环境下使用MySQL:修改数据库sql_mode,可解决“this is incompatible with sql_mode=only_full_group_by”问题
NOR flash和NAND flash的区别
GaN图腾柱无桥 Boost PFC(单相)五-细节处理
GaN图腾柱无桥 Boost PFC(单相)四(仿真理解)
基本控件属性
Mysql Workbench uses .sql file to import data into database
如何选择云服务器与轻量应用服务器?谈谈自己的看法
数学建模美赛题型分类
笔记| 矩阵分析中需要复习的线性代数知识
如何解决在使用keepAlive后使用grid+echart的页面高度异常的问题
VsCode配置自己喜欢的字体,背景,妈妈再也不担心我写代码枯燥了
mysql排序总结
【全排列】
摘桃子(推式子+优化)
[GYCTF2020]Ezsqli-1|SQL注入
[Deep Learning] TensorFlow Learning Road 2: Introduction to ANN and TensorFlow Implementation
Why software development methodology make you feel bad?
Mysql Workbench导出sql文件出错:Error executing task: ‘ascii‘ codec can‘t decode byte 0xd0 in position 26:
桌面内容整理,用时高效