Inter-class and intra-class NMS implemented two ways: PyTorch and NumPy
2022-08-11 05:41:00 【gy-77】
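Inter-class: NMS is applied across all bboxes, so bboxes of different classes can suppress each other.
Intra-class: NMS is applied only among bboxes of the same class.
NumPy implementation of inter-class and intra-class NMS: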
import numpy as np


# Inter-class NMS
def nms(bboxes, scores, thresh):
    x1, y1, x2, y2 = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    # Sort by score in descending order (what we keep are the indices)
    indices = scores.argsort()[::-1]
    indice_res = []
    while indices.size > 0:
        # The highest-scoring remaining box is always kept
        i = indices[0]
        indice_res.append(i)
        # Intersection of box i with every other remaining box
        inter_x1 = np.maximum(x1[i], x1[indices[1:]])
        inter_y1 = np.maximum(y1[i], y1[indices[1:]])
        inter_x2 = np.minimum(x2[i], x2[indices[1:]])
        inter_y2 = np.minimum(y2[i], y2[indices[1:]])
        inter_w = np.maximum(0.0, inter_x2 - inter_x1 + 1)
        inter_h = np.maximum(0.0, inter_y2 - inter_y1 + 1)
        inter_area = inter_w * inter_h
        union_area = areas[i] + areas[indices[1:]] - inter_area + 1e-6
        ious = inter_area / union_area
        # np.where(ious < thresh) returns a tuple; its first element is the array
        # of positions that satisfy the condition
        idxs = np.where(ious < thresh)[0]
        # ious was computed against indices[1:], so shift the positions by 1
        indices = indices[idxs + 1]
    return indice_res
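# Intra-class NMS: multiply each class id by an offset so that the bboxes of
# different classes are shifted to different locations.
def class_nms(bboxes, scores, cat_ids, iou_threshold):
    '''
    :param bboxes: np.array, shape of (N, 4), N is the number of bboxes, np.float32
    :param scores: np.array, shape of (N,), np.float32
    :param cat_ids: np.array, shape of (N,), np.int32
    :param iou_threshold: float
    '''
    max_coordinate = bboxes.max()
    # Generate a large enough offset for each class / each level so that bboxes
    # of different classes cannot intersect
    offsets = cat_ids * (max_coordinate + 1)
    # After adding the per-class offset, bboxes of different classes no longer overlap.
    # offsets[:, None] has shape (N, 1) and broadcasts over the 4 coordinates,
    # which is why cat_ids must be 1-D.
    bboxes_for_nms = bboxes + offsets[:, None]
    indice_res = nms(bboxes_for_nms, scores, iou_threshold)
    return indice_res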
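To see what the offset actually changes, here is a minimal sketch on three made-up boxes. It is not the code above: `pairwise_iou` and `greedy_nms` are hypothetical helpers written for this illustration (a variant that precomputes the full IoU matrix), and the sample boxes, scores and class ids are assumptions. Boxes 0 and 1 overlap heavily but belong to different classes, so inter-class NMS should drop box 1 while intra-class NMS should keep it.

```python
import numpy as np


def pairwise_iou(boxes):
    """IoU matrix for (N, 4) boxes in [x1, y1, x2, y2] form, with the +1 pixel convention."""
    x1, y1, x2, y2 = boxes.T
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    inter_w = np.maximum(0.0, np.minimum(x2[:, None], x2[None, :])
                         - np.maximum(x1[:, None], x1[None, :]) + 1)
    inter_h = np.maximum(0.0, np.minimum(y2[:, None], y2[None, :])
                         - np.maximum(y1[:, None], y1[None, :]) + 1)
    inter = inter_w * inter_h
    return inter / (areas[:, None] + areas[None, :] - inter)


def greedy_nms(boxes, scores, thresh):
    """Greedy NMS on a precomputed IoU matrix; returns kept indices, best score first."""
    iou = pairwise_iou(boxes)
    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        rest = order[1:]
        order = rest[iou[i, rest] < thresh]
    return keep


# Assumed sample data: boxes 0 and 1 overlap (IoU is about 0.83) but differ in class.
boxes = np.array([[10, 10, 50, 50],
                  [12, 12, 52, 52],
                  [100, 100, 140, 140]], dtype=np.float32)
scores = np.array([0.9, 0.8, 0.7], dtype=np.float32)
cat_ids = np.array([0, 1, 0], dtype=np.int32)

# Inter-class: classes are ignored, so box 1 is suppressed by box 0.
keep_inter = greedy_nms(boxes, scores, 0.5)

# Intra-class: shift every class into its own coordinate region first.
offsets = cat_ids * (boxes.max() + 1)
shifted = boxes + offsets[:, None]
keep_intra = greedy_nms(shifted, scores, 0.5)

print(keep_inter)  # [0, 2]
print(keep_intra)  # [0, 1, 2]
```

After the shift, the IoU between box 0 and box 1 is exactly 0, so the plain NMS loop never compares boxes across classes. The returned values are indices, so the original unshifted coordinates are still available as `boxes[keep_intra]`.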
PyTorch implementation of inter-class and intra-class NMS:
import torch


# Inter-class NMS
def nms(bboxes, scores, thresh):
    x1, y1, x2, y2 = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    # Sort by score in descending order (what we keep are the indices)
    # values, indices = torch.sort(scores, descending=True)
    indices = scores.sort(descending=True)[1]  # torch
    indice_res = []
    # indices.size() is a torch.Size object; take its first element to get an int we can compare
    while indices.size()[0] > 0:
        save_idx, other_idx = indices[0], indices[1:]
        # Keep the index rather than the box itself. When called from class_nms
        # the boxes carry per-class offsets, so returning them would leak
        # shifted coordinates to the caller.
        indice_res.append(save_idx)
        inter_x1 = torch.max(x1[save_idx], x1[other_idx])
        inter_y1 = torch.max(y1[save_idx], y1[other_idx])
        inter_x2 = torch.min(x2[save_idx], x2[other_idx])
        inter_y2 = torch.min(y2[save_idx], y2[other_idx])
        inter_w = torch.max(inter_x2 - inter_x1 + 1, torch.tensor(0).to(bboxes))
        inter_h = torch.max(inter_y2 - inter_y1 + 1, torch.tensor(0).to(bboxes))
        inter_area = inter_w * inter_h
        union_area = areas[save_idx] + areas[other_idx] - inter_area + 1e-6
        iou = inter_area / union_area
        indices = other_idx[iou < thresh]
    # Kept indices as a 1-D LongTensor; use bboxes[result] to get the kept boxes
    if not indice_res:
        return indices  # empty input: indices is already an empty index tensor
    return torch.stack(indice_res)


# Intra-class NMS: multiply each class id by an offset so that the bboxes of
# different classes are shifted to different locations.
def class_nms(bboxes, scores, cat_ids, iou_threshold):
    '''
    :param bboxes: torch.Tensor, shape [n, 4], dtype=torch.float32
    :param scores: torch.Tensor, shape [n], dtype=torch.float32
    :param cat_ids: torch.Tensor, shape [n], dtype=torch.int32
    :param iou_threshold: float
    :return: indices of the kept bboxes (into the original, unshifted bboxes)
    '''
    max_coordinate = bboxes.max()
    # Generate a large offset for each class / each level
    offsets = cat_ids * (max_coordinate + 1)
    # After adding the per-class offset, bboxes of different classes no longer overlap
    bboxes_for_nms = bboxes + offsets[:, None]
    indice_res = nms(bboxes_for_nms, scores, iou_threshold)
    return indice_res
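One practical detail with the offset trick is that the boxes passed to NMS are shifted, so it helps to work with indices and look up the original boxes afterwards. The following is a minimal, self-contained sketch of that idea, not the code above: `greedy_nms_indices` is a hypothetical helper name, and the sample boxes, scores and class ids are made up for illustration.

```python
import torch


def greedy_nms_indices(boxes, scores, thresh):
    """Greedy NMS returning kept indices as a LongTensor, highest score first."""
    x1, y1, x2, y2 = boxes.unbind(dim=1)
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort(descending=True)
    keep = []
    while order.numel() > 0:
        i, rest = order[0], order[1:]
        keep.append(i)
        # Intersection of box i with all remaining boxes, clamped at zero
        inter_w = (torch.minimum(x2[i], x2[rest]) - torch.maximum(x1[i], x1[rest]) + 1).clamp(min=0)
        inter_h = (torch.minimum(y2[i], y2[rest]) - torch.maximum(y1[i], y1[rest]) + 1).clamp(min=0)
        inter = inter_w * inter_h
        iou = inter / (areas[i] + areas[rest] - inter)
        order = rest[iou < thresh]
    return torch.stack(keep) if keep else order


# Assumed sample data: boxes 0 and 1 overlap heavily but belong to different classes.
boxes = torch.tensor([[10., 10., 50., 50.],
                      [12., 12., 52., 52.],
                      [100., 100., 140., 140.]])
scores = torch.tensor([0.9, 0.8, 0.7])
cat_ids = torch.tensor([0, 1, 0], dtype=torch.int32)

# Inter-class NMS: box 1 is suppressed by box 0.
keep_inter = greedy_nms_indices(boxes, scores, 0.5)

# Intra-class NMS: run NMS on shifted boxes, then index the ORIGINAL boxes.
offsets = cat_ids.to(boxes) * (boxes.max() + 1)
keep_intra = greedy_nms_indices(boxes + offsets[:, None], scores, 0.5)
kept_boxes = boxes[keep_intra]

print(keep_inter.tolist())  # [0, 2]
print(keep_intra.tolist())  # [0, 1, 2]
```

Because `kept_boxes` is gathered from `boxes` and not from the shifted copy, its coordinates are the original ones, and the same indices can be used to select the matching `scores` and `cat_ids`.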