当前位置:网站首页>NanoDet代码逐行精读与修改(三)辅助训练模块AGM
NanoDet代码逐行精读与修改(三)辅助训练模块AGM
2022-08-09 03:34:00 【HNU跃鹿战队】
笔者已经为nanodet增加了非常详细的注释,代码请戳此仓库:nanodet_detail_notes: detail every detail about nanodet 。
此仓库会跟着文章推送的节奏持续更新!
目录
3. Assign Guidance Module
AGM负责生成cost矩阵,进行标签分配,相当于一个非常轻量的KD模型中的教师,使得head能更好的学习bbox的回归与分类。
3.1. 参数和初始化
class SimpleConvHead(nn.Module):
def __init__(
self,
num_classes,
input_channel, # 输入的特征通道数
feat_channels=256, # AGM内部的特征通道数
stacked_convs=4, # 使用四层卷积
# 默认三个尺度,但是PAN中添加了额外层,配置文件可以看到是[8,16,32,64]
strides=[8, 16, 32],
conv_cfg=None,
# 使用group norm作为归一化层,效果优于BN
norm_cfg=dict(type="GN", num_groups=32, requires_grad=True),
activation="LeakyReLU",
# 配置文件中的默认参数是7
reg_max=16,
**kwargs
):
super(SimpleConvHead, self).__init__()
self.num_classes = num_classes
self.in_channels = input_channel
self.feat_channels = feat_channels
self.stacked_convs = stacked_convs
self.strides = strides
self.reg_max = reg_max
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.activation = activation
self.cls_out_channels = num_classes
self._init_layers()
self.init_weights()
使用了GFL的检测头在输出位置时将会输出4*(reg_max+1)个值,每条边都有reg_max+1个输出用于建模其分布,即用reg_max+1个离散值的积分来得到最终的位置预测。至于为什么是reg_max1而不是reg_max,请看下图:
关于DFL的部分解释
因此reg_max=7实际上是根据用于检测的feature map相对于原图的上采样率计算得到的。
这部分对于稍后要介绍的 NanoDet-plus head的回归分支也是同理。
3.2. 构建卷积层
def _init_layers(self):
self.relu = nn.ReLU(inplace=True)
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
# range从0开始索引到stacked_convs-1
for i in range(self.stacked_convs):
# 第一层需要和输入对齐通道数,之后始终保持为feat_channels
chn = self.in_channels if i == 0 else self.feat_channels
# 分类分支
self.cls_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
activation=self.activation,
)
)
# 回归分支
self.reg_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
activation=self.activation,
)
)
# 最后加上分类头
self.gfl_cls = nn.Conv2d(
self.feat_channels, self.cls_out_channels, 3, padding=1
)
# 回归头的输出为4*(reg_max+1),解释见 3.1
self.gfl_reg = nn.Conv2d(
self.feat_channels, 4 * (self.reg_max + 1), 3, padding=1
)
# 用于缩放回归出的bbox的系数,这是一个可学习的参数
self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])
Scale的构成非常简单,就是乘上一个数值,使得回归出的框更加精确:
class Scale(nn.Module):
"""
A learnable scale parameter
"""
def __init__(self, scale=1.0):
super(Scale, self).__init__()
self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))
def forward(self, x):
return x * self.scale
3.3. forward()
# 全部采用normal init,没什么好说的
def init_weights(self):
for m in self.cls_convs:
normal_init(m.conv, std=0.01)
for m in self.reg_convs:
normal_init(m.conv, std=0.01)
bias_cls = -4.595
normal_init(self.gfl_cls, std=0.01, bias=bias_cls)
normal_init(self.gfl_reg, std=0.01)
def forward(self, feats):
outputs = []
for x, scale in zip(feats, self.scales):
cls_feat = x
reg_feat = x
# 对于来自PAN的每一层输入,计算class分支
for cls_conv in self.cls_convs:
cls_feat = cls_conv(cls_feat)
# 计算regression分支
for reg_conv in self.reg_convs:
reg_feat = reg_conv(reg_feat)
# 得到类别分数
cls_score = self.gfl_cls(cls_feat)
# 得到回归分布并进行缩放
bbox_pred = scale(self.gfl_reg(reg_feat)).float()
# 拼接得到输出
output = torch.cat([cls_score, bbox_pred], dim=1)
# 追加到aux_pred后面
outputs.append(output.flatten(start_dim=2))
# 整理对齐维度,在之后的dsl_assigner中我们会详细介绍如何处理来自AGM和head的输出
outputs = torch.cat(outputs, dim=2).permute(0, 2, 1)
return outputs
了解了AGM的输出后,第四部分会介绍本文最重要的Dynamic soft label assigner这个模块了。
边栏推荐
- 了解CV和RoboMaster视觉组(五)local-distribution汇聚方法
- C18-PEG- ALD批发_C18-PEG-CHO_C18-PEG-醛基
- 【问题记录】pip 安装报错 Failed to establish a new connection
- The condition variable condition_variable implements thread synchronization
- Win7电脑无法进入睡眠模式?
- Embedded system driver advanced [3] - __ID matching and device tree matching under platform bus driver development
- win10怎么安装.net framework 3.5?
- 下秒数据CEO蔡致暖受邀参加联合数据举办《数据要素加速跑》线上沙龙
- 全链路UI设计笔记
- static成员及代码块
猜你喜欢
Error detected while processing /home/test/.vim/plugin/visualmark.vim
Leetcode Brushing Questions - 148. Sort Linked List
Kaggle(六)特征衍生技术 特征聚合
Linux安装MySQL8
SQL注入(2)
光刻机随感
leetcode-23. Merge K ascending linked lists
从暴力递归到动态规划leetcode第62题:不同路径
Embedded system driver advanced [2] - platform bus driver development _ basic framework
One Pass 1258 - Digital Pyramid (Dynamic Programming)
随机推荐
23 Lectures on Disassembly of Multi-merchant Mall System Functions-Platform Distribution Level
365 days challenge LeetCode1000 topic - Day 051 special binary sequence partition
下秒数据CEO蔡致暖受邀参加联合数据举办《数据要素加速跑》线上沙龙
项目中'说到做不到'的个人分析
条件变量condition_variable实现线程同步
29 机器学习中常常提到的正则化到底是什么意思
hcip MPLS 实验
H264之sps解析分辨率
SQL注入(1)
06 Dynamic memory
07.1 类的的补充
交换VLAN实验
Embedded system driver advanced [2] - platform bus driver development _ basic framework
Arrays and slices
EventLoop同步异步,宏任务微任务笔记
30 norm
甲乙丙丁加工零件,加工的总数是370, 如果甲加工的零件数多10,如果乙加工的零件数少20,如果丙加工的 零件数乘以2,如果丁加工的零件数除以2,四个人的加工数量相等,求甲乙丙丁各自加工多少个零件?
SQL注入(3)
智能计数器控制板的功能及应用有哪些?
[Graphics] 19 Lighting model (four, Blinn-Phong lighting model)