PyTorch Model Pruning Tutorial, Part 3: Multi-Parameter and Global Pruning
2022-04-23 06:11:00 【小风_】
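Most state-of-the-art (SOTA) deep learning models perform well, but their parameter counts and compute costs are too high for practical deployment. Biological neural networks, by contrast, are known to rely on efficient sparse connectivity. With that in mind, compressing a model by reducing its number of parameters is an important technique: it cuts memory, storage, and hardware requirements without sacrificing prediction accuracy, makes it possible to deploy lightweight models on-device, and keeps computation private to the device.
A sparse neural network can match the prediction accuracy of a dense one, and because it has fewer effective parameters, inference can in theory be much faster. Model pruning is one way to turn a dense network into a sparse one.
Following the official PyTorch tutorial, this article walks through a simple pruning example as a practical introduction to model compression and acceleration.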
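Parts 1 and 2 showed how to perform unstructured and structured pruning in PyTorch. In practice we usually want to prune the parameters of a fairly deep network, and inspecting and pruning its modules one by one quickly becomes tedious. Multi-parameter pruning and global pruning let us prune all parameters of the same type in one pass.
1. Imports & defining a simple network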
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

'''Build a LeNet-style network'''
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # Single-channel image input, 5x5 kernels
        self.conv1 = nn.Conv2d(1, 3, 5)
        self.conv2 = nn.Conv2d(3, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten everything except the batch dimension
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
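As a quick sanity check of the network above, the sketch below runs one forward pass and counts the weights in each layer. The 32×32 input size is an assumption (it is what makes the `16 * 5 * 5` input of `fc1` line up), and the batch size of 4 is arbitrary.

```python
import torch
from torch import nn
import torch.nn.functional as F


class LeNet(nn.Module):
    """Same layer shapes as the LeNet-style network in this article."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 3, 5)
        self.conv2 = nn.Conv2d(3, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(x.shape[0], -1)  # flatten all but the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


model = LeNet()

# Assumed input: a batch of 4 single-channel 32x32 images.
dummy = torch.randn(4, 1, 32, 32)
out = model(dummy)
print(out.shape)

# Number of weight entries per layer (biases excluded), i.e. the
# tensors that the pruning calls in the next sections operate on.
weight_counts = {
    name: p.numel()
    for name, p in model.named_parameters()
    if name.endswith("weight")
}
print(weight_counts)
print("total weights:", sum(weight_counts.values()))
```

Knowing the per-layer counts helps when reading sparsity percentages later: `conv1` has only 75 weights, while `fc1` holds most of the model.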
2. Multi-parameter pruning
new_model = LeNet()
for name, module in new_model.named_modules():
    # Apply 20% L1 unstructured pruning to the weights of every Conv2d layer
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # Apply 40% L1 unstructured pruning to the weights of every Linear layer
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)
print(dict(new_model.named_buffers()).keys())  # check that every layer now has a mask
Output:
dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])
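The buffer names only show that masks exist. To confirm the ratios, the following sketch counts the zeros in each `weight_mask`. To stay short it uses an `nn.ModuleDict` as a stand-in container with the same layer shapes as the LeNet above (an assumption for illustration; no forward pass is needed to prune).

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

# Stand-in container with the same layer shapes as the article's LeNet.
layers = nn.ModuleDict({
    "conv1": nn.Conv2d(1, 3, 5),
    "conv2": nn.Conv2d(3, 16, 5),
    "fc1": nn.Linear(16 * 5 * 5, 120),
    "fc2": nn.Linear(120, 84),
    "fc3": nn.Linear(84, 10),
})

# Same rule as above: 20% for Conv2d weights, 40% for Linear weights.
for name, module in layers.named_modules():
    if isinstance(module, nn.Conv2d):
        prune.l1_unstructured(module, name="weight", amount=0.2)
    elif isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name="weight", amount=0.4)

# Fraction of entries masked out in each layer.
sparsity = {}
for name, module in layers.items():
    mask = module.weight_mask
    sparsity[name] = 100.0 * (mask == 0).sum().item() / mask.numel()
    print(f"{name}: {sparsity[name]:.2f}% of weights masked")

# After pruning, the learnable tensor is stored as `weight_orig`;
# `weight` is recomputed as weight_orig * weight_mask.
param_names = [n for n, _ in layers["conv1"].named_parameters()]
print(param_names)
```

Every convolutional layer ends up at its own 20% and every linear layer at its own 40%, which is what "local" pruning means here.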
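3. Global pruning
All the methods discussed so far are local: each layer is pruned on its own. We can also prune globally, removing the lowest-magnitude 20% of connections across the whole model rather than the lowest 20% within each layer. As a result, the percentage removed can differ from layer to layer.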
model = LeNet()
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)
print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Sparsity in conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "Sparsity in fc1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "Sparsity in fc2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "Sparsity in fc3.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)
Output (the per-layer values depend on the random initialization):
Sparsity in conv1.weight: 8.00%
Sparsity in conv2.weight: 9.33%
Sparsity in fc1.weight: 22.07%
Sparsity in fc2.weight: 12.20%
Sparsity in fc3.weight: 11.31%
Global sparsity: 20.00%
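Per-layer sparsity differs because global pruning ranks all weights together and cuts at a single shared magnitude threshold. The sketch below reproduces that idea by hand and compares it with `prune.global_unstructured`. The two small `nn.Linear` layers, the 10× rescaling, and the seed are all hypothetical choices made only to exaggerate the effect; the by-hand threshold is an illustration of the concept, not the library's internal code.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Two hypothetical layers of equal size but very different weight scales.
small = nn.Linear(50, 20)
large = nn.Linear(50, 20)
with torch.no_grad():
    large.weight.mul_(10.0)

amount = 0.2

# By hand: pool the absolute values of all weights and find the magnitude
# of the k-th smallest entry, where k is 20% of the pooled count.
all_scores = torch.cat([
    small.weight.detach().abs().flatten(),
    large.weight.detach().abs().flatten(),
])
k = round(amount * all_scores.numel())
threshold = torch.kthvalue(all_scores, k).values

manual_pruned = {
    "small": int((small.weight.detach().abs() <= threshold).sum()),
    "large": int((large.weight.detach().abs() <= threshold).sum()),
}

# Library call on the same two layers.
prune.global_unstructured(
    [(small, "weight"), (large, "weight")],
    pruning_method=prune.L1Unstructured,
    amount=amount,
)
library_pruned = {
    "small": int((small.weight_mask == 0).sum()),
    "large": int((large.weight_mask == 0).sum()),
}

print("by hand:", manual_pruned)
print("library:", library_pruned)
```

The layer with the smaller weights absorbs most of the pruning budget, while the total still adds up to 20% of all weights. This is also why magnitudes across layers should be roughly comparable before relying on a global threshold.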
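4. Summary
This example first built a LeNet-style model. For multi-parameter pruning, we iterated over all layers with .named_modules() and used isinstance() to check whether each one is a Conv2d or a Linear layer, so that layers of the same type are pruned in the same way. For global pruning, we used the prune.global_unstructured function. The key difference is that multi-parameter pruning applies a fixed ratio to every layer it touches, whereas global pruning only guarantees the overall sparsity (20% here): the sparsity of the individual layers is not uniform.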
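Core functions and methods used in this article:
- .named_modules(): iterates over the model's submodules together with their names
- isinstance(): checks whether a module is of a given type
- prune.global_unstructured(): the global pruning function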
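The three pieces listed above can be combined so that the list of parameters to prune does not have to be written out by hand. This is a minimal sketch under stated assumptions: `collect_prunable` and `mask_sparsity` are hypothetical helper names, not part of PyTorch, and an `nn.ModuleDict` with the article's layer shapes stands in for the full model.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune


def collect_prunable(model, layer_types=(nn.Conv2d, nn.Linear)):
    """Hypothetical helper: (name, module) pairs for layers of the given types."""
    return [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, layer_types)
    ]


def mask_sparsity(module, name="weight"):
    """Hypothetical helper: (masked entries, total entries) for a pruned parameter."""
    mask = getattr(module, name + "_mask")
    return int((mask == 0).sum()), mask.numel()


# Stand-in container with the same layer shapes as the article's LeNet.
model = nn.ModuleDict({
    "conv1": nn.Conv2d(1, 3, 5),
    "conv2": nn.Conv2d(3, 16, 5),
    "fc1": nn.Linear(16 * 5 * 5, 120),
    "fc2": nn.Linear(120, 84),
    "fc3": nn.Linear(84, 10),
})

named_layers = collect_prunable(model)
prune.global_unstructured(
    [(module, "weight") for _, module in named_layers],
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

total_zeros, total_count = 0, 0
for name, module in named_layers:
    zeros, count = mask_sparsity(module)
    total_zeros += zeros
    total_count += count
    print(f"Sparsity in {name}.weight: {100.0 * zeros / count:.2f}%")
print(f"Global sparsity: {100.0 * total_zeros / total_count:.2f}%")
```

Counting zeros in the masks rather than in the weights avoids miscounting a weight that happens to be exactly zero without having been pruned.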
Copyright notice
This article was written by [小风_]. Please include a link to the original when reposting. Thanks.
https://blog.csdn.net/qq_33952811/article/details/124354155