SeNet || Attention Mechanism: Source Code with Comments
2022-04-22 03:45:00 【研究生不迟到】
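1 Introduction to SENet
- SENet is short for Squeeze-and-Excitation Networks. The work came out of Momenta, appeared as a preprint in 2017, and was published at CVPR 2018. The SENet described in the paper won the image classification task of the final ImageNet challenge (ImageNet 2017).
- SENet learns the dependencies between channels and turns them into channel-wise attention. It adds only a little computation, yet gives a clear improvement.
- The importance of each feature channel is learned automatically. These importance scores are then used to boost useful features and to suppress features that contribute little to the current task.
- The SE module is conceptually simple, easy to implement, and easy to plug into existing network architectures.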
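2 Advantages of SENet
- It adds only a small number of parameters and can improve model accuracy to some extent.
- It is usually demonstrated on top of ResNet (SE-ResNet), but the block itself does not depend on a particular backbone, which makes it a good building block when designing a new model to push accuracy higher.
- It is easy to insert into your own deep neural network to improve accuracy.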
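To make "a small number of parameters" concrete, the following minimal sketch counts the weights of one SE block, assuming two bias-free fully connected layers as in the code later in this post. The helper name `se_param_count` is hypothetical and used only for illustration.

```python
def se_param_count(channel: int, reduction: int = 16) -> int:
    """Weights in one SE block with two bias-free linear layers (hypothetical helper)."""
    hidden = channel // reduction      # width of the bottleneck layer
    squeeze_fc = channel * hidden      # Linear(channel -> hidden), no bias
    excite_fc = hidden * channel       # Linear(hidden -> channel), no bias
    return squeeze_fc + excite_fc


if __name__ == "__main__":
    for c in (64, 256, 512):
        print(c, se_param_count(c))
```

For a 512-channel feature map with reduction 16 this gives 2 × 512 × 32 = 32,768 weights. For comparison, a single 3×3 convolution with 512 input and 512 output channels has 512 × 512 × 9 = 2,359,296 weights.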
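3 The SE module in detail
- Squeeze: compress the features along the spatial dimensions (H×W), turning each two-dimensional feature channel into a single real number. This number has, in a sense, a global receptive field, and the output dimension matches the number of input feature channels. It describes the global distribution of responses over the feature channels and lets layers close to the input obtain a global receptive field as well. Concretely (the numbers match those in the code): global average pooling is applied to the original 50×512×7×7 feature map, which gives a 50×512×1×1 feature map with a global receptive field.
- Excitation: the 50×512×1×1 output passes through two fully connected layers, followed by a gating mechanism similar to the gates in recurrent neural networks (the paper uses sigmoid), which produces one weight per feature channel. The parameters of these layers are learned to explicitly model the correlations between feature channels. 50×512×1×1 becomes 50×(512/16)×1×1 and is then restored to 50×512×1×1.
- Feature recalibration: the excitation output is used as a set of weights and multiplied channel by channel onto the C channels of U (50×512×1×1 is expanded to 50×512×7×7 with expand_as). This completes the recalibration of the original features along the channel dimension, and the result is the input to the next layer.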
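The three steps can also be written compactly. The notation here is an assumption, since the post defines no symbols beyond U and C: u_c is channel c of the input U with spatial size H×W, r is the reduction ratio, δ is ReLU, σ is sigmoid, and W_1, W_2 are the weights of the two fully connected layers.

```latex
\begin{aligned}
z_c &= \frac{1}{H W} \sum_{i=1}^{H} \sum_{j=1}^{W} u_c(i, j) && \text{(squeeze)} \\
s &= \sigma\bigl(W_2\, \delta(W_1 z)\bigr), \quad
     W_1 \in \mathbb{R}^{\frac{C}{r} \times C},\;
     W_2 \in \mathbb{R}^{C \times \frac{C}{r}} && \text{(excitation)} \\
\tilde{x}_c &= s_c \cdot u_c && \text{(feature recalibration)}
\end{aligned}
```

With the numbers used in this post, C = 512, H = W = 7 and r = 16, so z and s each have 512 entries per sample and the bottleneck has 32 units.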
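from torch import nn


class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        # Squeeze: global average pooling down to 1×1 per channel
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: bottleneck of two fully connected layers, then a sigmoid gate
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)      # (b, c, 1, 1) -> (b, c)
        y = self.fc(y).view(b, c, 1, 1)      # per-channel weights in (0, 1)
        return x * y.expand_as(x)            # rescale every channel of x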
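To follow the shapes at each step without a deep learning framework, the forward pass above can be re-traced in NumPy. This is only a sketch under assumptions: the arrays `w1` and `w2` are random stand-ins for the two learned, bias-free linear layers, and the function name `se_forward` is hypothetical.

```python
import numpy as np


def se_forward(x, w1, w2):
    """Trace the SE forward pass on an array of shape (B, C, H, W).

    w1 has shape (C // r, C) and w2 has shape (C, C // r); both are
    stand-ins for learned, bias-free linear layers.
    """
    z = x.mean(axis=(2, 3))                  # squeeze: (B, C)
    h = np.maximum(z @ w1.T, 0.0)            # FC + ReLU: (B, C // r)
    s = 1.0 / (1.0 + np.exp(-(h @ w2.T)))    # FC + sigmoid: (B, C)
    return x * s[:, :, None, None], s        # broadcast the gate over H and W


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    b, c, height, width, r = 50, 512, 7, 7, 16
    x = rng.standard_normal((b, c, height, width))
    w1 = rng.standard_normal((c // r, c)) * 0.01
    w2 = rng.standard_normal((c, c // r)) * 0.01
    out, s = se_forward(x, w1, w2)
    print(out.shape, s.shape)
```

The multiplication relies on broadcasting: the gate of shape (B, C, 1, 1) is stretched over the spatial dimensions, which is the same effect `expand_as` makes explicit in the PyTorch version.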



4 Complete code
import torch
from torch import nn
from torch.nn import init


class SEAttention(nn.Module):
    def __init__(self, channel=512, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # global average pooling; output is c×1×1 per sample
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),  # channel // reduction compresses the channels
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),  # restore the original channel count
            nn.Sigmoid()
        )

    def init_weights(self):
        # Optional initializer. Nothing in this script calls it, so it never runs here.
        for m in self.modules():
            print(m)  # debug print; never reached in this script
            if isinstance(m, nn.Conv2d):  # type check: is m an nn.Conv2d?
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, _, _ = x.size()  # 50×512×7×7
        y = self.avg_pool(x).view(b, c)  # ① average pooling gives 50×512×1×1 ② view reshapes it to 50×512
        y = self.fc(y).view(b, c, 1, 1)  # 50×512×1×1
        return x * y.expand_as(x)  # expand y to the size of x, then rescale x


if __name__ == '__main__':
    input = torch.randn(50, 512, 7, 7)
    se = SEAttention(channel=512, reduction=8)  # instantiate the model; reduction=8 gives a 512 // 8 = 64 unit bottleneck
    output = se(input)
    print(output.shape)  # torch.Size([50, 512, 7, 7])
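Section 2 notes that the module is easy to insert into an existing network. One common placement, sketched below, applies the SE gate to the output of a residual branch before the skip connection is added. The class `SEBasicBlock` and its layer layout are illustrative assumptions, not code from the original post; the block keeps the channel count and spatial size unchanged so the identity shortcut needs no projection.

```python
import torch
from torch import nn


class SEAttention(nn.Module):
    """SE module repeated here so that this sketch runs on its own."""

    def __init__(self, channel=512, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y  # broadcasting expands the gate over H and W


class SEBasicBlock(nn.Module):
    """Hypothetical residual block: conv-bn-relu-conv-bn, SE gate, then skip connection."""

    def __init__(self, channel, reduction=16):
        super().__init__()
        self.conv1 = nn.Conv2d(channel, channel, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channel)
        self.conv2 = nn.Conv2d(channel, channel, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channel)
        self.relu = nn.ReLU(inplace=True)
        self.se = SEAttention(channel, reduction)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.se(out)           # recalibrate the residual branch
        return self.relu(out + x)    # add the identity shortcut


if __name__ == "__main__":
    block = SEBasicBlock(channel=64, reduction=16)
    x = torch.randn(2, 64, 7, 7)
    print(block(x).shape)
```

Gating only the residual branch leaves the shortcut path untouched, so the block can still pass its input through unchanged when the learned gates are small.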
Copyright notice
This article was written by [研究生不迟到]. Please include a link to the original when reposting. Thank you.
https://blog.csdn.net/weixin_42521185/article/details/124330333