From 1932ff904dea89947b2b3a2857bd780789d4b653 Mon Sep 17 00:00:00 2001 From: Rabbit314271 <59315380+Rabbit314271@users.noreply.github.com> Date: Wed, 20 Oct 2021 15:17:30 +0800 Subject: [PATCH] Update coco.py --- datasets/COCO/coco.py | 51 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/datasets/COCO/coco.py b/datasets/COCO/coco.py index 1f90d8b..6d83f1a 100755 --- a/datasets/COCO/coco.py +++ b/datasets/COCO/coco.py @@ -21,7 +21,56 @@ # from dataset.JointsDataset import JointsDataset from .JointsDataset import JointsDataset -from .nms.nms import oks_nms + +def oks_iou(g, d, a_g, a_d, sigmas=None, in_vis_thre=None): + if not isinstance(sigmas, np.ndarray): + sigmas = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0 + vars = (sigmas * 2) ** 2 + xg = g[0::3] + yg = g[1::3] + vg = g[2::3] + ious = np.zeros((d.shape[0])) + for n_d in range(0, d.shape[0]): + xd = d[n_d, 0::3] + yd = d[n_d, 1::3] + vd = d[n_d, 2::3] + dx = xd - xg + dy = yd - yg + e = (dx ** 2 + dy ** 2) / vars / ((a_g + a_d[n_d]) / 2 + np.spacing(1)) / 2 + if in_vis_thre is not None: + ind = list(vg > in_vis_thre) and list(vd > in_vis_thre) + e = e[ind] + ious[n_d] = np.sum(np.exp(-e)) / e.shape[0] if e.shape[0] != 0 else 0.0 + return ious + +def oks_nms(kpts_db, thresh, sigmas=None, in_vis_thre=None): + """ + greedily select boxes with high confidence and overlap with current maximum <= thresh + rule out overlap >= thresh, overlap = oks + :param kpts_db + :param thresh: retain overlap < thresh + :return: indexes to keep + """ + if len(kpts_db) == 0: + return [] + + scores = np.array([kpts_db[i]['score'] for i in range(len(kpts_db))]) + kpts = np.array([kpts_db[i]['keypoints'].flatten() for i in range(len(kpts_db))]) + areas = np.array([kpts_db[i]['area'] for i in range(len(kpts_db))]) + + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + + oks_ovr = oks_iou(kpts[i], kpts[order[1:]], areas[i], areas[order[1:]], sigmas, in_vis_thre) + + inds = np.where(oks_ovr <= thresh)[0] + order = order[inds + 1] + + return keep logger = logging.getLogger(__name__)