当前位置:网站首页>模型冻结对应层参数freeze
模型冻结对应层参数freeze
2022-08-09 00:25:00 【helloworld_Fly】
目的
在做迁移学习或者自监督学习时,一般先预训练一个模型,再将该模型参数作为目标任务模型的初始化参数,或者直接freeze预训练模型,不再更新其参数。
方法
from collections.abc import Iterable
def set_freeze_by_names(model, layer_names, freeze=True):
if not isinstance(layer_names, Iterable):
layer_names = [layer_names]
for name, child in model.named_children():
if name not in layer_names:
continue
for param in child.parameters():
#print(param.name)
param.requires_grad = not freeze
def freeze_by_names(model, layer_names):
set_freeze_by_names(model, layer_names, True)
def unfreeze_by_names(model, layer_names):
set_freeze_by_names(model, layer_names, False)
def set_freeze_by_idxs(model, idxs, freeze=True):
if not isinstance(idxs, Iterable):
idxs = [idxs]
num_child = len(list(model.children()))
idxs = tuple(map(lambda idx: num_child + idx if idx < 0 else idx, idxs))
for idx, child in enumerate(model.children()):
if idx not in idxs:
continue
for param in child.parameters():
param.requires_grad = not freeze
def freeze_by_idxs(model, idxs):
set_freeze_by_idxs(model, idxs, True)
def unfreeze_by_idxs(model, idxs):
set_freeze_by_idxs(model, idxs, False)
我的使用
# select params to freeze
# print(all_model_list[:5])
for name, child in self.backbone.named_children():
if name in all_model_list:
for param in child.parameters():
# print(param.name)
param.requires_grad = False
思路
- 读取整个模型
- 获取对应的子模型名字和child
- 将对应child参数冻结(child.parameters() 中param转化为 requires_grad = False)
边栏推荐
- 整流十二 -有效值、平均值、瞬时值、幅值的关系以及相关方法
- GaN图腾柱无桥 Boost PFC(单相)三(预测模型)
- 小G砍树 (换根dp)
- 笔记&代码 | 统计学——基于R(第四版) 第十章 多元线性回归
- 在Ubuntu/Linux环境下使用MySQL:修改数据库sql_mode,可解决“this is incompatible with sql_mode=only_full_group_by”问题
- 菲涅尔反射
- 自考成绩总结
- GaN图腾柱无桥 Boost PFC(单相)五-细节处理
- 【科研-学习-pytorch】7-梯度、激活函数和loss
- 全新Swagger3.0教程,OAS3快速配置指南,实现API接口文档自动化!
猜你喜欢
随机推荐
整流七 - 三相PWM整流器—公式推导篇
GaN图腾柱无桥 Boost PFC(单相)五-细节处理
【科研-学习-pytorch】3-分类问题
全新Swagger3.0教程,OAS3快速配置指南,实现API接口文档自动化!
菲涅尔反射
MySQL5.7设置MySQL/MariaDB 数据库默认编码为utf8mb4
【学习-目标检测】目标检测之——YOLO v3
cmd切换硬盘的命令,从C盘切换到D盘怎么操作
插值拟合——数据处理或预测
GaN图腾柱无桥 Boost PFC(单相)二 (公式推到理解篇)
什么是阿里云服务器系统盘和数据盘?
Pytorch预训练模型和修改——记录
Mysql 根据一个表数据更新另外一个表
2021江苏省赛
GaN图腾柱无桥 Boost PFC(单相)三(预测模型)
mysql排序总结
笔记&代码 | 统计学——基于R(第四版) 第九章一元线性回归
Mysql Workbench uses .sql file to import data into database
千分位数字
【科研-学习-pytorch】4-数据类型、创建、索引和维度变化









