当前位置:网站首页>Are realrange and einsum really elegant
Are realrange and einsum really elegant
2022-04-23 07:27:00 【wujpbb7】
The conclusion is that , O.k. .
In terms of the amount of code , almost :
# -*- coding:utf-8 -*-
import torch
from torch import nn
from torch import einsum
from einops import rearrange
class SimpleQKV(nn.Module):
def __init__(self, dim, use_ein):
super().__init__()
self.proj = nn.Linear(dim, dim*3, bias=False)
self.dim = dim
self.scale = self.dim ** -0.5
self.use_ein = use_ein
torch.manual_seed(777) # To make the weights the same , Easy to compare output
nn.init.xavier_uniform_(self.proj.weight)
def forward(self, x):
n,c,h,w = x.shape
#assert c==self.dim
if (self.use_ein):
x = rearrange(x, 'n c h w -> n (h w) c')
else:
x = x.permute(0,2,3,1).view(n, -1, c)
qkv = self.proj(x)
q,k,v = qkv.chunk(chunks=3,dim=-1)
if (self.use_ein):
attn = (einsum('n i c, n j c -> n i j', q, k) * self.scale).softmax(dim=-1)
v = einsum('n i j, n j c -> n i c', attn, v)
output = rearrange(v, 'n (h w) c -> n c h w', h=h)
else:
attn = (torch.matmul(q, k.transpose(1,2)) * self.scale).softmax(dim=-1)
v = torch.matmul(attn, v)
output = v.permute(0,2,1).view(n,c,h,w)
return output
batch, chan, height, width = 1, 20, 7, 7
simple_qkv_ein = SimpleQKV(chan, True)
simple_qkv_noein = SimpleQKV(chan, False)
x = torch.randn(batch, chan, height, width, device='cpu')
out1 = simple_qkv_ein(x)
out2 = simple_qkv_noein(x)
assert(out1.equal(out2))
# preservation onnx
simple_qkv_ein.eval()
onnx_filename = './simple_qkv_ein.onnx'
torch.onnx.export(simple_qkv_ein, x, onnx_filename,
input_names=['input'], output_names=['ouput'],
export_params=True, verbose=False, opset_version=12)
simple_qkv_noein.eval()
onnx_filename = './simple_qkv_noein.onnx'
torch.onnx.export(simple_qkv_noein, x, onnx_filename,
input_names=['input'], output_names=['ouput'],
export_params=True, verbose=False, opset_version=12)
print('save onnx succ.')
From saved onnx see ( after onnxsim Optimize ), Also almost :

版权声明
本文为[wujpbb7]所创,转载请带上原文链接,感谢
https://yzsam.com/2022/04/202204230611550455.html
边栏推荐
- Gee configuring local development environment
- 画 ArcFace 中的 margin 曲线
- Chapter 2 pytoch foundation 1
- 基于51单片机的体脂检测系统设计(51+oled+hx711+us100)
- 【點雲系列】SG-GAN: Adversarial Self-Attention GCN for Point Cloud Topological Parts Generation
- swin transformer 转 onnx
- Proteus 8.10安装问题(亲测稳定不闪退!)
- unhandled system error, NCCL version 2.7.8
- excel实战应用案例100讲(八)-Excel的报表连接功能
- N states of prime number solution
猜你喜欢

AUTOSAR从入门到精通100讲(八十四)-UDS之时间参数总结篇

【指标】Precision、Recall

项目文件“ ”已被重命名或已不在解决方案中、未能找到与解决方案关联的源代码管理提供程序——两个工程问题

【点云系列】 A Rotation-Invariant Framework for Deep Point Cloud Analysis

【点云系列】Unsupervised Multi-Task Feature Learning on Point Clouds

【点云系列】DeepMapping: Unsupervised Map Estimation From Multiple Point Clouds

EasyUI combobox determines whether the input item exists in the drop-down list

Machine learning II: logistic regression classification based on Iris data set

【期刊会议系列】IEEE系列模板下载指南

RISCV MMU 概述
随机推荐
美摄助力百度“度咔剪辑”,让知识创作更容易
[3D shape reconstruction series] implicit functions in feature space for 3D shape reconstruction and completion
Chapter 8 generative deep learning
SSL / TLS application example
AUTOSAR从入门到精通100讲(八十四)-UDS之时间参数总结篇
Device Tree 详解
【点云系列】Fully-Convolutional geometric features
Machine learning II: logistic regression classification based on Iris data set
《Multi-modal Visual Tracking:Review and Experimental Comparison》翻译
Mysql database installation and configuration details
How keras saves and loads the keras model
EMMC/SD学习小记
【无标题】制作一个0-99的计数器,P1.7接按键,P2接数码管段,共阳极数码管,P3.0,P3.1接数码管位码,每按一次键,数码管显示加一。请写出单片机的C51代码
PyTorch 12. hook的用法
Common regular expressions
Chapter 2 pytoch foundation 1
[point cloud series] sg-gan: advantageous self attention GCN for point cloud topological parts generation
安装 pycuda 出现 PEP517 的错误
【点云系列】DeepMapping: Unsupervised Map Estimation From Multiple Point Clouds
Minesweeping games