Deep Learning Notes: Fine-Tuning
2022-04-23 04:53:00 【Whisper_yl】
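We generally hope that a model already trained on a very large dataset can help improve accuracy on our own task.
A network can be viewed as two parts: one part extracts features, and the other performs linear classification.
Core idea: for a model trained on a source dataset (usually a fairly large one), we expect the feature-extraction part to be reusable. (The lower-level a feature is, the more general it tends to be.)
When training on our own dataset, we use a model with the same architecture as the pre-trained one. Every layer except the last is initialized not randomly but with the pre-trained weights, which may already be close to the final solution and are in any case better than a random start. This is equivalent to copying the feature-extraction module over as the initial model, so that it produces a reasonably good feature representation from the very beginning. (The labels of the last layer differ between the two datasets, so that layer can be initialized randomly.)
Because the model already starts close to the optimum, fine-tuning uses a smaller learning rate and fewer iterations.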
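The initialization described above can be sketched on a toy model. This is only an illustration under stated assumptions: `TinyNet`, its layer sizes, and the class counts (10 and 2) are hypothetical stand-ins, not the ResNet-18 used later in these notes.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical toy network: a feature extractor plus an output layer."""

    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x):
        return self.fc(self.features(x))


torch.manual_seed(0)
source_net = TinyNet(num_classes=10)  # stands in for the pre-trained model
target_net = TinyNet(num_classes=2)   # same architecture, different labels

# Keep every entry of the source state dict except the output layer.
body_state = {name: tensor
              for name, tensor in source_net.state_dict().items()
              if not name.startswith("fc.")}

# strict=False lets the output layer be absent from the loaded state dict.
result = target_net.load_state_dict(body_state, strict=False)

# The output layer keeps a fresh random initialization.
nn.init.xavier_uniform_(target_net.fc.weight)

print(result.missing_keys)
```

The returned `missing_keys` list names the parameters that were not loaded, which is a quick way to confirm that only the output layer was left out.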
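Another option during fine-tuning is to leave the weights of the bottom layers unchanged: they are fixed and no longer updated, which also lowers the effective complexity of the model.
When the dataset is small and training all parameters seems likely to overfit, consider freezing the parameters of some bottom layers so that they do not take part in the updates.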
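Freezing bottom layers can be sketched as follows. This is a minimal example on a hypothetical three-layer network with random data; which layers to freeze in a real model is a judgment call that these notes do not specify.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy network; index 0 plays the role of the "bottom" layer.
net = nn.Sequential(
    nn.Linear(8, 16), nn.ReLU(),
    nn.Linear(16, 16), nn.ReLU(),
    nn.Linear(16, 2),
)

# Freeze the bottom layer: its parameters no longer receive gradients.
for param in net[0].parameters():
    param.requires_grad_(False)

# Hand only the trainable parameters to the optimizer.
trainable = [p for p in net.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.01)

frozen_before = net[0].weight.clone()

# One training step on random data.
x = torch.randn(4, 8)
y = torch.tensor([0, 1, 0, 1])
loss = nn.functional.cross_entropy(net(x), y)
loss.backward()
optimizer.step()

# The frozen weights are unchanged and have no gradient.
print(torch.equal(net[0].weight, frozen_before), net[0].weight.grad)
```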
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt
# Register the hotdog dataset in d2l's DATA_HUB, then download and extract it
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip',
                          'fba480ffa8aa7e0febbb511d181409f899b9baa5')
data_dir = d2l.download_extract('hotdog')
train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))
hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4)
plt.show()
# Standardize each channel using the per-channel RGB mean and standard deviation.
# The model pre-trained on ImageNet was trained with this preprocessing,
# so the same preprocessing is applied here.
normalize = torchvision.transforms.Normalize(
    [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    normalize])
test_augs = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    normalize])
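What the per-channel standardization computes can be shown on a single pixel. This is a plain-Python sketch of the arithmetic `(value - mean) / std`, not torchvision's implementation; `normalize_pixel` is a hypothetical helper, and the pixel is assumed to be already scaled to [0, 1], as `ToTensor()` does.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Standardize one RGB pixel whose channels lie in [0, 1]."""
    return tuple((value - m) / s for value, m, s in zip(rgb, mean, std))


# A pixel equal to the channel means maps to zero in every channel.
mean_pixel = normalize_pixel(IMAGENET_MEAN)

# A pure white pixel maps to a positive value in every channel.
white_pixel = normalize_pixel((1.0, 1.0, 1.0))

print(mean_pixel)
print(white_pixel)
```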
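# Define and initialize the model
# pretrained=True loads the trained parameters as well as the model definition.
# (torchvision 0.13 and later deprecate `pretrained` in favor of the `weights` argument.)
pretrained_net = torchvision.models.resnet18(pretrained=True)
print(pretrained_net.fc)
finetune_net = torchvision.models.resnet18(pretrained=True)
# Replace the output layer with a randomly initialized linear layer;
# this is a binary classification problem, so it has 2 outputs
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
# Re-initialize only the weight of the last layer
nn.init.xavier_uniform_(finetune_net.fc.weight)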
# Fine-tune the model
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,
                      param_group=True):
    train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train'), transform=train_augs),
        batch_size=batch_size, shuffle=True)
    test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'test'), transform=test_augs),
        batch_size=batch_size)
    devices = d2l.try_all_gpus()
    loss = nn.CrossEntropyLoss(reduction="none")
    # All layers except the last use the base learning rate; the last layer
    # uses a learning rate 10 times larger so that it learns faster
    if param_group:
        params_1x = [param for name, param in net.named_parameters()
                     if name not in ["fc.weight", "fc.bias"]]
        trainer = torch.optim.SGD([{'params': params_1x},
                                   {'params': net.fc.parameters(),
                                    'lr': learning_rate * 10}],
                                  lr=learning_rate, weight_decay=0.001)
    else:
        trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,
                                  weight_decay=0.001)
    d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
                   devices)

train_fine_tuning(finetune_net, 5e-5)
# For comparison, train the same architecture without pre-trained weights
scratch_net = torchvision.models.resnet18()
scratch_net.fc = nn.Linear(scratch_net.fc.in_features, 2)
# Training from scratch uses a larger learning rate
train_fine_tuning(scratch_net, 5e-4, param_group=False)
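The parameter-group setup inside `train_fine_tuning` can be checked in isolation, without downloading data or weights. The sketch below assumes a hypothetical toy module `TinyNet` whose output layer is named `fc`, mirroring the attribute name on ResNet-18; it only inspects the learning rates that the optimizer ends up with.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical toy network whose output layer is named `fc`."""

    def __init__(self):
        super().__init__()
        self.body = nn.Linear(8, 16)
        self.fc = nn.Linear(16, 2)

    def forward(self, x):
        return self.fc(torch.relu(self.body(x)))


net = TinyNet()
base_lr = 5e-5

# Every parameter outside the output layer goes into the first group.
params_1x = [param for name, param in net.named_parameters()
             if not name.startswith("fc.")]

# Options given inside a group override the optimizer-wide defaults;
# options a group omits (here weight_decay) fall back to those defaults.
trainer = torch.optim.SGD(
    [{"params": params_1x},
     {"params": net.fc.parameters(), "lr": base_lr * 10}],
    lr=base_lr, weight_decay=0.001)

group_lrs = [group["lr"] for group in trainer.param_groups]
print(group_lrs)
```

The first group inherits the base learning rate and the second uses ten times that value, while both groups share the same weight decay.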
Copyright notice
This article was written by [Whisper_yl]. Please include a link to the original when reposting. Thank you.
https://blog.csdn.net/LightInDarkness/article/details/124259233