当前位置:网站首页>PyTorch 19. PyTorch中相似操作的区别与联系
PyTorch 19. PyTorch中相似操作的区别与联系
2022-04-23 06:11:00 【DCGJ666】
PyTorch 19. PyTorch中相似操作的区别与联系
view() 和 reshape()
写在开头:
有一篇大佬的总结非常到位:博客
总结
- view() 在操作tensor时,需要tensor是内存连续的,而且在进行尺寸变换时,view()操作不会新开辟内存空间。但是要保证tensor连续,对tensor进行
tensor.contiguous()
时,会开辟新的内存空间,存放内存连续的数据。 - reshape()操作,与view()的作用一模一样,但是它比view()更高级,被操作的tensor是内存连续时,直接采用reshape不会开辟新的内存;被操作的tensor不是内存连续时,reshape操作会开辟新的内存,再对tensor进行reshape。
- 最后,用reshape操作就完事了
expand()和repeat()
expand()
返回当前张量在某维扩展更大后的张量。扩展(expand)张量不会分配新的内存,只是在存在的张量上创建一个新的视图(view),一个大小等于1的维度扩展到更大的尺寸
例子:
import torch
x = torch.tensor([1, 2, 3])
x.expand(2,3)
tensor([[1, 2, 3],
[1, 2, 3]])
注意 expand()只能扩展维度为1的维数,维数不为1的部分要保持一致
repeat()
沿着特定的维度重复这个张量,和expand()不同的是,这个函数拷贝张量的数据
例子
import torch
x = torch.tensor([1, 2, 3])
x.repeat(3,2)
tensor([[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3]])
x2 = torch.randn(2, 3, 4)
x2.repeat(2, 1, 3).shape
torch.Tensor([4, 3, 12])
乘法操作
pytorch中的乘法操作有:torch.mm(), torch.bmm(), torch.matmul(), torch.mul(), 运算符,以及torch.einsum()
二维矩阵乘法 torch.mm()
该函数一般只用来计算两个二维矩阵的矩阵乘法,并且不支持broadcast操作。
torch.mm(mat1, mat2, out=None), 其中mat1为(nxm),mat2为(mxd),输出维度是(nxd)
三维带batch的矩阵乘法torch.bmm()
该函数的两个输入必须是三维矩阵且第一维相同(表示Batch维度),不支持broadcast操作。
由于神经网络训练一般采用mini-batch,经常输入的三维带batch的矩阵,所以提供torch.bmm(bmat1, bmat2, out=None)
,其中bmat1
为(bxnxm),bmat2
为(bxmxd),输出out
的维度是(bxnxd)
多维矩阵乘法 torch.matmul()
torch,matmul(input, other, out=None)支持broadcast操作
针对多维数据matmul()
乘法,可以认为该matmul()
乘法使用两个参数的后两个维度计算,其他的维度都可以认为是batch维度。假设两个输入的维度分别为input->(100x500x99x11)
,other->(500x11x99)
那么我们可以认为该乘法首先进行后两位矩阵乘法得到(99x11)x(11x99)->(99,99)
,然后分析两个参数的batch size分别为(1000x500)
和500
,可以广播为(1000x500)
,因此最终输出的维度是(1000x500x99x99)
矩阵逐元素(Element-wise)乘法torch.mul()
函数torch.mul(mat1, other, out=None)
,其中other乘数可以是标量,也可以是任意维度的矩阵,只要满足最终相乘是可以broadcast即可。
两个运算符@和*
- @:矩阵乘法,自动执行合适的矩阵乘法函数
- *:elemnet-wise乘法
register_parameter()和parameter()
- Parameter()
Parameter是Tensor, 即Tensor拥有的属性它都有,比如可以根据data来访问参数数值,用grad来访问参数梯度
# 随便定义一个网络
net = nn.Sequential(nn.Linear(4,3), nn.ReLU(), nn.Linear(3,1))
# list让它可以访问
weight_0 = list(net[0].parameters())[0]
print(weight_0.data)
print(weight_0.grad)
- register_parameter(name, parameters)
向建立的网络module添加parameter
最大的区别:parameter可以通过注册网络时候的name获取
例子
class Example(nn.Module):
def __init__(self):
super(Example, self).__init__()
self.W1_params = nn.Parameter(torch.rand(2,3))
self.register_parameter('W2_params', nn.Parameter(torch.rand(2,3)))
def forward(self, x):
return x
版权声明
本文为[DCGJ666]所创,转载请带上原文链接,感谢
https://blog.csdn.net/DCGJ666/article/details/121807973
边栏推荐
- 【点云系列】 A Rotation-Invariant Framework for Deep Point Cloud Analysis
- [dynamic programming] triangle minimum path sum
- Record WebView shows another empty pit
- Android interview Online Economic encyclopedia [constantly updating...]
- MySQL5. 7 insert Chinese data and report an error: ` incorrect string value: '\ xb8 \ XDF \ AE \ xf9 \ X80 at row 1`
- What did you do during the internship
- Pytorch model pruning example tutorial III. multi parameter and global pruning
- 三子棋小游戏
- 【动态规划】三角形最小路径和
- Chapter 2 pytoch foundation 1
猜你喜欢
Machine learning notes 1: learning ideas
【2021年新书推荐】Red Hat Certified Engineer (RHCE) Study Guide
【点云系列】DeepMapping: Unsupervised Map Estimation From Multiple Point Clouds
微信小程序 使用wxml2canvas插件生成图片部分问题记录
树莓派:双色LED灯实验
Record WebView shows another empty pit
Project, how to package
1.1 pytorch and neural network
[2021 book recommendation] effortless app development with Oracle visual builder
MySQL的安装与配置——详细教程
随机推荐
1.2 preliminary pytorch neural network
Handlerthread principle and practical application
MySQL notes 4_ Primary key auto_increment
winform滚动条美化
“Expression #1 of SELECT list is not in GROUP BY clause and contains nonaggregated
1.1 pytorch and neural network
N states of prime number solution
【2021年新书推荐】Red Hat Certified Engineer (RHCE) Study Guide
【Tensorflow】共享机制
PyTorch训练一个网络的基本流程5步法
【点云系列】SG-GAN: Adversarial Self-Attention GCN for Point Cloud Topological Parts Generation
扫雷小游戏
Component learning (2) arouter principle learning
ThreadLocal, just look at me!
[point cloud series] pnp-3d: a plug and play for 3D point clouds
WebRTC ICE candidate里面的raddr和rport表示什么?
PyTorch 模型剪枝实例教程三、多参数与全局剪枝
[dynamic programming] triangle minimum path sum
【点云系列】Neural Opacity Point Cloud(NOPC)
机器学习——模型优化