当前位置:网站首页>ResNet的基础:残差块的原理
ResNet的基础:残差块的原理
2022-08-10 05:29:00 【公众号学一点会一点】
在深度学习中,为了增强模型的学习能力,网络层会变得越来越深,但是随着深度的增加,也带来了比较一些问题,主要包括:
模型复杂度上升,网络训练困难; 梯度消失/梯度爆炸 网络退化,也就是说模型的学习能力达到了饱和,增加网络层数并不能提升精度了。
为了解决网络退化问题,何凯明大佬提出了深度残差网络,可以说是深度学习中一个非常大的创造性工作。
残差网络
残差网络的思想就是将网络学习的映射从X到Y转为学习从X到Y-X的差,然后把学习到的残差信息加到原来的输出上即可。即便在某些极端情况下,这个残差为0,那么网络就是一个X到Y的恒等映射。其示意图如下:

上图的结构中其实主线与正常的网络结构没什么区别,差异在于右边的连接线,作者称之为Shortcut Connection,意思就是跳过了一些网络层直接与后面的某一个层的输出结果进行连接。
优势
残差网络中,因为残差块保留了原始输入的信息,所以网络有如下优势:
随着深度的增加,可以获取更高的精度,因为其学习的残差越准确; 网络优化比较简单; 比较通用;
残差块的实现
按照上面的图所示的结构,在Pytorch中实现一个残差块也非常简单,无非就是在传统的网络中加上一个shortcut connection,比如一个最基础的残差块代码如下:
class ResidualBlock(nn.Module):
def __init__(self, channels):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(channels)
self.prelu = nn.PReLU()
self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(channels)
def forward(self, x):
residual = self.conv1(x)
residual = self.bn1(residual)
residual = self.prelu(residual)
residual = self.conv2(residual)
residual = self.bn2(residual)
out = self.prelu(x + residual)
return out
通过上面的代码就实现了一个最最基础的残差块(仅仅是按图实现的,跟原文里面的不太一样)。需要注意的地方有:
残差块因为在forward函数的最后需要将输入x和学习到的残差(也就是 )相加,所以这两个张量的尺寸应该是完全一致的; 在最后将 相加之后再输入激活函数; 每一个卷积层后要跟上一个批归一化层。
在真正用的时候,上面的代码还需要再进行复杂化,比如是否需要对数据进行下采样等,不过看懂了上面的基础,就可以自己进行相应的修改,来适用于自己的网络。
参考
【1】HE K, ZHANG X, REN S, et al. Deep Residual Learning for Image Recognition[C]//2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR).2016:770-778. 10.1109/CVPR.2016.90.
【2】https://towardsdev.com/implement-resnet-with-pytorch-a9fb40a77448
本文由 mdnice 多平台发布
边栏推荐
- 【yolov5训练错误】WARNING: Ignoring corrupted image
- Kubernetes:(十六)Ingress的概念和原理
- 25张炫酷交互图表,一文入门Plotly
- 在yolov5的网络结构中添加注意力机制模块
- The time for flinkcdc to read pgsql is enlarged. Does anyone know what happened? gmt_create':1
- SSM框架整合实例
- pytorch框架学习(6)训练一个简单的自己的CNN (三)细节篇
- How cursors work in Pulsar
- 再肝3天,整理了90个 NumPy 例子,不能不收藏!
- 论文精度 —— 2016 CVPR 《Context Encoders: Feature Learning by Inpainting》
猜你喜欢
How to simulate the background API call scene, very detailed!
pytest测试框架
Qiskit 学习笔记2
几种绘制时间线图的方法
Interface documentation evolution illustration, some ancient interface documentation tools, you may not have used it
How to use Apifox's Smart Mock function?
如何在报表控件FastReport.NET中连接XLSX 文件作为数据源?
基于Qiskit——《量子计算编程实战》读书笔记(一)
基本比例尺标准分幅编号流程
聊聊 API 管理-开源版 到 SaaS 版
随机推荐
【Static proxy】
scikit-learn机器学习 读书笔记(一)
PyTorch 入门之旅
Pony语言学习(七)——表达式(Expressions)语法(单篇向)
在yolov5的网络结构中添加注意力机制模块
论文精读 —— 2021 CVPR《Progressive Temporal Feature Alignment Network for Video Inpainting》
Interface debugging also can play this?
基于Qiskit——《量子计算编程实战》读书笔记(六)
The sword refers to Offer 033. Variation array
pytorch learning
Pony语言学习(九)——泛型与模式匹配(终章)
基于BP神经网络的多因素房屋价格预测matlab仿真
Error when installing oracle rac 11g and executing root.sh
Pony语言学习(八):引用能力(Reference Capabilities)
aliases node analysis
Read the excerpt notes made by dozens of lightweight target detection papers for literacy
Jenkins 如何玩转接口自动化测试?
【YOLOv5训练错误】权重文件出错?
When oracle cdc, set the parallelism to 2 and the number of slots to 1, and the final task has only one tm. Is it because oracle does not support concurrency
聊聊 API 管理-开源版 到 SaaS 版