PyTorch Model Pruning Example Tutorial 3: Multi-Parameter and Global Pruning
2022-04-23 07:18:00 【Breeze_】
Most current state-of-the-art (SOTA) deep learning models achieve strong results, but their large parameter counts and heavy computation make them hard to deploy in practice. It is also well known that biological neural networks rely on efficient sparse connectivity. With this in mind, pruning, which compresses a model by reducing its number of parameters, is a key technique for cutting memory, storage, and hardware consumption without sacrificing prediction accuracy. It makes it possible to deploy lightweight models on-device and run inference locally, which also helps preserve privacy.
A sparse neural network can match a dense one in prediction accuracy, and because it has far fewer effective parameters, inference can in principle be much faster. Model pruning is a method for turning a trained dense network into a sparse one.
This article follows the official PyTorch pruning tutorial and walks through a simple example of pruning a model, putting deep learning model compression and acceleration into practice.
Related links
- Deep Learning Model Compression and Acceleration (1): Parameter Pruning
- PyTorch Model Pruning Example Tutorial 1: Unstructured Pruning
- PyTorch Model Pruning Example Tutorial 2: Structured Pruning
- PyTorch Model Pruning Example Tutorial 3: Multi-Parameter and Global Pruning
Tutorials 1 and 2 showed how to perform unstructured and structured pruning with PyTorch, one module at a time. For deeper networks, inspecting and pruning each module individually quickly becomes tedious. Instead, we can use multi-parameter pruning and global pruning to prune all parameters of the same type in one pass.
1. Imports & Defining a Simple Network
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build a LeNet-style network
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # Single-channel image input, 5x5 kernels
        self.conv1 = nn.Conv2d(1, 3, 5)
        self.conv2 = nn.Conv2d(3, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten all dimensions except the batch dimension
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
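Before pruning anything, it helps to confirm the network runs end to end. A minimal sanity check of our own (assuming a single-channel 32x32 input, the size implied by the 16 * 5 * 5 flatten feeding fc1):

net = LeNet()
x = torch.randn(1, 1, 32, 32)  # one 32x32 grayscale image
print(net(x).shape)            # torch.Size([1, 10])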
2. Multi-Parameter Pruning
new_model = LeNet()
for name, module in new_model.named_modules():
    # Apply 20% L1 unstructured pruning to the weights of every Conv2d layer
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # Apply 40% L1 unstructured pruning to the weights of every Linear layer
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # verify that a mask was registered for each layer
Output:
dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])
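At this point the pruning exists only as weight_orig / weight_mask reparametrizations. As in the earlier tutorials, prune.remove can bake each mask into its weight tensor to make the pruning permanent; a sketch:

# Fold each mask into the corresponding weight and drop the
# weight_orig / weight_mask entries.
for name, module in new_model.named_modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        prune.remove(module, 'weight')

print(dict(new_model.named_buffers()).keys())  # dict_keys([]) - no masks remain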
3. Global Pruning
All the methods above are local: each layer is pruned independently. We can also use global pruning, which removes the lowest 20% of connections across the entire model rather than the lowest 20% within each layer. In other words, the resulting pruning percentage may differ from layer to layer.
model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)
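Listing every layer by hand is fine for a small model like LeNet, but for deeper networks the same collection can be built programmatically, reusing the isinstance() check from Section 2. A sketch (prune.global_unstructured accepts any iterable of (module, name) pairs):

# Collect the weight of every Conv2d and Linear layer in the model.
parameters_to_prune = [
    (module, 'weight')
    for module in model.modules()
    if isinstance(module, (nn.Conv2d, nn.Linear))
]

Either way, we can now check the per-layer and the global sparsity: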
print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Sparsity in conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "Sparsity in fc1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "Sparsity in fc2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "Sparsity in fc3.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)
Output:
Sparsity in conv1.weight: 8.00%
Sparsity in conv2.weight: 9.33%
Sparsity in fc1.weight: 22.07%
Sparsity in fc2.weight: 12.20%
Sparsity in fc3.weight: 11.31%
Global sparsity: 20.00%
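The repeated print blocks above can be collapsed into a small loop. A minimal sketch (the helper name report_sparsity is ours):

def report_sparsity(model):
    # Accumulate zero counts and element counts over all prunable layers.
    total_zeros, total_elems = 0, 0
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            zeros = int(torch.sum(module.weight == 0))
            elems = module.weight.nelement()
            total_zeros += zeros
            total_elems += elems
            print("Sparsity in {}.weight: {:.2f}%".format(name, 100. * zeros / elems))
    print("Global sparsity: {:.2f}%".format(100. * total_zeros / total_elems))

report_sparsity(model)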
4. Summary
This example first builds a LeNet-style network model. For multi-parameter pruning, we iterate over all layers with .named_modules() and use isinstance() to check whether each module is a Conv2d or a Linear layer, so that parameters of the same structural type receive the same kind of pruning. For global pruning, we call prune.global_unstructured. The difference is visible in the output: global pruning reaches the requested overall sparsity (20% here), but unlike per-layer multi-parameter pruning, that sparsity is not spread evenly, and each layer ends up pruned by a different percentage.
The core functions used in this article:
- .named_modules(): iterates over the model's submodules together with their names
- isinstance(): checks whether an object is an instance of a given type
- prune.global_unstructured(): the global pruning method
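Note that prune.global_unstructured is not tied to L1 magnitude: its pruning_method argument accepts other pruning classes from torch.nn.utils.prune as well. As a sketch beyond what this tutorial covers, the built-in RandomUnstructured class drops connections at random instead:

# Globally remove 20% of all connections at random (instead of by L1 magnitude).
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.RandomUnstructured,
    amount=0.2,
)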
Copyright notice: this article was written by [Breeze_]. Please include a link to the original when reposting: https://yzsam.com/2022/04/202204230610322733.html