当前位置:网站首页>ResNet的基础:残差块的原理
ResNet的基础:残差块的原理
2022-08-10 05:29:00 【公众号学一点会一点】
在深度学习中,为了增强模型的学习能力,网络层会变得越来越深,但是随着深度的增加,也带来了比较一些问题,主要包括:
模型复杂度上升,网络训练困难; 梯度消失/梯度爆炸 网络退化,也就是说模型的学习能力达到了饱和,增加网络层数并不能提升精度了。
为了解决网络退化问题,何凯明大佬提出了深度残差网络,可以说是深度学习中一个非常大的创造性工作。
残差网络
残差网络的思想就是将网络学习的映射从X到Y转为学习从X到Y-X的差,然后把学习到的残差信息加到原来的输出上即可。即便在某些极端情况下,这个残差为0,那么网络就是一个X到Y的恒等映射。其示意图如下:

上图的结构中其实主线与正常的网络结构没什么区别,差异在于右边的连接线,作者称之为Shortcut Connection,意思就是跳过了一些网络层直接与后面的某一个层的输出结果进行连接。
优势
残差网络中,因为残差块保留了原始输入的信息,所以网络有如下优势:
随着深度的增加,可以获取更高的精度,因为其学习的残差越准确; 网络优化比较简单; 比较通用;
残差块的实现
按照上面的图所示的结构,在Pytorch中实现一个残差块也非常简单,无非就是在传统的网络中加上一个shortcut connection,比如一个最基础的残差块代码如下:
class ResidualBlock(nn.Module):
def __init__(self, channels):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(channels)
self.prelu = nn.PReLU()
self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(channels)
def forward(self, x):
residual = self.conv1(x)
residual = self.bn1(residual)
residual = self.prelu(residual)
residual = self.conv2(residual)
residual = self.bn2(residual)
out = self.prelu(x + residual)
return out
通过上面的代码就实现了一个最最基础的残差块(仅仅是按图实现的,跟原文里面的不太一样)。需要注意的地方有:
残差块因为在forward函数的最后需要将输入x和学习到的残差(也就是 )相加,所以这两个张量的尺寸应该是完全一致的; 在最后将 相加之后再输入激活函数; 每一个卷积层后要跟上一个批归一化层。
在真正用的时候,上面的代码还需要再进行复杂化,比如是否需要对数据进行下采样等,不过看懂了上面的基础,就可以自己进行相应的修改,来适用于自己的网络。
参考
【1】HE K, ZHANG X, REN S, et al. Deep Residual Learning for Image Recognition[C]//2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR).2016:770-778. 10.1109/CVPR.2016.90.
【2】https://towardsdev.com/implement-resnet-with-pytorch-a9fb40a77448
本文由 mdnice 多平台发布
边栏推荐
- What are the common commands of mysql
- pytorch框架学习(5)torchvision模块&训练一个简单的自己的CNN (二)
- FPGA工程师面试试题集锦41~50
- How cursors work in Pulsar
- Jenkins 如何玩转接口自动化测试?
- When oracle cdc, set the parallelism to 2 and the number of slots to 1, and the final task has only one tm. Is it because oracle does not support concurrency
- 【LeetCode】41. The first missing positive number
- pytorch 学习
- 基于BP神经网络的多因素房屋价格预测matlab仿真
- FPGA工程师面试试题集锦11~20
猜你喜欢
How to simulate the background API call scene, very detailed!
如何用Apifox 的智能Mock功能?
EasyGBS connects to mysql database and prompts "can't connect to mysql server", how to solve it?
【写下自用】每次都忘记如何train?记录如何训练自己的yolov5
scikit-learn机器学习 读书笔记(二)
接口调试还能这么玩?
Kubernetes:(十六)Ingress的概念和原理
What are the common commands of mysql
Interface documentation evolution illustration, some ancient interface documentation tools, you may not have used it
pytorch框架学习(5)torchvision模块&训练一个简单的自己的CNN (二)
随机推荐
Buu Web
MongoDB 基础了解(一)
Qiskit学习笔记(三)
Stacks and Queues | Valid parentheses, delete all adjacent elements in a string, reverse Polish expression evaluation, maximum sliding window, top K high frequency elements | leecode brush questions
聊聊 API 管理-开源版 到 SaaS 版
手把手带你写嵌入式物联网的第一个项目
Guys, is it normal that the oracle archive log grows by 3G in 20 minutes after running cdc?
Why are negative numbers in binary represented in two's complement form - binary addition and subtraction
Interface documentation evolution illustration, some ancient interface documentation tools, you may not have used it
论文精度 —— 2016 CVPR 《Context Encoders: Feature Learning by Inpainting》
Matlab simulation of multi-factor house price prediction based on BP neural network
aliases node analysis
深度梳理:防止模型过拟合的方法汇总
FPGA engineer interview questions collection 21~30
Big guys, mysql cdc (2.2.1 and previous versions) sometimes has this situation since savepoint, is there anything wrong?
动手写prometheus的exporter-02-Counter(计数器)
Kubernetes:(十七)Helm概述、安装及配置
常用工具系列 - 常用正则表达式
AVL树的插入--旋转笔记
Flutter development: error The following assertion was thrown resolving an image codec: Solution for Unable to...