diff --git a/layers/box_utils.py b/layers/box_utils.py index 84214947b..bea7bfc96 100644 --- a/layers/box_utils.py +++ b/layers/box_utils.py @@ -186,7 +186,7 @@ def nms(boxes, scores, overlap=0.5, top_k=200): keep = scores.new(scores.size(0)).zero_().long() if boxes.numel() == 0: - return keep + return keep, 0 x1 = boxes[:, 0] y1 = boxes[:, 1] x2 = boxes[:, 2] diff --git a/layers/functions/detection.py b/layers/functions/detection.py index 0d1ef8d30..2c26eeceb 100644 --- a/layers/functions/detection.py +++ b/layers/functions/detection.py @@ -52,6 +52,8 @@ def forward(self, loc_data, conf_data, prior_data): boxes = decoded_boxes[l_mask].view(-1, 4) # idx of highest scoring and non-overlapping boxes per class ids, count = nms(boxes, scores, self.nms_thresh, self.top_k) + if count == 0: + continue output[i, cl, :count] = \ torch.cat((scores[ids[:count]].unsqueeze(1), boxes[ids[:count]]), 1)