当前位置:网站首页>第2章 Pytorch基础1
第2章 Pytorch基础1
2022-04-23 06:11:00 【sunshinecxm_BJTU】
2.4 Numpy与Tensor
2.4.1 Tensor概述
对tensor的操作很多,从接口的角度来划分,可以分为两类:
(1)torch.function,如torch.sum、torch.add等,
(2)tensor.function,如tensor.view、tensor.add等。
这些操作对大部分tensor都是等价的,如torch.add(x,y)与x.add(y)等价。实际使用中可以根据个人爱好选择。
如果从修改方式的角度,可以分为以下两类:
(1)不修改自身数据,如x.add(y),x的数据不变,返回一个新的tensor。
(2)修改自身数据,如x.add_(y)(运行符带下划线后缀),运算结果存在x中,x被修改。
import torch
x=torch.tensor([1,2])
y=torch.tensor([3,4])
z=x.add(y)
print(z)
print(x)
x.add_(y)
print(x)
运行结果
tensor([4, 6])
tensor([1, 2])
tensor([4, 6])
2.4.2 创建Tensor
新建tensor的方法很多,可以从列表或ndarray等类型进行构建,也可根据指定的形状构建。常见的构建tensor的方法,可参考表2-1。
表2-1 常见的新建tensor方法
import torch
#根据list数据生成tensor
torch.Tensor([1,2,3,4,5,6])
#根据指定形状生成tensor
torch.Tensor(2,3)
#根据给定的tensor的形状
t=torch.Tensor([[1,2,3],[4,5,6]])
#查看tensor的形状
t.size()
#shape与size()等价方式
t.shape
#根据已有形状创建tensor
torch.Tensor(t.size())
【说明】注意torch.Tensor与torch.tensor的几点区别
①torch.Tensor是torch.empty和torch.tensor之间的一种混合,但是,当传入数据时,torch.Tensor使用全局默认dtype(FloatTensor),torch.tensor从数据中推断数据类型。
②torch.tensor(1)返回一个固定值1,而torch.Tensor(1)返回一个大小为1的张量,它是随机初始化的值。
import torch
t1=torch.Tensor(1)
t2=torch.tensor(1)
print("t1的值{},t1的数据类型{}".format(t1,t1.type()))
print("t2的值{},t2的数据类型{}".format(t2,t2.type()))
运行结果
t1的值tensor([3.5731e-20]),t1的数据类型torch.FloatTensor
t2的值1,t2的数据类型torch.LongTensor
根据一定规则,自动生成tensor的一些例子。
import torch
#生成一个单位矩阵
torch.eye(2,2)
#自动生成全是0的矩阵
torch.zeros(2,3)
#根据规则生成数据
torch.linspace(1,10,4)
#生成满足均匀分布随机数
torch.rand(2,3)
#生成满足标准分布随机数
torch.randn(2,3)
#返回所给数据形状相同,值全为0的张量
torch.zeros_like(torch.rand(2,3))
2.4.3 修改Tensor形状
在处理数据、构建网络层等过程中,经常需要了解Tensor的形状、修改Tensor的形状。
与修改Numpy的形状类似,修改tenor的形状也有很多类似函数,具体可参考表2-2。 表2-2 为tensor常用修改形状的函数。
以下为一些实例
import torch
#生成一个形状为2x3的矩阵
x = torch.randn(2, 3)
#查看矩阵的形状
x.size() #结果为torch.Size([2, 3])
#查看x的维度
x.dim() #结果为2
#把x变为3x2的矩阵
x.view(3,2)
#把x展平为1维向量
y=x.view(-1)
y.shape
#添加一个维度
z=torch.unsqueeze(y,0)
#查看z的形状
z.size() #结果为torch.Size([1, 6])
#计算Z的元素个数
z.numel() #结果为6
【说明】torch.view与torch.reshpae的异同
①reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()调用。view()只可由torch.Tensor.view()来调用。
②对于一个将要被view的Tensor,新的size必须与原来的size与stride兼容。否则,在view之前必须调用contiguous()方法。
③同样也是返回与input数据量相同,但形状不同的tensor。若满足view的条件,则不会copy,若不满足,则会copy
④如果您只想重塑张量,请使用torch.reshape。 如果您还关注内存使用情况并希望确保两个张量共享相同的数据,请使用torch.view。
2.4.4 索引操作
Tensor的索引操作与Numpy类似,一般情况下索引结果与源数据共享内存。从tensor获取元素除了可以通过索引,也可借助一些函数,常用的选择函数,可参考表2-3。
表2-3 常用选择操作函数
以下为部分函数的实现代码:
import torch
#设置一个随机种子
torch.manual_seed(100)
#生成一个形状为2x3的矩阵
x = torch.randn(2, 3)
#根据索引获取第1行,所有数据
x[0,:]
#获取最后一列数据
x[:,-1]
#生成是否大于0的Byter张量
mask=x>0
#获取大于0的值
torch.masked_select(x,mask)
#获取非0下标,即行,列索引
torch.nonzero(mask)
#获取指定索引对应的值,输出根据以下规则得到
#out[i][j] = input[index[i][j]][j] # if dim == 0
#out[i][j] = input[i][index[i][j]] # if dim == 1
index=torch.LongTensor([[0,1,1]])
torch.gather(x,0,index)
index=torch.LongTensor([[0,1,1],[1,1,1]])
a=torch.gather(x,1,index)
#把a的值返回到一个2x3的0矩阵中
z=torch.zeros(2,3)
z.scatter_(1,index,a)
2.4.5 广播机制
本书1.7小节介绍了Numpy的广播机制,是向量运算的重要技巧。Pytorch也支持广播规则,以下通过几个示例进行说明。
import torch
import numpy as np
A = np.arange(0, 40,10).reshape(4, 1)
B = np.arange(0, 3)
#把ndarray转换为Tensor
A1=torch.from_numpy(A) #形状为4x1
B1=torch.from_numpy(B) #形状为3
#Tensor自动实现广播
C=A1+B1
#我们可以根据广播机制,手工进行配置
#根据规则1,B1需要向A1看齐,把B变为(1,3)
B2=B1.unsqueeze(0) #B2的形状为1x3
#使用expand函数重复数组,分别的4x3的矩阵
A2=A1.expand(4,3)
B3=B2.expand(4,3)
#然后进行相加,C1与C结果一致
C1=A2+B3
2.4.6 逐元素操作
与Numpy一样,tensor也有逐元素操作(element-wise),操作内容相似,但使用函数可能不尽相同。大部分数学运算都属于逐元操作,逐元素操作输入与输出的形状相同。,常见的逐元素操作,可参考表2-4。
表2-4常见逐元素操作
【说明】这些操作均创建新的tensor,如果需要就地操作,可以使用这些方法的下划线版本,例如abs_。
以下为部分逐元素操作代码实例。
import torch
t = torch.randn(1, 3)
t1 = torch.randn(3, 1)
t2 = torch.randn(1, 3)
#t+0.1*(t1/t2)
torch.addcdiv(t, 0.1, t1, t2)
#计算sigmoid
torch.sigmoid(t)
#将t限制在[0,1]之间
torch.clamp(t,0,1)
#t+2进行就地运算
t.add_(2)
2.4.7 归并操作
归并操作顾名思义,就是对输入进行归并或合计等操作,这类操作的输入输出形状一般不相同,而且往往是输入大于输出形状。归并操作可以对整个tensor,也可以沿着某个维度进行归并。常见的归并操作可参考表2-5。
表2-5 常见的归并操作
【说明】
归并操作一般涉及一个dim参数,指定沿哪个维进行归并。另一个参数是keepdim,说明输出结果中是否保留维度1,缺省情况是False,即不保留。
以下为归并操作的部分代码
import torch
#生成一个含6个数的向量
a=torch.linspace(0,10,6)
#使用view方法,把a变为2x3矩阵
a=a.view((2,3))
#沿y轴方向累加,即dim=0
b=a.sum(dim=0) #b的形状为[3]
#沿y轴方向累加,即dim=0,并保留含1的维度
b=a.sum(dim=0,keepdim=True) #b的形状为[1,3]
2.4.8 比较操作
比较操作一般进行逐元素比较,有些是按指定方向比较。常用的比较函数可参考表2-6。
表2-6 常用的比较函数
以下是部分函数的代码实现。
import torch
x=torch.linspace(0,10,6).view(2,3)
#求所有元素的最大值
torch.max(x) #结果为10
#求y轴方向的最大值
torch.max(x,dim=0) #结果为[6,8,10]
#求最大的2个元素
torch.topk(x,1,dim=0) #结果为[6,8,10],对应索引为tensor([[1, 1, 1]
2.4.9 矩阵操作
机器学习和深度学习中存在大量的矩阵运算,用的比较多的有两种,一种是逐元素乘法,另外一种是点积乘法。Pytorch中常用的矩阵函数可参考表2-7。
表2-7 常用矩阵函数
【说明】
①torch的dot与Numpy的dot有点不同,torch中dot对两个为1D张量进行点积运算,Numpy中的dot无此限制。
②mm是对2D的矩阵进行点积,bmm对含batch的3D进行点积运算。
③转置运算会导致存储空间不连续,需要调用contiguous方法转为连续。
import torch
a=torch.tensor([2, 3])
b=torch.tensor([3, 4])
torch.dot(a,b) #运行结果为18
x=torch.randint(10,(2,3))
y=torch.randint(6,(3,4))
torch.mm(x,y)
x=torch.randint(10,(2,2,3))
y=torch.randint(6,(2,3,4))
torch.bmm(x,y)
2.4.10 Pytorch与Numpy比较
Pytorch与Numpy有很多类似的地方,并且有很多相同的操作函数名称,或虽然函数名称不同但含义相同;当然也有一些虽然函数名称相同,但含义不尽相同。对此,有时很容易混淆,下面我们把一些主要的区别进行汇总,具体可参考表2-8。
表2-8 Pytorch与Numpy函数对照表
版权声明
本文为[sunshinecxm_BJTU]所创,转载请带上原文链接,感谢
https://blog.csdn.net/qq_36744449/article/details/124027622
边栏推荐
- BottomSheetDialogFragment 与 ListView RecyclerView ScrollView 滑动冲突问题
- 微信小程序 使用wxml2canvas插件生成图片部分问题记录
- Binder机制原理
- Handlerthread principle and practical application
- Android面试计网面经大全【持续更新中。。。】
- PyTorch最佳实践和代码编写风格指南
- [recommendation of new books in 2021] enterprise application development with C 9 and NET 5
- Google AdMob advertising learning
- Thanos.sh灭霸脚本,轻松随机删除系统一半的文件
- 【动态规划】最长递增子序列
猜你喜欢
Encapsulate a set of project network request framework from 0
face_recognition人脸检测
机器学习笔记 一:学习思路
补补网络缺口
机器学习 三: 基于逻辑回归的分类预测
this. getOptions is not a function
机器学习 二:基于鸢尾花(iris)数据集的逻辑回归分类
组件化学习(3)ARouter中的Path和Group注解
[recommendation of new books in 2021] practical IOT hacking
[recommendation for new books in 2021] professional azure SQL managed database administration
随机推荐
Encapsulate a set of project network request framework from 0
基于BottomNavigationView实现底部导航栏
[2021 book recommendation] effortless app development with Oracle visual builder
ThreadLocal,看我就够了!
Apprentissage par composantes
MySQL5. 7 insert Chinese data and report an error: ` incorrect string value: '\ xb8 \ XDF \ AE \ xf9 \ X80 at row 1`
Itop4412 HDMI display (4.4.4_r1)
Handlerthread principle and practical application
MySQL notes 4_ Primary key auto_increment
MySQL5.7插入中文数据,报错:`Incorrect string value: ‘\xB8\xDF\xAE\xF9\x80 at row 1`
常用UI控件简写名
Using stack to realize queue out and in
补补网络缺口
ProcessBuilder工具类
Bottom navigation bar based on bottomnavigationview
launcher隐藏不需要显示的app icon
记录webView显示空白的又一坑
【2021年新书推荐】Red Hat Certified Engineer (RHCE) Study Guide
adb shell top 命令详解
Reading notes - activity