DAB-DETR paper interpretation + core source code interpretation
2022-04-21 14:33:00 【Wu lele~】
Preface
This article mainly introduces the basic idea of DAB-DETR, published at ICLR 2022, and its code implementation.
1. Code address
2. Paper address
In addition, if you are interested, you can read my other articles on DETR:
1. nn.Transformer usage
2. Reading Detr in mmdet
3. DeformableDetr
4. ConditionalDetr
1. Paper Interpretation
The overall model structure is very similar to DETR:

1.1. Spatial attention heat map visualization

This paper argues that in the original DETR series, the learnable object queries only provide reference-point (center-point) information for the model's bbox predictions, but no box width and height information. Therefore, this paper introduces a learnable anchor box so that the model can adapt to objects of different sizes. The figure above visualizes the spatial attention heat maps (pq * pk) of three models; if you are interested in how the heat maps are generated, refer to my Detr heat map visualization article. As the figure shows, after introducing the learnable anchor box, DAB-DETR covers objects of different sizes well. One conclusion drawn in the paper: the content query computes similarity with the key to perform feature extraction, while the pos query limits the range and size of the region being extracted.
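For readers who just want a rough picture of how such a positional attention map can be rendered, here is a minimal sketch of my own (not the author's visualization code); it assumes you already have the positional embedding pq of one query and the positional embeddings pk of all H*W feature-map locations:
import torch

def positional_attention_map(pq, pk, H, W):
    # pq: (d,) positional embedding of one query
    # pk: (H*W, d) positional embeddings of the flattened feature map
    scores = pk @ pq / pq.shape[0] ** 0.5  # scaled dot-product similarity, (H*W,)
    attn = scores.softmax(dim=0)           # normalize over all locations
    return attn.reshape(H, W)              # ready for plt.imshow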
1.2. Model overview

The purple parts in the figure mark what has changed. The general process: DAB-DETR directly presets N learnable anchors, similar to Sparse R-CNN. Then, through the width- and height-modulated cross-attention module, it predicts a four-element offset for each anchor box and uses it to update the anchor.
1.3. Model details

The figure above is a PPT diagram I made, showing the first DecoderLayer. Briefly, the process is: first set N learnable 4-dimensional anchors, then map them to Pq through PE (positional encoding) and an MLP.
1) In the self-attn part: regular self-attention, with Cq and Pq added together;
2) In the cross-attn part: the reference-point (x, y) handling is exactly the same as in ConditionalDetr, with Cq and Pq concatenated to form Qq; the only difference is the width- and height-modulated cross-attention module: when computing the similarity between Pq and Pk, a scale factor of (w_ref/w, h_ref/h) is applied to the x and y parts of the positional embedding (see the modulation lines in the decoder code below; a small demo of the concatenation idea follows this list).
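As mentioned above, here is a small demo (mine, not from the paper or the repo) of why the concatenation decouples the two similarity terms: the dot product of the concatenated query and key is exactly the content similarity cq·ck plus the positional similarity pq·pk, so the positional part can steer where attention goes without mixing into the content matching.
import torch

d = 256
cq, ck = torch.randn(d, dtype=torch.float64), torch.randn(d, dtype=torch.float64)  # content query / key
pq, pk = torch.randn(d, dtype=torch.float64), torch.randn(d, dtype=torch.float64)  # positional query / key

q = torch.cat([cq, pq])  # concatenated query, as in the cross-attention
k = torch.cat([ck, pk])  # concatenated key

# Dot product of concatenated vectors = content term + positional term
assert torch.allclose(q @ k, cq @ ck + pq @ pk)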
1.4. Setting the temperature coefficient
DETR generates the positional encoding Pk for each location in the feature map using the temperature coefficient from the Transformer (10000). But that value was designed for word embeddings, whereas the coordinates here are normalized to [0, 1], so blindly adopting 10000 is not a good fit; this paper uses 20 instead. It is a trick, and it brings roughly a one-point gain.
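Below is a minimal sketch of a sinusoidal embedding with a configurable temperature (my own simplification; the repo's gen_sineembed_for_position and position-encoding modules handle x/y/w/h separately): with coordinates normalized to [0, 1], a smaller temperature such as 20 spreads the frequencies better than the word-embedding default of 10000.
import math
import torch

def sine_embed(coords, num_feats=128, temperature=20):
    # coords: tensor of normalized coordinates in [0, 1]; returns (..., num_feats)
    scale = 2 * math.pi
    dim_t = torch.arange(num_feats)
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / num_feats)
    pos = coords.unsqueeze(-1) * scale / dim_t
    return torch.stack((pos[..., 0::2].sin(), pos[..., 1::2].cos()), dim=-1).flatten(-2)

emb = sine_embed(torch.tensor([0.1, 0.5, 0.9]))  # shape (3, 128)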

1.5. Experiments
Performance is compared on four backbones; overall, DAB-DETR achieves the best results.

2. Code Explanation
I feel the quality of this code base is very high, because the authors have open-sourced the code for basically every experiment (including the deformable attn operator, distributed training, and so on); it is worth reading repeatedly.
2.1. Decoder
First look at the forward function of the whole Decoder:
def forward(self, tgt, memory,
            tgt_mask: Optional[Tensor] = None,
            memory_mask: Optional[Tensor] = None,
            tgt_key_padding_mask: Optional[Tensor] = None,
            memory_key_padding_mask: Optional[Tensor] = None,
            pos: Optional[Tensor] = None,
            refpoints_unsigmoid: Optional[Tensor] = None,  # num_queries, bs, 4
            ):
    # In the first layer tgt is initialized to all zeros; output is the input Cq
    output = tgt
    # Save intermediate results
    intermediate = []
    reference_points = refpoints_unsigmoid.sigmoid()  # [300, batch, 4]
    ref_points = [reference_points]
    # import ipdb; ipdb.set_trace()
    for layer_id, layer in enumerate(self.layers):
        # Take the anchor Aq (cx, cy, w, h)
        obj_center = reference_points[..., :self.query_dim]  # [num_queries, batch_size, 4]
        # Pq = MLP(PE(obj_center)): turn the anchor into a 256-dim embedding
        query_sine_embed = gen_sineembed_for_position(obj_center)
        query_pos = self.ref_point_head(query_sine_embed)
        # For the first decoder layer, we do not apply transformation over p_s
        if self.query_scale_type != 'fix_elewise':
            if layer_id == 0:
                pos_transformation = 1
            else:
                # Cq goes through an MLP to get the transformation applied to the positional embedding
                pos_transformation = self.query_scale(output)
        else:
            pos_transformation = self.query_scale.weight[layer_id]
        # Obtain Pq
        query_sine_embed = query_sine_embed[..., :self.d_model] * pos_transformation
        # modulated HW attentions
        if self.modulate_hw_attn:
            # Cq goes through an MLP and sigmoid to get w_{q,ref} and h_{q,ref}
            refHW_cond = self.ref_anchor_head(output).sigmoid()  # nq, bs, 2
            # Apply the width/height modulation
            query_sine_embed[..., self.d_model // 2:] *= (refHW_cond[..., 0] / obj_center[..., 2]).unsqueeze(-1)
            query_sine_embed[..., :self.d_model // 2] *= (refHW_cond[..., 1] / obj_center[..., 3]).unsqueeze(-1)
        # Run the decoder layer of the current iteration
        output = layer(output, memory, tgt_mask=tgt_mask,
                       memory_mask=memory_mask,
                       tgt_key_padding_mask=tgt_key_padding_mask,
                       memory_key_padding_mask=memory_key_padding_mask,
                       pos=pos, query_pos=query_pos, query_sine_embed=query_sine_embed,
                       is_first=(layer_id == 0))
        # iter update
        if self.bbox_embed is not None:
            if self.bbox_embed_diff_each_layer:
                # Predict tmp from Cq: the bbox offsets [delta_x, delta_y, delta_w, delta_h]
                tmp = self.bbox_embed[layer_id](output)
            else:
                tmp = self.bbox_embed(output)
            # Update the bbox in unsigmoided (logit) space
            tmp[..., :self.query_dim] += inverse_sigmoid(reference_points)
            # Apply sigmoid to get the new bbox
            new_reference_points = tmp[..., :self.query_dim].sigmoid()
            if layer_id != self.num_layers - 1:
                # Store the reference points of each layer
                ref_points.append(new_reference_points)
            # Update the reference points for the next decoder layer (detached from the graph)
            reference_points = new_reference_points.detach()
        # Save the intermediate Cq
        if self.return_intermediate:
            intermediate.append(self.norm(output))
    # After the loop, return the requested values
    if self.norm is not None:
        output = self.norm(output)
        if self.return_intermediate:
            intermediate.pop()
            intermediate.append(output)
    if self.return_intermediate:
        if self.bbox_embed is not None:
            return [
                torch.stack(intermediate).transpose(1, 2),
                torch.stack(ref_points).transpose(1, 2),
            ]
        else:
            return [
                torch.stack(intermediate).transpose(1, 2),
                reference_points.unsqueeze(0).transpose(1, 2)
            ]
    return output.unsqueeze(0)
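One helper used above but not shown here is inverse_sigmoid. A minimal sketch of how it is commonly implemented in DETR-family code bases (the eps clamp keeps the log finite at 0 and 1); with it, each layer effectively refines the anchor as sigmoid(inverse_sigmoid(anchor) + delta):
import torch

def inverse_sigmoid(x, eps=1e-3):
    # Map values in [0, 1] back to logit space so offsets can be added there
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)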
2.2. DecoderLayer
Internally it just calls self-attn and cross-attn; pq, pk, cq, ck are added or concatenated as described in the paper.
def forward(self, tgt, memory,
            tgt_mask: Optional[Tensor] = None,
            memory_mask: Optional[Tensor] = None,
            tgt_key_padding_mask: Optional[Tensor] = None,
            memory_key_padding_mask: Optional[Tensor] = None,
            pos: Optional[Tensor] = None,
            query_pos: Optional[Tensor] = None,
            query_sine_embed=None,
            is_first=False):
    # ========== Begin of Self-Attention =============
    if not self.rm_self_attn_decoder:
        # Apply projections here
        # shape: num_queries x batch_size x 256
        q_content = self.sa_qcontent_proj(tgt)  # target is the input of the first decoder layer. zero by default.
        q_pos = self.sa_qpos_proj(query_pos)
        k_content = self.sa_kcontent_proj(tgt)
        k_pos = self.sa_kpos_proj(query_pos)
        v = self.sa_v_proj(tgt)
        num_queries, bs, n_model = q_content.shape
        hw, _, _ = k_content.shape
        # Self-attention: content and positional embeddings are added
        q = q_content + q_pos
        k = k_content + k_pos
        tgt2 = self.self_attn(q, k, value=v, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        # ========== End of Self-Attention =============
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
    # ========== Begin of Cross-Attention =============
    # Apply projections here
    # shape: num_queries x batch_size x 256
    q_content = self.ca_qcontent_proj(tgt)
    k_content = self.ca_kcontent_proj(memory)
    v = self.ca_v_proj(memory)
    num_queries, bs, n_model = q_content.shape
    hw, _, _ = k_content.shape
    k_pos = self.ca_kpos_proj(pos)
    # For the first decoder layer, we concatenate the positional embedding predicted from
    # the object query (the positional embedding) into the original query (key) in DETR.
    if is_first or self.keep_query_pos:
        q_pos = self.ca_qpos_proj(query_pos)
        q = q_content + q_pos
        k = k_content + k_pos
    else:
        q = q_content
        k = k_content
    # Split into multiple heads and concatenate cq with pq
    q = q.view(num_queries, bs, self.nhead, n_model // self.nhead)
    query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed)
    query_sine_embed = query_sine_embed.view(num_queries, bs, self.nhead, n_model // self.nhead)
    q = torch.cat([q, query_sine_embed], dim=3).view(num_queries, bs, n_model * 2)
    # Split into multiple heads and concatenate ck with pk
    k = k.view(hw, bs, self.nhead, n_model // self.nhead)
    k_pos = k_pos.view(hw, bs, self.nhead, n_model // self.nhead)
    k = torch.cat([k, k_pos], dim=3).view(hw, bs, n_model * 2)
    # Call the MultiheadAttention module (q/k are 2*d_model wide, v stays d_model)
    tgt2 = self.cross_attn(query=q,
                           key=k,
                           value=v, attn_mask=memory_mask,
                           key_padding_mask=memory_key_padding_mask)[0]
    # ========== End of Cross-Attention =============
    # Residual connection, LayerNorm and FFN
    tgt = tgt + self.dropout2(tgt2)
    tgt = self.norm2(tgt)
    tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
    tgt = tgt + self.dropout3(tgt2)
    tgt = self.norm3(tgt)
    return tgt
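One detail worth noting is why the content and positional parts are reshaped per head before concatenation: each attention head should see its own content slice followed by its own positional slice, not all content dims followed by all positional dims. Below is a small sanity check of that layout (my own demo, not from the repo). Note that q and k end up with width 2 * d_model while v stays at d_model, which is why the cross-attention there is built with a doubled query/key dimension and a separate value dimension (check the repo's attention module for the exact construction).
import torch

d_model, nhead, nq, bs = 256, 8, 300, 2
head_dim = d_model // nhead

q_content = torch.randn(nq, bs, d_model)
q_pos = torch.randn(nq, bs, d_model)

# Per-head concatenation, as in the layer above
qc = q_content.view(nq, bs, nhead, head_dim)
qp = q_pos.view(nq, bs, nhead, head_dim)
q = torch.cat([qc, qp], dim=3).view(nq, bs, d_model * 2)

# Each head now attends over [its 32 content dims ; its 32 positional dims]
per_head = q.view(nq, bs, nhead, 2 * head_dim)
assert torch.equal(per_head[..., :head_dim], qc)
assert torch.equal(per_head[..., head_dim:], qp)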
Summary
I will write about DN-DETR later; stay tuned.