Pytorch learning record (XII): learning rate attenuation + regularization
2022-04-23 05:51:00 【Zuo Xiaotian ^ o^】
Learning rate decay

For gradient-based optimization methods, the update step is relatively large at the beginning, which means the initial learning rate can be set a little larger. But once the training loss has fallen to a certain level, such a large learning rate makes the loss oscillate back and forth instead of continuing to decrease.
At this point we need to decay the learning rate so that the loss can keep falling. Learning rate decay resolves this tension: as training progresses, the learning rate is gradually reduced.
In PyTorch it is very convenient to reduce the learning rate, using torch.optim.lr_scheduler.
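Before adjusting the learning rate by hand, here is a minimal sketch of the scheduler approach (this example assumes StepLR; the linear model below is just a stand-in for the resnet used in this post):

```python
import torch
from torch import nn

# A small stand-in model (not the post's resnet).
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# StepLR multiplies the learning rate by gamma every step_size epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

for epoch in range(60):
    # ... forward / backward / loss computation for one epoch would go here ...
    optimizer.step()
    scheduler.step()

# After 60 epochs the learning rate has been decayed twice: 0.1 -> 0.01 -> 0.001
print(optimizer.param_groups[0]['lr'])
```

torch.optim.lr_scheduler also provides other policies such as MultiStepLR and ExponentialLR, which follow the same step() pattern.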
import torch

net = resnet(3, 10)  # resnet is the model defined earlier in this series
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, weight_decay=1e-4)
''' Here we define the model and the optimizer. Through optimizer.param_groups we can access all parameter groups and their attributes. What is a parameter group? The model's parameters can be divided into several groups, each with its own learning rate. This is a slightly advanced feature; generally speaking, unless something special is done, there is only one parameter group. A parameter group is a dictionary holding many attributes, such as the learning rate and the weight decay, which we can access as follows '''
# Access the learning rate and the weight decay
print('learning rate: {}'.format(optimizer.param_groups[0]['lr']))
print('weight decay: {}'.format(optimizer.param_groups[0]['weight_decay']))
# So we can change the learning rate during training simply by modifying this attribute
optimizer.param_groups[0]['lr'] = 1e-5
# In case there are multiple parameter groups, use a loop
for param_group in optimizer.param_groups:
    param_group['lr'] = 1e-1
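Wrapping that loop in a helper function makes manual decay easy to use inside a training loop. A minimal sketch (the set_learning_rate helper and the linear stand-in model are illustrative, not from the original post):

```python
import torch
from torch import nn

def set_learning_rate(optimizer, lr):
    # Write the new learning rate into every parameter group.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

# A small stand-in model (not the post's resnet).
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for epoch in range(30):
    if epoch == 20:  # decay once the loss has plateaued
        set_learning_rate(optimizer, 0.001)
    # ... forward / backward / optimizer.step() for one epoch ...

print(optimizer.param_groups[0]['lr'])  # 0.001
```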
Regularization
torch.optim.SGD(net.parameters(), lr=0.1, weight_decay=1e-4)

That is all it takes: the weight_decay argument is the coefficient λ of the L2 penalty, i.e. the optimizer effectively minimizes loss + (λ/2)·||w||². Very convenient.

Pay attention to the size of the regularization coefficient. If it is too large, it strongly suppresses parameter updates and causes underfitting; if it is too small, the regularization term contributes almost nothing. Choosing an appropriate weight decay coefficient is therefore important and needs to be tuned for the specific task; 1e-4 or 1e-3 are reasonable first tries.
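The equivalence between weight_decay and an explicit L2 penalty can be checked directly: for SGD, weight_decay=λ adds λ·w to the gradient, which is exactly the gradient of (λ/2)·||w||². A small sketch (the toy one-step setup below is illustrative):

```python
import torch

lam, lr = 1e-2, 0.1

# Optimizer-side regularization via weight_decay.
w1 = torch.ones(3, requires_grad=True)
opt1 = torch.optim.SGD([w1], lr=lr, weight_decay=lam)
loss1 = (w1 ** 2).sum()
loss1.backward()
opt1.step()

# The same penalty written explicitly into the loss.
w2 = torch.ones(3, requires_grad=True)
opt2 = torch.optim.SGD([w2], lr=lr)
loss2 = (w2 ** 2).sum() + 0.5 * lam * (w2 ** 2).sum()  # + (lambda/2)*||w||^2
loss2.backward()
opt2.step()

print(torch.allclose(w1, w2))  # the two one-step updates agree
```

Note that this exact equivalence holds for plain SGD; for adaptive optimizers such as Adam, weight decay and an explicit L2 penalty behave differently (which is what AdamW addresses).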
import torch
from torch import nn
from torchvision import transforms as tfs
from torchvision.datasets import CIFAR10

def data_tf(x):
    im_aug = tfs.Compose([
        tfs.Resize(96),
        tfs.ToTensor(),
        tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    x = im_aug(x)
    return x

train_set = CIFAR10('./data', train=True, transform=data_tf)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)
test_set = CIFAR10('./data', train=False, transform=data_tf)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False, num_workers=4)

net = resnet(3, 10)  # resnet is the model defined earlier in this series
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, weight_decay=1e-4)  # weight_decay adds the regularization term
criterion = nn.CrossEntropyLoss()
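The setup above feeds into a standard training loop. A minimal runnable sketch of that loop (a random tensor batch and a linear model stand in for CIFAR10 and resnet, so the snippet runs without downloading data):

```python
import torch
from torch import nn

# Fake batch: 64 flattened 3x96x96 "images" with random labels in [0, 10).
fake_images = torch.randn(64, 3 * 96 * 96)
fake_labels = torch.randint(0, 10, (64,))

net = nn.Linear(3 * 96 * 96, 10)  # stand-in for resnet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

for epoch in range(3):
    optimizer.zero_grad()        # clear gradients from the previous step
    out = net(fake_images)       # forward pass
    loss = criterion(out, fake_labels)
    loss.backward()              # backward pass
    optimizer.step()             # parameter update (includes weight decay)

print(loss.item())
```

With the real DataLoader, the inner body would run once per batch inside a `for im, label in train_data:` loop.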
Copyright notice
This article was written by [Zuo Xiaotian ^ o^]. Please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/04/202204230543243970.html