当前位置:网站首页>深度学习中的学习率调整策略(1)
深度学习中的学习率调整策略(1)
2022-08-10 05:30:00 【公众号学一点会一点】

学习率(LearningRate, LR/lr)是深度学习中很重要的一个超参数了。其公式:
也就是说它是在训练过程中更新网络权重的一个调整因子,为什么说其重要呢?简单说:
学习率太大,梯度容易爆炸,loss的振幅较大,模型难以收敛; 学习率太小,容易过拟合,也容易陷入“局部最优”点;
因此选择一个合适的学习率是非常重要的。 对于新手来说,一般可能是看网上的经验或者开源代码选择一个差不多的lr(比如0.1-0.001之间)。
但是,真正用自己的数据来进行模型调试的时候就会发现,学习率也是一个非常重要的超参数,且不是那么好确定的。。。
理解了太上老君炼丹的不易。

不过还好,有大佬们想到了动态调整学习率的方法,其原理也非常简单:根据某种策略,在模型训练的过程中动态地对学习率进行调整,一般是按照某种策略进行衰减(可以想象当快要到达谷底或者山峰的时候就会放慢步伐)。
学习率调整策略
学习率调整策略在pytorch的torch.optim模块下,称其为scheduler,所以也可以说它仍然是优化器的一部分。 学习率调整一般是在优化器进行更新之后进行调整,其示例代码(来自官网):
model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler = ExponentialLR(optimizer, gamma=0.9)
for epoch in range(20):
for input, target in dataset:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
scheduler.step()
注意看上面的代码,其在epoch的循环中,而不是在最内层的batch循环中,因为一般是在训练了几个epoch之后调整学习率,如果是在batch中,lr更新的太快了;
对于学习率的调整,Pytorch中提供了如下14种方法(具体见参考链接【3】):
lr_scheduler.LambdaLR lr_scheduler.MultiplicativeLR lr_scheduler.StepLR lr_scheduler.MultiStepLR lr_scheduler.ConstantLR lr_scheduler.LinearLR lr_scheduler.ExponentialLR lr_scheduler.CosineAnnealingLR lr_scheduler.ChainedScheduler lr_scheduler.SequentialLR lr_scheduler.ReduceLROnPlateau lr_scheduler.CyclicLR lr_scheduler.OneCycleLR lr_scheduler.CosineAnnealingWarmRestarts
具体每种方法的用法后面再讲,我们先看下一个例子的:
model = torchvision.models.AlexNet(num_classes=2)
optimizer = optim.Adam(model.parameters(),lr=0.01)
scheduler = optim.lr_scheduler.LinearLR(optimizer,start_factor=0.1, total_iters=100)
for epoch in range(100):
print(f"当前学习率:{optimizer.param_groups[0]['lr']}")
optimizer.step()
scheduler.step()
上面的例子使用了Adam作为优化器,然后用线性的方式在训练的过程中更新学习率;
其学习率的变化如下:

可以看到LinearLR的策略就是设定起始的学习率(优化器中的学习率 start_factor)和终止的学习率(默认是优化器中的学习率end_factor,end_factor默认为1.0),然后 按照total_iters把起始学习率和终止学习率确定的区间进行均分,然后每个epoch更新一次。 需要注意的是,当达到设定的终止学习率之后,即便还没训练完,学习率也不会再更新了。
那如果我们设置了不合适的参数,导致学习率很快就更新到头了,比如10个epoch就更新完了,但是训练一共是100个epoch怎么办?不要慌,Pytorch中的学习率更新可以进行链式调度,也就是说可以同时使用多个学习率更新策略!示例:
model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler1 = ExponentialLR(optimizer, gamma=0.9)
scheduler2 = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
for epoch in range(20):
for input, target in dataset:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
scheduler1.step()
scheduler2.step()
也就是说我们可以同时使用多个策略来更新学习率,比如每训练多个epoch更新一次+loss不变化的时候再主动更新,等等。。
下篇文章详解。

参考
【1】https://zhuanlan.zhihu.com/p/41681558
【2】https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html
【3】https://pytorch.org/docs/stable/optim.html
【4】https://hasty.ai/content-hub/mp-wiki/scheduler/cycliclr
本文由 mdnice 多平台发布
边栏推荐
- [Thesis Notes] Prototypical Contrast Adaptation for Domain Adaptive Semantic Segmentation
- 细数国产接口协作平台的六把武器!
- 基于Qiskit——《量子计算编程实战》读书笔记(七)
- pytorch框架学习(9)torchvision.transform
- strongest brain (1)
- Pony语言学习(一):环境配置(续)
- 大咖说·对话生态|当Confluent遇见云:实时流动的数据更有价值
- CORS跨域资源共享漏洞的原理与挖掘方法
- Pony语言学习(七)——表达式(Expressions)语法(单篇向)
- 【论文笔记1】小样本分类
猜你喜欢

每周推荐短视频:探索AI的应用边界

速刷正则表达式一周目(上)

MySQL simple tutorial

pytorch框架学习(9)torchvision.transform

scikit-learn机器学习 读书笔记(二)

暑期学前作业

一篇文章带你搞懂什么是幂等性问题?如何解决幂等性问题?

Error when installing oracle rac 11g and executing root.sh

图纸怎么折?(A0,A1,A2,A3の图纸如何折成A4大小)

Interface documentation evolution illustration, some ancient interface documentation tools, you may not have used it
随机推荐
Zhongang Mining: Strong downstream demand for fluorite
Stacks and Queues | Valid parentheses, delete all adjacent elements in a string, reverse Polish expression evaluation, maximum sliding window, top K high frequency elements | leecode brush questions
速刷正则表达式一周目(上)
常用工具系列 - 常用正则表达式
Consulting cdc 2.0 for mysql does not execute flush with read lock. How to ensure bin
SSM框架整合实例
【写下自用】每次都忘记如何train?记录如何训练自己的yolov5
conda创建虚拟环境方法和pqi使用国内镜像源安装第三方库的方法教程
strongest brain (1)
FPGA工程师面试试题集锦31~40
Depth of carding: prevent model fitting method
aliases node analysis
Nexus_Warehouse Type
pytorch框架学习(9)torchvision.transform
Why are negative numbers in binary represented in two's complement form - binary addition and subtraction
基于Qiskit——《量子计算编程实战》读书笔记(四)
Flutter development: error The following assertion was thrown resolving an image codec: Solution for Unable to...
看了几十篇轻量化目标检测论文扫盲做的摘抄笔记
Stacks and Queues | Implementing Queues with Stacks | Implementing Stacks with Queues | Basic Theory and Code Principles
CSDN Markdown 之我见代码块 | CSDN编辑器测评