当前位置:网站首页>ResNet的基础:残差块的原理
ResNet的基础:残差块的原理
2022-08-10 05:29:00 【公众号学一点会一点】
在深度学习中,为了增强模型的学习能力,网络层会变得越来越深,但是随着深度的增加,也带来了比较一些问题,主要包括:
模型复杂度上升,网络训练困难; 梯度消失/梯度爆炸 网络退化,也就是说模型的学习能力达到了饱和,增加网络层数并不能提升精度了。
为了解决网络退化问题,何凯明大佬提出了深度残差网络,可以说是深度学习中一个非常大的创造性工作。
残差网络
残差网络的思想就是将网络学习的映射从X到Y转为学习从X到Y-X的差,然后把学习到的残差信息加到原来的输出上即可。即便在某些极端情况下,这个残差为0,那么网络就是一个X到Y的恒等映射。其示意图如下:

上图的结构中其实主线与正常的网络结构没什么区别,差异在于右边的连接线,作者称之为Shortcut Connection,意思就是跳过了一些网络层直接与后面的某一个层的输出结果进行连接。
优势
残差网络中,因为残差块保留了原始输入的信息,所以网络有如下优势:
随着深度的增加,可以获取更高的精度,因为其学习的残差越准确; 网络优化比较简单; 比较通用;
残差块的实现
按照上面的图所示的结构,在Pytorch中实现一个残差块也非常简单,无非就是在传统的网络中加上一个shortcut connection,比如一个最基础的残差块代码如下:
class ResidualBlock(nn.Module):
def __init__(self, channels):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(channels)
self.prelu = nn.PReLU()
self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(channels)
def forward(self, x):
residual = self.conv1(x)
residual = self.bn1(residual)
residual = self.prelu(residual)
residual = self.conv2(residual)
residual = self.bn2(residual)
out = self.prelu(x + residual)
return out
通过上面的代码就实现了一个最最基础的残差块(仅仅是按图实现的,跟原文里面的不太一样)。需要注意的地方有:
残差块因为在forward函数的最后需要将输入x和学习到的残差(也就是 )相加,所以这两个张量的尺寸应该是完全一致的; 在最后将 相加之后再输入激活函数; 每一个卷积层后要跟上一个批归一化层。
在真正用的时候,上面的代码还需要再进行复杂化,比如是否需要对数据进行下采样等,不过看懂了上面的基础,就可以自己进行相应的修改,来适用于自己的网络。
参考
【1】HE K, ZHANG X, REN S, et al. Deep Residual Learning for Image Recognition[C]//2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR).2016:770-778. 10.1109/CVPR.2016.90.
【2】https://towardsdev.com/implement-resnet-with-pytorch-a9fb40a77448
本文由 mdnice 多平台发布
边栏推荐
猜你喜欢

CSDN Markdown 之我见代码块 | CSDN编辑器测评

An article to master the entire JVM, JVM ultra-detailed analysis!!!

ThreadPoolExecutor线程池原理

Depth of carding: prevent model fitting method

pytorch框架学习(5)torchvision模块&训练一个简单的自己的CNN (二)

论文精度 —— 2017 ACM《Globally and Locally Consistent Image Completion》

Order table delete, insert and search operations

如何用Apifox 的智能Mock功能?

Matlab simulation of multi-factor house price prediction based on BP neural network

大咖说·对话生态|当Confluent遇见云:实时流动的数据更有价值
随机推荐
I have a dream for Career .
MySQL simple tutorial
Get started with the OAuth protocol easily with a case
基于Qiskit——《量子计算编程实战》读书笔记(七)
Kubernetes:(十六)Ingress的概念和原理
【写下自用】每次都忘记如何train?记录如何训练自己的yolov5
常用工具系列 - 常用正则表达式
Become a language that hackers have to learn. Do you think it's okay after reading it?
AVL树的插入--旋转笔记
Order table delete, insert and search operations
Kubernetes:(十七)Helm概述、安装及配置
SQL Server query optimization
【格式转换】将JPEG图片批量处理为jpg格式
EasyGBS connects to mysql database and prompts "can't connect to mysql server", how to solve it?
aliases节点分析
pygame学习计划(1)
How does flinksql write that the value of redis has only the last field?
Nexus_Warehouse Type
接口调试还能这么玩?
OneFlow源码解析:算子指令在虚拟机中的执行