SENet | Attention Mechanism - Source Code + Comments
2022-04-22 03:47:00 【Graduate students are not late】
1 SENet Introduction
- SENet is short for Squeeze-and-Excitation Networks, proposed by the company Momenta. The paper appeared at CVPR 2017, and SENet won the image classification championship of the last ImageNet competition (ImageNet 2017).
- SENet mainly learns the correlation between channels and distills a channel-wise attention. It adds a small amount of computation but gives better results.
- It automatically learns the importance of each feature channel, then uses this importance to enhance useful features and suppress features that are of little use to the current task (the equations after this list summarize the computation).
- The idea of the SE module is simple, it is easy to implement, and it is easy to plug into existing network architectures.
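As a compact reference, the SE computation can be written in the notation of the original Squeeze-and-Excitation paper (here δ is ReLU, σ is the sigmoid gate, C is the number of channels and r is the reduction ratio); the three steps are detailed in Section 3:

z_c = \frac{1}{H \times W} \sum_{i=1}^{H} \sum_{j=1}^{W} x_c(i, j)                      % Squeeze: global average pooling over H×W
s = \sigma\big( W_2 \, \delta( W_1 z ) \big), \quad W_1 \in \mathbb{R}^{C/r \times C},\; W_2 \in \mathbb{R}^{C \times C/r}   % Excitation: bottleneck FC layers + sigmoid gate
\tilde{x}_c = s_c \cdot x_c                                                             % Scale: channel-wise recalibration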
2 SENet advantages
- Adds only a small number of parameters, and can improve the accuracy of a model to a certain extent.
- A strategy built on top of ResNet with good novelty; it is well suited for building new, high-accuracy models.
- Easy to insert into your own deep neural network model to improve its accuracy (a sketch of inserting it into a residual block follows the SELayer code below).
3 The SE module in detail
- Squeeze: compress the features along the spatial dimensions (H×W), turning each two-dimensional feature map into a single real number per channel. Each of these real numbers has, to some extent, a global receptive field, and the output dimension matches the number of input feature channels. It represents the global distribution of responses over the feature channels, and lets even layers close to the input obtain a global receptive field.
  Concrete operation (corresponding one-to-one to the numbers in the code): apply global average pooling to the original 50×512×7×7 feature map, which yields a 50×512×1×1 feature map with a global receptive field.
- Excitation: the 50×512×1×1 feature map is passed through two fully connected layers, followed by a gating mechanism similar to the gates in recurrent neural networks. The learned parameters generate a weight for each feature channel and explicitly model the correlation between channels (the paper uses sigmoid). 50×512×1×1 becomes 50×(512/16)×1×1 and is then restored to 50×512×1×1.
- Feature recalibration: the Excitation output is used as a set of per-channel weights, which are multiplied channel by channel onto the C channels of U (the 50×512×1×1 weights are broadcast to 50×512×7×7 via expand_as). This completes the recalibration of the original features in the channel dimension, and the result is fed into the next layer.
from torch import nn


class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # Squeeze: global average pooling -> b×c×1×1
        self.fc = nn.Sequential(                 # Excitation: FC (compress) -> ReLU -> FC (restore) -> Sigmoid
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)          # b×c×1×1 -> b×c
        y = self.fc(y).view(b, c, 1, 1)          # per-channel weights, reshaped to b×c×1×1
        return x * y.expand_as(x)                # Scale: reweight each channel of x
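To illustrate the "easy to insert into an existing network" point above, here is a minimal sketch of dropping SELayer into a ResNet-style basic block. The block structure and names (SEBasicBlock, conv3x3) are assumptions made only for this example, not part of any specific library; it reuses the SELayer class defined above.

import torch
from torch import nn


def conv3x3(in_planes, out_planes, stride=1):
    # Hypothetical helper for this sketch: 3x3 convolution with padding
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


class SEBasicBlock(nn.Module):
    # Minimal residual block with an SELayer applied before the residual addition
    def __init__(self, planes, reduction=16):
        super().__init__()
        self.conv1 = conv3x3(planes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.se = SELayer(planes, reduction)   # channel attention on the block output

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.se(out)                     # recalibrate channels
        return self.relu(out + identity)       # residual connection


if __name__ == '__main__':
    block = SEBasicBlock(planes=512, reduction=16)
    x = torch.randn(50, 512, 7, 7)
    print(block(x).shape)                      # torch.Size([50, 512, 7, 7])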



4 Complete code
import numpy as np
import torch
from torch import nn
from torch.nn import init


class SEAttention(nn.Module):

    def __init__(self, channel=512, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # global average pooling, output is c×1×1
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),  # channel // reduction compresses the channels
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),  # restore the original channel count
            nn.Sigmoid()
        )

    def init_weights(self):
        for m in self.modules():
            print(m)  # does not execute in this demo, because init_weights is never called below
            if isinstance(m, nn.Conv2d):          # type check: is m an instance of nn.Conv2d?
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, _, _ = x.size()                     # 50×512×7×7
        y = self.avg_pool(x).view(b, c)           # ① after avg_pool: 50×512×1×1  ② view reshapes to 50×512
        y = self.fc(y).view(b, c, 1, 1)           # 50×512×1×1
        return x * y.expand_as(x)                 # expand y according to x.size()


if __name__ == '__main__':
    input = torch.randn(50, 512, 7, 7)
    se = SEAttention(channel=512, reduction=8)    # instantiate the model se
    output = se(input)
    print(output.shape)
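Note that init_weights is never invoked in the demo above, which is why its print(m) line never runs. If you want the explicit initialization, a minimal sketch (reusing the SEAttention class, imports, and tensor shape from the code above) is simply to call it after constructing the module:

se = SEAttention(channel=512, reduction=8)
se.init_weights()                          # walks self.modules() and re-initializes the Linear layers
output = se(torch.randn(50, 512, 7, 7))
print(output.shape)                        # torch.Size([50, 512, 7, 7])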
Copyright notice
This article was created by [Graduate students are not late]. When reposting, please include a link to the original article. Thank you.
https://yzsam.com/2022/04/202204220345024898.html