Converting Swin Transformer to ONNX
2022-04-23 06:12:00 【wujpbb7】
Swin Transformer code: an unofficial implementation (swin_transformer_pytorch), but easy to understand.
Code to convert a trained .pth checkpoint to ONNX:
import torch
from swin_transformer_pytorch import swin_t

pth_filename = './demo.pth'    # trained weights
onnx_filename = './demo.onnx'

net = swin_t()
weights = torch.load(pth_filename)
#net.load_state_dict(weights)
net.load_state_dict({k.replace('module.', ''): v for k, v in weights['embedding'].items()})
net.eval()

dummy_input = torch.randn(1, 3, 224, 224, device='cpu')
torch.onnx.export(net, dummy_input, onnx_filename,
                  input_names=['input'], output_names=['output'],
                  export_params=True, verbose=False, opset_version=12,
                  dynamic_axes={'input': {0: "batch_size"},
                                'output': {0: "batch_size"}})
print('save onnx succ')
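Once the export succeeds (after the fixes below), the ONNX model can be sanity-checked against the PyTorch model. A minimal sketch, assuming the onnxruntime package is installed and reusing net, dummy_input, and onnx_filename from the script above:

import numpy as np
import onnxruntime as ort

# run the same dummy input through the exported ONNX model
sess = ort.InferenceSession(onnx_filename)
onnx_out = sess.run(None, {'input': dummy_input.numpy()})[0]

# compare against the PyTorch output
with torch.no_grad():
    torch_out = net(dummy_input).numpy()
print('max abs diff:', np.abs(onnx_out - torch_out).max())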
Errors encountered:
1. "Exporting the operator roll to ONNX opset version 12 is not supported."
Replace roll with cat:
class CyclicShift(nn.Module):
    def __init__(self, displacement):
        super().__init__()
        self.displacement = displacement

    def forward(self, x):
        # torch.roll is not exportable at opset 12, so emulate it with slicing + cat
        #return torch.roll(x, shifts=(self.displacement, self.displacement), dims=(1, 2))
        x = torch.cat((x[:, -self.displacement:, :, :], x[:, :-self.displacement, :, :]), dim=1)
        x = torch.cat((x[:, :, -self.displacement:, :], x[:, :, :-self.displacement, :]), dim=2)
        return x
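A quick standalone check (not from the original post) that the slicing + cat rewrite produces the same result as torch.roll, using an example displacement:

import torch

d = 3  # example displacement
x = torch.randn(1, 7, 7, 96)
rolled = torch.roll(x, shifts=(d, d), dims=(1, 2))
shifted = torch.cat((x[:, -d:, :, :], x[:, :-d, :, :]), dim=1)
shifted = torch.cat((shifted[:, :, -d:, :], shifted[:, :, :-d, :]), dim=2)
print(torch.equal(rolled, shifted))  # True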
2. "RuntimeError: Expected node type 'onnx::Constant', got 'onnx::Cast'."
Replace the in-place addition on tensor slices with cat:
class WindowAttention(nn.Module):
    ...
    def forward(self, x):
        ...
        #if self.shifted:
        #    dots[:, :, -nw_w:] += self.upper_lower_mask
        #    dots[:, :, nw_w - 1::nw_w] += self.left_right_mask
        if self.shifted:
            # split the flattened window axis into (n_h, n_w), add the masks to the last
            # row / last column of windows via cat instead of in-place slice addition,
            # then merge the axes back
            dots = rearrange(dots, 'b c (n_h n_w) h w -> b c n_h n_w h w', n_h=nw_h, n_w=nw_w)
            dots = torch.cat((dots[:, :, :-1], dots[:, :, -1:] + self.upper_lower_mask), dim=2)
            dots = dots.permute(0, 1, 3, 2, 4, 5)
            dots = torch.cat((dots[:, :, :-1], dots[:, :, -1:] + self.left_right_mask), dim=2)
            dots = dots.permute(0, 1, 3, 2, 4, 5)
            dots = rearrange(dots, 'b c n_h n_w h w -> b c (n_h n_w) h w')
        ...
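The pattern is the same in both mask additions: instead of writing into a slice of dots in place, the untouched slices and the modified slice are concatenated into a new tensor. A toy illustration of this rewrite (hypothetical tensor names, not from the original model):

import torch

t = torch.zeros(2, 4, 5)
mask = torch.ones(1, 5)

# in-place version that can break ONNX export:
# t[:, -1:] += mask

# export-friendly version: build a new tensor with cat
t = torch.cat((t[:, :-1], t[:, -1:] + mask), dim=1)
print(t[:, -1])  # the last row now carries the mask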
References:
Copyright notice
This article was written by [wujpbb7]. Please include a link to the original when reposting. Thank you.
https://blog.csdn.net/blueblood7/article/details/121034635