Converting Swin Transformer to ONNX
2022-04-23 06:12:00 【wujpbb7】
Swin Transformer code: an unofficial implementation, but easy to understand.
Code to convert the trained .pth weights to ONNX:
import torch
from swin_transformer_pytorch import swin_t

pth_filename = './demo.pth'    # trained weights
onnx_filename = './demo.onnx'

net = swin_t()
# Load on CPU, since the export below also runs on CPU
weights = torch.load(pth_filename, map_location='cpu')
# net.load_state_dict(weights)
# The state dict is stored under 'embedding'; strip the 'module.' prefix left by (Distributed)DataParallel
net.load_state_dict({k.replace('module.', ''): v for k, v in weights['embedding'].items()})
net.eval()

dummy_input = torch.randn(1, 3, 224, 224, device='cpu')
# output_names must match the 'output' key used in dynamic_axes
torch.onnx.export(net, dummy_input, onnx_filename,
                  input_names=['input'], output_names=['output'],
                  export_params=True, verbose=False, opset_version=12,
                  dynamic_axes={'input': {0: 'batch_size'},
                                'output': {0: 'batch_size'}})
print('save onnx succ')
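A mismatch between `output_names` and the keys of `dynamic_axes` (for example `'ouput'` vs `'output'`) is easy to miss: as far as I know, the TorchScript-based exporter only warns about an unknown key, and the output's batch axis then silently stays fixed at 1. A small pre-flight check can catch this before calling `torch.onnx.export`. The helper `invalid_dynamic_axes_keys` below is hypothetical, not part of PyTorch:

```python
def invalid_dynamic_axes_keys(input_names, output_names, dynamic_axes):
    """Return the dynamic_axes keys that match no declared input or output name."""
    valid = set(input_names) | set(output_names)
    return sorted(key for key in dynamic_axes if key not in valid)


axes = {'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}

# A typo in output_names leaves the 'output' key unmatched
print(invalid_dynamic_axes_keys(['input'], ['ouput'], axes))   # -> ['output']
# Consistent names: nothing to report
print(invalid_dynamic_axes_keys(['input'], ['output'], axes))  # -> []
```

Raising an error when the returned list is non-empty turns the exporter's warning into a hard failure.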
Errors encountered:
1. "Exporting the operator roll to ONNX opset version 12 is not supported."
Replace torch.roll with slicing plus torch.cat:
import torch
from torch import nn

class CyclicShift(nn.Module):
    def __init__(self, displacement):
        super().__init__()
        self.displacement = displacement

    def forward(self, x):
        # torch.roll cannot be exported at opset 12:
        # return torch.roll(x, shifts=(self.displacement, self.displacement), dims=(1, 2))
        # Same cyclic shift built from slices: first along dim 1 (height), then dim 2 (width)
        x = torch.cat((x[:, -self.displacement:, :, :], x[:, :-self.displacement, :, :]), dim=1)
        x = torch.cat((x[:, :, -self.displacement:, :], x[:, :, :-self.displacement, :]), dim=2)
        return x
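A quick way to confirm the slicing version is a drop-in replacement for `torch.roll` is to compare both on random input. This is a sketch under assumptions: the input is taken to be laid out as (batch, height, width, channels), as the shift on dims 1 and 2 implies, and the tensor shape and shift values are arbitrary. The identity holds for positive and negative displacements as long as the absolute displacement does not exceed the feature-map size.

```python
import torch
from torch import nn


class CyclicShift(nn.Module):
    """Cat-based cyclic shift, identical to the export-friendly version above."""

    def __init__(self, displacement):
        super().__init__()
        self.displacement = displacement

    def forward(self, x):
        x = torch.cat((x[:, -self.displacement:, :, :], x[:, :-self.displacement, :, :]), dim=1)
        x = torch.cat((x[:, :, -self.displacement:, :], x[:, :, :-self.displacement, :]), dim=2)
        return x


def matches_roll(displacement, shape=(2, 14, 14, 8)):
    """Compare the cat-based shift with torch.roll on a random (B, H, W, C) tensor."""
    x = torch.randn(*shape)
    expected = torch.roll(x, shifts=(displacement, displacement), dims=(1, 2))
    # Slicing and cat only move values, so exact equality is expected
    return torch.equal(CyclicShift(displacement)(x), expected)


for d in (3, -3, 1, -1):
    print(d, matches_roll(d))
```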
2. "RuntimeError: Expected node type 'onnx::Constant', got 'onnx::Cast'."
Replace the in-place additions on slices (`+=` applied to a slice of `dots`) with torch.cat:
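class WindowAttention(nn.Module):
    ...
    def forward(self, x):
        ...
        # Original in-place updates on slices, which break the export:
        # if self.shifted:
        #     dots[:, :, -nw_w:] += self.upper_lower_mask
        #     dots[:, :, nw_w - 1::nw_w] += self.left_right_mask
        if self.shifted:
            # Split the flattened window axis into (window rows, window columns)
            dots = rearrange(dots, 'b c (n_h n_w) h w -> b c n_h n_w h w', n_h=nw_h, n_w=nw_w)
            # Add the upper/lower mask to the last row of windows
            dots = torch.cat((dots[:, :, :-1], dots[:, :, -1:] + self.upper_lower_mask), dim=2)
            # Swap window rows and columns so the last column of windows sits on dim 2
            dots = dots.permute(0, 1, 3, 2, 4, 5)
            # Add the left/right mask to the last column of windows
            dots = torch.cat((dots[:, :, :-1], dots[:, :, -1:] + self.left_right_mask), dim=2)
            # Swap back and flatten the window grid again
            dots = dots.permute(0, 1, 3, 2, 4, 5)
            dots = rearrange(dots, 'b c n_h n_w h w -> b c (n_h n_w) h w')
        ...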
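The cat-based masking can be checked the same way against the original in-place version. In this sketch, `einops.rearrange` is replaced by an equivalent `reshape` to avoid the extra dependency, both masks are assumed to have shape `(window_size**2, window_size**2)`, and the window grid sizes are arbitrary illustration values; the helper names `add_masks_inplace` and `add_masks_cat` are hypothetical.

```python
import torch


def add_masks_inplace(dots, upper_lower_mask, left_right_mask, nw_w):
    """Original in-place slice updates (the version that fails to export)."""
    dots = dots.clone()
    dots[:, :, -nw_w:] += upper_lower_mask            # last row of windows
    dots[:, :, nw_w - 1::nw_w] += left_right_mask     # last column of windows
    return dots


def add_masks_cat(dots, upper_lower_mask, left_right_mask, nw_h, nw_w):
    """Export-friendly version; reshape stands in for 'b c (n_h n_w) h w -> b c n_h n_w h w'."""
    b, c, _, h, w = dots.shape
    dots = dots.reshape(b, c, nw_h, nw_w, h, w)
    dots = torch.cat((dots[:, :, :-1], dots[:, :, -1:] + upper_lower_mask), dim=2)
    dots = dots.permute(0, 1, 3, 2, 4, 5)   # window columns now on dim 2
    dots = torch.cat((dots[:, :, :-1], dots[:, :, -1:] + left_right_mask), dim=2)
    dots = dots.permute(0, 1, 3, 2, 4, 5)   # back to (b, c, n_h, n_w, h, w)
    return dots.reshape(b, c, nw_h * nw_w, h, w)


nw_h, nw_w, n = 3, 4, 9                      # 3x4 window grid, 3x3 windows (n = 3 * 3)
dots = torch.randn(2, 3, nw_h * nw_w, n, n)
upper_lower = torch.randn(n, n)
left_right = torch.randn(n, n)
print(torch.allclose(add_masks_inplace(dots, upper_lower, left_right, nw_w),
                     add_masks_cat(dots, upper_lower, left_right, nw_h, nw_w)))
```

Because the bottom-right window receives both masks in the same order in each version, the two results should agree.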
Copyright notice
This article was written by [wujpbb7]. Please include a link to the original when reposting. Thanks.
https://blog.csdn.net/blueblood7/article/details/121034635