Deep learning notes - fine tuning
2022-04-23 04:57:00 【Whisper_ yl】
The hope is that a model trained on a large dataset can help improve accuracy on your own, smaller task.
A network can be viewed as two parts: one part does feature extraction, the other is a linear classifier.
The core idea: train a model on a source dataset (usually a relatively large one), and assume that the part used for feature extraction can be reused. (The lower the layer, the more general its features.)
When training on your own dataset, use a model with the same architecture as the pre-trained one. Every layer except the last is no longer initialized randomly; it starts from the pre-trained weights, which may already be close to the final result and are a better starting point than random initialization. This amounts to copying the feature-extraction module into the new model, so it has a good feature representation from the start. (The labels of the last layer differ from the source task, so that layer is initialized randomly.)
Because the starting point is already close to the optimal solution, use a smaller learning rate and fewer iterations.
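The initialization described above can be sketched on a toy model. This is a minimal illustration, not the d2l code: `TinyNet`, its layer sizes, and the names `source_net` and `target_net` are hypothetical stand-ins for a pre-trained network and the network to be fine-tuned.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in: `features` plays the feature extractor, `fc` the output layer."""

    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x):
        return self.fc(self.features(x))


torch.manual_seed(0)
source_net = TinyNet(num_classes=10)  # pretend this was trained on the source dataset
target_net = TinyNet(num_classes=2)   # same architecture, different number of labels

# Copy every weight except the output layer, whose shape depends on the labels
backbone_state = {k: v for k, v in source_net.state_dict().items()
                  if not k.startswith("fc.")}
result = target_net.load_state_dict(backbone_state, strict=False)

# The output layer keeps a fresh random initialization
nn.init.xavier_uniform_(target_net.fc.weight)

print(result.missing_keys)  # the keys that were intentionally not copied
```

With `strict=False`, `load_state_dict` reports the skipped output-layer keys instead of raising an error, which makes it easy to confirm that only the last layer was left out.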
During fine-tuning you can also leave the weights of the bottom layers unchanged: fix (freeze) them and do not update those parameters. This reduces the complexity of the model.
When the dataset is very small and training all parameters seems likely to overfit, consider fixing the parameters of some of the bottom layers so that they do not take part in updates.
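One way to fix the bottom layers in PyTorch is to turn off `requires_grad` for their parameters and hand only the remaining parameters to the optimizer. The sketch below uses a small hypothetical `nn.Sequential` model rather than ResNet-18, so the layer sizes and the choice of which layer to freeze are assumptions for illustration only.

```python
import torch
from torch import nn

torch.manual_seed(0)
net = nn.Sequential(
    nn.Linear(8, 16),  # stands in for the bottom (feature) layers
    nn.ReLU(),
    nn.Linear(16, 2),  # output layer that is still trained
)

# Freeze the bottom layer: no gradients are computed for it
for param in net[0].parameters():
    param.requires_grad = False

# Give the optimizer only the parameters that are still trainable
trainable = [p for p in net.parameters() if p.requires_grad]
trainer = torch.optim.SGD(trainable, lr=0.1)

frozen_before = net[0].weight.clone()
head_before = net[2].weight.clone()

# One training step on random data
X = torch.randn(4, 8)
y = torch.tensor([0, 1, 0, 1])
loss = nn.CrossEntropyLoss()(net(X), y)
trainer.zero_grad()
loss.backward()
trainer.step()

frozen_unchanged = torch.equal(net[0].weight, frozen_before)
head_changed = not torch.equal(net[2].weight, head_before)
print(frozen_unchanged, head_changed)
```

The same pattern would apply to a larger network by looping over the parameters of whichever layers should stay fixed.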
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt
# Register the hotdog dataset (download URL and checksum) with d2l
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip',
                          'fba480ffa8aa7e0febbb511d181409f899b9baa5')
data_dir = d2l.download_extract('hotdog')
train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))
hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4)
plt.show()
# Standardize each channel using the means and standard deviations of the RGB channels
# The model pre-trained on ImageNet was trained this way, so do the same here
normalize = torchvision.transforms.Normalize(
    [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    normalize])
test_augs = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    normalize])
# Define and initialize the model
# pretrained=True loads the trained parameters in addition to defining the model
# (torchvision 0.13 and later deprecate pretrained= in favor of the weights= argument)
pretrained_net = torchvision.models.resnet18(pretrained=True)
print(pretrained_net.fc)
finetune_net = torchvision.models.resnet18(pretrained=True)
# Replace the output layer with a new, randomly initialized linear layer; this is a binary classification problem
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
# Only the weight of the last layer is re-initialized
nn.init.xavier_uniform_(finetune_net.fc.weight)
# Fine-tune the model
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,
                      param_group=True):
    train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train'), transform=train_augs),
        batch_size=batch_size, shuffle=True)
    test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'test'), transform=test_augs),
        batch_size=batch_size)
    devices = d2l.try_all_gpus()
    loss = nn.CrossEntropyLoss(reduction="none")
    # All layers except the last use the base learning rate; the last layer's
    # learning rate is multiplied by 10 so that it learns faster
    if param_group:
        params_1x = [param for name, param in net.named_parameters()
                     if name not in ["fc.weight", "fc.bias"]]
        trainer = torch.optim.SGD([{'params': params_1x},
                                   {'params': net.fc.parameters(),
                                    'lr': learning_rate * 10}],
                                  lr=learning_rate, weight_decay=0.001)
    else:
        trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,
                                  weight_decay=0.001)
    d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
                   devices)
train_fine_tuning(finetune_net, 5e-5)
# For comparison, train the same architecture from scratch (pretrained not set)
scratch_net = torchvision.models.resnet18()
scratch_net.fc = nn.Linear(scratch_net.fc.in_features, 2)
# Use a larger learning rate
train_fine_tuning(scratch_net, 5e-4, param_group=False)
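The parameter grouping used inside `train_fine_tuning` can be inspected without downloading any data. The sketch below builds the same kind of optimizer on a hypothetical `TinyNet` (an assumed stand-in with an `fc` output layer, not ResNet-18) and reads back the learning rate of each group.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for ResNet-18: a small backbone plus an `fc` output layer."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 16)
        self.fc = nn.Linear(16, 2)

    def forward(self, x):
        return self.fc(torch.relu(self.backbone(x)))


net = TinyNet()
learning_rate = 5e-5

# Same grouping as in train_fine_tuning: everything except fc uses the base rate
params_1x = [param for name, param in net.named_parameters()
             if name not in ["fc.weight", "fc.bias"]]
trainer = torch.optim.SGD([{'params': params_1x},
                           {'params': net.fc.parameters(),
                            'lr': learning_rate * 10}],
                          lr=learning_rate, weight_decay=0.001)

# Each group carries its own learning rate; options not set per group use the defaults
group_lrs = [group['lr'] for group in trainer.param_groups]
group_sizes = [len(group['params']) for group in trainer.param_groups]
group_decays = [group['weight_decay'] for group in trainer.param_groups]
print(group_lrs, group_sizes, group_decays)
```

The first group inherits the optimizer-level `lr`, while the second overrides it with the 10x value; `weight_decay` is shared by both groups because it is only set at the optimizer level.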
Copyright notice
This article was written by [Whisper_ yl]. Please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/04/202204230453143799.html