当前位置:网站首页>模型冻结对应层参数freeze
模型冻结对应层参数freeze
2022-08-09 00:25:00 【helloworld_Fly】
目的
在做迁移学习或者自监督学习时,一般先预训练一个模型,再将该模型参数作为目标任务模型的初始化参数,或者直接freeze预训练模型,不再更新其参数。
方法
from collections.abc import Iterable
def set_freeze_by_names(model, layer_names, freeze=True):
if not isinstance(layer_names, Iterable):
layer_names = [layer_names]
for name, child in model.named_children():
if name not in layer_names:
continue
for param in child.parameters():
#print(param.name)
param.requires_grad = not freeze
def freeze_by_names(model, layer_names):
set_freeze_by_names(model, layer_names, True)
def unfreeze_by_names(model, layer_names):
set_freeze_by_names(model, layer_names, False)
def set_freeze_by_idxs(model, idxs, freeze=True):
if not isinstance(idxs, Iterable):
idxs = [idxs]
num_child = len(list(model.children()))
idxs = tuple(map(lambda idx: num_child + idx if idx < 0 else idx, idxs))
for idx, child in enumerate(model.children()):
if idx not in idxs:
continue
for param in child.parameters():
param.requires_grad = not freeze
def freeze_by_idxs(model, idxs):
set_freeze_by_idxs(model, idxs, True)
def unfreeze_by_idxs(model, idxs):
set_freeze_by_idxs(model, idxs, False)
我的使用
# select params to freeze
# print(all_model_list[:5])
for name, child in self.backbone.named_children():
if name in all_model_list:
for param in child.parameters():
# print(param.name)
param.requires_grad = False
思路
- 读取整个模型
- 获取对应的子模型名字和child
- 将对应child参数冻结(child.parameters() 中param转化为 requires_grad = False)
边栏推荐
猜你喜欢
随机推荐
神经网络基本原理
js中常用方法总结
杭电多校8 补题
【学习-目标检测】目标检测之—FPN+Cascade+Libra
「复盘」面试 BAMT 回来整理 398 道高频面试题,助你拿高薪 offer
[GYCTF2020]Ezsqli-1|SQL注入
逐片元-兰伯特光照模型
灰色预测模型
最优化问题——线性规划模型
整流七 - 三相PWM整流器—公式推导篇
笔记&代码 | 统计学——基于R(第四版) 第二章数据可视化
牛客小白月赛 37 补题
Unity3D小白学习日记(01):如何把物体移动到鼠标点击处
【 StoneDB Class 】 introductory lesson 3: StoneDB installation of compilation
牛客多校8 补题
纹理映射-TextureMapping
A - A + B Problem II
ShadowMap-Example
[Deep Learning] TensorFlow Learning Road 2: Introduction to ANN and TensorFlow Implementation
笔记&代码 | 统计学——基于R(第四版) 第九章一元线性回归