Implementing inter-class and intra-class NMS in PyTorch and NumPy
2022-08-11 05:41:00 [gy-77]
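Inter-class NMS: NMS is also performed between boxes of different classes, so a box can be suppressed by a higher-scoring box of another class.
Intra-class NMS: NMS is performed only among bboxes of the same class.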
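To make the two definitions concrete, here is a minimal pure-Python sketch (not the author's code; the helpers `iou` and `greedy_nms` and the `class_aware` flag are hypothetical) that uses the same inclusive-pixel `+1` area convention as the implementations below. Two heavily overlapping boxes with different labels both survive intra-class NMS, but only one survives inter-class NMS.

```python
def iou(a, b):
    """IoU of two boxes (x1, y1, x2, y2), using the inclusive-pixel (+1) convention."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]) + 1)
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]) + 1)
    inter = inter_w * inter_h
    area_a = (a[2] - a[0] + 1) * (a[3] - a[1] + 1)
    area_b = (b[2] - b[0] + 1) * (b[3] - b[1] + 1)
    return inter / (area_a + area_b - inter)


def greedy_nms(boxes, scores, labels, thresh, class_aware):
    """Return kept indices in descending score order.

    class_aware=False: any higher-scoring kept box can suppress (inter-class).
    class_aware=True: only a kept box with the same label can suppress (intra-class).
    """
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        suppressed = any(
            (not class_aware or labels[i] == labels[k]) and iou(boxes[i], boxes[k]) >= thresh
            for k in keep
        )
        if not suppressed:
            keep.append(i)
    return keep


boxes = [(0, 0, 9, 9), (1, 1, 10, 10), (50, 50, 59, 59)]
scores = [0.9, 0.8, 0.7]
labels = [0, 1, 0]  # boxes 0 and 1 overlap (IoU = 81/119) but belong to different classes

print(greedy_nms(boxes, scores, labels, 0.5, class_aware=False))  # [0, 2]
print(greedy_nms(boxes, scores, labels, 0.5, class_aware=True))   # [0, 1, 2]
```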
NumPy implementation of inter-class + intra-class NMS:
```python
import numpy as np

# Inter-class NMS: boxes suppress each other regardless of class
def nms(bboxes, scores, thresh):
    x1, y1, x2, y2 = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    # Sort by score in descending order (stores indices, not scores)
    indices = scores.argsort()[::-1]
    indice_res = []
    while indices.size > 0:
        # Keep the highest-scoring remaining box
        i = indices[0]
        indice_res.append(i)
        # Intersection of box i with every other remaining box
        inter_x1 = np.maximum(x1[i], x1[indices[1:]])
        inter_y1 = np.maximum(y1[i], y1[indices[1:]])
        inter_x2 = np.minimum(x2[i], x2[indices[1:]])
        inter_y2 = np.minimum(y2[i], y2[indices[1:]])
        inter_w = np.maximum(0.0, inter_x2 - inter_x1 + 1)
        inter_h = np.maximum(0.0, inter_y2 - inter_y1 + 1)
        inter_area = inter_w * inter_h
        union_area = areas[i] + areas[indices[1:]] - inter_area + 1e-6
        ious = inter_area / union_area
        # np.where(ious < thresh) returns a tuple; its first element is the array of positions satisfying the condition
        idxs = np.where(ious < thresh)[0]
        # idxs index into indices[1:], so shift by 1 to index into indices
        indices = indices[idxs + 1]
    return indice_res
```
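For comparison, a minimal NumPy sketch (not from the original post; the names `pairwise_iou` and `nms_matrix` are hypothetical) that computes the full N×N IoU matrix once with broadcasting and then runs the same greedy suppression. Under the same `+1` convention it should keep the same boxes as `nms` above.

```python
import numpy as np


def pairwise_iou(bboxes):
    """(N, 4) boxes -> (N, N) IoU matrix, inclusive-pixel (+1) convention."""
    x1, y1, x2, y2 = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    # Broadcasting (N, 1) against (1, N) evaluates every pair at once
    inter_w = np.maximum(0.0, np.minimum(x2[:, None], x2[None, :]) - np.maximum(x1[:, None], x1[None, :]) + 1)
    inter_h = np.maximum(0.0, np.minimum(y2[:, None], y2[None, :]) - np.maximum(y1[:, None], y1[None, :]) + 1)
    inter = inter_w * inter_h
    return inter / (areas[:, None] + areas[None, :] - inter + 1e-6)


def nms_matrix(bboxes, scores, thresh):
    """Greedy NMS that reuses a precomputed IoU matrix; returns kept indices."""
    ious = pairwise_iou(bboxes)
    order = scores.argsort()[::-1]
    suppressed = np.zeros(len(bboxes), dtype=bool)
    keep = []
    for i in order:
        if suppressed[i]:
            continue
        keep.append(int(i))
        # Mark every box overlapping the kept box too much (it also marks itself, which is harmless)
        suppressed |= ious[i] >= thresh
    return keep


bboxes = np.array([[0, 0, 9, 9], [1, 1, 10, 10], [50, 50, 59, 59], [2, 2, 11, 11]], dtype=np.float32)
scores = np.array([0.9, 0.8, 0.7, 0.95], dtype=np.float32)
print(nms_matrix(bboxes, scores, 0.5))  # [3, 0, 2]
```

The matrix version trades O(N²) memory for computing all IoUs in one vectorized pass, while the loop in `nms` only compares the current box against the boxes that are still alive.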
```python
# Intra-class NMS: shift each class's boxes by an offset proportional to its class id,
# so boxes of different classes land in disjoint regions (assumes non-negative coordinates).
def class_nms(bboxes, scores, cat_ids, iou_threshold):
    '''
    :param bboxes: np.array of shape (N, 4), N is the number of bboxes, np.float32
    :param scores: np.array of shape (N,), np.float32
    :param cat_ids: np.array of shape (N,), np.int32
    :param iou_threshold: float
    '''
    max_coordinate = bboxes.max()
    # Generate an offset per class (or per feature level) large enough that boxes of different classes cannot intersect
    offsets = cat_ids * (max_coordinate + 1)
    # After adding its class offset, no box overlaps a box of another class
    bboxes_for_nms = bboxes + offsets[:, None]
    indice_res = nms(bboxes_for_nms, scores, iou_threshold)
    return indice_res
```
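The offset trick should give the same result as running NMS separately for each class and merging the kept indices. A minimal NumPy sketch of that per-class loop, with the hypothetical helpers `class_nms_loop` and `class_nms_offset` (and `nms` re-declared so the block runs on its own), lets you check both on a toy example where a cross-class overlap would otherwise cause a suppression. It assumes non-negative box coordinates.

```python
import numpy as np


def nms(bboxes, scores, thresh):
    """Greedy inter-class NMS with the +1 convention; returns kept indices."""
    x1, y1, x2, y2 = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    indices = scores.argsort()[::-1]
    keep = []
    while indices.size > 0:
        i, rest = indices[0], indices[1:]
        keep.append(int(i))
        inter_w = np.maximum(0.0, np.minimum(x2[i], x2[rest]) - np.maximum(x1[i], x1[rest]) + 1)
        inter_h = np.maximum(0.0, np.minimum(y2[i], y2[rest]) - np.maximum(y1[i], y1[rest]) + 1)
        inter = inter_w * inter_h
        ious = inter / (areas[i] + areas[rest] - inter + 1e-6)
        indices = rest[ious < thresh]
    return keep


def class_nms_offset(bboxes, scores, cat_ids, thresh):
    """Intra-class NMS via the offset trick (one NMS call)."""
    offsets = cat_ids * (bboxes.max() + 1)
    return nms(bboxes + offsets[:, None], scores, thresh)


def class_nms_loop(bboxes, scores, cat_ids, thresh):
    """Intra-class NMS via one NMS call per class, then merging."""
    keep = []
    for c in np.unique(cat_ids):
        idx = np.flatnonzero(cat_ids == c)  # original indices of class c
        keep.extend(int(idx[k]) for k in nms(bboxes[idx], scores[idx], thresh))
    # Re-sort by score so the order matches the offset-based version
    return sorted(keep, key=lambda k: -scores[k])


bboxes = np.array([[0, 0, 9, 9], [1, 1, 10, 10], [50, 50, 59, 59], [2, 2, 11, 11]], dtype=np.float32)
scores = np.array([0.9, 0.8, 0.7, 0.95], dtype=np.float32)
cat_ids = np.array([0, 1, 0, 0], dtype=np.int32)  # box 1 overlaps box 3 but has another class

print(nms(bboxes, scores, 0.5))                         # [3, 0, 2]: box 1 is suppressed
print(class_nms_offset(bboxes, scores, cat_ids, 0.5))  # [3, 0, 1, 2]
print(class_nms_loop(bboxes, scores, cat_ids, 0.5))    # [3, 0, 1, 2]
```

The loop costs one NMS call per class; the offset version does all classes in a single call, which is why it is convenient when NMS is a single batched kernel.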
PyTorch implementation of inter-class + intra-class NMS:
```python
import torch

# Inter-class NMS: boxes suppress each other regardless of class
def nms(bboxes, scores, thresh):
    x1, y1, x2, y2 = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    # Sort by score in descending order (keep the indices)
    # values, indices = torch.sort(scores, descending=True)
    indices = scores.sort(descending=True)[1]
    indice_res = []  # indices of the kept boxes
    while indices.numel() > 0:  # numel() returns a plain int, so it can be compared directly
        save_idx, other_idx = indices[0], indices[1:]
        indice_res.append(save_idx.item())
        inter_x1 = torch.maximum(x1[save_idx], x1[other_idx])
        inter_y1 = torch.maximum(y1[save_idx], y1[other_idx])
        inter_x2 = torch.minimum(x2[save_idx], x2[other_idx])
        inter_y2 = torch.minimum(y2[save_idx], y2[other_idx])
        # clamp(min=0) zeroes the width/height of non-overlapping pairs
        inter_w = (inter_x2 - inter_x1 + 1).clamp(min=0)
        inter_h = (inter_y2 - inter_y1 + 1).clamp(min=0)
        inter_area = inter_w * inter_h
        union_area = areas[save_idx] + areas[other_idx] - inter_area + 1e-6
        iou = inter_area / union_area
        # Keep only the boxes whose IoU with the kept box is below the threshold
        indices = other_idx[iou < thresh]
    return torch.tensor(indice_res, dtype=torch.long, device=bboxes.device)
```
```python
# Intra-class NMS: shift each class's boxes by an offset proportional to its class id,
# so boxes of different classes land in disjoint regions (assumes non-negative coordinates).
def class_nms(bboxes, scores, cat_ids, iou_threshold):
    '''
    :param bboxes: tensor of shape [n, 4], dtype=torch.float32
    :param scores: tensor of shape [n], dtype=torch.float32
    :param cat_ids: tensor of shape [n], dtype=torch.int32
    :param iou_threshold: float
    '''
    max_coordinate = bboxes.max()
    # Generate a sufficiently large offset for each class (or each feature level)
    offsets = cat_ids * (max_coordinate + 1)
    # After adding its class offset, no box overlaps a box of another class
    bboxes_for_nms = bboxes + offsets[:, None]
    indice_res = nms(bboxes_for_nms, scores, iou_threshold)
    return indice_res
```
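Why is `max_coordinate + 1` enough? Assuming every coordinate lies in [0, M] with M = `max_coordinate`, a class-c box is shifted into [c(M+1), c(M+1)+M], while any box of a higher class starts at (c+1)(M+1) or later. With the inclusive `+1` width, the intersection width is therefore at most c(M+1)+M − (c+1)(M+1) + 1 = 0. An offset of only M would leave a one-pixel overlap in the tight case. A small PyTorch sketch (hypothetical helpers `pairwise_iou` and `offset_boxes`) that checks this on toy boxes, including a pair whose shifted boxes touch exactly:

```python
import torch


def pairwise_iou(bboxes):
    """(N, 4) boxes -> (N, N) IoU matrix, inclusive-pixel (+1) convention."""
    x1, y1, x2, y2 = bboxes.unbind(dim=1)
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    # Broadcasting [N, 1] against [1, N] evaluates every pair at once
    inter_w = (torch.minimum(x2[:, None], x2[None, :]) - torch.maximum(x1[:, None], x1[None, :]) + 1).clamp(min=0)
    inter_h = (torch.minimum(y2[:, None], y2[None, :]) - torch.maximum(y1[:, None], y1[None, :]) + 1).clamp(min=0)
    inter = inter_w * inter_h
    return inter / (areas[:, None] + areas[None, :] - inter + 1e-6)


def offset_boxes(bboxes, cat_ids):
    """Shift boxes by cat_id * (max_coordinate + 1), the same trick class_nms uses."""
    offsets = cat_ids * (bboxes.max() + 1)
    return bboxes + offsets[:, None]


bboxes = torch.tensor([[0, 0, 9, 9], [1, 1, 10, 10], [30, 30, 40, 40], [0, 0, 40, 40]], dtype=torch.float32)
cat_ids = torch.tensor([0, 1, 1, 2], dtype=torch.int32)

shifted = offset_boxes(bboxes, cat_ids)
different_class = cat_ids[:, None] != cat_ids[None, :]
print(pairwise_iou(bboxes)[0, 1])                    # boxes 0 and 1 overlap before shifting
print(pairwise_iou(shifted)[different_class].max())  # cross-class IoU after shifting
print(shifted[2], shifted[3])  # box 2 ends at x=81, box 3 starts at x=82: zero-width overlap
```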