当前位置:网站首页>rearrange 和 einsum 真的优雅吗
rearrange 和 einsum 真的优雅吗
2022-04-23 06:12:00 【wujpbb7】
结论是,还好吧。
从代码量看,差不多:
# -*- coding:utf-8 -*-
import torch
from torch import nn
from torch import einsum
from einops import rearrange
class SimpleQKV(nn.Module):
def __init__(self, dim, use_ein):
super().__init__()
self.proj = nn.Linear(dim, dim*3, bias=False)
self.dim = dim
self.scale = self.dim ** -0.5
self.use_ein = use_ein
torch.manual_seed(777) # 为了使权重相同,便于比较输出
nn.init.xavier_uniform_(self.proj.weight)
def forward(self, x):
n,c,h,w = x.shape
#assert c==self.dim
if (self.use_ein):
x = rearrange(x, 'n c h w -> n (h w) c')
else:
x = x.permute(0,2,3,1).view(n, -1, c)
qkv = self.proj(x)
q,k,v = qkv.chunk(chunks=3,dim=-1)
if (self.use_ein):
attn = (einsum('n i c, n j c -> n i j', q, k) * self.scale).softmax(dim=-1)
v = einsum('n i j, n j c -> n i c', attn, v)
output = rearrange(v, 'n (h w) c -> n c h w', h=h)
else:
attn = (torch.matmul(q, k.transpose(1,2)) * self.scale).softmax(dim=-1)
v = torch.matmul(attn, v)
output = v.permute(0,2,1).view(n,c,h,w)
return output
# Demo: build both variants with identical weights, check that they agree,
# then export each one to ONNX so the two graphs can be compared.
batch, chan, height, width = 1, 20, 7, 7
simple_qkv_ein = SimpleQKV(chan, True)
simple_qkv_noein = SimpleQKV(chan, False)
x = torch.randn(batch, chan, height, width, device='cpu')
out1 = simple_qkv_ein(x)
out2 = simple_qkv_noein(x)
# einsum and matmul may dispatch to different kernels, so bit-exact equality
# is not guaranteed; compare within floating-point tolerance. An explicit
# raise (unlike assert) still runs under `python -O`.
if not torch.allclose(out1, out2, rtol=1e-5, atol=1e-6):
    max_diff = (out1 - out2).abs().max().item()
    raise RuntimeError(f'einsum and matmul outputs differ (max abs diff {max_diff:.3g})')
# Save ONNX (the graph output is named 'output'; the original spelled it 'ouput').
for model, onnx_filename in ((simple_qkv_ein, './simple_qkv_ein.onnx'),
                             (simple_qkv_noein, './simple_qkv_noein.onnx')):
    model.eval()
    torch.onnx.export(model, x, onnx_filename,
                      input_names=['input'], output_names=['output'],
                      export_params=True, verbose=False, opset_version=12)
print('save onnx succ.')
从导出的 ONNX 模型看(经过 onnxsim 简化),两者的计算图也差不多:
版权声明
本文为[wujpbb7]所创,转载请带上原文链接,感谢
https://blog.csdn.net/blueblood7/article/details/121223135
边栏推荐
猜你喜欢
【点云系列】点云隐式表达相关论文概要
Gee configuring local development environment
ArcGIS License Server Administrator 无法启动解决方法
【点云系列】Learning Representations and Generative Models for 3D pointclouds
Summary of white-box adversarial attack techniques for image classification
Chapter 4: PyTorch data processing toolbox
Chapter 8 generative deep learning
Easyui combobox 判断输入项是否存在于下拉列表中
【点云系列】Neural Opacity Point Cloud(NOPC)
Record WebView shows another empty pit
随机推荐
素数求解的n种境界
Chapter 4: PyTorch data processing toolbox
【动态规划】不同的二叉搜索树
【动态规划】三角形最小路径和
Machine learning III: classification prediction based on logistic regression
Binder mechanism principle
【点云系列】DeepMapping: Unsupervised Map Estimation From Multiple Point Clouds
[point cloud series] a rotation invariant framework for deep point cloud analysis
[point cloud series] SG-GAN: Adversarial Self-Attention GCN for Point Cloud Topological Parts Generation
第2章 Pytorch基础2
Visual Studio 2019安装与使用
三子棋小游戏
[3D shape reconstruction series] implicit functions in feature space for 3D shape reconstruction and completion
免费使用OriginPro学习版
【点云系列】Fully-Convolutional geometric features
Pymysql connection database
SSL/TLS应用示例
面试总结之特征工程
Summary of white-box adversarial attack techniques for image classification
【点云系列】Pointfilter: Point Cloud Filtering via Encoder-Decoder Modeling