当前位置:网站首页>Pytorch学习记录(九):Pytorch中卷积神经网络
Pytorch学习记录(九):Pytorch中卷积神经网络
2022-04-23 05:43:00 【左小田^O^】
卷积神经网络中所有的层结构都可以通过 nn这个包调用。
1.卷积层nn.Conv2d()
卷积在 pytorch 中有两种方式,一种是torch.nn.Conv2d()
,一种是 torch.nn.functional.conv2d()
,这两种形式本质都是使用一个卷积操作。
这两种形式的卷积对于输入的要求都是一样的,首先需要输入是一个 torch.autograd.Variable()
的类型,大小是 (batch, channel, H, W)
,其中 batch
表示输入的一批数据的数目,第二个是输入的通道数,一般一张彩色的图片是 3,灰度图是 1,而卷积网络过程中的通道数比较大,会出现几十到几百的通道数,H 和 W 表示输入图片的高度和宽度,比如一个 batch 是 32 张图片,每张图片是 3 通道,高和宽分别是 50 和 100,那么输入的大小就是 (32, 3, 50, 100)
torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
nn.Conv2d()
就是PyTorch中的卷积模块了,
里面常用的参数有5个,分别是 in_channels,out_channels,kernel_size,stride,padding,
除此之外还有参数dilation,groups,bias
。下面来解释每个参数的含义。
in_channels
对应的是输入数据体的深度;out_channels
表示输出数据体的深度;
kernel_size
表示滤波器(卷积核)的大小,可以使用一个数字来表示高和宽相同的卷积核,比如 kernel_size=3,也可以使用不同的数字来表示高和宽不同的卷积核,比如 kernel_size=(3,2);
stride
表示滑动的步长;
padding=0
表示四周不进行零填充,
padding=1
表示四周进行1个像素点的零填充;
bias
是一个布尔值,默认 bias=True
,表示使用偏置;
groups
表示输出数据体深度上和输入数据体深度上的联系,默认 groups=1
,也就是所有的输出和输入都是相关联的,如果 groups=2,这表示输入的深度被分割成两份,输出的深度也被分割成两份,它们之间分别对应起来,所以要求输出和输入都必须要能被 groups整除;
dilation
表示卷积对于输入数据体的空间间
隔,默认 dilation=1,
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
im = Image.open('cat.png').convert('L') # 读入一张灰度图片
im = np.array(im, dtype='float32') # 将其转换为一个矩阵
plt.imshow(im.astype('uint8'), cmap='gray') # 显示图像,输出为灰度图
plt.show()
# 将图片矩阵转化为 pytorch tensor,并适配卷积输入的要求
im = torch.from_numpy(im.reshape((1, 1, im.shape[0], im.shape[1])))
# 使用 nn.Conv2d
conv1 = nn.Conv2d(1, 1, 3, bias=False) # 定义卷积,输入数据深度1,输出深度1,卷积核大小3*3,不设置偏置
sobel_kernel = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype='float32') # 定义边缘检测算子大小
sobel_kernel = sobel_kernel.reshape((1, 1, 3, 3)) # 1个卷积核,深度为1,大小3*3
conv1.weight.data = torch.from_numpy(sobel_kernel) # 给卷积的卷积核赋值
edge1 = conv1(Variable(im)) # 卷积作用在图片上
edge1 = edge1.data.squeeze().numpy() # 将输出转化为图片格式
# 显示边缘检测结果
plt.imshow(edge1, cmap='gray')
plt.show()
# 使用 F.conv2d
sobel_kernel = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype='float32') # 定义轮廓检测算子
sobel_kernel = sobel_kernel.reshape((1, 1, 3, 3)) # 适配卷积的输入输出
weight = Variable(torch.from_numpy(sobel_kernel))
edge2 = F.conv2d(Variable(im), weight) # 作用在图片上
edge2 = edge2.data.squeeze().numpy() # 将输出转换为图片的格式
plt.imshow(edge2, cmap='gray')
2.池化层
卷积网络中另外一个非常重要的结构就是池化,这是利用了图片的下采样不变性,即一张图片变小了还是能够看出了这张图片的内容,而使用池化层能够将图片大小降低,非常好地提高了计算效率,同时池化层也没有参数。池化的方式有很多种,比如最大值池化,均值池化等等,在卷积网络中一般使用最大值池化。
在 pytorch 中最大值池化的方式也有两种,一种是 nn.MaxPool2d()
,一种是torch.nn.functional.max_pool2d()
,他们对于图片的输入要求跟卷积对于图片的输入要求是一样了,就不再赘述,下面我们也举例说明
nn.MaxPool2d()
torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)
nn.MaxPool2d()
表示网络中的最大值池化,其中的参数有kernel_size、stride、padding、dilation、return_indices、ceil_mode
下面解释一下它们各自的含义。
· kernel_size,stride,padding,dilation之前卷积层已经介绍过了,是相同的含义;
· return_indices
表示是否返回最大值所处的下标,默认return_indices=False
;
· ceil_mode
表示使用一些方格代替层结构,默认ceil_mode=False
,一般都不会设置这些参数。
nn.AvgPool2d()
torch.nn.AvgPool2d(kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False),count_include_pad=True
· nn.AvgPool2d()
表示均值池化,里面的参数和
nn.MaxPool2d()类似,但多一个参数count_incl ude_pad
,这个参数表示计算均值的时候是否包含零填充,默认count_include_pad=True。
# 使用 nn.MaxPool2d
pool1 = nn.MaxPool2d(2, 2)
print('before max pool, image shape: {} x {}'.format(im.shape[2], im.shape[3]))
small_im1 = pool1(Variable(im))
small_im1 = small_im1.data.squeeze().numpy()
print('after max pool, image shape: {} x {} '.format(small_im1.shape[0], small_im1.shape[1]))
before max pool, image shape: 224 x 224
after max pool, image shape: 112 x 112
可以看到图片的大小减小了一半,但是图片不变
一般使用nn.MaxPool2d()
3.提取层结构
对于一个给定的模型,如果不想要模型中所有的层结构,只希望能够提取网络中的某一层或者几层,应该如何来实现呢?
版权声明
本文为[左小田^O^]所创,转载请带上原文链接,感谢
https://blog.csdn.net/qq_45802081/article/details/120116824
边栏推荐
猜你喜欢
Radar equipment (greedy)
JVM family (4) -- memory overflow (OOM)
基于ssm 包包商城系统
Navicate连接oracle(11g)时ORA:28547 Connection to server failed probable Oeacle Net admin error
mysql sql优化之Explain
PreparedStatement防止SQL注入
建表到页面完整实例演示—联表查询
SQL statement simple optimization
Manually delete registered services on Eureka
opensips(1)——安装opensips详细流程
随机推荐
Introduction to data security -- detailed explanation of database audit system
Map对象 map.get(key)
io. lettuce. core. RedisCommandExecutionException: ERR wrong number of arguments for ‘auth‘ command
2 - principes de conception de logiciels
Mysql 查询使用\G,列转行
umi官网yarn create @umijs/umi-app 报错:文件名、目录名或卷标语法不正确
C language - Spoof shutdown applet
The 8th Blue Bridge Cup 2017 - frog jumping cup
‘EddiesObservations‘ object has no attribute ‘filled‘
多线程与高并发(1)——线程的基本知识(实现,常用方法,状态)
PreparedStatement防止SQL注入
Meta annotation (annotation of annotation)
基于ssm 包包商城系统
类的加载与ClassLoader的理解
关于二叉树的遍历
Map object map get(key)
Font shape `OMX/cmex/m/n‘ in size <10.53937> not available (Font) size <10.95> substituted.
Typescript interface & type rough understanding
AcWing 836. Merge set (merge set)
MySql基础狂神说