当前位置:网站首页>Pytorch学习记录(九):Pytorch中卷积神经网络
Pytorch学习记录(九):Pytorch中卷积神经网络
2022-04-23 05:43:00 【左小田^O^】
卷积神经网络中所有的层结构都可以通过 nn这个包调用。
1.卷积层nn.Conv2d()
卷积在 pytorch 中有两种方式,一种是torch.nn.Conv2d()
,一种是 torch.nn.functional.conv2d()
,这两种形式本质都是使用一个卷积操作。
这两种形式的卷积对于输入的要求都是一样的,首先需要输入是一个 torch.autograd.Variable()
的类型,大小是 (batch, channel, H, W)
,其中 batch
表示输入的一批数据的数目,第二个是输入的通道数,一般一张彩色的图片是 3,灰度图是 1,而卷积网络过程中的通道数比较大,会出现几十到几百的通道数,H 和 W 表示输入图片的高度和宽度,比如一个 batch 是 32 张图片,每张图片是 3 通道,高和宽分别是 50 和 100,那么输入的大小就是 (32, 3, 50, 100)
torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
nn.Conv2d()
就是PyTorch中的卷积模块了,
里面常用的参数有5个,分别是 in_channels,out_channels,kernel_size,stride,padding,
除此之外还有参数dilation,groups,bias
。下面来解释每个参数的含义。
in_channels
对应的是输入数据体的深度;out_channels
表示输出数据体的深度;
kernel_size
表示滤波器(卷积核)的大小,可以使用一个数字来表示高和宽相同的卷积核,比如 kernel_size=3,也可以使用不同的数字来表示高和宽不同的卷积核,比如 kernel_size=(3,2);
stride
表示滑动的步长;
padding=0
表示四周不进行零填充,
padding=1
表示四周进行1个像素点的零填充;
bias
是一个布尔值,默认 bias=True
,表示使用偏置;
groups
表示输出数据体深度上和输入数据体深度上的联系,默认 groups=1
,也就是所有的输出和输入都是相关联的,如果 groups=2,这表示输入的深度被分割成两份,输出的深度也被分割成两份,它们之间分别对应起来,所以要求输出和输入都必须要能被 groups整除;
dilation
表示卷积对于输入数据体的空间间
隔,默认 dilation=1,
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
im = Image.open('cat.png').convert('L') # 读入一张灰度图片
im = np.array(im, dtype='float32') # 将其转换为一个矩阵
plt.imshow(im.astype('uint8'), cmap='gray') # 显示图像,输出为灰度图
plt.show()
# 将图片矩阵转化为 pytorch tensor,并适配卷积输入的要求
im = torch.from_numpy(im.reshape((1, 1, im.shape[0], im.shape[1])))
# 使用 nn.Conv2d
conv1 = nn.Conv2d(1, 1, 3, bias=False) # 定义卷积,输入数据深度1,输出深度1,卷积核大小3*3,不设置偏置
sobel_kernel = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype='float32') # 定义边缘检测算子大小
sobel_kernel = sobel_kernel.reshape((1, 1, 3, 3)) # 1个卷积核,深度为1,大小3*3
conv1.weight.data = torch.from_numpy(sobel_kernel) # 给卷积的卷积核赋值
edge1 = conv1(Variable(im)) # 卷积作用在图片上
edge1 = edge1.data.squeeze().numpy() # 将输出转化为图片格式
# 显示边缘检测结果
plt.imshow(edge1, cmap='gray')
plt.show()
# 使用 F.conv2d
sobel_kernel = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype='float32') # 定义轮廓检测算子
sobel_kernel = sobel_kernel.reshape((1, 1, 3, 3)) # 适配卷积的输入输出
weight = Variable(torch.from_numpy(sobel_kernel))
edge2 = F.conv2d(Variable(im), weight) # 作用在图片上
edge2 = edge2.data.squeeze().numpy() # 将输出转换为图片的格式
plt.imshow(edge2, cmap='gray')
2.池化层
卷积网络中另外一个非常重要的结构就是池化,这是利用了图片的下采样不变性,即一张图片变小了还是能够看出了这张图片的内容,而使用池化层能够将图片大小降低,非常好地提高了计算效率,同时池化层也没有参数。池化的方式有很多种,比如最大值池化,均值池化等等,在卷积网络中一般使用最大值池化。
在 pytorch 中最大值池化的方式也有两种,一种是 nn.MaxPool2d()
,一种是torch.nn.functional.max_pool2d()
,他们对于图片的输入要求跟卷积对于图片的输入要求是一样了,就不再赘述,下面我们也举例说明
nn.MaxPool2d()
torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)
nn.MaxPool2d()
表示网络中的最大值池化,其中的参数有kernel_size、stride、padding、dilation、return_indices、ceil_mode
下面解释一下它们各自的含义。
· kernel_size,stride,padding,dilation之前卷积层已经介绍过了,是相同的含义;
· return_indices
表示是否返回最大值所处的下标,默认return_indices=False
;
· ceil_mode
表示使用一些方格代替层结构,默认ceil_mode=False
,一般都不会设置这些参数。
nn.AvgPool2d()
torch.nn.AvgPool2d(kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False),count_include_pad=True
· nn.AvgPool2d()
表示均值池化,里面的参数和
nn.MaxPool2d()类似,但多一个参数count_incl ude_pad
,这个参数表示计算均值的时候是否包含零填充,默认count_include_pad=True。
# 使用 nn.MaxPool2d
pool1 = nn.MaxPool2d(2, 2)
print('before max pool, image shape: {} x {}'.format(im.shape[2], im.shape[3]))
small_im1 = pool1(Variable(im))
small_im1 = small_im1.data.squeeze().numpy()
print('after max pool, image shape: {} x {} '.format(small_im1.shape[0], small_im1.shape[1]))
before max pool, image shape: 224 x 224
after max pool, image shape: 112 x 112
可以看到图片的大小减小了一半,但是图片不变
一般使用nn.MaxPool2d()
3.提取层结构
对于一个给定的模型,如果不想要模型中所有的层结构,只希望能够提取网络中的某一层或者几层,应该如何来实现呢?
版权声明
本文为[左小田^O^]所创,转载请带上原文链接,感谢
https://blog.csdn.net/qq_45802081/article/details/120116824
边栏推荐
- ‘EddiesObservations‘ object has no attribute ‘filled‘
- Conda 虚拟环境管理(创建、删除、克隆、重命名、导出和导入)
- 线性规划问题中可行解,基本解和基本可行解有什么区别?
- PreparedStatement防止SQL注入
- MySQL realizes master-slave replication / master-slave synchronization
- Dwsurvey is an open source questionnaire system. Solve the problem that cannot be run and modify the bug.
- Breadth first search topics (BFS)
- Split and merge multiple one-dimensional arrays into two-dimensional arrays
- SQL基础:初识数据库与SQL-安装与基本介绍等—阿里云天池
- Understand the current commonly used encryption technology system (symmetric, asymmetric, information abstract, digital signature, digital certificate, public key system)
猜你喜欢
MySQL lock mechanism
Ora: 28547 connection to server failed probable Oracle net admin error
基于ssm 包包商城系统
Flutter nouvelle génération de rendu graphique Impeller
Package mall system based on SSM
Software architecture design - software architecture style
Anaconda
域内用户访问域外samba服务器用户名密码错误
‘EddiesObservations‘ object has no attribute ‘filled‘
PyEMD安装及简单使用
随机推荐
C language - Spoof shutdown applet
引航成长·匠心赋能——YonMaster开发者培训领航计划全面开启
MDN文档里面入参写法中括号‘[]‘的作用
Typescript interface & type rough understanding
Radar equipment (greedy)
JVM系列(4)——内存溢出(OOM)
RedHat6之smb服务访问速度慢解决办法记录
数据处理之Numpy常用函数表格整理
‘EddiesObservations‘ object has no attribute ‘filled‘
容器
Common status codes
实体中list属性为空或者null,设置为空数组
SQL注入
Dwsurvey is an open source questionnaire system. Solve the problem that cannot be run and modify the bug.
Character recognition easyocr
opensips(1)——安装opensips详细流程
多个一维数组拆分合并为二维数组
mysql-触发器、存储过程、存储函数
Breadth first search topics (BFS)
K/3 WISE系统考勤客户端日期只能选到2019年问题