Are rearrange and einsum really that elegant?
2022-04-23 06:12:00 【wujpbb7】
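The verdict: they're fine, but nothing special.
In terms of code length, the einops/einsum version and the plain permute/matmul version come out about the same: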
```python
# -*- coding:utf-8 -*-
import torch
from torch import nn
from torch import einsum
from einops import rearrange


class SimpleQKV(nn.Module):
    def __init__(self, dim, use_ein):
        super().__init__()
        self.proj = nn.Linear(dim, dim * 3, bias=False)
        self.dim = dim
        self.scale = self.dim ** -0.5
        self.use_ein = use_ein
        torch.manual_seed(777)  # same seed so both modules get identical weights and outputs can be compared
        nn.init.xavier_uniform_(self.proj.weight)

    def forward(self, x):
        n, c, h, w = x.shape
        # assert c == self.dim
        # flatten the spatial dims into a token axis: (n, c, h, w) -> (n, h*w, c)
        if self.use_ein:
            x = rearrange(x, 'n c h w -> n (h w) c')
        else:
            x = x.permute(0, 2, 3, 1).reshape(n, -1, c)
        qkv = self.proj(x)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.use_ein:
            # scaled scores q @ k^T, softmax over the key axis j
            attn = (einsum('n i c, n j c -> n i j', q, k) * self.scale).softmax(dim=-1)
            v = einsum('n i j, n j c -> n i c', attn, v)
            # undo the flattening: (n, h*w, c) -> (n, c, h, w)
            output = rearrange(v, 'n (h w) c -> n c h w', h=h)
        else:
            attn = (torch.matmul(q, k.transpose(1, 2)) * self.scale).softmax(dim=-1)
            v = torch.matmul(attn, v)
            output = v.permute(0, 2, 1).reshape(n, c, h, w)
        return output


batch, chan, height, width = 1, 20, 7, 7
simple_qkv_ein = SimpleQKV(chan, True)
simple_qkv_noein = SimpleQKV(chan, False)
x = torch.randn(batch, chan, height, width, device='cpu')
out1 = simple_qkv_ein(x)
out2 = simple_qkv_noein(x)
# einsum and matmul may run different kernels, so compare with a tolerance instead of bit-for-bit
assert torch.allclose(out1, out2, atol=1e-6)

# export both versions to ONNX (the ONNX Einsum op needs opset >= 12)
simple_qkv_ein.eval()
onnx_filename = './simple_qkv_ein.onnx'
torch.onnx.export(simple_qkv_ein, x, onnx_filename,
                  input_names=['input'], output_names=['output'],
                  export_params=True, verbose=False, opset_version=12)
simple_qkv_noein.eval()
onnx_filename = './simple_qkv_noein.onnx'
torch.onnx.export(simple_qkv_noein, x, onnx_filename,
                  input_names=['input'], output_names=['output'],
                  export_params=True, verbose=False, opset_version=12)
print('save onnx succ.')
```
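To make the one-to-one correspondence between the two branches explicit without needing einops or PyTorch installed, here is a minimal sketch in NumPy (a swapped-in library, not what the article uses). It assumes toy q/k/v tensors derived directly from the input instead of the nn.Linear projection, and checks that each einsum equation and each rearrange pattern matches its transpose/reshape/matmul counterpart.

```python
import numpy as np

# Shapes mirror the article's test: batch 1, 20 channels, 7x7 feature map.
n, c, h, w = 1, 20, 7, 7
rng = np.random.default_rng(0)
x = rng.standard_normal((n, c, h, w))

# 'n c h w -> n (h w) c': move channels last, then merge h and w into one token axis.
tokens = x.transpose(0, 2, 3, 1).reshape(n, h * w, c)

# Toy q/k/v (assumption: the article projects x with nn.Linear instead).
q, k, v = tokens, tokens * 0.5, tokens + 1.0
scale = c ** -0.5

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

# 'n i c, n j c -> n i j' is a batched q @ k^T.
scores_ein = np.einsum('nic,njc->nij', q, k) * scale
scores_mm = (q @ k.transpose(0, 2, 1)) * scale
attn = softmax(scores_mm)

# 'n i j, n j c -> n i c' is a batched attn @ v.
out_ein = np.einsum('nij,njc->nic', attn, v)
out_mm = attn @ v

# 'n (h w) c -> n c h w' (with h given): channels back to axis 1, then split h*w.
out_img = out_mm.transpose(0, 2, 1).reshape(n, c, h, w)

print(np.allclose(scores_ein, scores_mm), np.allclose(out_ein, out_mm))  # → True True
print(out_img.shape)  # → (1, 20, 7, 7)
```

The einops pattern is essentially self-documenting shape bookkeeping; the transpose/reshape pair does the same work but leaves the axis order for the reader to track.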
The exported ONNX graphs (after simplification with onnxsim) also come out about the same:

Copyright notice
This article is original work by [wujpbb7]. Please include a link to the original when reposting. Thanks.
https://blog.csdn.net/blueblood7/article/details/121223135