
YOLOv5 NMS source code understanding

2022-04-23 21:00:00 · Top of the program

Below is the YOLOv5 NMS source code pasted directly (lightly modified), followed by my understanding of NMS, using a binary-classification model as the example.

import time

import torch
import torchvision
# xywh2xyxy and box_iou are YOLOv5 utility helpers (defined in the repo's utils package)


def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=(), max_det=100000, return_index=False):
    """Runs Non-Maximum Suppression (NMS) on inference results.
    Returns: list of detections, one (n, 6) tensor per image [xyxy, conf, cls]
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_nms = 200000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 60.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    i = None
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        # x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded
    if return_index:
        return output, i
    else:
        return output

  • The first two variables:
    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates
  • nc = prediction.shape[2] - 5  # number of classes

prediction is the raw output of the network model.
Its shape is (1, 50000, 7): 1 is the number of images, 50000 is the number of candidate boxes the network predicts, and 7 is the number of values per candidate box, with the following meaning:
(figure: the 7 values are center x, center y, width, height, objectness confidence, and one confidence score per class, i.e. two class scores in this binary-classification example)
From this you can see that nc recovers the number of classes the network predicts: the 5 subtracted from shape[2] accounts for the 4 box coordinates plus the objectness confidence.

  • xc = prediction[..., 4] > conf_thres  # candidates
    prediction[..., 4] is a tensor of shape torch.Size([1, 50000]); it takes the 5th value of every prediction, i.e. the probability that the box contains an object. The full expression prediction[..., 4] > conf_thres returns a tensor of the same shape (1, 50000) whose entries are each True or False, and this tensor is assigned to xc. So xc has shape (1, 50000), every value is True or False, and it records for every box whether its objectness confidence is above or below conf_thres. A minimal sketch of this step follows.
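A minimal sketch (my own toy tensor, not from the post) of how this candidate mask is built:

import torch

prediction = torch.rand(1, 8, 7)        # (images, candidate boxes, 4 box + 1 obj + 2 class values)
conf_thres = 0.25
xc = prediction[..., 4] > conf_thres    # boolean mask over the objectness column
print(xc.shape, xc.dtype)               # torch.Size([1, 8]) torch.bool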

  • Checking the input parameters conf_thres and iou_thres

    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

These two lines are straightforward: they simply check that conf_thres and iou_thres both lie between 0 and 1; it is plain parameter validation.

  • Variable definitions
    The meaning of each setting is given by its trailing comment:
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_nms = 200000  # maximum number of boxes fed into torchvision.ops.nms()
    time_limit = 60.0  # abort NMS if it runs longer than this many seconds
    redundant = True  # require redundant detections (only used by merge-NMS)
    multi_label &= nc > 1  # allow multiple labels per box only when there is more than one class (adds 0.5ms/img)
    merge = False  # whether to use merge-NMS
  • Definition of the output tensor
output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]

output is defined as a list whose length equals the number of images, prediction.shape[0], which is 1 here; every element of the list is an empty tensor with 6 fields per row, which will later hold that image's detections.

  • Use a for loop to walk through the images and process each image's inference result in turn
for xi, x in enumerate(prediction):  # image index, image inference

xi is the image index; its value here is 0, meaning the 0th image. x is that image's inference result, with shape torch.Size([50000, 7]). This is a neat use of enumerate to split the image index from the image's predictions and store them separately in xi and x. prediction is originally a three-dimensional tensor of shape torch.Size([1, 50000, 7]); its first axis is the image index, and the other two axes hold that image's predictions, so enumerate cleanly separates the two parts. This line of code is elegant; a toy demonstration follows.
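A toy demonstration (my own example, not from the post) of how enumerate splits the batch axis:

import torch

prediction = torch.rand(1, 8, 7)
for xi, x in enumerate(prediction):
    print(xi, x.shape)                  # 0 torch.Size([8, 7])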

x = x[xc[xi]]  # confidence

xc is the tensor of shape (1, 50000) computed earlier, in which every value is True or False and records whether the corresponding box's confidence is above or below conf_thres. xc[xi] indexes into that tensor: since xi is 0, it picks out all of the True/False values belonging to the 0th image, giving a tensor of shape torch.Size([50000]), which matches the 0th dimension of x. x[xc[xi]] is in turn an indexing operation on x: it keeps the rows of x where xc[xi] is True and discards the rows where it is False. In other words, it pulls out every box of image i (here i is 0) whose confidence is greater than conf_thres and reassigns the result to x. A small sketch follows.
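A small sketch (toy values, not from the post) of this boolean-mask row filtering:

import torch

x = torch.rand(8, 7)
mask = x[:, 4] > 0.5                    # plays the role of xc[xi]
x = x[mask]                             # only rows whose objectness exceeds 0.5 survive
print(x.shape)                          # (k, 7) with k <= 8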

# Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

I am not yet sure what this block does; I will skip it for now and add an explanation once I have figured it out.

# If none remain process next image
        if not x.shape[0]:
            continue

This checks whether any boxes passed the conf_thres filter, i.e. whether x.shape[0] is greater than 0. If there are boxes above conf_thres, processing continues; if not, the loop moves on to the next image.

# Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

This line computes, for each box, the confidence of every predicted class. The formula from the YOLO paper is

$$\operatorname{Pr}(\text{Class}_i \mid \text{Object}) \cdot \operatorname{Pr}(\text{Object}) \cdot \mathrm{IOU}_{\text{pred}}^{\text{truth}} = \operatorname{Pr}(\text{Class}_i) \cdot \mathrm{IOU}_{\text{pred}}^{\text{truth}}$$

The paper describes it as follows: "which gives us class-specific confidence scores for each box. These scores encode both the probability of that class appearing in the box and how well the predicted box fits the object."
From that sentence, the scores computed by the formula combine the per-class probability with how well the predicted box fits the ground-truth box. The code does not use the IOU part; it only performs the class-probability product: x[:, 5:] corresponds to $\operatorname{Pr}(\text{Class}_i \mid \text{Object})$ and x[:, 4:5] corresponds to $\operatorname{Pr}(\text{Object})$.
The computation itself is a neat broadcast: x[:, 5:] is an n-row, 2-column tensor (n surviving boxes, one column per class) and x[:, 4:5] is an n-row, 1-column tensor. x[:, 5:] *= x[:, 4:5] multiplies each row's two class scores by that row's objectness value and writes the result back into x[:, 5:]. After this step, the last two columns of x (columns 6 and 7) hold the per-class confidence of each box. A small numeric sketch follows.
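A small sketch with made-up numbers of the broadcast multiply behind x[:, 5:] *= x[:, 4:5] for a two-class model:

import torch

x = torch.tensor([[10., 10., 20., 20., 0.9, 0.8, 0.2],
                  [30., 30., 10., 10., 0.5, 0.3, 0.7]])
x[:, 5:] *= x[:, 4:5]                   # (n, 2) multiplied by (n, 1), broadcast over the class columns
print(x[:, 5:])                         # tensor([[0.7200, 0.1800], [0.1500, 0.3500]])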

# Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

The raw YOLO output stores each box as (center x, center y, width, height); this step converts the four values to the corner format (x1, y1, x2, y2).
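For reference, a minimal sketch of what the xywh2xyxy helper does (the real function lives in YOLOv5's utils; this version only assumes a 2-D tensor input):

import torch

def xywh2xyxy(x):
    y = x.clone()
    y[:, 0] = x[:, 0] - x[:, 2] / 2     # x1 = cx - w/2
    y[:, 1] = x[:, 1] - x[:, 3] / 2     # y1 = cy - h/2
    y[:, 2] = x[:, 0] + x[:, 2] / 2     # x2 = cx + w/2
    y[:, 3] = x[:, 1] + x[:, 3] / 2     # y2 = cy + h/2
    return y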

# Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

The code above builds the n×6 detection matrix (xyxy, conf, cls). I do not fully understand the multi_label branch yet; I will come back to it once I do. The best-class-only branch is explained below.

conf, j = x[:, 5:].max(1, keepdim=True)
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

This uses the tensor max function; see https://www.jianshu.com/p/3ed11362b54f for its detailed usage. Passing 1 as the dimension argument means taking the maximum over each row, and the function returns two tensors: the first holds each row's maximum value and the second holds the index of that maximum. For keepdim, see https://blog.csdn.net/zylooooooooong/article/details/112576268; it controls whether the output keeps the same number of dimensions as the input (True keeps them consistent).
The returned conf has shape (n, 1), where n is the number of surviving boxes; each row is the maximum of the corresponding row of x[:, 5:]. j holds the index of each row's maximum, and this index ultimately becomes the class id of each NMS output.

x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

This line concatenates box, conf and j column-wise into one tensor as the network's detections, and the trailing [conf.view(-1) > conf_thres] keeps only the boxes whose confidence is greater than conf_thres.
Once execution reaches this line, essentially all boxes with confidence above conf_thres have been selected and stored in the tensor x: columns 0 to 3 hold the box coordinates, column 4 the confidence, and column 5 the class id. A small sketch of the whole step follows.
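A sketch with made-up numbers of picking the best class per box and assembling the (n, 6) detection matrix, as the two lines above do:

import torch

box = torch.tensor([[0., 0., 10., 10.],
                    [5., 5., 15., 15.]])
cls_scores = torch.tensor([[0.72, 0.18],
                           [0.15, 0.35]])
conf_thres = 0.25
conf, j = cls_scores.max(1, keepdim=True)          # best score and its class index per row
det = torch.cat((box, conf, j.float()), 1)         # (n, 6): x1, y1, x2, y2, conf, cls
det = det[conf.view(-1) > conf_thres]              # drop low-confidence rows (both survive here)
print(det)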

# Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

This filters the detections by class: only the classes listed in classes are kept, so NMS is run only on the specified classes. A small sketch follows.
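A sketch (assumed class ids) of this keep-only-listed-classes filter, using the same broadcast comparison plus .any(1):

import torch

x = torch.tensor([[0., 0., 10., 10., 0.9, 0.],
                  [5., 5., 15., 15., 0.8, 1.]])
classes = [1]
keep = (x[:, 5:6] == torch.tensor(classes)).any(1)  # (n,) boolean, True where the class id is listed
x = x[keep]
print(x)                                            # only the class-1 row remains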

# Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

This checks whether x still contains any boxes; if it does not, the loop moves on to the next image. If the number of boxes exceeds max_nms, the boxes are sorted by confidence in descending order and only the top max_nms are kept for NMS, as sketched below.
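A quick sketch (random values) of the sort-by-confidence-and-truncate pattern used above:

import torch

x = torch.rand(10, 6)
max_nms = 5
x = x[x[:, 4].argsort(descending=True)[:max_nms]]   # keep the 5 highest-confidence rows
print(x.shape)                                      # torch.Size([5, 6])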

# Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS

c = x[:, 5:6] * (0 if agnostic else max_wh): this line implements multi-class (class-aware) NMS; see https://blog.csdn.net/flyfish1986/article/details/119177472 for the details.
The strategy of multi-class NMS (non-maximum suppression) is to let every class run NMS independently by adding an offset to all boxes. The offset depends only on the class id (that is, x[:, 5:6]) and is large enough that boxes from different classes never overlap.
boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
This line extracts boxes and scores; boxes has the offset c added, and boxes of different classes get different offsets.
i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
This calls torchvision's built-in NMS interface to suppress overlapping boxes; the function returns a tensor i:

Tensor: int64 tensor with the indices of the elements that have been kept
by NMS, sorted in decreasing order of scores

So i is an int64 tensor holding the indices of the boxes kept by NMS, sorted by score (confidence) from high to low. A small end-to-end sketch of the class-offset trick follows.
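A small end-to-end sketch (made-up boxes) of class-offset batched NMS: the two boxes overlap heavily, but because they belong to different classes the offset pushes them far apart before torchvision.ops.nms sees them, so both survive.

import torch
import torchvision

boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.]])
scores = torch.tensor([0.9, 0.8])
cls = torch.tensor([[0.], [1.]])        # two different class ids
max_wh = 4096
offset = cls * max_wh                   # per-box offset, proportional to the class id
i = torchvision.ops.nms(boxes + offset, scores, 0.45)
print(i)                                # tensor([0, 1]): both kept; with equal class ids the second would be suppressed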

        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]

This checks whether the number of kept detections exceeds the maximum max_det; if it does, the lowest-confidence ones are dropped.

        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

I do not understand this merge-NMS part yet; I will update the post once I do.

output[xi] = x[i]

x[i] uses the NMS result i to gather all of the boxes that i refers to, and the result is saved into the slot of output corresponding to image xi.

        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

This checks whether NMS has exceeded its time limit; if it has, the code breaks out of the for loop and does not go on to run NMS on the next image.

return output

The results are returned in output, which is a list with one element per image, each element holding that image's NMS result. For the two-class example in this post there is only one image, so output[0] contains the NMS output; its shape is torch.Size([1892, 6]). Printing one of its rows shows the meaning of each line of the NMS result:
(figure: each output row is x1, y1, x2, y2, confidence, class id)
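A usage sketch (it assumes a loaded YOLOv5 model `model`, a preprocessed image tensor `img`, and the modified non_max_suppression above; these names are mine, not from the post):

pred = model(img)[0]                               # raw predictions, shape (1, n_boxes, 7)
out = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
for det in out:                                    # one (k, 6) tensor per image
    for *xyxy, conf, cls in det:
        print([float(v) for v in xyxy], float(conf), int(cls))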

Copyright notice
This article was written by [Top of the program]; if you repost it, please include a link to the original. Thanks.
https://yzsam.com/2022/111/202204210545091270.html