当前位置:网站首页>ResNet的基础:残差块的原理
ResNet的基础:残差块的原理
2022-08-10 05:29:00 【公众号学一点会一点】
在深度学习中,为了增强模型的学习能力,网络层会变得越来越深,但是随着深度的增加,也带来了比较一些问题,主要包括:
模型复杂度上升,网络训练困难; 梯度消失/梯度爆炸 网络退化,也就是说模型的学习能力达到了饱和,增加网络层数并不能提升精度了。
为了解决网络退化问题,何凯明大佬提出了深度残差网络,可以说是深度学习中一个非常大的创造性工作。
残差网络
残差网络的思想就是将网络学习的映射从X到Y转为学习从X到Y-X的差,然后把学习到的残差信息加到原来的输出上即可。即便在某些极端情况下,这个残差为0,那么网络就是一个X到Y的恒等映射。其示意图如下:

上图的结构中其实主线与正常的网络结构没什么区别,差异在于右边的连接线,作者称之为Shortcut Connection,意思就是跳过了一些网络层直接与后面的某一个层的输出结果进行连接。
优势
残差网络中,因为残差块保留了原始输入的信息,所以网络有如下优势:
随着深度的增加,可以获取更高的精度,因为其学习的残差越准确; 网络优化比较简单; 比较通用;
残差块的实现
按照上面的图所示的结构,在Pytorch中实现一个残差块也非常简单,无非就是在传统的网络中加上一个shortcut connection,比如一个最基础的残差块代码如下:
class ResidualBlock(nn.Module):
def __init__(self, channels):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(channels)
self.prelu = nn.PReLU()
self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(channels)
def forward(self, x):
residual = self.conv1(x)
residual = self.bn1(residual)
residual = self.prelu(residual)
residual = self.conv2(residual)
residual = self.bn2(residual)
out = self.prelu(x + residual)
return out
通过上面的代码就实现了一个最最基础的残差块(仅仅是按图实现的,跟原文里面的不太一样)。需要注意的地方有:
残差块因为在forward函数的最后需要将输入x和学习到的残差(也就是 )相加,所以这两个张量的尺寸应该是完全一致的; 在最后将 相加之后再输入激活函数; 每一个卷积层后要跟上一个批归一化层。
在真正用的时候,上面的代码还需要再进行复杂化,比如是否需要对数据进行下采样等,不过看懂了上面的基础,就可以自己进行相应的修改,来适用于自己的网络。
参考
【1】HE K, ZHANG X, REN S, et al. Deep Residual Learning for Image Recognition[C]//2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR).2016:770-778. 10.1109/CVPR.2016.90.
【2】https://towardsdev.com/implement-resnet-with-pytorch-a9fb40a77448
本文由 mdnice 多平台发布
边栏推荐
- Concurrency tool class - introduction and use of CountDownLatch, CyclicBarrier, Semaphore, Exchanger
- 接口文档进化图鉴,有些古早接口文档工具,你可能都没用过
- FPGA工程师面试试题集锦31~40
- Important transformation and upgrading
- ThreadPoolExecutor thread pool principle
- How to simulate the background API call scene, very detailed!
- Depth of carding: prevent model fitting method
- Guys, the test in the idea uses FlinkCDC SQL to read Mysql data and write it into Kafka. The code creates
- 自适应空间特征融合( adaptively spatial feature fusion)一种基于数据驱动的金字塔特征融合策略
- 利用PyQt5制作YOLOv5的GUI界面
猜你喜欢

【Pei Shu Theorem】CF1055C Lucky Days

每周推荐短视频:探索AI的应用边界

Pony语言学习(八):引用能力(Reference Capabilities)

基本比例尺标准分幅编号流程

Stacks and Queues | Valid parentheses, delete all adjacent elements in a string, reverse Polish expression evaluation, maximum sliding window, top K high frequency elements | leecode brush questions

summer preschool assignments

Error when installing oracle rac 11g and executing root.sh

Get started with the OAuth protocol easily with a case

Kubernetes:(十七)Helm概述、安装及配置

Pony语言学习(七)——表达式(Expressions)语法(单篇向)
随机推荐
pytorch框架学习(6)训练一个简单的自己的CNN (三)细节篇
How to improve product quality from the code layer
SEO搜索引擎优化
MySql's json_extract function processes json fields
论文精度 —— 2017 ACM《Globally and Locally Consistent Image Completion》
How cursors work in Pulsar
应用在智能触摸遥控器中的触摸芯片
You can‘t specify target table ‘kms_report_reportinfo‘ for update in FROM clause
FPGA工程师面试试题集锦1~10
Order table delete, insert and search operations
几种绘制时间线图的方法
接口调试还能这么玩?
如何模拟后台API调用场景,很细!
基于Qiskit——《量子计算编程实战》读书笔记(七)
pytorch框架学习(1)网络的简单构建
How does flinksql write that the value of redis has only the last field?
Pony语言学习(六):Struct, Type Alias, Type Expressions
FPGA engineer interview questions collection 31~40
Hezhou ESP32C3 +1.8"tft network clock under Arduino framework
常用工具系列 - 常用正则表达式