Understanding the YOLOv5 NMS source code
2022-04-23 21:00:00 【Top of the program】
Below is the NMS source code from YOLOv5, followed by my understanding of NMS, using a binary (two-class) classification model as the example.
```python
import time

import torch
import torchvision

# Helpers from the YOLOv5 repository (assumes the repo layout utils/general.py and utils/metrics.py)
from utils.general import xywh2xyxy
from utils.metrics import box_iou


def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=(), max_det=100000, return_index=False):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """
    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_nms = 200000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 60.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    i = None
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    if return_index:
        # Note: i holds the NMS indices of the last image that reached NMS only,
        # and they index rows of the filtered x, not rows of the original prediction
        return output, i
    return output
```
- The first two variables:
```python
nc = prediction.shape[2] - 5  # number of classes
xc = prediction[..., 4] > conf_thres  # candidates
```
- nc = prediction.shape[2] - 5 # number of classes
prediction is the raw output of the network model. In this example its shape is (1, 50000, 7): 1 is the number of images, 50000 is the number of candidate boxes predicted by the network, and 7 is the length of each prediction vector, laid out as follows:
center x, center y, width, height, objectness confidence, class-0 confidence, class-1 confidence
From this, nc = 7 - 5 = 2 gives the number of classes predicted by the network.
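A minimal sketch of this layout on a tiny made-up tensor, assuming the per-row order (cx, cy, w, h, obj_conf, cls0_conf, cls1_conf) described above; the shape (1, 3, 7) and all values are invented for illustration:

```python
import torch

# Hypothetical raw output: 1 image, 3 candidate boxes, 7 values per box
# layout per row: cx, cy, w, h, obj_conf, cls0_conf, cls1_conf
prediction = torch.tensor([[
    [50., 50., 20., 20., 0.75, 0.5, 0.5],
    [80., 30., 10., 40., 0.125, 0.25, 0.75],
    [10., 10., 5., 5., 0.5, 1.0, 0.0],
]])

print(prediction.shape)            # torch.Size([1, 3, 7])
nc = prediction.shape[2] - 5       # 7 - 5 = 2 classes
print(nc)                          # 2

boxes_xywh = prediction[..., :4]   # box geometry, shape (1, 3, 4)
obj_conf = prediction[..., 4]      # objectness, shape (1, 3)
cls_conf = prediction[..., 5:]     # per-class confidence, shape (1, 3, nc)
print(boxes_xywh.shape, obj_conf.shape, cls_conf.shape)
# torch.Size([1, 3, 4]) torch.Size([1, 3]) torch.Size([1, 3, 2])
```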
- xc = prediction[..., 4] > conf_thres # candidates
prediction[..., 4] is a tensor of shape torch.Size([1, 50000]): it takes the 5th value (index 4) of every prediction, which is the probability that the box contains an object. The full expression prediction[..., 4] > conf_thres returns a tensor with the same shape (1, 50000) as prediction[..., 4], where each entry is True or False, and this tensor is assigned to xc. So xc has shape (1, 50000), and each value says whether the objectness confidence of the corresponding box is greater than conf_thres.
- Checking the input parameters conf_thres and iou_thres
```python
assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
```
These two lines are simple: they check that conf_thres and iou_thres lie between 0 and 1, i.e. they validate the input parameters.
- Variable definitions
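The meaning of each variable is given by the comment after it; the comments are explained below.
```python
min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
max_nms = 200000  # maximum number of boxes into torchvision.ops.nms()
time_limit = 60.0  # seconds to quit after
redundant = True  # require redundant detections
multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
merge = False  # use merge-NMS
```
min_wh, max_wh: the minimum and maximum box width and height, in pixels. The width-height constraint that would use them is commented out in this version; max_wh is still used later as the per-class offset in batched NMS.
max_nms: the maximum total number of boxes passed to torchvision.ops.nms().
time_limit: the timeout for the nms function, in seconds.
redundant: require redundant detections. When merge-NMS is on, a kept box is retained only if it overlaps at least one other box above iou_thres.
multi_label: allow several labels per box; the &= nc > 1 switches it off automatically when there is only one class.
merge: whether to use merge-NMS (boxes merged by a confidence-weighted mean); off by default.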
- Defining the output tensors
```python
output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
```
output is defined as a list whose length equals the number of images in the batch, prediction.shape[0] (1 here). Each element starts as an empty tensor of shape (0, 6), i.e. zero rows of 6 fields each.
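One detail worth noting: `[t] * n` repeats a reference to the same empty tensor n times. That is harmless here because the loop later replaces list entries (`output[xi] = x[i]`) rather than modifying the tensor in place. A small sketch with a made-up batch size of 3:

```python
import torch

batch_size = 3  # hypothetical number of images
output = [torch.zeros((0, 6))] * batch_size

print(len(output))              # 3
print(output[0].shape)          # torch.Size([0, 6])
print(output[0] is output[2])   # True: every entry is the same tensor object

# Reassigning an entry, as the NMS loop does, only changes that slot
output[1] = torch.ones((2, 6))
print([o.shape[0] for o in output])  # [0, 2, 0]
```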
- Use a for loop to process each image's inference result in turn
```python
for xi, x in enumerate(prediction):  # image index, image inference
```
xi is the image index; here it is 0, meaning the 0th image. x is that image's inference result, with shape torch.Size([50000, 7]). prediction is a 3-D tensor of shape torch.Size([1, 50000, 7]) whose axis 0 indexes the images, while the other two axes hold each image's predictions. enumerate iterates over axis 0, so it neatly separates the image index (stored in xi) from that image's predictions (stored in x). This line of code is clever.
```python
x = x[xc[xi]]  # confidence
```
xc is the boolean tensor of shape (1, 50000) computed earlier, where each value says whether that box's confidence is greater than conf_thres. xc[xi] selects the row for image xi (image 0 here), giving shape torch.Size([50000]), which matches dimension 0 of x. x[xc[xi]] is boolean-mask indexing on x: along dimension 0 it keeps the rows where xc[xi] is True and discards those where it is False. In other words, it takes every box of image xi (here xi is 0) whose confidence is greater than conf_thres and reassigns the result to x.
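The two steps (building xc, then masking each image's rows) can be sketched end to end on a made-up batch of one image with 4 boxes; the values are invented and the column layout is the one described above:

```python
import torch

conf_thres = 0.25
# Hypothetical raw output: 1 image, 4 boxes, 7 values (cx, cy, w, h, obj, cls0, cls1)
prediction = torch.tensor([[
    [50., 50., 20., 20., 0.75, 0.5, 0.5],
    [52., 51., 22., 18., 0.125, 0.5, 0.5],
    [80., 30., 10., 40., 0.5, 0.25, 0.75],
    [10., 10., 5., 5., 0.0625, 1.0, 0.0],
]])

xc = prediction[..., 4] > conf_thres  # shape (1, 4): objectness above threshold?
print(xc.tolist())                    # [[True, False, True, False]]

for xi, x in enumerate(prediction):   # xi = 0, x has shape (4, 7)
    x = x[xc[xi]]                     # boolean mask on dim 0 keeps rows 0 and 2
    print(xi, x.shape)                # 0 torch.Size([2, 7])
```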
```python
# Cat apriori labels if autolabelling
if labels and len(labels[xi]):
    l = labels[xi]
    v = torch.zeros((len(l), nc + 5), device=x.device)
    v[:, :4] = l[:, 1:5]  # box
    v[:, 4] = 1.0  # conf
    v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
    x = torch.cat((x, v), 0)
```
I'm not sure what this code is for yet, so I'll skip it for now and add an explanation once I figure it out.
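Reading the code alone, the block appends known labels to the predictions as extra rows: each label row is assumed to be (cls, cx, cy, w, h), and it becomes a prediction row with that box, an objectness of 1.0 and a one-hot class vector, so it would rank at the top in the NMS that follows. A minimal sketch with made-up values (torch.arange is used here in place of range for the row indices):

```python
import torch

nc = 2                        # number of classes
x = torch.zeros((0, nc + 5))  # predictions left after the confidence filter (none here)

# Hypothetical labels for one image, one row per object: (cls, cx, cy, w, h)
l = torch.tensor([[1., 50., 50., 20., 20.]])

v = torch.zeros((len(l), nc + 5))
v[:, :4] = l[:, 1:5]                                    # copy the box
v[:, 4] = 1.0                                           # objectness confidence = 1
v[torch.arange(len(l)), l[:, 0].long() + 5] = 1.0       # one-hot class in column 5 + cls
x = torch.cat((x, v), 0)

print(x.tolist())  # [[50.0, 50.0, 20.0, 20.0, 1.0, 0.0, 1.0]]
```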
```python
# If none remain process next image
if not x.shape[0]:
    continue
```
This checks whether any boxes survived the conf_thres filter, i.e. whether x.shape[0] is greater than 0. If there are boxes above conf_thres, processing continues; if there are none, the loop moves on to the next image.
```python
# Compute conf
x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf
```
This line computes the class probability for each predicted box. The corresponding formula in the YOLO paper is
$$\Pr(\text{Class}_i \mid \text{Object}) \cdot \Pr(\text{Object}) \cdot \mathrm{IOU}_{\text{pred}}^{\text{truth}} = \Pr(\text{Class}_i) \cdot \mathrm{IOU}_{\text{pred}}^{\text{truth}}$$
The YOLO paper says this "gives us class-specific confidence scores for each box. These scores encode both the probability of that class appearing in the box and how well the predicted box fits the object." So the score given by this formula combines the probability of each class with how well the predicted box fits the target box. The code does not use the IOU term; it only computes the class probability. x[:, 5:] corresponds to $\Pr(\text{Class}_i \mid \text{Object})$ and x[:, 4:5] corresponds to $\Pr(\text{Object})$.
The probability computation is also neat. x[:, 5:] is an N×2 tensor and x[:, 4:5] is an N×1 tensor, where N is the number of boxes left after the confidence filter. x[:, 5:] *= x[:, 4:5] uses broadcasting: each element of the two class columns is multiplied by the objectness value in the same row, and the result is written back into the two columns of x[:, 5:]. After this step, the last two columns of x (the 6th and 7th) hold the confidence of each class for the target.
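A small sketch of this broadcasted multiplication on two made-up rows (values chosen to be exact in float32 so the printed result is exact):

```python
import torch

# Hypothetical rows: cx, cy, w, h, obj_conf, cls0_conf, cls1_conf
x = torch.tensor([
    [50., 50., 20., 20., 0.5, 0.75, 0.25],
    [80., 30., 10., 40., 0.25, 0.5, 0.5],
])

# (N, 2) *= (N, 1): the objectness column is broadcast across both class columns
x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf
print(x[:, 5:].tolist())  # [[0.375, 0.125], [0.125, 0.125]]
```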
```python
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
box = xywh2xyxy(x[:, :4])
```
The raw YOLO output box format is (center x, center y, width, height). This step converts the four values of each box to (x1, y1, x2, y2), i.e. the top-left and bottom-right corners.
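A minimal sketch of such a conversion. This is my own re-implementation for illustration (the name xywh_to_xyxy is hypothetical and it handles torch tensors only); it is not the YOLOv5 xywh2xyxy source:

```python
import torch

def xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor:
    """Convert (cx, cy, w, h) boxes of shape (N, 4) to (x1, y1, x2, y2) corners."""
    xyxy = xywh.clone()
    xyxy[:, 0] = xywh[:, 0] - xywh[:, 2] / 2  # x1 = cx - w/2
    xyxy[:, 1] = xywh[:, 1] - xywh[:, 3] / 2  # y1 = cy - h/2
    xyxy[:, 2] = xywh[:, 0] + xywh[:, 2] / 2  # x2 = cx + w/2
    xyxy[:, 3] = xywh[:, 1] + xywh[:, 3] / 2  # y2 = cy + h/2
    return xyxy

boxes = torch.tensor([[50., 50., 20., 10.]])
print(xywh_to_xyxy(boxes).tolist())  # [[40.0, 45.0, 60.0, 55.0]]
```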
```python
# Detections matrix nx6 (xyxy, conf, cls)
if multi_label:
    i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
    x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
else:  # best class only
    conf, j = x[:, 5:].max(1, keepdim=True)
    x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
```
The code above builds the n×6 detection matrix. I don't understand the multi_label branch yet; I'll come back to it once I do.
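The effect of the multi_label branch can at least be observed by running its two lines on a toy input: every (box, class) pair whose score exceeds conf_thres becomes its own row, so one box can appear several times with different class ids. The values below are made up, box is assumed to already be in xyxy format, and column 4 is just a placeholder:

```python
import torch

conf_thres = 0.25
box = torch.tensor([[0., 0., 10., 10.],
                    [5., 5., 15., 15.]])
# Hypothetical class scores (already obj_conf * cls_conf): 2 boxes x 2 classes
scores = torch.tensor([[0.5, 0.375],
                       [0.125, 0.75]])
x = torch.cat((box, torch.zeros(2, 1), scores), 1)  # columns 5: hold the class scores

i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T  # i: box index, j: class index
x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)

print(i.tolist(), j.tolist())  # [0, 0, 1] [0, 1, 1]
print(x.tolist())
# [[0.0, 0.0, 10.0, 10.0, 0.5, 0.0],
#  [0.0, 0.0, 10.0, 10.0, 0.375, 1.0],
#  [5.0, 5.0, 15.0, 15.0, 0.75, 1.0]]
```

Box 0 passes the threshold for both classes and therefore yields two detections, whereas the best-class-only branch would keep just one row per box.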
```python
conf, j = x[:, 5:].max(1, keepdim=True)
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
```
This uses the tensor max function; for detailed usage see https://www.jianshu.com/p/3ed11362b54f. The argument 1 means the maximum is taken along dimension 1, i.e. for each row. The function returns two tensors: the first holds the maximum of each row, the second holds the index of that maximum within the row. For keepdim, see https://blog.csdn.net/zylooooooooong/article/details/112576268: it controls whether the output keeps the same number of dimensions as the input; with True, the reduced dimension is kept with size 1.
The returned conf has shape torch.Size([n, 1]), where n is the number of boxes left after the confidence filter; each row is the maximum of the corresponding row of x[:, 5:]. j is the index of each row's maximum, and this index is ultimately used as the class id of each box output by nms.
```python
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
```
This line concatenates box, conf and j column-wise into one tensor, which serves as the detection output. The index [conf.view(-1) > conf_thres] then keeps only the boxes whose confidence is greater than conf_thres.
At this point, every box with confidence greater than conf_thres has been selected and the result is stored in x. Columns 0 to 3 of x hold the box, column 4 holds conf, and column 5 holds the class id.
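A small sketch of the best-class-only branch on made-up scores for three boxes; the third box's best score (0.125) is below conf_thres, so its row is dropped:

```python
import torch

conf_thres = 0.25
box = torch.tensor([[0., 0., 10., 10.],
                    [5., 5., 15., 15.],
                    [20., 20., 30., 30.]])
# Hypothetical class scores (obj_conf * cls_conf), one row per box
cls_scores = torch.tensor([[0.5, 0.375],
                           [0.125, 0.75],
                           [0.125, 0.0625]])

conf, j = cls_scores.max(1, keepdim=True)  # per-row max value and its column index, both (3, 1)
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

print(conf.shape)  # torch.Size([3, 1])
print(x.tolist())
# [[0.0, 0.0, 10.0, 10.0, 0.5, 0.0],
#  [5.0, 5.0, 15.0, 15.0, 0.75, 1.0]]
```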
```python
# Filter by class
if classes is not None:
    x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
```
This line filters by class: only boxes whose class is in classes are kept, so NMS is applied only to the specified classes.
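A small sketch of how this comparison works, using three made-up detections and a hypothetical classes = [0, 2]: the (n, 1) class column is compared against all k requested ids at once, and .any(1) keeps a row if it matches any of them:

```python
import torch

# Hypothetical detections: x1, y1, x2, y2, conf, cls
x = torch.tensor([[0., 0., 10., 10., 0.5, 0.],
                  [5., 5., 15., 15., 0.75, 1.],
                  [20., 20., 30., 30., 0.5, 2.]])
classes = [0, 2]  # keep only class ids 0 and 2

# (n, 1) == (k,) broadcasts to an (n, k) boolean matrix
mask = (x[:, 5:6] == torch.tensor(classes)).any(1)
x = x[mask]
print(mask.tolist())     # [True, False, True]
print(x[:, 5].tolist())  # [0.0, 2.0]
```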
```python
# Check shape
n = x.shape[0]  # number of boxes
if not n:  # no boxes
    continue
elif n > max_nms:  # excess boxes
    x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence
```
Check whether x contains any boxes. If it does not, move on to the next image. If the number of boxes exceeds max_nms, sort them by confidence in descending order and keep only the top max_nms boxes for NMS.
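A small sketch of the truncation step with a deliberately tiny, hypothetical cap of max_nms = 2 and made-up rows:

```python
import torch

max_nms = 2  # hypothetical cap, tiny for illustration
x = torch.tensor([[0., 0., 1., 1., 0.25, 0.],
                  [0., 0., 1., 1., 0.75, 1.],
                  [0., 0., 1., 1., 0.5, 0.]])

order = x[:, 4].argsort(descending=True)  # row indices by confidence, high to low
x = x[order[:max_nms]]                    # keep only the max_nms most confident rows
print(order.tolist())    # [1, 2, 0]
print(x[:, 4].tolist())  # [0.75, 0.5]
```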
```python
# Batched NMS
c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
```
The line c = x[:, 5:6] * (0 if agnostic else max_wh) implements multi-class NMS; for details see https://blog.csdn.net/flyfish1986/article/details/119177472.
The strategy for multi-class NMS (non-maximum suppression) is to let each class run NMS independently by adding an offset to every box. The offset depends only on the class id (x[:, 5:6]) and is large enough that boxes from different classes never overlap. With agnostic=True the offset is 0, so NMS is applied across all classes together.
```python
boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
```
This line extracts boxes and scores; boxes are shifted by the offset c, which differs from class to class.
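A small sketch of why the offset separates classes: two identical boxes with different class ids overlap completely when agnostic is True, and not at all once the class-1 box is shifted by max_wh = 4096 pixels. iou_pairwise is a hypothetical helper written here for illustration, not the YOLOv5 box_iou:

```python
import torch

def iou_pairwise(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """IoU between every box in a (N, 4) and every box in b (M, 4), xyxy format."""
    lt = torch.maximum(a[:, None, :2], b[None, :, :2])  # top-left of intersection
    rb = torch.minimum(a[:, None, 2:], b[None, :, 2:])  # bottom-right of intersection
    inter = (rb - lt).clamp(min=0).prod(2)
    area_a = (a[:, 2:] - a[:, :2]).prod(1)
    area_b = (b[:, 2:] - b[:, :2]).prod(1)
    return inter / (area_a[:, None] + area_b[None, :] - inter)

max_wh = 4096
# Two identical boxes (xyxy, conf, cls) with class ids 0 and 1
x = torch.tensor([[10., 10., 50., 50., 0.75, 0.],
                  [10., 10., 50., 50., 0.5, 1.]])

for agnostic in (True, False):
    c = x[:, 5:6] * (0 if agnostic else max_wh)  # per-class offset
    boxes = x[:, :4] + c
    print(agnostic, iou_pairwise(boxes[:1], boxes[1:]).item())
# True 1.0   (class-agnostic: the boxes overlap completely)
# False 0.0  (per-class: the class-1 box is shifted by 4096 px)
```

The trick assumes box coordinates stay below max_wh; otherwise shifted boxes from neighbouring classes could still overlap.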
```python
i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
```
This calls torchvision's built-in nms interface to suppress overlapping boxes. The function returns a tensor i, documented as:
Tensor: int64 tensor with the indices of the elements that have been kept by NMS, sorted in decreasing order of scores
In other words, i is an int64 tensor holding the indices of the kept boxes, sorted by score (confidence) from high to low.
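To make "indices kept, sorted by decreasing score" concrete, here is a plain greedy NMS in pure PyTorch. It is a simplified sketch for illustration (greedy_nms is a hypothetical name), not torchvision's actual implementation, which runs in compiled kernels; like torchvision.ops.nms, it discards boxes whose IoU with a kept box is greater than iou_thres:

```python
import torch

def greedy_nms(boxes: torch.Tensor, scores: torch.Tensor, iou_thres: float) -> torch.Tensor:
    """Keep the highest-scoring box, drop boxes overlapping it by > iou_thres, repeat."""
    order = scores.argsort(descending=True)
    keep = []
    while order.numel() > 0:
        best = order[0]
        keep.append(best.item())
        rest = order[1:]
        # IoU of the best box with all remaining boxes (xyxy format)
        lt = torch.maximum(boxes[best, :2], boxes[rest, :2])
        rb = torch.minimum(boxes[best, 2:], boxes[rest, 2:])
        inter = (rb - lt).clamp(min=0).prod(1)
        area_best = (boxes[best, 2:] - boxes[best, :2]).prod()
        area_rest = (boxes[rest, 2:] - boxes[rest, :2]).prod(1)
        iou = inter / (area_best + area_rest - inter)
        order = rest[iou <= iou_thres]  # survivors go to the next round
    return torch.tensor(keep, dtype=torch.int64)

boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],
                      [20., 20., 30., 30.]])
scores = torch.tensor([0.5, 0.75, 0.25])
print(greedy_nms(boxes, scores, 0.45).tolist())  # [1, 2]
```

Box 1 has the highest score and suppresses box 0 (IoU 81/119, about 0.68); box 2 does not overlap and is kept, so the result is already ordered by score.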
```python
if i.shape[0] > max_det:  # limit detections
    i = i[:max_det]
```
Check whether the number of detections exceeds max_det. If it does, keep only the first max_det, which are the highest-confidence ones because i is sorted by score.
```python
if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
    # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
    iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
    weights = iou * scores[None]  # box weights
    x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
    if redundant:
        i = i[iou.sum(1) > 1]  # require redundancy
```
I don't understand this part yet; I'll update the post once I do.
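Reading the code, merge-NMS replaces each kept box by a confidence-weighted mean of every box that overlaps it by more than iou_thres (itself included), and redundant then drops kept boxes that had no such overlapping partner. A minimal sketch of those lines with made-up values; iou_pairwise is a hypothetical stand-in for YOLOv5's box_iou, and i is hard-coded instead of coming from torchvision.ops.nms:

```python
import torch

def iou_pairwise(a, b):
    """IoU matrix between xyxy boxes a (N, 4) and b (M, 4); stands in for box_iou."""
    lt = torch.maximum(a[:, None, :2], b[None, :, :2])
    rb = torch.minimum(a[:, None, 2:], b[None, :, 2:])
    inter = (rb - lt).clamp(min=0).prod(2)
    area_a = (a[:, 2:] - a[:, :2]).prod(1)
    area_b = (b[:, 2:] - b[:, :2]).prod(1)
    return inter / (area_a[:, None] + area_b[None, :] - inter)

iou_thres = 0.45
# x1, y1, x2, y2, conf, cls (single class, so no class offset is needed)
x = torch.tensor([[0., 0., 10., 10., 0.75, 0.],
                  [2., 0., 12., 10., 0.25, 0.],
                  [50., 50., 60., 60., 0.5, 0.]])
boxes, scores = x[:, :4].clone(), x[:, 4]
i = torch.tensor([0, 2])  # indices NMS would keep (box 1 is suppressed by box 0)

iou = iou_pairwise(boxes[i], boxes) > iou_thres  # (kept, all) boolean matrix
weights = iou * scores[None]                      # overlapping boxes weighted by their conf
x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # weighted-mean boxes
i = i[iou.sum(1) > 1]  # redundant=True: keep only boxes that overlapped at least one other box

print(x[0, :4].tolist())  # [0.5, 0.0, 10.5, 10.0]
print(i.tolist())         # [0]
```

Box 0 is pulled a quarter of the way toward box 1 (weights 0.75 and 0.25), while the isolated box 2 keeps its coordinates but is removed by the redundancy check.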
```python
output[xi] = x[i]
```
x[i] uses the NMS result i to select all the boxes that NMS kept, and the result is stored in output at position xi, the slot for the corresponding image.
```python
if (time.time() - t) > time_limit:
    print(f'WARNING: NMS time limit {time_limit}s exceeded')
    break  # time limit exceeded
```
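This checks whether nms has exceeded its time limit. If it has, the for loop breaks and the remaining images are not processed. Since t is set once before the loop, the limit applies to the whole batch rather than to each image.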
```python
return output
```
Return the results, which are stored in output. output is a list in which each element is the nms result for one image. The binary classification example in this post has only one image, so we can look at output[0] to check the nms output; its shape is torch.Size([1892, 6]). Each row is one nms detection in the format x1, y1, x2, y2, conf, cls.
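Assuming the (x1, y1, x2, y2, conf, cls) row layout above, a detections tensor such as output[0] can be unpacked row by row as sketched below; the two rows are made-up stand-ins for real nms output:

```python
import torch

# Stand-in for output[0]: one row per kept detection, columns x1, y1, x2, y2, conf, cls
det = torch.tensor([[40., 45., 60., 55., 0.75, 1.],
                    [10., 10., 30., 40., 0.5, 0.]])

results = []
for *xyxy, conf, cls in det.tolist():
    results.append((int(cls), round(conf, 2), [round(v) for v in xyxy]))

print(results)  # [(1, 0.75, [40, 45, 60, 55]), (0, 0.5, [10, 10, 30, 40])]
```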
Copyright notice
This article was written by [Top of the program]; please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/111/202204210545091270.html