当前位置:网站首页>torch.mm() torch.sparse.mm() torch.bmm() torch.mul() torch.matmul()的区别
torch.mm() torch.sparse.mm() torch.bmm() torch.mul() torch.matmul()的区别
2022-04-23 06:11:00 【小风_】
torch.mm()
二维矩阵的乘法,假设输入矩阵mat1维度是 ( m × n ) (m×n) (m×n),矩阵mat2维度是 ( n × p ) (n×p) (n×p),则输出维度为 ( m × p ) (m×p) (m×p),只能是二维的
mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
out = torch.mm(mat1, mat2)
''' tensor([[ 0.4851, 0.5037, -0.3633], [-0.0760, -3.6705, 2.4784]]) '''
torch.sparse.mm()
a是稀疏矩阵,b是稀疏矩阵或者密集矩阵,sparse.mm的作用和torch.mm一样,都是做矩阵乘法计算
a = torch.randn(2, 3).to_sparse().requires_grad_(True)
b = torch.randn(3, 2, requires_grad=True)
y = torch.sparse.mm(a, b)
''' a: tensor(indices=tensor([[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]]), values=tensor([ 1.5901, 0.0183, -0.6146, 1.8061, -0.0112, 0.6302]), size=(2, 3), nnz=6, layout=torch.sparse_coo, requires_grad=True) b: tensor([[-0.6479, 0.7874], [-1.2056, 0.5641], [-1.1716, -0.9923]], requires_grad=True) y: tensor([[-0.3323, 1.8723], [-1.8951, 0.7904]], grad_fn=<SparseAddmmBackward>) '''
torch.bmm()
与torch.mm类似,但多了一个batch_size维度,输入矩阵张量mat1维度是 ( b × m × n ) (b×m×n) (b×m×n),矩阵张量mat2维度是 ( b × n × p ) (b×n×p) (b×n×p),则输出维度为 ( b × m × p ) (b×m×p) (b×m×p)
mat1 = torch.randn(10, 3, 4)
mat2 = torch.randn(10, 4, 5)
res = torch.bmm(mat1, mat2)
print(res.size())
# torch.Size([10, 3, 5])
torch.mul()
将输入张量input的每个元素与另一个标量other相乘,返回一个新的张量out,两者维度需满足广播规则
# 方式1:张量 和 标量相乘
a = torch.randn(3)
torch.mul(a, 100)
''' a: tensor([ 0.2015, -0.4255, 2.6087]) tensor([ 20.1494, -42.5491, 260.8663]) '''
# 方式2:张量 和 张量(需满足广播规则)
a = torch.randn(4, 1)
b = torch.randn(1, 4)
c = torch.mul(a,b)
''' c: tensor([[-0.1183, -0.4246, -0.0512, 0.1757], [-0.4215, -1.5121, -0.1823, 0.6257], [-0.0358, -0.1284, -0.0155, 0.0531], [ 0.1649, 0.5917, 0.0713, -0.2448]]) '''
# 方式3:元素对应项相乘
a = torch.randn(3, 2)
b = torch.randn(3, 2)
c = torch.mul(a,b)
''' C: tensor([[-1.9259, -0.0116], [-1.8523, -0.0392], [-0.4881, -0.4235]]) '''
torch.matmul()
两个张量的矩阵乘积。其行为取决于张量的维数如下:
- 如果两个张量都是一维的,则返回点积(标量)。
- 如果两个参数都是二维的,则返回矩阵-矩阵乘积。
- 如果第一个参数是一维的,第二个参数是二维的,则在其维数前加一个1,以实现矩阵乘法。在矩阵相乘之后,附加的维度被删除。
- 如果第一个参数是二维的,第二个参数是一维的,则返回矩阵向量乘积。
- 如果两个参数至少是一维的,且至少一个参数是N维的(其中N > 2),则返回一个批处理矩阵乘法。如果第一个参数是一维的,则在其维数前加上1,以便批处理矩阵相乘,然后删除。如果第二个参数是一维的,则为批处理矩阵倍数的目的,将在其维上追加一个1,然后删除它。非矩阵(即批处理)维度是广播的(因此必须是可广播的)
# vector x vector
tensor1 = torch.randn(3)
tensor2 = torch.randn(3)
torch.matmul(tensor1, tensor2).size()
torch.Size([])
# matrix x vector
tensor1 = torch.randn(3, 4)
tensor2 = torch.randn(4)
torch.matmul(tensor1, tensor2).size()
torch.Size([3])
# batched matrix x broadcasted vector
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4)
torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3])
# batched matrix x batched matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
# batched matrix x broadcasted matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)
torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
总结
- 二维矩阵乘积用
torch.mm()或torch.sparse.mm() - 多批次的二维矩阵之间的乘积用
torch.bmm() - 标量乘积或对应项乘积用
torch.mul() - 批次或广播进行乘积用
torch.matmul()
版权声明
本文为[小风_]所创,转载请带上原文链接,感谢
https://blog.csdn.net/qq_33952811/article/details/120710801
边栏推荐
- ViewPager2实现画廊效果执行notifyDataSetChanged后PageTransformer显示异常 界面变形问题
- ARGB透明度换算
- DCMTK (dcm4che) works together with dicoogle
- [2021 book recommendation] kubernetes in production best practices
- 机器学习 三: 基于逻辑回归的分类预测
- Component based learning (1) idea and Implementation
- AVD Pixel_2_API_24 is already running.If that is not the case, delete the files at C:\Users\admi
- 树莓派:双色LED灯实验
- iTOP4412无法显示开机动画(4.0.3_r1)
- Migrating your native/mobile application to Unified Plan/WebRTC 1.0 API
猜你喜欢

Cancel remote dependency and use local dependency

adb shell top 命令详解

PaddleOCR 图片文字提取

this.getOptions is not a function

【2021年新书推荐】Red Hat RHCSA 8 Cert Guide: EX200
树莓派:双色LED灯实验

【2021年新书推荐】Artificial Intelligence for IoT Cookbook

Cause: dx.jar is missing

【2021年新书推荐】Practical Node-RED Programming

ThreadLocal,看我就够了!
随机推荐
如何对多维矩阵进行标准化(基于numpy)
c语言编写一个猜数字游戏编写
利用栈实现队列的出队入队
iTOP4412 HDMI显示(4.4.4_r1)
SSL/TLS应用示例
js时间获取本周一、周日,判断时间是今天,今天前、后
Fill the network gap
ArcGIS License Server Administrator 无法启动解决方法
face_recognition人脸检测
MySQL笔记4_主键自增长(auto_increment)
利用队列实现栈
【2021年新书推荐】Kubernetes in Production Best Practices
补补网络缺口
【2021年新书推荐】Professional Azure SQL Managed Database Administration
[recommendation of new books in 2021] enterprise application development with C 9 and NET 5
ProcessBuilder工具类
【2021年新书推荐】Red Hat RHCSA 8 Cert Guide: EX200
org.xml.sax.SAXParseException; lineNumber: 141; columnNumber: 252; cvc-complex-type.2.4.a: 发现了以元素 ‘b
Using queue to realize stack
DCMTK (dcm4che) works together with dicoogle