PyTorch Learning Record (XII): Learning Rate Decay + Regularization
2022-04-23 05:51:00 【Zuo Xiaotian ^ o^】
Learning rate decay
For gradient-based optimization methods, the update steps are large at the beginning, which means the initial learning rate can be set somewhat high. But once the training loss has fallen to a certain level, keeping such a large learning rate makes the loss oscillate back and forth instead of continuing to decrease.
At that point the learning rate needs to be reduced so that the loss can keep falling. Learning rate decay resolves this tension: as training progresses, the learning rate is gradually lowered.
In PyTorch it is very convenient to decay the learning rate: one option is torch.optim.lr_scheduler, and another is to modify the optimizer's parameter groups directly.
net = resnet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, weight_decay=1e-4)
''' Here we define the model and the optimizer. Through optimizer.param_groups we can
access every parameter group and its attributes. What is a parameter group? The model's
parameters can be split into several groups, each with its own learning rate. That is a
slightly more advanced use; normally, unless something special is done, there is only
one parameter group. Each group is a dictionary with many attributes, such as the
learning rate and the weight decay, which we can access as follows. '''
# Access the learning rate and the weight decay
print('learning rate: {}'.format(optimizer.param_groups[0]['lr']))
print('weight decay: {}'.format(optimizer.param_groups[0]['weight_decay']))
# So we can change the learning rate during training simply by modifying this attribute
optimizer.param_groups[0]['lr'] = 1e-5
# In case there are multiple parameter groups, use a loop
for param_group in optimizer.param_groups:
    param_group['lr'] = 1e-1
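Besides editing param_groups by hand, torch.optim.lr_scheduler can manage the decay schedule automatically. The sketch below uses StepLR to halve the learning rate every 10 epochs; the nn.Linear model is a hypothetical stand-in for the resnet defined earlier in the series.

```python
import torch
from torch import nn

# Tiny stand-in model (assumption: the real post uses a resnet defined elsewhere)
net = nn.Linear(10, 2)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, weight_decay=1e-4)

# Halve the learning rate every 10 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

lrs = []
for epoch in range(30):
    # ... forward pass, loss.backward() would go here ...
    optimizer.step()      # step the optimizer first, then the scheduler
    scheduler.step()
    lrs.append(optimizer.param_groups[0]['lr'])

# lrs holds 0.01 for the first 10 epochs, then 0.005, then 0.0025
```

StepLR multiplies the learning rate by gamma every step_size epochs; other schedulers in the same module (ExponentialLR, MultiStepLR, CosineAnnealingLR) follow the same step() pattern.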
Regularization

L2 regularization (weight decay) adds a penalty term to the loss, f = loss + (λ/2) Σᵢ wᵢ², which pushes the weights toward smaller values during training. In PyTorch, torch.optim.SGD(net.parameters(), lr=0.1, weight_decay=1e-4) is all it takes: the weight_decay coefficient plays the role of λ in the formula, which is very convenient.
Pay attention to the size of the regularization coefficient. If it is too large, it strongly suppresses parameter updates and causes underfitting; if it is too small, the regularization term contributes almost nothing. Choosing an appropriate weight decay coefficient therefore matters and has to be tuned for the task at hand; 1e-4 or 1e-3 are reasonable first tries.
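As a sanity check on the claim that weight_decay is the λ of the L2 penalty, the sketch below compares one SGD step using the built-in weight_decay against one step where the (λ/2) Σ w² term is added to the loss by hand; both produce the same updated parameters. The toy linear model and random data are stand-ins, not part of the original post.

```python
import torch
from torch import nn

# Toy data; the point is only to compare the two updates, not to train well
torch.manual_seed(0)
x = torch.randn(8, 4)
y = torch.randn(8, 1)

def one_step(builtin_wd):
    # Re-seed so both runs start from identical weights
    torch.manual_seed(1)
    model = nn.Linear(4, 1)
    wd = 1e-2
    opt = torch.optim.SGD(model.parameters(), lr=0.1,
                          weight_decay=wd if builtin_wd else 0.0)
    loss = ((model(x) - y) ** 2).mean()
    if not builtin_wd:
        # Manual L2 penalty: (lambda/2) * sum of squared parameters
        loss = loss + (wd / 2) * sum((p ** 2).sum() for p in model.parameters())
    opt.zero_grad()
    loss.backward()
    opt.step()
    return torch.cat([p.detach().reshape(-1) for p in model.parameters()])

params_builtin = one_step(True)
params_manual = one_step(False)
```

The gradient of (λ/2) w² is λw, which is exactly what SGD's weight_decay adds to each parameter's gradient before the update.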
import torch
from torch import nn
import torchvision.transforms as tfs
from torchvision.datasets import CIFAR10

def data_tf(x):
    im_aug = tfs.Compose([
        tfs.Resize(96),
        tfs.ToTensor(),
        tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    x = im_aug(x)
    return x

train_set = CIFAR10('./data', train=True, transform=data_tf)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)
test_set = CIFAR10('./data', train=False, transform=data_tf)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False, num_workers=4)

net = resnet(3, 10)  # resnet is defined earlier in this series
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, weight_decay=1e-4)  # add the regularization term
criterion = nn.CrossEntropyLoss()
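To tie the two ideas together, here is a sketch of a training loop that decays the learning rate partway through training by writing to param_groups, as shown earlier. A toy linear model and random batches stand in for the resnet/CIFAR10 pipeline above, and set_learning_rate is a hypothetical helper, not part of the original post.

```python
import torch
from torch import nn

# Stand-ins for the resnet/CIFAR10 pipeline (assumption: the real loop is analogous)
net = nn.Linear(10, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

def set_learning_rate(optimizer, lr):
    # Modify the lr attribute on every parameter group, as in the snippet above
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

for epoch in range(3):
    if epoch == 2:
        # Decay the learning rate partway through training
        set_learning_rate(optimizer, 0.001)
    x = torch.randn(16, 10)                 # fake image batch
    y = torch.randint(0, 10, (16,))         # fake class labels
    out = net(x)
    loss = criterion(out, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

final_lr = optimizer.param_groups[0]['lr']
```

In practice the decay epoch (here epoch 2) is chosen by watching when the training loss plateaus, per the discussion at the top of the post.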
Copyright notice
This article was written by [Zuo Xiaotian ^ o^]. Please include the original link when reposting. Thanks.
https://yzsam.com/2022/04/202204230543243970.html