当前位置:网站首页>PyTorch 19. PyTorch中相似操作的区别与联系
PyTorch 19. PyTorch中相似操作的区别与联系
2022-04-23 06:11:00 【DCGJ666】
PyTorch 19. PyTorch中相似操作的区别与联系
view() 和 reshape()
写在开头:
有一篇大佬的总结非常到位:博客
总结
- view() 在操作tensor时,需要tensor是内存连续的,而且在进行尺寸变换时,view()操作不会新开辟内存空间。但是要保证tensor连续,对tensor进行
tensor.contiguous()时,会开辟新的内存空间,存放内存连续的数据。 - reshape()操作,与view()的作用一模一样,但是它比view()更高级,被操作的tensor是内存连续时,直接采用reshape不会开辟新的内存;被操作的tensor不是内存连续时,reshape操作会开辟新的内存,再对tensor进行reshape。
- 最后,用reshape操作就完事了
expand()和repeat()
expand()
返回当前张量在某维扩展更大后的张量。扩展(expand)张量不会分配新的内存,只是在存在的张量上创建一个新的视图(view),一个大小等于1的维度扩展到更大的尺寸
例子:
import torch
x = torch.tensor([1, 2, 3])
x.expand(2,3)
tensor([[1, 2, 3],
[1, 2, 3]])
注意 expand()只能扩展维度为1的维数,维数不为1的部分要保持一致
repeat()
沿着特定的维度重复这个张量,和expand()不同的是,这个函数拷贝张量的数据
例子
import torch
x = torch.tensor([1, 2, 3])
x.repeat(3,2)
tensor([[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3]])
x2 = torch.randn(2, 3, 4)
x2.repeat(2, 1, 3).shape
torch.Tensor([4, 3, 12])
乘法操作
pytorch中的乘法操作有:torch.mm(), torch.bmm(), torch.matmul(), torch.mul(), 运算符,以及torch.einsum()
二维矩阵乘法 torch.mm()
该函数一般只用来计算两个二维矩阵的矩阵乘法,并且不支持broadcast操作。
torch.mm(mat1, mat2, out=None), 其中mat1为(nxm),mat2为(mxd),输出维度是(nxd)
三维带batch的矩阵乘法torch.bmm()
该函数的两个输入必须是三维矩阵且第一维相同(表示Batch维度),不支持broadcast操作。
由于神经网络训练一般采用mini-batch,经常输入的三维带batch的矩阵,所以提供torch.bmm(bmat1, bmat2, out=None),其中bmat1为(bxnxm),bmat2为(bxmxd),输出out的维度是(bxnxd)
多维矩阵乘法 torch.matmul()
torch,matmul(input, other, out=None)支持broadcast操作
针对多维数据matmul()乘法,可以认为该matmul()乘法使用两个参数的后两个维度计算,其他的维度都可以认为是batch维度。假设两个输入的维度分别为input->(100x500x99x11),other->(500x11x99)那么我们可以认为该乘法首先进行后两位矩阵乘法得到(99x11)x(11x99)->(99,99),然后分析两个参数的batch size分别为(1000x500)和500,可以广播为(1000x500),因此最终输出的维度是(1000x500x99x99)
矩阵逐元素(Element-wise)乘法torch.mul()
函数torch.mul(mat1, other, out=None),其中other乘数可以是标量,也可以是任意维度的矩阵,只要满足最终相乘是可以broadcast即可。
两个运算符@和*
- @:矩阵乘法,自动执行合适的矩阵乘法函数
- *:elemnet-wise乘法
register_parameter()和parameter()
- Parameter()
Parameter是Tensor, 即Tensor拥有的属性它都有,比如可以根据data来访问参数数值,用grad来访问参数梯度
# 随便定义一个网络
net = nn.Sequential(nn.Linear(4,3), nn.ReLU(), nn.Linear(3,1))
# list让它可以访问
weight_0 = list(net[0].parameters())[0]
print(weight_0.data)
print(weight_0.grad)
- register_parameter(name, parameters)
向建立的网络module添加parameter
最大的区别:parameter可以通过注册网络时候的name获取
例子
class Example(nn.Module):
def __init__(self):
super(Example, self).__init__()
self.W1_params = nn.Parameter(torch.rand(2,3))
self.register_parameter('W2_params', nn.Parameter(torch.rand(2,3)))
def forward(self, x):
return x
版权声明
本文为[DCGJ666]所创,转载请带上原文链接,感谢
https://blog.csdn.net/DCGJ666/article/details/121807973
边栏推荐
- [8] Assertion failed: dims.nbDims == 4 || dims.nbDims == 5
- How to standardize multidimensional matrix (based on numpy)
- PaddleOCR 图片文字提取
- Exploration of SendMessage principle of advanced handler
- Component based learning (1) idea and Implementation
- 【动态规划】最长递增子序列
- Cancel remote dependency and use local dependency
- 1.2 初试PyTorch神经网络
- 机器学习 二:基于鸢尾花(iris)数据集的逻辑回归分类
- 给女朋友写个微信双开小工具
猜你喜欢

C# EF mysql更新datetime字段报错Modifying a column with the ‘Identity‘ pattern is not supported

第1章 NumPy基础

第2章 Pytorch基础1

MySQL的安装与配置——详细教程

WebView displays a blank due to a certificate problem
![[2021 book recommendation] learn winui 3.0](/img/1c/ca7e05946613e9eb2b8c24d121c2e1.png)
[2021 book recommendation] learn winui 3.0

C language, a number guessing game
![[recommendation of new books in 2021] enterprise application development with C 9 and NET 5](/img/1d/cc673ca857fff3c5c48a51883d96c4.png)
[recommendation of new books in 2021] enterprise application development with C 9 and NET 5

图像分类白盒对抗攻击技术总结
![[2021 book recommendation] kubernetes in production best practices](/img/78/2b5bf03adad5da9a109ea5d4e56b18.png)
[2021 book recommendation] kubernetes in production best practices
随机推荐
第2章 Pytorch基础2
torch_ Geometric learning 1, messagepassing
第2章 Pytorch基础1
机器学习——PCA与LDA
Easyui combobox 判断输入项是否存在于下拉列表中
Migrating your native/mobile application to Unified Plan/WebRTC 1.0 API
深度学习模型压缩与加速技术(一):参数剪枝
【点云系列】点云隐式表达相关论文概要
【点云系列】DeepMapping: Unsupervised Map Estimation From Multiple Point Clouds
ArcGIS License Server Administrator 无法启动解决方法
Chapter 2 pytoch foundation 2
[point cloud series] a rotation invariant framework for deep point cloud analysis
【点云系列】Neural Opacity Point Cloud(NOPC)
1.2 preliminary pytorch neural network
N states of prime number solution
【點雲系列】SG-GAN: Adversarial Self-Attention GCN for Point Cloud Topological Parts Generation
1.1 PyTorch和神经网络
PyTorch 13. 嵌套函数和闭包(狗头)
Pytorch model pruning example tutorial III. multi parameter and global pruning
torch.mm() torch.sparse.mm() torch.bmm() torch.mul() torch.matmul()的区别