Deep Learning Notes: Fine-Tuning
2022-04-23 04:53:00 【Whisper_yl】
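We usually hope that a model already trained on a very large dataset can help improve accuracy on our own task.
A network can be viewed as two parts: one extracts features, the other performs linear classification.
Core idea: for a model trained on a source dataset (usually a fairly large one), we expect the feature extraction part to be reusable. (The lower the layer, the more general its features.)
When training on our own dataset, we use a model with the same architecture as the pretrained one. Every layer except the last is initialized with the pretrained weights rather than randomly (these weights may already be close to the final result and are generally better than random initialization). This amounts to copying the feature extraction module over as our initial model, so that we start with a reasonably good feature representation. (The last layer predicts different labels, so it can be initialized randomly.)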
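The initialization described above can be sketched on a toy model. This is only an illustration under assumptions: `TinyNet`, its `features`/`fc` split, and the layer sizes are hypothetical stand-ins, and `source_net` is randomly initialized here rather than actually trained on a source dataset.

```python
import torch
from torch import nn

class TinyNet(nn.Module):
    """Hypothetical toy model: a feature extractor plus an output layer `fc`."""
    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x):
        return self.fc(self.features(x))

torch.manual_seed(0)
source_net = TinyNet(num_classes=10)  # stands in for the model trained on the source dataset
target_net = TinyNet(num_classes=2)   # same architecture, different number of labels

# Keep every entry of the source state dict except the output layer's.
backbone_state = {name: tensor for name, tensor in source_net.state_dict().items()
                  if not name.startswith("fc.")}

# strict=False allows the output layer to be absent from the loaded state dict;
# the returned object lists which keys were not loaded.
result = target_net.load_state_dict(backbone_state, strict=False)
print(result.missing_keys)
```

After this step the feature extractor of `target_net` holds the source weights, while its `fc` layer keeps its own random initialization.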
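Because the weights are already fairly close to the optimum, fine-tuning uses a smaller learning rate and fewer iterations.
During fine-tuning we can also leave the weights of the bottom layers unchanged. Fixing those parameters so that they no longer change reduces the model's effective complexity.
When the dataset is very small and training all parameters seems likely to overfit, consider freezing the parameters of some bottom layers so that they are not updated.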
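One common way to freeze bottom layers in PyTorch is to turn off `requires_grad` on their parameters and hand only the remaining parameters to the optimizer. The sketch below assumes a hypothetical three-module toy network and random data; it is not the model used later in these notes.

```python
import torch
from torch import nn

torch.manual_seed(0)
net = nn.Sequential(
    nn.Linear(8, 16), nn.ReLU(),  # "bottom" layers, treated as the feature extractor
    nn.Linear(16, 2),             # output layer, the only part we train
)

# Freeze everything except the last module.
for layer in list(net.children())[:-1]:
    for param in layer.parameters():
        param.requires_grad_(False)

# Only parameters that still require gradients go to the optimizer.
trainable = [p for p in net.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.1)

frozen_before = net[0].weight.clone()
head_before = net[2].weight.clone()

# One training step on random data (illustrative only).
X, y = torch.randn(4, 8), torch.tensor([0, 1, 0, 1])
loss = nn.CrossEntropyLoss()(net(X), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(torch.equal(net[0].weight, frozen_before))  # frozen layer compared with its old value
print(torch.equal(net[2].weight, head_before))    # output layer compared with its old value
```

Frozen parameters receive no gradient, so the step leaves them untouched and only the output layer moves.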
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt
# Register the hot dog dataset with d2l (download URL and SHA-1 checksum)
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip',
                          'fba480ffa8aa7e0febbb511d181409f899b9baa5')
data_dir = d2l.download_extract('hotdog')
train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))
hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4)
plt.show()
# Standardize each channel using the per-channel RGB mean and standard deviation.
# The ImageNet-pretrained model was trained with this preprocessing, so we apply it here too.
normalize = torchvision.transforms.Normalize(
    [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    normalize])
test_augs = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    normalize])
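The per-channel standardization can be written out by hand to see what it computes: each channel has its mean subtracted and is then divided by its standard deviation. The sketch below uses plain `torch` and a random tensor in place of a real image; `normalize_by_channel` is a hypothetical helper name, not part of torchvision.

```python
import torch

# Per-channel statistics, reshaped to broadcast over a (3, H, W) image tensor.
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

def normalize_by_channel(img):
    """Standardize a (3, H, W) tensor with values in [0, 1], channel by channel."""
    return (img - mean) / std

img = torch.rand(3, 224, 224)  # random stand-in for the output of ToTensor()
out = normalize_by_channel(img)
print(out.shape)

# An image whose every pixel equals the channel mean maps to all zeros.
flat = mean.expand(3, 4, 4)
print(normalize_by_channel(flat).abs().max())
```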
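# Define and initialize the model
# pretrained=True loads the trained parameters along with the model definition.
# Note: torchvision 0.13+ deprecates pretrained=True in favor of the weights= argument.
pretrained_net = torchvision.models.resnet18(pretrained=True)
print(pretrained_net.fc)
finetune_net = torchvision.models.resnet18(pretrained=True)
# Replace the output layer with a randomly initialized linear layer; this is a binary classification task
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
# Re-initialize only the weight of the last layer
nn.init.xavier_uniform_(finetune_net.fc.weight)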
# Fine-tune the model
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,
                      param_group=True):
    train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train'), transform=train_augs),
        batch_size=batch_size, shuffle=True)
    test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'test'), transform=test_augs),
        batch_size=batch_size)
    devices = d2l.try_all_gpus()
    # reduction="none" keeps per-sample losses; d2l.train_ch13 sums them before backward
    loss = nn.CrossEntropyLoss(reduction="none")
    # All layers except the last use the small base learning rate;
    # the last layer uses 10x that rate so that it learns faster
    if param_group:
        params_1x = [param for name, param in net.named_parameters()
                     if name not in ["fc.weight", "fc.bias"]]
        trainer = torch.optim.SGD([{'params': params_1x},
                                   {'params': net.fc.parameters(),
                                    'lr': learning_rate * 10}],
                                  lr=learning_rate, weight_decay=0.001)
    else:
        trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,
                                  weight_decay=0.001)
    d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
                   devices)
train_fine_tuning(finetune_net, 5e-5)
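# For comparison, train the same architecture without pretrained weights
scratch_net = torchvision.models.resnet18()
scratch_net.fc = nn.Linear(scratch_net.fc.in_features, 2)
# Use a larger learning rate, since every layer starts from random initialization
train_fine_tuning(scratch_net, 5e-4, param_group=False)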
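To check that the two parameter groups in `train_fine_tuning` really get different learning rates, the optimizer's `param_groups` can be inspected. The sketch below rebuilds the same optimizer on a hypothetical toy model (`TinyNet`, with an output layer named `fc` like ResNet-18's), so it runs without downloading data or weights.

```python
import torch
from torch import nn

class TinyNet(nn.Module):
    """Hypothetical toy model: a small body plus an output layer named `fc`."""
    def __init__(self):
        super().__init__()
        self.body = nn.Linear(8, 16)
        self.fc = nn.Linear(16, 2)

    def forward(self, x):
        return self.fc(torch.relu(self.body(x)))

net = TinyNet()
learning_rate = 5e-5

# Same grouping as in train_fine_tuning: everything except fc in the first group.
params_1x = [param for name, param in net.named_parameters()
             if name not in ["fc.weight", "fc.bias"]]
trainer = torch.optim.SGD([{'params': params_1x},
                           {'params': net.fc.parameters(),
                            'lr': learning_rate * 10}],
                          lr=learning_rate, weight_decay=0.001)

# Options missing from a group (here lr in the first one, weight_decay in both)
# fall back to the optimizer-level defaults.
for i, group in enumerate(trainer.param_groups):
    print(i, group['lr'], group['weight_decay'], len(group['params']))
```

The first group inherits the base learning rate, the `fc` group uses ten times that rate, and both share the same weight decay.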
Copyright notice
This article was written by [Whisper_yl]. Please include the original link when reposting. Thanks.
https://blog.csdn.net/LightInDarkness/article/details/124259233