当前位置:网站首页>torch.mm() torch.sparse.mm() torch.bmm() torch.mul() torch.matmul()的区别
torch.mm() torch.sparse.mm() torch.bmm() torch.mul() torch.matmul()的区别
2022-04-23 06:11:00 【小风_】
torch.mm()
二维矩阵的乘法,假设输入矩阵mat1
维度是 ( m × n ) (m×n) (m×n),矩阵mat2
维度是 ( n × p ) (n×p) (n×p),则输出维度为 ( m × p ) (m×p) (m×p),只能是二维的
mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
out = torch.mm(mat1, mat2)
''' tensor([[ 0.4851, 0.5037, -0.3633], [-0.0760, -3.6705, 2.4784]]) '''
torch.sparse.mm()
a是稀疏矩阵,b是稀疏矩阵或者密集矩阵,sparse.mm
的作用和torch.mm
一样,都是做矩阵乘法计算
a = torch.randn(2, 3).to_sparse().requires_grad_(True)
b = torch.randn(3, 2, requires_grad=True)
y = torch.sparse.mm(a, b)
''' a: tensor(indices=tensor([[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]]), values=tensor([ 1.5901, 0.0183, -0.6146, 1.8061, -0.0112, 0.6302]), size=(2, 3), nnz=6, layout=torch.sparse_coo, requires_grad=True) b: tensor([[-0.6479, 0.7874], [-1.2056, 0.5641], [-1.1716, -0.9923]], requires_grad=True) y: tensor([[-0.3323, 1.8723], [-1.8951, 0.7904]], grad_fn=<SparseAddmmBackward>) '''
torch.bmm()
与torch.mm
类似,但多了一个batch_size
维度,输入矩阵张量mat1
维度是 ( b × m × n ) (b×m×n) (b×m×n),矩阵张量mat2
维度是 ( b × n × p ) (b×n×p) (b×n×p),则输出维度为 ( b × m × p ) (b×m×p) (b×m×p)
mat1 = torch.randn(10, 3, 4)
mat2 = torch.randn(10, 4, 5)
res = torch.bmm(mat1, mat2)
print(res.size())
# torch.Size([10, 3, 5])
torch.mul()
将输入张量input
的每个元素与另一个标量other
相乘,返回一个新的张量out
,两者维度需满足广播规则
# 方式1:张量 和 标量相乘
a = torch.randn(3)
torch.mul(a, 100)
''' a: tensor([ 0.2015, -0.4255, 2.6087]) tensor([ 20.1494, -42.5491, 260.8663]) '''
# 方式2:张量 和 张量(需满足广播规则)
a = torch.randn(4, 1)
b = torch.randn(1, 4)
c = torch.mul(a,b)
''' c: tensor([[-0.1183, -0.4246, -0.0512, 0.1757], [-0.4215, -1.5121, -0.1823, 0.6257], [-0.0358, -0.1284, -0.0155, 0.0531], [ 0.1649, 0.5917, 0.0713, -0.2448]]) '''
# 方式3:元素对应项相乘
a = torch.randn(3, 2)
b = torch.randn(3, 2)
c = torch.mul(a,b)
''' C: tensor([[-1.9259, -0.0116], [-1.8523, -0.0392], [-0.4881, -0.4235]]) '''
torch.matmul()
两个张量的矩阵乘积。其行为取决于张量的维数如下:
- 如果两个张量都是一维的,则返回点积(标量)。
- 如果两个参数都是二维的,则返回矩阵-矩阵乘积。
- 如果第一个参数是一维的,第二个参数是二维的,则在其维数前加一个1,以实现矩阵乘法。在矩阵相乘之后,附加的维度被删除。
- 如果第一个参数是二维的,第二个参数是一维的,则返回矩阵向量乘积。
- 如果两个参数至少是一维的,且至少一个参数是N维的(其中N > 2),则返回一个批处理矩阵乘法。如果第一个参数是一维的,则在其维数前加上1,以便批处理矩阵相乘,然后删除。如果第二个参数是一维的,则为批处理矩阵倍数的目的,将在其维上追加一个1,然后删除它。非矩阵(即批处理)维度是广播的(因此必须是可广播的)
# vector x vector
tensor1 = torch.randn(3)
tensor2 = torch.randn(3)
torch.matmul(tensor1, tensor2).size()
torch.Size([])
# matrix x vector
tensor1 = torch.randn(3, 4)
tensor2 = torch.randn(4)
torch.matmul(tensor1, tensor2).size()
torch.Size([3])
# batched matrix x broadcasted vector
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4)
torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3])
# batched matrix x batched matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
# batched matrix x broadcasted matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)
torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
总结
- 二维矩阵乘积用
torch.mm()
或torch.sparse.mm()
- 多批次的二维矩阵之间的乘积用
torch.bmm()
- 标量乘积或对应项乘积用
torch.mul()
- 批次或广播进行乘积用
torch.matmul()
版权声明
本文为[小风_]所创,转载请带上原文链接,感谢
https://blog.csdn.net/qq_33952811/article/details/120710801
边栏推荐
- 个人博客网站搭建
- Encapsulate a set of project network request framework from 0
- C# EF mysql更新datetime字段报错Modifying a column with the ‘Identity‘ pattern is not supported
- ViewPager2实现画廊效果执行notifyDataSetChanged后PageTransformer显示异常 界面变形问题
- Record WebView shows another empty pit
- Thanos.sh灭霸脚本,轻松随机删除系统一半的文件
- AVD Pixel_ 2_ API_ 24 is already running. If that is not the case, delete the files at C:\Users\admi
- [2021 book recommendation] learn winui 3.0
- 素数求解的n种境界
- [recommendation of new books in 2021] enterprise application development with C 9 and NET 5
猜你喜欢
实习做了啥
开篇:双指针仪表盘的识别
Itop4412 HDMI display (4.0.3_r1)
Component learning (2) arouter principle learning
ArcGIS License Server Administrator 无法启动解决方法
Binder机制原理
iTOP4412 HDMI显示(4.0.3_r1)
Ffmpeg common commands
C connection of new world Internet of things cloud platform (simple understanding version)
Android interview Online Economic encyclopedia [constantly updating...]
随机推荐
如何对多维矩阵进行标准化(基于numpy)
MySQL笔记5_操作数据
Fill the network gap
Using queue to realize stack
PyTorch 模型剪枝实例教程三、多参数与全局剪枝
Three methods to realize the rotation of ImageView with its own center as the origin
三种实现ImageView以自身中心为原点旋转的方法
webView因证书问题显示一片空白
红外传感器控制开关
Easyui combobox 判断输入项是否存在于下拉列表中
iTOP4412 SurfaceFlinger(4.4.4_r1)
Itop4412 kernel restarts repeatedly
Recyclerview batch update view: notifyitemrangeinserted, notifyitemrangeremoved, notifyitemrangechanged
Android-Room数据库快速上手
xcode 编译速度慢的解决办法
PaddleOCR 图片文字提取
“Expression #1 of SELECT list is not in GROUP BY clause and contains nonaggregated
ViewPager2实现画廊效果执行notifyDataSetChanged后PageTransformer显示异常 界面变形问题
[Andorid] 通过JNI实现kernel与app进行spi通讯
ffmpeg常用命令