PyTorch Model Pruning Example Tutorial 3: Multi-Parameter and Global Pruning
2022-04-23 07:18:00 [Breeze_]
Most current state-of-the-art (SOTA) deep learning models perform well, but their large parameter counts and heavy computation make them difficult to deploy in practice. It is also well known that biological neural networks rely on efficient sparse connectivity. With this in mind, techniques that compress a model by reducing its number of parameters are very important: they cut memory, storage, and hardware requirements without sacrificing prediction accuracy, making it possible to deploy lightweight models on-device and run inference locally to preserve privacy.
A sparse neural network can match a dense one in prediction accuracy, and because it has far fewer parameters, inference is in theory much faster. Model pruning is a technique for turning a trained dense network into a sparse one.
Following the official PyTorch pruning tutorial, this article walks through a simple example of how to prune a model, as hands-on practice with compressing and accelerating deep learning models.
Related links
Deep learning model compression and acceleration techniques (1): Parameter pruning
PyTorch model pruning example tutorial (1): Unstructured pruning
PyTorch model pruning example tutorial (2): Structured pruning
PyTorch model pruning example tutorial (3): Multi-parameter and global pruning
Tutorials 1 and 2 showed how to perform unstructured and structured pruning with PyTorch. In practice we usually want to prune deeper networks, and inspecting and pruning each module one by one quickly becomes tedious. Instead, we can use multi-parameter pruning and global pruning to handle all parameters of the same type at once.
1. Imports & define a simple network
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

'''Build a LeNet-style network'''
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # single-channel image input, 5x5 kernels
        self.conv1 = nn.Conv2d(1, 3, 5)
        self.conv2 = nn.Conv2d(3, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
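Before pruning, it can help to confirm that the network runs end to end. The following is a minimal sanity-check sketch, not part of the original tutorial; it assumes single-channel 32×32 inputs, which is what the 16 * 5 * 5 input size of fc1 implies.

# Sanity check (illustrative): forward a random batch through the network.
# The 1x32x32 input size is an assumption implied by fc1 = nn.Linear(16 * 5 * 5, 120).
check_model = LeNet().to(device)
dummy_input = torch.randn(4, 1, 32, 32, device=device)  # batch of 4 fake images
output = check_model(dummy_input)
print(output.shape)  # expected: torch.Size([4, 10])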
2. Multi-parameter pruning
new_model = LeNet()
for name, module in new_model.named_modules():
    # prune 20% of the weights in every Conv2d layer (L1 unstructured)
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of the weights in every Linear layer (L1 unstructured)
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # verify that a mask was registered for each layer
Output:
dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])
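After this loop, each pruned module holds a weight_orig parameter and a weight_mask buffer, and module.weight is recomputed as their element-wise product. To check that each layer was pruned by the requested amount, here is a small sketch; the sparsity helper is an addition for illustration, not part of the tutorial.

# Report the fraction of zeroed weights in every pruned layer.
# `sparsity` is a hypothetical helper added here for illustration.
def sparsity(module):
    weight = module.weight  # pruned view: weight_orig * weight_mask
    return 100. * float(torch.sum(weight == 0)) / float(weight.nelement())

for name, module in new_model.named_modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        print(f"{name}.weight sparsity: {sparsity(module):.2f}%")

The Conv2d layers should report roughly 20% and the Linear layers roughly 40%, matching the amounts passed to prune.l1_unstructured above.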
3. Global pruning
All of the methods above are local: each layer is pruned independently. We can instead prune globally, removing the lowest 20% of connections across the whole model rather than the lowest 20% within each layer. As a result, the pruning percentage can differ from layer to layer.
model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)
print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Sparsity in conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "Sparsity in fc1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "Sparsity in fc2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "Sparsity in fc3.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)
Output:
Sparsity in conv1.weight: 8.00%
Sparsity in conv2.weight: 9.33%
Sparsity in fc1.weight: 22.07%
Sparsity in fc2.weight: 12.20%
Sparsity in fc3.weight: 11.31%
Global sparsity: 20.00%
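One follow-up step worth knowing (not shown in the tutorial output above): after pruning, each module still stores weight_orig and weight_mask and recomputes weight on every forward pass. If only the sparse weights should be kept, for example before saving the model, prune.remove makes the pruning permanent. A minimal sketch; the file name is just a placeholder:

# Make the pruning permanent: fold weight_orig * weight_mask into .weight
# and delete the mask buffers and forward pre-hooks.
for module, name in parameters_to_prune:
    prune.remove(module, name)

print(dict(model.named_buffers()).keys())  # the *_mask buffers are gone now
torch.save(model.state_dict(), "lenet_pruned.pth")  # "lenet_pruned.pth" is a placeholder name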
4. Summary
This example first builds a LeNet-style network. For multi-parameter pruning, we iterate over all layers with .named_modules() and use isinstance() to check whether each module is a Conv2d or a Linear layer, so that all parameters of the same type are pruned in the same way. For global pruning, we use prune.global_unstructured. The difference is that global pruning enforces the requested sparsity (20%) over the model as a whole rather than layer by layer, so individual layers can end up pruned by different percentages (see the sketch after the list below).
Core functions and methods used in this article:
- .named_modules(): iterate over the model's modules and their names
- isinstance(): check whether a module is of a given type
- prune.global_unstructured(): global pruning method
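As a closing illustration, these pieces can be combined so that the parameter list for global pruning is built automatically instead of being written out by hand. This is a sketch of that variation, not something from the original tutorial:

# Collect every Conv2d / Linear weight via isinstance(), then prune the
# lowest 20% of those connections globally in a single call.
model = LeNet()
parameters_to_prune = [
    (module, 'weight')
    for module in model.modules()
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear))
]
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)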
Copyright notice
This article was written by [Breeze_]. Please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/04/202204230610322733.html