当前位置:网站首页>Pytorch学习记录(十二):学习率衰减+正则化
Pytorch学习记录(十二):学习率衰减+正则化
2022-04-23 05:43:00 【左小田^O^】
学习率衰减
对于基于一阶梯度进行优化的方法而言,开始的时候更新的幅度是比较大的,也就是说开始的学习率可以设置大一点,但是当训练集的 loss 下降到一定程度之后,,使用这个太大的学习率就会导致 loss 一直来回震荡,比如
这个时候就需要对学习率进行衰减已达到 loss 的充分下降,而是用学习率衰减的办法能够解决这个矛盾,学习率衰减就是随着训练的进行不断的减小学习率。
在 pytorch 中学习率衰减非常方便,使用 torch.optim.lr_scheduler
net = resnet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, weight_decay=1e-4)
''' 这里我们定义好了模型和优化器,可以通过 optimizer.param_groups 来得到所有的参数组和其对应的属性, 参数组是什么意思呢?就是我们可以将模型的参数分成几个组,每个组定义一个学习率,这里比较复杂,一般来讲如果不做特别修改,就只有一个参数组 这个参数组是一个字典,里面有很多属性,比如学习率,权重衰减等等,我们可以访问以下 '''
# 访问学习率,权重衰减
print('learning rate: {}'.format(optimizer.param_groups[0]['lr']))
print('weight decay: {}'.format(optimizer.param_groups[0]['weight_decay']))
# 所以我们可以通过修改这个属性来改变我们训练过程中的学习率,非常简单
optimizer.param_groups[0]['lr'] = 1e-5
# 为了防止有多个参数组,我们可以使用一个循环
for param_group in optimizer.param_groups:
param_group['lr'] = 1e-1
正则化

torch.optim.SGD(net.parameters(), lr=0.1, weight_decay=1e-4) 就可以了,这个 weight_decay 系数就是上面公式中的 𝜆 ,非常方便
注意正则项的系数的大小非常重要,如果太大,会极大的抑制参数的更新,导致欠拟合,如果太小,那么正则项这个部分基本没有贡献,所以选择一个合适的权重衰减系数非常重要,这个需要根据具体的情况去尝试,初步尝试可以使用 1e-4 或者 1e-3
def data_tf(x):
im_aug = tfs.Compose([
tfs.Resize(96),
tfs.ToTensor(),
tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
x = im_aug(x)
return x
train_set = CIFAR10('./data', train=True, transform=data_tf)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)
test_set = CIFAR10('./data', train=False, transform=data_tf)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False, num_workers=4)
net = resnet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, weight_decay=1e-4) # 增加正则项
criterion = nn.CrossEntropyLoss()
版权声明
本文为[左小田^O^]所创,转载请带上原文链接,感谢
https://blog.csdn.net/qq_45802081/article/details/120940597
边栏推荐
- 创建二叉树
- 多个一维数组拆分合并为二维数组
- What is JSON? First acquaintance with JSON
- Getting started with JDBC \ getting a database connection \ using Preparedstatement
- Fletter next generation graphics renderer impaller
- 域内用户访问域外samba服务器用户名密码错误
- Hotkeys, interface visualization configuration (interface interaction)
- Solid contract DoS attack
- xxl-job采坑指南xxl-rpc remoting error(connect timed out)
- EditorConfig
猜你喜欢

Excel sets row and column colors according to cell contents

软件架构设计——软件架构风格

Common protocols of OSI layer

‘EddiesObservations‘ object has no attribute ‘filled‘

Batch import of orange single micro service

delete和truncate

Anaconda

Breadth first search topics (BFS)

开发环境 EAS登录 license 许可修改

C language - Spoof shutdown applet
随机推荐
protected( 被 protected 修饰的成员对于本包和其子类可见)
创建二叉树
SQL基础:初识数据库与SQL-安装与基本介绍等—阿里云天池
Isosceles triangle - the 9th Lanqiao provincial competition - group C
创建企业邮箱账户命令
多线程与高并发(1)——线程的基本知识(实现,常用方法,状态)
mysql如何将存储的秒转换为日期
Strategies to improve Facebook's touch rate and interaction rate | intelligent customer service helps you grasp users' hearts
xxl-job采坑指南xxl-rpc remoting error(connect timed out)
容器
JVM series (3) -- memory allocation and recycling strategy
RedHat6之smb服务访问速度慢解决办法记录
一文读懂当前常用的加密技术体系(对称、非对称、信息摘要、数字签名、数字证书、公钥体系)
2 - software design principles
SQL注入
PyEMD安装及简单使用
图解HashCode存在的意义
多线程与高并发(3)——synchronized原理
热键,界面可视化配置(界面交互)
excel获取两列数据的差异数据