Summary of PyTorch Model Quantization Methods
2022-08-06 20:33:00 【网络星空(luoc)】
Backends: x86 (server) and ARM (mobile and embedded platforms);
Corresponding parameters: 'fbgemm' and 'qnnpack', respectively;
Python call: torch.quantization.get_default_qconfig('fbgemm')
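As a minimal sketch of backend selection, the snippet below picks whichever of the two engines is available in the local PyTorch build and requests the matching default qconfig. The preference order (try 'fbgemm' first, then fall back to 'qnnpack') is an assumption made for illustration, not something the original post prescribes.

```python
import torch

# Engines compiled into this PyTorch build (the list varies by platform).
engines = torch.backends.quantized.supported_engines

# Assumption for illustration: prefer 'fbgemm' (x86) and otherwise
# fall back to 'qnnpack' (ARM / mobile).
backend = "fbgemm" if "fbgemm" in engines else "qnnpack"

# The runtime engine should match the backend the qconfig was built for.
torch.backends.quantized.engine = backend
qconfig = torch.quantization.get_default_qconfig(backend)

print("supported engines:", engines)
print("selected backend:", backend)
print(qconfig)
```

Setting torch.backends.quantized.engine as well as the qconfig keeps the observers chosen at preparation time consistent with the kernels used at inference time.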
1. Dynamic quantization code example:
import torch
# define a floating point model
class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.fc = torch.nn.Linear(4, 4)

    def forward(self, x):
        x = self.fc(x)
        return x

# create a model instance
model_fp32 = M()
# create a quantized model instance
model_int8 = torch.quantization.quantize_dynamic(
    model_fp32,         # the original model
    {torch.nn.Linear},  # a set of layers to dynamically quantize
    dtype=torch.qint8)  # the target dtype for quantized weights
# run the model
input_fp32 = torch.randn(4, 4, 4, 4)
res = model_int8(input_fp32)
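Applies to layers such as Linear, LSTM, and RNN;
Weights are quantized ahead of time; activations are quantized dynamically during inference, while the bias stays in floating point;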
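To make the points above concrete, here is a small sketch that applies dynamic quantization to both an LSTM and a Linear layer and then inspects the result. The TinyLSTM model and its sizes are hypothetical and exist only for this illustration.

```python
import torch
import torch.nn as nn


class TinyLSTM(nn.Module):
    """Hypothetical toy model: an LSTM followed by a Linear head."""

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
        self.fc = nn.Linear(16, 4)

    def forward(self, x):
        out, _ = self.lstm(x)
        # Use the last time step as the sequence summary.
        return self.fc(out[:, -1, :])


model_fp32 = TinyLSTM().eval()

# Quantize the weights of every LSTM and Linear layer to int8.
model_int8 = torch.quantization.quantize_dynamic(
    model_fp32, {nn.LSTM, nn.Linear}, dtype=torch.qint8)

# Both layers are swapped for their dynamically quantized counterparts.
print(type(model_int8.lstm).__module__)
print(type(model_int8.fc).__module__)

# Weights are stored quantized; the bias remains a float tensor.
print(model_int8.fc.weight().dtype)
print(model_int8.fc.bias().dtype)

# Inputs and outputs stay in floating point; activations are
# quantized on the fly inside the quantized layers.
out = model_int8(torch.randn(2, 5, 8))
print(out.dtype, out.shape)
```

No calibration data is needed here, which is why dynamic quantization is usually the quickest method to try first.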
2. Static quantization example:
import torch
# define a floating point model where some layers could be statically quantized
class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        # manually specify where tensors will be converted from floating
        # point to quantized in the quantized model
        x = self.quant(x)
        x = self.conv(x)
        x = self.relu(x)
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
        x = self.dequant(x)
        return x

# create a model instance
model_fp32 = M()
# model must be set to eval mode for static quantization logic to work
model_fp32.eval()
# attach a global qconfig, which contains information about what kind
# of observers to attach. Use 'fbgemm' for server inference and
# 'qnnpack' for mobile inference. Other quantization configurations such
# as selecting symmetric or asymmetric quantization and MinMax or L2Norm
# calibration techniques can be specified here.
model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')
# Fuse the activations to preceding layers, where applicable.
# This needs to be done manually depending on the model architecture.
# Common fusions include `conv + relu` and `conv + batchnorm + relu`
model_fp32_fused = torch.quantization.fuse_modules(model_fp32, [['conv', 'relu']])
# Prepare the model for static quantization. This inserts observers in
# the model that will observe activation tensors during calibration.
model_fp32_prepared = torch.quantization.prepare(model_fp32_fused)
# calibrate the prepared model to determine quantization parameters for activations
# in a real world setting, the calibration would be done with a representative dataset
input_fp32 = torch.randn(4, 1, 4, 4)
model_fp32_prepared(input_fp32)
# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and bias value to be
# used with each activation tensor, and replaces key operators with quantized
# implementations.
model_int8 = torch.quantization.convert(model_fp32_prepared)
# run the model, relevant calculations will happen in int8
res = model_int8(input_fp32)
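1. Static quantization requires quant and dequant stubs at the start and end of the model;
2. Configure the backend;
3. Declare the layers to fuse, typically conv+relu or conv+bn+relu;
4. Prepare the model for quantization;
5. Run calibration data through the prepared model (usually a representative sample from your training task);
6. Convert the model; here it is converted to int8 precision;
7. Validate the quantized model;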
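The steps above can be sketched end to end with a conv+bn+relu fusion and a simple validation check. This is an illustrative sketch under stated assumptions: the ConvBNReLU model, the random calibration batches, and the use of maximum absolute error as the validation metric are all hypothetical choices, and a real workflow would calibrate on representative data and compare task accuracy.

```python
import torch
import torch.nn as nn


class ConvBNReLU(nn.Module):
    """Hypothetical toy model with a fusable conv + bn + relu chain."""

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # step 1: entry stub
        self.conv = nn.Conv2d(1, 2, 3)
        self.bn = nn.BatchNorm2d(2)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()  # step 1: exit stub

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.bn(self.conv(x)))
        return self.dequant(x)


torch.manual_seed(0)

# Step 2: configure the backend (assumed preference: fbgemm, then qnnpack).
engines = torch.backends.quantized.supported_engines
backend = "fbgemm" if "fbgemm" in engines else "qnnpack"
torch.backends.quantized.engine = backend

model_fp32 = ConvBNReLU().eval()
model_fp32.qconfig = torch.quantization.get_default_qconfig(backend)

# Step 3: fuse conv + bn + relu (returns a fused copy by default).
fused = torch.quantization.fuse_modules(model_fp32, [["conv", "bn", "relu"]])

# Step 4: insert observers.
prepared = torch.quantization.prepare(fused)

# Step 5: calibrate. Random tensors stand in for a representative dataset.
calibration_batches = [torch.randn(4, 1, 8, 8) for _ in range(8)]
with torch.no_grad():
    for batch in calibration_batches:
        prepared(batch)

# Step 6: convert to an int8 model.
model_int8 = torch.quantization.convert(prepared)

# Step 7: validate by comparing against the float model on the same input.
x = calibration_batches[0]
with torch.no_grad():
    ref = model_fp32(x)
    out = model_int8(x)
max_abs_err = (ref - out).abs().max().item()
print("max abs error vs. fp32:", max_abs_err)
```

Because fuse_modules, prepare, and convert each return a copy by default, model_fp32 remains available as the floating-point reference for the comparison in step 7.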
For the code, follow: https://github.com/oyjGithub