当前位置:网站首页>自适应空间特征融合( adaptively spatial feature fusion)一种基于数据驱动的金字塔特征融合策略
自适应空间特征融合( adaptively spatial feature fusion)一种基于数据驱动的金字塔特征融合策略
2022-08-10 05:29:00 【Time.Xu】
代码地址:GitHub - GOATmessi7/ASFF: yolov3 with mobilenet v2 and ASFF
https://github.com/GOATmessi7/ASFF
【侵删】部分来源于其他博客,文末有链接
ASFF的核心思想是通过学习自适应地调整各个尺度特征在融合时地空间权重。
特征金字塔FPN是解决目标检测中尺度变化问题的常用方法。
为什么使用特征金字塔呢,是因为CNN对物体的平移是保持不变的,而对于物体的尺度变换是无法处理的,所以使用特征金字塔进行处理。
这里有一个语义信息和位置信息的问题/感受野的问题:
特征图越小(越顶层)、提取的信息就越抽象、感受野就越大,那么对于小的特征可能就检测不到。
特征图越大(越底层),感受野小也就更容易检测到小特征和小物体。但用于提取的信息不够抽象,即FeatureMap的语义信息差,物体容易被错分。
总结一下就是:高层的特征语义信息多,低层的特征位置信息多
然而,不同特征尺度之间的不一致性是基于特征金字塔的single-shot detectors【我认为single-shot就是之后的one-stage】的主要限制。它学习了空间过滤冲突信息的方法来抑制不一致性,从而提高了特征的尺度不变性,并引入了几乎自由的推理开销。
为了实现尺度不变性,最新的先进detectors构建特征金字塔或多层特征塔。它利用前向传递中计算的不同层的多尺度特征图来预测不同大小的目标。然而,由于浅层特征图包含的语义信息不足,这种自底向上的路径在小实例上的准确性较低。特征金字塔网络(FPN)将骨干网模型中两个相邻的特征层次顺序地结合在一起,采用自顶向下的路径和横向连接。将低分辨率、语义强的特征向上采样,并与高分辨率、语义弱的特征相结合,构建一个在所有层次上共享丰富语义的特征金字塔。FPN和其他类似的自顶向下结构简单而有效,但仍有很大的改进空间。事实上,许多具有先进跨尺度连接的最新模型通过加强特征融合,来提高精度。除了人工设计的融合结构,NAS- fpn应用神经架构搜索(NAS)技术来追求更好的架构,在许多主干网上产生显著的改进。虽然这些先进的研究提供了更强大的特征金字塔,但它们仍然为尺度不变预测留下了空间。SNIP[33,34]给出了一些证据,它采用了一种尺度归一化方法,在多尺度图像金字塔的每个图像尺度上有选择地训练和推断出适当大小的目标,从而进一步改进了基于金字塔特征的多尺度测试检测器的结果。然而,图像金字塔解决方案大大增加了推理时间,这使得它们不适用于实际应用。
同时,与图像金字塔相比,特征金字塔的一个主要缺点是不同尺度的不一致性。在检测具有特征金字塔的目标时,采用启发式引导的特征选择:大实例通常与上特征图关联,小实例通常与下特征图关联。当某一物体在某一层次的特征图中被指定为正时,其他层次特征图中的相应区域被视为背景。因此,如果一幅图像同时包含大小物体,那么不同层次的特征之间的冲突往往会占据特征金字塔的主要部分。这种不一致性干扰了训练过程中的梯度计算,降低了特征金字塔的有效性。一些模型采用了几种尝试性的策略来解决这个问题。一些学者将相邻层次上特征图对应的区域设置为忽略区域(即零梯度),但这种缓解可能会增加邻近特征水平的较差预测。一些学者创建多个特定于尺度的分支,这些分支具有不同的接受字段,用于尺度感知训练和推理。它脱离了特征金字塔以避免不一致,但也没有重用其高分辨率地图,限制了小实例的准确性。
ASFF方法使网络能够直接学习如何在空间上过滤其他层次的特征,以便只保留有用的信息用于组合。对于某一层次上的特征,首先将其他层次上的特征整合并调整到相同的分辨率,然后训练得到最优融合。在每个空间位置上,不同层次的特征自适应融合,即某些特征在该位置携带矛盾的信息,可能被过滤掉,而另一些特征可能以更具判别性的线索占主导地位。ASFF算法具有以下几个优点:(1)由于搜索最优融合的操作是差分的,可以方便地在反向传播中学习;(2)与主干模型无关,适用于具有特征金字塔结构的single-shot detector;(3)实现简单,增加的计算成本很小。
放在common里
import torch.nn.functional as F
class ASFFV5(nn.Module):
def __init__(self, level, multiplier=1, rfb=False, vis=False, act_cfg=True):
"""
ASFF version for YoloV5 .
different than YoloV3
multiplier should be 1, 0.5
which means, the channel of ASFF can be
512, 256, 128 -> multiplier=1
256, 128, 64 -> multiplier=0.5
For even smaller, you need change code manually.
"""
super(ASFFV5, self).__init__()
self.level = level
self.dim = [int(1024 * multiplier), int(512 * multiplier),
int(256 * multiplier)]
# print(self.dim)
self.inter_dim = self.dim[self.level]
if level == 0:
self.stride_level_1 = Conv(int(512 * multiplier), self.inter_dim, 3, 2)
self.stride_level_2 = Conv(int(256 * multiplier), self.inter_dim, 3, 2)
self.expand = Conv(self.inter_dim, int(
1024 * multiplier), 3, 1)
elif level == 1:
self.compress_level_0 = Conv(
int(1024 * multiplier), self.inter_dim, 1, 1)
self.stride_level_2 = Conv(
int(256 * multiplier), self.inter_dim, 3, 2)
self.expand = Conv(self.inter_dim, int(512 * multiplier), 3, 1)
elif level == 2:
self.compress_level_0 = Conv(
int(1024 * multiplier), self.inter_dim, 1, 1)
self.compress_level_1 = Conv(
int(512 * multiplier), self.inter_dim, 1, 1)
self.expand = Conv(self.inter_dim, int(
256 * multiplier), 3, 1)
# when adding rfb, we use half number of channels to save memory
compress_c = 8 if rfb else 16
self.weight_level_0 = Conv(
self.inter_dim, compress_c, 1, 1)
self.weight_level_1 = Conv(
self.inter_dim, compress_c, 1, 1)
self.weight_level_2 = Conv(
self.inter_dim, compress_c, 1, 1)
self.weight_levels = Conv(
compress_c * 3, 3, 1, 1)
self.vis = vis
def forward(self, x): # l,m,s
"""
# 128, 256, 512
512, 256, 128
from small -> large
"""
x_level_0 = x[2] # l
x_level_1 = x[1] # m
x_level_2 = x[0] # s
# print('x_level_0: ', x_level_0.shape)
# print('x_level_1: ', x_level_1.shape)
# print('x_level_2: ', x_level_2.shape)
if self.level == 0:
level_0_resized = x_level_0
level_1_resized = self.stride_level_1(x_level_1)
level_2_downsampled_inter = F.max_pool2d(
x_level_2, 3, stride=2, padding=1)
level_2_resized = self.stride_level_2(level_2_downsampled_inter)
elif self.level == 1:
level_0_compressed = self.compress_level_0(x_level_0)
level_0_resized = F.interpolate(
level_0_compressed, scale_factor=2, mode='nearest')
level_1_resized = x_level_1
level_2_resized = self.stride_level_2(x_level_2)
elif self.level == 2:
level_0_compressed = self.compress_level_0(x_level_0)
level_0_resized = F.interpolate(
level_0_compressed, scale_factor=4, mode='nearest')
x_level_1_compressed = self.compress_level_1(x_level_1)
level_1_resized = F.interpolate(
x_level_1_compressed, scale_factor=2, mode='nearest')
level_2_resized = x_level_2
# print('level: {}, l1_resized: {}, l2_resized: {}'.format(self.level,
# level_1_resized.shape, level_2_resized.shape))
level_0_weight_v = self.weight_level_0(level_0_resized)
level_1_weight_v = self.weight_level_1(level_1_resized)
level_2_weight_v = self.weight_level_2(level_2_resized)
# print('level_0_weight_v: ', level_0_weight_v.shape)
# print('level_1_weight_v: ', level_1_weight_v.shape)
# print('level_2_weight_v: ', level_2_weight_v.shape)
levels_weight_v = torch.cat(
(level_0_weight_v, level_1_weight_v, level_2_weight_v), 1)
levels_weight = self.weight_levels(levels_weight_v)
levels_weight = F.softmax(levels_weight, dim=1)
fused_out_reduced = level_0_resized * levels_weight[:, 0:1, :, :] + \
level_1_resized * levels_weight[:, 1:2, :, :] + \
level_2_resized * levels_weight[:, 2:, :, :]
out = self.expand(fused_out_reduced)
if self.vis:
return out, levels_weight, fused_out_reduced.sum(dim=1)
else:
return out
放在yolo.py中
class ASFF_Detect(nn.Module): #add ASFFV5 layer and Rfb
stride = None # strides computed during build
onnx_dynamic = False # ONNX export parameter
def __init__(self, nc=80, anchors=(), ch=(), multiplier=0.5,rfb=False,inplace=True): # detection layer
super(ASFF_Detect, self).__init__()
self.nc = nc # number of classes
self.no = nc + 5 # number of outputs per anchor
self.nl = len(anchors) # number of detection layers
self.na = len(anchors[0]) // 2 # number of anchors
self.grid = [torch.zeros(1)] * self.nl # init grid
self.l0_fusion = ASFFV5(level=0, multiplier=multiplier,rfb=rfb)
self.l1_fusion = ASFFV5(level=1, multiplier=multiplier,rfb=rfb)
self.l2_fusion = ASFFV5(level=2, multiplier=multiplier,rfb=rfb)
self.anchor_grid = [torch.zeros(1)] * self.nl # init anchor grid
self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2)
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
self.inplace = inplace # use in-place ops (e.g. slice assignment)
def forward(self, x):
z = [] # inference output
result=[]
self.training |= self.export
result.append(self.l2_fusion(x))
result.append(self.l1_fusion(x))
result.append(self.l0_fusion(x))
x=result
for i in range(self.nl):
x[i] = self.m[i](x[i]) # conv
bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
if not self.training: # inference
if self.grid[i].shape[2:4] != x[i].shape[2:4] or self.onnx_dynamic:
self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
y = x[i].sigmoid()
if self.inplace:
y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2) # wh
y = torch.cat((xy, wh, y[..., 4:]), -1)
z.append(y.view(bs, -1, self.no))
return x if self.training else (torch.cat(z, 1), x)
def _make_grid(self, nx=20, ny=20, i=0):
d = self.anchors[i].device
if check_version(torch.__version__, '1.10.0'): # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
yv, xv = torch.meshgrid([torch.arange(ny).to(d), torch.arange(nx).to(d)], indexing='ij')
else:
yv, xv = torch.meshgrid([torch.arange(ny).to(d), torch.arange(nx).to(d)])
grid = torch.stack((xv, yv), 2).expand((1, self.na, ny, nx, 2)).float()
anchor_grid = (self.anchors[i].clone() * self.stride[i]) \
.view((1, self.na, 1, 1, 2)).expand((1, self.na, ny, nx, 2)).float()
return grid, anchor_grid
边栏推荐
- Advanced Feature Selection Techniques in Linear Models - Based on R
- AVL树的插入--旋转笔记
- 8.STM32F407之HAL库——PWM笔记
- pytest测试框架
- 反转链表中的第m至第n个节点---leetcode
- 基于Qiskit——《量子计算编程实战》读书笔记(六)
- 一篇文章掌握整个JVM,JVM超详细解析!!!
- Talk about API Management - Open Source Edition to SaaS Edition
- pytorch框架学习(7) tensorboard使用
- What are the common commands of mysql
猜你喜欢
How to improve product quality from the code layer
Get started with the OAuth protocol easily with a case
如何在报表控件FastReport.NET中连接XLSX 文件作为数据源?
CSDN Markdown 之我见代码块 | CSDN编辑器测评
OneFlow源码解析:算子指令在虚拟机中的执行
Pony语言学习(八):引用能力(Reference Capabilities)
Zhongang Mining: Strong downstream demand for fluorite
看了几十篇轻量化目标检测论文扫盲做的摘抄笔记
基于BP神经网络的多因素房屋价格预测matlab仿真
scikit-learn机器学习 读书笔记(一)
随机推荐
scikit-learn机器学习 读书笔记(二)
论文精读 —— 2021 CVPR《Progressive Temporal Feature Alignment Network for Video Inpainting》
Pulsar中游标的工作原理
论文精度 —— 2017 ACM《Globally and Locally Consistent Image Completion》
How to use Apifox's Smart Mock function?
pytorch框架学习(7) tensorboard使用
oracle cdc时,设置并行度2插槽数1,最终任务只有一个tm,是不是因为oracle不支持并发
Read the excerpt notes made by dozens of lightweight target detection papers for literacy
如何用Apifox 的智能Mock功能?
Pony语言学习(九)——泛型与模式匹配(终章)
一篇文章掌握整个JVM,JVM超详细解析!!!
FPGA工程师面试试题集锦31~40
【Pei Shu Theorem】CF1055C Lucky Days
基于Qiskit——《量子计算编程实战》读书笔记(四)
AVL tree insertion--rotation notes
pytorch框架学习(9)torchvision.transform
一文带你搞懂OAuth2.0
暑期学前作业
EasyGBS connects to mysql database and prompts "can't connect to mysql server", how to solve it?
Pony语言学习(八):引用能力(Reference Capabilities)