当前位置:网站首页>Pytorch学习记录(十二):学习率衰减+正则化
Pytorch学习记录(十二):学习率衰减+正则化
2022-04-23 05:43:00 【左小田^O^】
学习率衰减
对于基于一阶梯度进行优化的方法而言,开始的时候更新的幅度是比较大的,也就是说开始的学习率可以设置大一点,但是当训练集的 loss 下降到一定程度之后,,使用这个太大的学习率就会导致 loss 一直来回震荡,比如
这个时候就需要对学习率进行衰减已达到 loss 的充分下降,而是用学习率衰减的办法能够解决这个矛盾,学习率衰减就是随着训练的进行不断的减小学习率。
在 pytorch 中学习率衰减非常方便,使用 torch.optim.lr_scheduler
net = resnet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, weight_decay=1e-4)
''' 这里我们定义好了模型和优化器,可以通过 optimizer.param_groups 来得到所有的参数组和其对应的属性, 参数组是什么意思呢?就是我们可以将模型的参数分成几个组,每个组定义一个学习率,这里比较复杂,一般来讲如果不做特别修改,就只有一个参数组 这个参数组是一个字典,里面有很多属性,比如学习率,权重衰减等等,我们可以访问以下 '''
# 访问学习率,权重衰减
print('learning rate: {}'.format(optimizer.param_groups[0]['lr']))
print('weight decay: {}'.format(optimizer.param_groups[0]['weight_decay']))
# 所以我们可以通过修改这个属性来改变我们训练过程中的学习率,非常简单
optimizer.param_groups[0]['lr'] = 1e-5
# 为了防止有多个参数组,我们可以使用一个循环
for param_group in optimizer.param_groups:
param_group['lr'] = 1e-1
正则化

torch.optim.SGD(net.parameters(), lr=0.1, weight_decay=1e-4) 就可以了,这个 weight_decay 系数就是上面公式中的 𝜆 ,非常方便
注意正则项的系数的大小非常重要,如果太大,会极大的抑制参数的更新,导致欠拟合,如果太小,那么正则项这个部分基本没有贡献,所以选择一个合适的权重衰减系数非常重要,这个需要根据具体的情况去尝试,初步尝试可以使用 1e-4 或者 1e-3
def data_tf(x):
im_aug = tfs.Compose([
tfs.Resize(96),
tfs.ToTensor(),
tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
x = im_aug(x)
return x
train_set = CIFAR10('./data', train=True, transform=data_tf)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)
test_set = CIFAR10('./data', train=False, transform=data_tf)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False, num_workers=4)
net = resnet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, weight_decay=1e-4) # 增加正则项
criterion = nn.CrossEntropyLoss()
版权声明
本文为[左小田^O^]所创,转载请带上原文链接,感谢
https://blog.csdn.net/qq_45802081/article/details/120940597
边栏推荐
- 事实最终变量与最终变量
- freemark中插入图片
- PyEMD安装及简单使用
- Add days to date
- Strategy for improving the conversion rate of independent stations | recovering abandoned users
- Isosceles triangle - the 9th Lanqiao provincial competition - group C
- The role of brackets' [] 'in the parameter writing method in MDN documents
- mysql-触发器、存储过程、存储函数
- Sea Level Anomaly 和 Sea Surface Height Anomaly 的区别
- 字符串(String)笔记
猜你喜欢
随机推荐
引航成长·匠心赋能——YonMaster开发者培训领航计划全面开启
Split and merge multiple one-dimensional arrays into two-dimensional arrays
C语言——恶搞关机小程序
事实最终变量与最终变量
Font shape `OMX/cmex/m/n‘ in size <10.53937> not available (Font) size <10.95> substituted.
一文读懂当前常用的加密技术体系(对称、非对称、信息摘要、数字签名、数字证书、公钥体系)
freemark中插入图片
2-軟件設計原則
Linear sieve method (prime sieve)
PHP处理json_decode()解析JSON.stringify
Breadth first search topics (BFS)
基于thymeleaf实现数据库图片展示到浏览器表格
Error 2003 (HY000) when Windows connects MySQL: can't connect to MySQL server on 'localhost' (10061)
图解HashCode存在的意义
Common status codes
解决报错:ImportError: IProgress not found. Please update jupyter and ipywidgets
Excel sets row and column colors according to cell contents
TypeScript interface & type 粗略理解
Add days to date
Find the number of "blocks" in the matrix (BFS)









