Dive into Deep Learning: ResNet
2022-08-06 12:32:00 【CV Small Rookie】
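The defining feature of a residual network is the residual block, which can be written compactly as

$$f(\mathbf{x}) = g(\mathbf{x}) + \mathbf{x},$$

where $f(\mathbf{x})$ is the mapping the block should learn and $g(\mathbf{x}) = f(\mathbf{x}) - \mathbf{x}$ is the residual mapping that its layers actually fit.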
Why design the block this way? The following example answers that question.
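Suppose $\mathcal{F}$ is the class of functions that a specific neural network architecture can reach, including its learning rate and other hyperparameter settings. That is, for every $f \in \mathcal{F}$ there exists some set of parameters (for example, weights and biases) that can be obtained by training on a suitable dataset. Now suppose $f^*$ is the function we really want to find. If $f^* \in \mathcal{F}$, we can train our way to it easily, but usually we are not that lucky. Instead, we try to find a function $f^*_{\mathcal{F}}$, which is our best choice within $\mathcal{F}$. For example, given a dataset with features $\mathbf{X}$ and labels $\mathbf{y}$, we can try to find it by solving the following optimization problem:

$$f^*_{\mathcal{F}} \stackrel{\mathrm{def}}{=} \mathop{\mathrm{argmin}}_f L(\mathbf{X}, \mathbf{y}, f) \quad \text{subject to } f \in \mathcal{F}.$$

When we build a more powerful function class $\mathcal{F}'$, its relationship to $\mathcal{F}$ usually falls into one of two cases: nested function classes ($\mathcal{F} \subseteq \mathcal{F}'$) and non-nested function classes.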

As the figure shows, with non-nested function classes, the fitting ability of $\mathcal{F}_i$ keeps increasing, yet it does not necessarily get any closer to $f^*$. In contrast, with nested function classes ($\mathcal{F}_1 \subseteq \mathcal{F}_2 \subseteq \cdots$), each class contains the previous one, so increasing the fitting ability can never move us farther from $f^*$.
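To make the nested versus non-nested distinction concrete, here is a small sketch that is not from the original post: it uses least-squares fitting instead of neural networks. Polynomials of degree at most $d$ form nested classes, so the best training error can only stay the same or shrink as $d$ grows. The classes $\{a x^d + c\}$ are not nested, so a "new" class can do worse than the previous one. The toy data, the degrees, and the helper name `best_train_mse` are arbitrary assumptions for illustration.

```python
import numpy as np

# Toy 1-D regression data (an arbitrary choice for this illustration).
rng = np.random.default_rng(0)
x = np.linspace(-1.0, 1.0, 50)
y = np.sin(3.0 * x) + 0.1 * rng.standard_normal(x.shape)


def best_train_mse(design):
    """Training MSE of the least-squares fit over the span of the columns
    of `design`, i.e. the best function f*_F within that class."""
    w, *_ = np.linalg.lstsq(design, y, rcond=None)
    return float(np.mean((design @ w - y) ** 2))


# Nested classes: polynomials of degree <= d, so F_0 ⊆ F_1 ⊆ ... ⊆ F_6.
nested = [best_train_mse(np.vander(x, d + 1)) for d in range(7)]

# Non-nested classes: G_d = {a * x**d + c}; G_2 does not contain G_1.
non_nested = {d: best_train_mse(np.stack([x ** d, np.ones_like(x)], axis=1))
              for d in range(1, 7)}

print("nested:    ", [round(m, 4) for m in nested])
print("non-nested:", {d: round(m, 4) for d, m in non_nested.items()})
```

In the nested list the error never increases, while in the non-nested family moving from $G_1$ to $G_2$ makes the fit worse, because an even function $a x^2 + c$ cannot follow the odd target.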
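To address this problem, Kaiming He et al. proposed the residual network (ResNet). It won the 2015 ImageNet image recognition challenge and profoundly influenced the design of later deep neural networks. The core idea of ResNet is that every additional layer should more easily contain the identity function as one of its elements, so that the enlarged function class contains the original one. This is how residual blocks were born.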
Residual blocks
In the figure, the left side shows a regular block and the right side a residual block. The residual structure adds a path that takes the input and adds it directly to the output of the regular block. Whereas the convolutions in a regular block must fit $f(\mathbf{x})$ directly, the convolutional part of a residual block only needs to fit the residual $g(\mathbf{x}) = f(\mathbf{x}) - \mathbf{x}$. In particular, if the ideal mapping is the identity, the layers only need to drive their output toward zero. This makes the network easier to optimize, and the input can also propagate forward more easily through the skip connection (backpropagation benefits as well; more on that later).
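A minimal NumPy sketch of why the identity is easy for a residual block. It assumes a hypothetical two-layer residual branch $g(\mathbf{x}) = W_2\,\mathrm{relu}(W_1 \mathbf{x})$, with fully connected layers standing in for the convolutions to keep it short; it is an illustration, not the post's `Residual` class. Setting $W_2$ to zero makes the residual block output exactly $\mathbf{x}$, while the same weights make a plain block output zero.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def residual_block(x, W1, W2):
    """f(x) = g(x) + x, with a two-layer residual branch g."""
    g = W2 @ relu(W1 @ x)
    return g + x


def plain_block(x, W1, W2):
    """A plain block has to fit f(x) directly."""
    return W2 @ relu(W1 @ x)


rng = np.random.default_rng(0)
d = 8
x = rng.standard_normal(d)
W1 = rng.standard_normal((d, d))
W2_zero = np.zeros((d, d))

# With the last weights at zero, the residual block is exactly the identity...
print(np.allclose(residual_block(x, W1, W2_zero), x))   # True
# ...while the plain block collapses to the zero function.
print(np.allclose(plain_block(x, W1, W2_zero), 0.0))    # True
```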

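Let's look at the residual block in ResNet in detail. ResNet follows VGG's full 3 × 3 convolutional layer design. The residual block first has two 3 × 3 convolutional layers with the same number of output channels. The first convolution is followed by batch normalization and a ReLU activation; the second is followed by batch normalization. Then a cross-layer data path skips these two convolutions and adds the input directly before the final ReLU activation. This design requires the output of the two convolutional layers to have the same shape as the input, so that the two can be added. If we want to change the number of channels (or, with a stride, the height and width), we need an additional 1 × 1 convolutional layer that transforms the input into the required shape before the addition, as shown in the figure below.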
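The shape requirement above can be checked with the standard convolution output-size formula $\lfloor (n + 2p - k)/s \rfloor + 1$. The sketch below uses a hypothetical helper `conv_out` and mirrors the layer settings of the `Residual` class in this post: the first 3 × 3 convolution has padding 1 and stride `strides`, the second has padding 1 and stride 1, and the 1 × 1 shortcut uses the same stride. It shows that both paths produce the same height and width, so the addition is well defined. The 1 × 1 convolution also changes the channel count, which this spatial-only check does not model.

```python
def conv_out(n, kernel, padding, stride):
    """Spatial output size of a conv/pool layer: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * padding - kernel) // stride + 1


for n in (6, 7, 56, 96):
    # Main path with strides=2: 3x3 conv (stride 2), then 3x3 conv (stride 1).
    main = conv_out(conv_out(n, 3, 1, 2), 3, 1, 1)
    # Shortcut: 1x1 conv with the same stride.
    shortcut = conv_out(n, 1, 0, 2)
    print(n, main, shortcut)

# With strides=1, the main path keeps n unchanged, so no 1x1 conv is needed
# as long as the channel count stays the same.
print(conv_out(conv_out(56, 3, 1, 1), 3, 1, 1))
```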

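```python
import torch
from torch import nn
from torch.nn import functional as F


class Residual(nn.Module):  #@save
    """Residual block: two 3x3 convs, optional 1x1 conv on the shortcut."""
    def __init__(self, input_channels, num_channels,
                 use_1x1conv=False, strides=1):
        super().__init__()
        # Main path: 3x3 conv (may downsample via `strides`), then 3x3 conv
        self.conv1 = nn.Conv2d(input_channels, num_channels,
                               kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels,
                               kernel_size=3, padding=1)
        # Shortcut: 1x1 conv to match channels and height/width when needed
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels, num_channels,
                                   kernel_size=1, stride=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3 is not None:
            X = self.conv3(X)
        Y += X  # add the (possibly transformed) input before the final ReLU
        return F.relu(Y)
```

ResNet architecture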
Take ResNet-18 as an example:

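The first two layers of ResNet are the same as in the GoogLeNet introduced earlier: a 7 × 7 convolutional layer with 64 output channels and stride 2, followed by a 3 × 3 max-pooling layer with stride 2. The difference is that ResNet adds batch normalization after each convolutional layer.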
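```python
# Stem: 7x7 conv (64 channels, stride 2) + BN + ReLU, then 3x3 max-pool (stride 2).
# The input has 1 channel.
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                   nn.BatchNorm2d(64), nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
```

ResNet then uses 4 modules made of residual blocks, each using several residual blocks with the same number of output channels. The first module has the same number of channels as its input. Because a max-pooling layer with stride 2 has already been applied, this module does not need to reduce the height and width. Each subsequent module doubles the number of channels of the previous module in its first residual block and halves the height and width. (These 4 modules are marked in the figure above.)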
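The module rules just described can be checked with a little arithmetic. The sketch below traces the channel count and height/width through `b1` to `b5`, and counts the layers that give ResNet-18 its name: the 7 × 7 convolution, two 3 × 3 convolutions per residual block (the 1 × 1 shortcut convolutions are not counted), and the final fully connected layer. The 224 × 224 input size and the helper name `conv_out` are assumptions for illustration; the stage list mirrors the `resnet_block(...)` calls in this post.

```python
def conv_out(n, kernel, padding, stride):
    """Spatial output size: floor((n + 2*padding - kernel) / stride) + 1."""
    return (n + 2 * padding - kernel) // stride + 1


# (num_channels, num_residuals, first_block) for b2..b5, mirroring the
# resnet_block(...) calls used for ResNet-18 in this post.
stages = [(64, 2, True), (128, 2, False), (256, 2, False), (512, 2, False)]

size = 224                       # assumed input height/width
size = conv_out(size, 7, 3, 2)   # b1: 7x7 conv, stride 2
size = conv_out(size, 3, 1, 2)   # b1: 3x3 max-pool, stride 2
channels = 64
print("b1:", channels, size, size)

conv_layers = 1                  # the 7x7 convolution in b1
for i, (num_channels, num_residuals, first_block) in enumerate(stages, start=2):
    if not first_block:
        size = conv_out(size, 3, 1, 2)  # first residual block halves H and W
    channels = num_channels
    conv_layers += 2 * num_residuals    # two 3x3 convs per block (1x1 not counted)
    print(f"b{i}:", channels, size, size)

total_layers = conv_layers + 1   # plus the final fully connected layer
print("weighted layers:", total_layers)
```

For a 224 × 224 input the spatial size goes 56, 56, 28, 14, 7 after `b1` to `b5`, and the count comes to 1 + 4 × 2 × 2 + 1 = 18.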
```python
def resnet_block(input_channels, num_channels, num_residuals,
                 first_block=False):
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            # First block of a module: change channels, halve height/width
            blk.append(Residual(input_channels, num_channels,
                                use_1x1conv=True, strides=2))
        else:
            blk.append(Residual(num_channels, num_channels))
    return blk


b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))
```

Finally, as in GoogLeNet, a global average pooling layer and a fully connected output layer are added to ResNet.
```python
net = nn.Sequential(b1, b2, b3, b4, b5,
                    nn.AdaptiveAvgPool2d((1, 1)),
                    nn.Flatten(), nn.Linear(512, 10))
```

Why ResNet can train very deep networks

In addition, consider what the skip connections do during forward propagation and backpropagation, as illustrated in the figure.
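One common way to see the backpropagation benefit, sketched here under heavy simplification and not as the post's own derivation: for a residual block $f(x) = g(x) + x$, the chain rule gives $\partial f / \partial x = 1 + \partial g / \partial x$. Even when each layer's $\partial g / \partial x$ is small, the per-layer factor stays close to 1, whereas a plain layer contributes only the small factor itself, so the product over many layers vanishes. The sketch models every layer as a scalar linear map and ignores ReLU and batch normalization; the 50 layers and the ±0.05 weights are arbitrary assumptions, and the helper `chain_grad` is hypothetical.

```python
def chain_grad(weights, residual):
    """d(output)/d(input) for a chain of scalar linear layers.

    plain layer:    h <- w * h      contributes a factor w
    residual layer: h <- h + w * h  contributes a factor (1 + w)
    """
    grad = 1.0
    for w in weights:
        grad *= (1.0 + w) if residual else w
    return grad


# 50 layers with small weights alternating +0.05 / -0.05 (arbitrary choice).
weights = [0.05 if i % 2 == 0 else -0.05 for i in range(50)]
plain = chain_grad(weights, residual=False)
res = chain_grad(weights, residual=True)
print(f"plain:    {plain:.3e}")
print(f"residual: {res:.3f}")
```

The plain chain's gradient is on the order of $10^{-65}$, while the residual chain's stays close to 1. Real networks are not scalar or linear, so treat this only as intuition for the $1 + \partial g / \partial x$ term.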
