PyTorch 10. Learning rate
2022-04-23 07:28:00 [DCGJ666]
scheduler
A scheduler is bound to an optimizer and adjusts its learning rate during training.
class _LRScheduler(object):
    def __init__(self, optimizer, last_epoch=-1):
        pass

    def get_lr(self):
        # overridden by subclasses; shown here is StepLR's version
        return [base_lr * self.gamma ** (self.last_epoch // self.step_size)
                for base_lr in self.base_lrs]

    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr
Main attributes:
optimizer: the associated optimizer
last_epoch: records the epoch count
base_lrs: records the initial learning rates
Main methods:
step(): update the learning rate for the next epoch
get_lr(): compute the learning rate for the next epoch
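The workflow above can be sketched as follows (a minimal example; the model and hyperparameters are placeholders):

```python
import torch

# Minimal sketch of the scheduler workflow: the scheduler is bound to
# an optimizer and rewrites each param_group's 'lr' on every step().
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

for epoch in range(3):
    # ... forward pass, loss.backward(), optimizer.step() would go here ...
    scheduler.step()               # update the learning rate once per epoch

current_lr = scheduler.get_last_lr()[0]   # learning rate now in effect
```

Since PyTorch 1.1, scheduler.step() should be called after optimizer.step() in each epoch.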
StepLR
lr_scheduler.StepLR(optimizer, step_size, gamma=0.1,last_epoch=-1)
Function: adjust the learning rate at equal intervals
Main parameters:
step_size: number of epochs between adjustments
gamma: decay factor
Adjustment rule: lr = lr * gamma
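A small sketch of this behavior (values chosen for illustration): with step_size=2 and gamma=0.1, the learning rate is multiplied by 0.1 every 2 epochs.

```python
import torch

# Dummy single-parameter setup; a real training loop would compute
# gradients before optimizer.step().
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

lrs = []
for epoch in range(6):
    lrs.append(optimizer.param_groups[0]['lr'])  # lr used in this epoch
    optimizer.step()                             # training step (dummy here)
    scheduler.step()
# lrs is approximately [0.1, 0.1, 0.01, 0.01, 0.001, 0.001]
```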
MultiStepLR
lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=-1)
Function: adjust the learning rate at specified epochs
Main parameters:
milestones: list of epochs at which to adjust, e.g. milestones = [50, 125, 160]
gamma: decay factor
Adjustment rule: lr = lr * gamma
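An illustrative sketch (milestones shortened so the effect is visible in a few epochs): the learning rate drops by gamma when the epoch count reaches each milestone.

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[2, 4], gamma=0.1)

lrs = []
for epoch in range(6):
    lrs.append(optimizer.param_groups[0]['lr'])  # lr used in this epoch
    optimizer.step()                             # training step (dummy here)
    scheduler.step()
# lr drops by 10x entering epochs 2 and 4:
# approximately [0.1, 0.1, 0.01, 0.01, 0.001, 0.001]
```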
ExponentialLR
lr_scheduler.ExponentialLR(optimizer, gamma, last_epoch=-1)
Function: decay the learning rate exponentially
Main parameters:
gamma: base of the exponential decay (lr = base_lr * gamma ** epoch)
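A short sketch (gamma=0.5 is chosen only so the decay is easy to see): the learning rate halves every epoch.

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5)

lrs = [optimizer.param_groups[0]['lr']]
for _ in range(3):
    optimizer.step()                 # training step (dummy here)
    scheduler.step()
    lrs.append(optimizer.param_groups[0]['lr'])
# lr = 0.1 * 0.5 ** epoch: approximately [0.1, 0.05, 0.025, 0.0125]
```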
CosineAnnealingLR
lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min=0, last_epoch=-1)
Function: adjust the learning rate on a cosine annealing schedule
Main parameters:
T_max: number of epochs in the descent period (half a cosine cycle)
eta_min: lower bound of the learning rate
Adjustment rule:
lr_t = lr_{min} + \frac{1}{2}(lr_{max} - lr_{min})\left(1 + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)
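The cosine formula can be checked numerically (T_max=10 and eta_min=0 are illustrative choices): the rate starts at the base value, reaches the midpoint halfway through, and hits eta_min at epoch T_max.

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=10, eta_min=0)

lrs = [optimizer.param_groups[0]['lr']]
for _ in range(10):
    optimizer.step()                 # training step (dummy here)
    scheduler.step()
    lrs.append(optimizer.param_groups[0]['lr'])
# follows the cosine formula: starts at 0.1,
# about 0.05 at the halfway point (epoch 5), eta_min = 0 at epoch T_max
```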
ReduceLROnPlateau
lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel',cooldown=0, min_lr=0, eps=1e-08)
Function: monitor a metric and reduce the learning rate when it stops improving
Main parameters:
mode: 'min' or 'max'
factor: decay factor
patience: number of epochs with no improvement to tolerate before reducing
cooldown: number of epochs to suspend monitoring after each reduction
verbose: whether to print a message on each update
min_lr: lower bound of the learning rate
eps: minimal decay of the learning rate; if the change would be smaller than eps, the update is skipped
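A sketch of the plateau behavior with a fabricated validation-loss sequence (the losses are made up for illustration): once the metric stops improving for more than `patience` epochs, the rate is multiplied by `factor`.

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=2)

# fabricated validation losses: improvement stalls after the second epoch
val_losses = [1.0, 0.9, 0.9, 0.9, 0.9, 0.9]
lrs = []
for loss in val_losses:
    scheduler.step(loss)                 # pass the monitored metric to step()
    lrs.append(optimizer.param_groups[0]['lr'])
# the 3rd epoch without improvement exceeds patience=2,
# so lr drops from 0.1 to 0.01 at that point
```

Note that unlike the other schedulers, step() here takes the monitored metric as an argument.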
References:
https://zhuanlan.zhihu.com/p/146865009
Copyright notice
This article was written by [DCGJ666]. Please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/04/202204230611343971.html