当前位置:网站首页>PyTorch 19. PyTorch中相似操作的区别与联系
PyTorch 19. PyTorch中相似操作的区别与联系
2022-04-23 06:11:00 【DCGJ666】
PyTorch 19. PyTorch中相似操作的区别与联系
view() 和 reshape()
写在开头:
有一篇大佬的总结非常到位:博客
总结
- view() 在操作tensor时,需要tensor是内存连续的,而且在进行尺寸变换时,view()操作不会新开辟内存空间。但是要保证tensor连续,对tensor进行
tensor.contiguous()
时,会开辟新的内存空间,存放内存连续的数据。 - reshape()操作,与view()的作用一模一样,但是它比view()更高级,被操作的tensor是内存连续时,直接采用reshape不会开辟新的内存;被操作的tensor不是内存连续时,reshape操作会开辟新的内存,再对tensor进行reshape。
- 最后,用reshape操作就完事了
expand()和repeat()
expand()
返回当前张量在某维扩展更大后的张量。扩展(expand)张量不会分配新的内存,只是在存在的张量上创建一个新的视图(view),一个大小等于1的维度扩展到更大的尺寸
例子:
import torch
x = torch.tensor([1, 2, 3])
x.expand(2,3)
tensor([[1, 2, 3],
[1, 2, 3]])
注意 expand()只能扩展维度为1的维数,维数不为1的部分要保持一致
repeat()
沿着特定的维度重复这个张量,和expand()不同的是,这个函数拷贝张量的数据
例子
import torch
x = torch.tensor([1, 2, 3])
x.repeat(3,2)
tensor([[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3]])
x2 = torch.randn(2, 3, 4)
x2.repeat(2, 1, 3).shape
torch.Tensor([4, 3, 12])
乘法操作
pytorch中的乘法操作有:torch.mm(), torch.bmm(), torch.matmul(), torch.mul(), 运算符,以及torch.einsum()
二维矩阵乘法 torch.mm()
该函数一般只用来计算两个二维矩阵的矩阵乘法,并且不支持broadcast操作。
torch.mm(mat1, mat2, out=None), 其中mat1为(nxm),mat2为(mxd),输出维度是(nxd)
三维带batch的矩阵乘法torch.bmm()
该函数的两个输入必须是三维矩阵且第一维相同(表示Batch维度),不支持broadcast操作。
由于神经网络训练一般采用mini-batch,经常输入的三维带batch的矩阵,所以提供torch.bmm(bmat1, bmat2, out=None)
,其中bmat1
为(bxnxm),bmat2
为(bxmxd),输出out
的维度是(bxnxd)
多维矩阵乘法 torch.matmul()
torch,matmul(input, other, out=None)支持broadcast操作
针对多维数据matmul()
乘法,可以认为该matmul()
乘法使用两个参数的后两个维度计算,其他的维度都可以认为是batch维度。假设两个输入的维度分别为input->(100x500x99x11)
,other->(500x11x99)
那么我们可以认为该乘法首先进行后两位矩阵乘法得到(99x11)x(11x99)->(99,99)
,然后分析两个参数的batch size分别为(1000x500)
和500
,可以广播为(1000x500)
,因此最终输出的维度是(1000x500x99x99)
矩阵逐元素(Element-wise)乘法torch.mul()
函数torch.mul(mat1, other, out=None)
,其中other乘数可以是标量,也可以是任意维度的矩阵,只要满足最终相乘是可以broadcast即可。
两个运算符@和*
- @:矩阵乘法,自动执行合适的矩阵乘法函数
- *:elemnet-wise乘法
register_parameter()和parameter()
- Parameter()
Parameter是Tensor, 即Tensor拥有的属性它都有,比如可以根据data来访问参数数值,用grad来访问参数梯度
# 随便定义一个网络
net = nn.Sequential(nn.Linear(4,3), nn.ReLU(), nn.Linear(3,1))
# list让它可以访问
weight_0 = list(net[0].parameters())[0]
print(weight_0.data)
print(weight_0.grad)
- register_parameter(name, parameters)
向建立的网络module添加parameter
最大的区别:parameter可以通过注册网络时候的name获取
例子
class Example(nn.Module):
def __init__(self):
super(Example, self).__init__()
self.W1_params = nn.Parameter(torch.rand(2,3))
self.register_parameter('W2_params', nn.Parameter(torch.rand(2,3)))
def forward(self, x):
return x
版权声明
本文为[DCGJ666]所创,转载请带上原文链接,感谢
https://blog.csdn.net/DCGJ666/article/details/121807973
边栏推荐
- C# EF mysql更新datetime字段报错Modifying a column with the ‘Identity‘ pattern is not supported
- 【点云系列】Multi-view Neural Human Rendering (NHR)
- Write a wechat double open gadget to your girlfriend
- [dynamic programming] longest increasing subsequence
- Google AdMob advertising learning
- 第1章 NumPy基础
- 【点云系列】 场景识别类导读
- 【动态规划】不同的二叉搜索树
- Component learning (2) arouter principle learning
- 【Tensorflow】共享机制
猜你喜欢
[point cloud series] sg-gan: advantageous self attention GCN for point cloud topological parts generation
Visual studio 2019 installation and use
1.2 preliminary pytorch neural network
微信小程序 使用wxml2canvas插件生成图片部分问题记录
Miscellaneous learning
MySQL数据库安装与配置详解
Android interview Online Economic encyclopedia [constantly updating...]
画 ArcFace 中的 margin 曲线
Chapter 8 generative deep learning
【2021年新书推荐】Artificial Intelligence for IoT Cookbook
随机推荐
PyTorch 22. PyTorch常用代码段合集
C language, a number guessing game
Android interview Online Economic encyclopedia [constantly updating...]
PyTorch 9. 优化器
PyTorch训练一个网络的基本流程5步法
Google AdMob advertising learning
机器学习 三: 基于逻辑回归的分类预测
PyTorch最佳实践和代码编写风格指南
PyTorch 11.正则化
画 ArcFace 中的 margin 曲线
Mysql database installation and configuration details
MySQL5. 7 insert Chinese data and report an error: ` incorrect string value: '\ xb8 \ XDF \ AE \ xf9 \ X80 at row 1`
【点云系列】Unsupervised Multi-Task Feature Learning on Point Clouds
[dynamic programming] different paths 2
第1章 NumPy基础
Component based learning (3) path and group annotations in arouter
【点云系列】Relationship-based Point Cloud Completion
MySQL数据库安装与配置详解
C# EF mysql更新datetime字段报错Modifying a column with the ‘Identity‘ pattern is not supported
[point cloud series] pnp-3d: a plug and play for 3D point clouds