Inter-class and intra-class NMS implemented two ways: PyTorch and NumPy
2022-08-11 05:41:00 【gy-77】
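Inter-class: NMS runs over all bboxes together, so boxes of different classes can also suppress each other.
Intra-class: NMS runs only among bboxes of the same class.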
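Both modes rely on the same IoU test and differ only in which pairs of boxes are allowed to compete. The sketch below illustrates this for a single pair, using the inclusive-pixel (+1) convention. The helper names `iou_xyxy` and `is_suppressed` and the sample boxes are hypothetical, chosen only for illustration.

```python
def iou_xyxy(a, b):
    """IoU of two [x1, y1, x2, y2] boxes, inclusive-pixel (+1) convention."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]) + 1)
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]) + 1)
    inter = inter_w * inter_h
    area_a = (a[2] - a[0] + 1) * (a[3] - a[1] + 1)
    area_b = (b[2] - b[0] + 1) * (b[3] - b[1] + 1)
    return inter / (area_a + area_b - inter)


def is_suppressed(iou, thresh, cat_kept, cat_other, per_class):
    """True if the lower-scoring box is removed by the kept box."""
    if per_class and cat_kept != cat_other:
        return False  # intra-class mode: different classes never compete
    return iou >= thresh


box_a, cat_a = [0, 0, 9, 9], 0    # higher-scoring box, class 0
box_b, cat_b = [0, 0, 9, 19], 1   # lower-scoring box, class 1

iou = iou_xyxy(box_a, box_b)      # 100 / (100 + 200 - 100) = 0.5
across = is_suppressed(iou, 0.4, cat_a, cat_b, per_class=False)
within = is_suppressed(iou, 0.4, cat_a, cat_b, per_class=True)
print(iou, across, within)
```

With a threshold of 0.4, the second box is removed in inter-class mode but survives in intra-class mode, because the two boxes belong to different classes.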
NumPy implementation of inter-class and intra-class NMS:
import numpy as np

# Inter-class NMS: every box competes with every other box
def nms(bboxes, scores, thresh):
    x1, y1, x2, y2 = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    # Sort by score in descending order (indices holds box indices, not scores)
    indices = scores.argsort()[::-1]
    indice_res = []
    while indices.size > 0:
        i = indices[0]
        indice_res.append(i)
        inter_x1 = np.maximum(x1[i], x1[indices[1:]])
        inter_y1 = np.maximum(y1[i], y1[indices[1:]])
        inter_x2 = np.minimum(x2[i], x2[indices[1:]])
        inter_y2 = np.minimum(y2[i], y2[indices[1:]])
        inter_w = np.maximum(0.0, inter_x2 - inter_x1 + 1)
        inter_h = np.maximum(0.0, inter_y2 - inter_y1 + 1)
        inter_area = inter_w * inter_h
        union_area = areas[i] + areas[indices[1:]] - inter_area + 1e-6
        ious = inter_area / union_area
        # np.where(cond) returns a tuple; its first element is the array of positions where cond holds
        idxs = np.where(ious < thresh)[0]
        # idxs is relative to indices[1:], hence the + 1
        indices = indices[idxs + 1]
    return indice_res
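
# Intra-class NMS: multiply each class id by an offset so that the bboxes of
# different classes are shifted to different, non-overlapping regions.
def class_nms(bboxes, scores, cat_ids, iou_threshold):
    '''
    :param bboxes: np.array, shape of (N, 4), N is the number of bboxes, np.float32
    :param scores: np.array, shape of (N,), np.float32
    :param cat_ids: np.array, shape of (N,), np.int32
    :param iou_threshold: float
    '''
    max_coordinate = bboxes.max()
    # Generate an offset per class (or per level) that is large enough to keep
    # the bboxes of different classes from intersecting
    offsets = cat_ids * (max_coordinate + 1)
    # After adding its class offset to each bbox, bboxes of different classes
    # can no longer overlap
    bboxes_for_nms = bboxes + offsets[:, None]
    # The returned indices refer to the original, unshifted bboxes
    indice_res = nms(bboxes_for_nms, scores, iou_threshold)
    return indice_res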
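The effect of the offset can be checked directly by comparing pairwise IoU before and after the shift. The sketch below does so under the assumption that all coordinates are non-negative; the helper `pairwise_iou` and the sample boxes are hypothetical and are not part of the code above.

```python
import numpy as np


def pairwise_iou(boxes):
    """IoU matrix for (N, 4) [x1, y1, x2, y2] boxes, inclusive-pixel (+1) convention."""
    x1, y1, x2, y2 = boxes.T
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    inter_w = np.maximum(
        0.0, np.minimum(x2[:, None], x2[None, :]) - np.maximum(x1[:, None], x1[None, :]) + 1
    )
    inter_h = np.maximum(
        0.0, np.minimum(y2[:, None], y2[None, :]) - np.maximum(y1[:, None], y1[None, :]) + 1
    )
    inter = inter_w * inter_h
    return inter / (areas[:, None] + areas[None, :] - inter)


# Boxes 1 and 2 are identical in position but belong to different classes.
bboxes = np.array([[0, 0, 9, 9], [0, 0, 9, 19], [0, 0, 9, 19]], dtype=np.float32)
cat_ids = np.array([0, 0, 1], dtype=np.int32)

# Same shift as in class_nms: class c moves by c * (max_coordinate + 1).
offsets = cat_ids * (bboxes.max() + 1)
shifted = bboxes + offsets[:, None]

iou_before = pairwise_iou(bboxes)
iou_after = pairwise_iou(shifted)
print(iou_before)
print(iou_after)
```

Pairs within one class keep their IoU (boxes 0 and 1 stay at 0.5), while pairs from different classes drop to 0 (boxes 1 and 2 go from 1.0 to 0). The `+ 1` in the offset matters with the inclusive-pixel convention: a box ending at `max_coordinate` and a shifted box starting at `max_coordinate + 1` get an intersection width of exactly 0.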
PyTorch implementation of inter-class and intra-class NMS:
import torch

# Inter-class NMS: every box competes with every other box
def nms(bboxes, scores, thresh):
    x1, y1, x2, y2 = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    # Sort by score in descending order and keep the indices
    # values, indices = torch.sort(scores, descending=True)
    indices = scores.sort(descending=True)[1]
    # Collect kept indices rather than box coordinates, so that class_nms
    # never hands back boxes that still carry the class offset
    indice_res = []
    while indices.size()[0] > 0:  # indices.size() is a torch.Size; take its first element to get an int
        save_idx, other_idx = indices[0], indices[1:]
        indice_res.append(save_idx)
        inter_x1 = torch.max(x1[save_idx], x1[other_idx])
        inter_y1 = torch.max(y1[save_idx], y1[other_idx])
        inter_x2 = torch.min(x2[save_idx], x2[other_idx])
        inter_y2 = torch.min(y2[save_idx], y2[other_idx])
        inter_w = (inter_x2 - inter_x1 + 1).clamp(min=0)
        inter_h = (inter_y2 - inter_y1 + 1).clamp(min=0)
        inter_area = inter_w * inter_h
        union_area = areas[save_idx] + areas[other_idx] - inter_area + 1e-6
        iou = inter_area / union_area
        indices = other_idx[iou < thresh]
    # 1-D LongTensor of kept indices; bboxes[result] recovers the kept boxes
    return torch.stack(indice_res) if indice_res else indices

# Intra-class NMS: multiply each class id by an offset so that the bboxes of
# different classes are shifted to different, non-overlapping regions.
def class_nms(bboxes, scores, cat_ids, iou_threshold):
    '''
    :param bboxes: torch.tensor([n, 4], dtype=torch.float32)
    :param scores: torch.tensor([n], dtype=torch.float32)
    :param cat_ids: torch.tensor([n], dtype=torch.int32)
    :param iou_threshold: float
    '''
    max_coordinate = bboxes.max()
    # Generate a large offset per class (or per level)
    offsets = cat_ids * (max_coordinate + 1)
    # After adding its class offset to each bbox, bboxes of different classes
    # can no longer overlap
    bboxes_for_nms = bboxes + offsets[:, None]
    indice_res = nms(bboxes_for_nms, scores, iou_threshold)
    return indice_res
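One way to sanity-check the offset trick is to compare it with the more direct approach of running NMS separately for each class. The following is a minimal, self-contained sketch of that comparison, assuming non-negative coordinates. The names `greedy_nms`, `per_class_nms_loop` and `per_class_nms_offset`, and the sample data, are hypothetical and are not the functions defined above.

```python
import torch


def greedy_nms(bboxes, scores, thresh):
    """Plain greedy NMS over all boxes; returns kept indices, highest score first."""
    x1, y1, x2, y2 = bboxes.unbind(dim=1)
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort(descending=True)
    keep = []
    while order.numel() > 0:
        i, rest = order[0], order[1:]
        keep.append(int(i))
        inter_w = (torch.minimum(x2[i], x2[rest]) - torch.maximum(x1[i], x1[rest]) + 1).clamp(min=0)
        inter_h = (torch.minimum(y2[i], y2[rest]) - torch.maximum(y1[i], y1[rest]) + 1).clamp(min=0)
        inter = inter_w * inter_h
        iou = inter / (areas[i] + areas[rest] - inter)
        order = rest[iou < thresh]
    return keep


def per_class_nms_loop(bboxes, scores, cat_ids, thresh):
    """Intra-class NMS by looping over the classes one at a time."""
    keep = []
    for c in cat_ids.unique():
        idx = torch.nonzero(cat_ids == c).squeeze(1)   # global indices of class c
        kept_local = greedy_nms(bboxes[idx], scores[idx], thresh)
        keep.extend(idx[kept_local].tolist())          # map back to global indices
    return sorted(keep)


def per_class_nms_offset(bboxes, scores, cat_ids, thresh):
    """Intra-class NMS in a single pass, using the per-class coordinate offset."""
    offsets = cat_ids.to(bboxes) * (bboxes.max() + 1)
    return sorted(greedy_nms(bboxes + offsets[:, None], scores, thresh))


bboxes = torch.tensor(
    [[0, 0, 9, 9], [0, 0, 9, 19], [0, 0, 9, 19], [30, 30, 39, 39]], dtype=torch.float32
)
scores = torch.tensor([0.9, 0.8, 0.7, 0.6])
cat_ids = torch.tensor([0, 0, 1, 1], dtype=torch.int32)

loop_keep = per_class_nms_loop(bboxes, scores, cat_ids, 0.4)
offset_keep = per_class_nms_offset(bboxes, scores, cat_ids, 0.4)
agnostic_keep = sorted(greedy_nms(bboxes, scores, 0.4))
print(loop_keep, offset_keep, agnostic_keep)
```

On this sample both intra-class variants keep the same boxes, while the inter-class run additionally drops box 2, which overlaps box 0 but belongs to a different class. The offset version needs only one NMS pass regardless of the number of classes, which is the main reason to prefer it over the loop.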