Skip to content

Commit 08fe9d6

Browse files
authored
Fix empty anno - merge back develop (#4022)
Refactor empty label workaround in iseg and mask_target.py
1 parent 92d25e7 commit 08fe9d6

File tree

4 files changed

+153
-7
lines changed

4 files changed

+153
-7
lines changed

src/otx/algo/common/utils/bbox_overlaps.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
from __future__ import annotations
1010

11+
import warnings
12+
1113
import torch
1214
from torch import Tensor
1315

@@ -142,15 +144,28 @@ def bbox_overlaps(
142144
>>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0)
143145
>>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0)
144146
"""
147+
if not (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0):
148+
msg = "bboxes1 must have a last dimension of size 4 or be an empty tensor."
149+
raise ValueError(msg)
150+
151+
if not (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0):
152+
msg = "bboxes2 must have a last dimension of size 4 or be an empty tensor."
153+
raise ValueError(msg)
154+
155+
if bboxes1.shape[:-2] != bboxes2.shape[:-2]:
156+
msg = "The batch dimension of bboxes must be the same."
157+
raise ValueError(msg)
158+
145159
batch_shape = bboxes1.shape[:-2]
146160

147161
rows = bboxes1.size(-2)
148162
cols = bboxes2.size(-2)
149163

150164
if rows * cols == 0:
165+
warnings.warn("No bboxes are provided! Returning empty boxes!", stacklevel=2)
151166
if is_aligned:
152-
return bboxes1.new((*batch_shape, rows))
153-
return bboxes1.new((*batch_shape, rows, cols))
167+
return bboxes1.new(batch_shape + (rows,)) # noqa: RUF005
168+
return bboxes1.new(batch_shape + (rows, cols)) # noqa: RUF005
154169

155170
area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
156171
area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1])

src/otx/algo/instance_segmentation/utils/structures/mask/mask_target.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
from __future__ import annotations
1010

11+
import warnings
12+
1113
import numpy as np
1214
import torch
1315
from datumaro.components.annotation import Polygon
@@ -62,16 +64,20 @@ def mask_target_single(
6264
meta_info: dict,
6365
) -> Tensor:
6466
"""Compute mask target for each positive proposal in the image."""
67+
mask_size = _pair(mask_size)
68+
if len(gt_masks) == 0:
69+
warnings.warn("No ground truth masks are provided!", stacklevel=2)
70+
return pos_proposals.new_zeros((0, *mask_size))
71+
6572
if isinstance(gt_masks[0], Polygon):
6673
crop_and_resize = crop_and_resize_polygons
6774
elif isinstance(gt_masks, tv_tensors.Mask):
6875
crop_and_resize = crop_and_resize_masks
6976
else:
70-
msg = f"Unsupported type of masks: {type(gt_masks[0])}"
71-
raise NotImplementedError(msg)
77+
warnings.warn("Unsupported ground truth mask type!", stacklevel=2)
78+
return pos_proposals.new_zeros((0, *mask_size))
7279

7380
device = pos_proposals.device
74-
mask_size = _pair(mask_size)
7581
num_pos = pos_proposals.size(0)
7682
if num_pos > 0:
7783
proposals_np = pos_proposals.cpu().numpy()

src/otx/core/data/pre_filtering.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,15 @@ def pre_filtering(
4040
dataset = DmDataset.filter(dataset, is_valid_annot, filter_annotations=True)
4141
dataset = remove_unused_labels(dataset, data_format, ignore_index)
4242
if unannotated_items_ratio > 0:
43-
empty_items = [item.id for item in dataset if item.subset == "train" and len(item.annotations) == 0]
43+
empty_items = [
44+
item.id for item in dataset if item.subset in ("train", "TRAINING") and len(item.annotations) == 0
45+
]
4446
used_background_items = set(sample(empty_items, int(len(empty_items) * unannotated_items_ratio)))
4547

4648
return DmDataset.filter(
4749
dataset,
4850
lambda item: not (
49-
item.subset == "train" and len(item.annotations) == 0 and item.id not in used_background_items
51+
item.subset in ("train", "TRAINING") and len(item.annotations) == 0 and item.id not in used_background_items
5052
),
5153
)
5254

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# Copyright (C) 2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
# Copyright (c) OpenMMLab. All rights reserved.
4+
from __future__ import annotations
5+
6+
import numpy as np
7+
import pytest
8+
import torch
9+
from otx.algo.common.utils.assigners.iou2d_calculator import BboxOverlaps2D
10+
from otx.algo.common.utils.bbox_overlaps import bbox_overlaps
11+
12+
13+
def test_bbox_overlaps_2d(eps: float = 1e-7):
14+
def _construct_bbox(num_bbox: int | None = None) -> tuple[torch.Tensor, int]:
15+
img_h = int(np.random.randint(3, 1000))
16+
img_w = int(np.random.randint(3, 1000))
17+
if num_bbox is None:
18+
num_bbox = np.random.randint(1, 10)
19+
x1y1 = torch.rand((num_bbox, 2))
20+
x2y2 = torch.max(torch.rand((num_bbox, 2)), x1y1)
21+
bboxes = torch.cat((x1y1, x2y2), -1)
22+
bboxes[:, 0::2] *= img_w
23+
bboxes[:, 1::2] *= img_h
24+
return bboxes, num_bbox
25+
26+
# Test where is_aligned is True, bboxes.size(-1) == 5 (include score)
27+
self = BboxOverlaps2D()
28+
bboxes1, num_bbox = _construct_bbox()
29+
bboxes2, _ = _construct_bbox(num_bbox)
30+
bboxes1 = torch.cat((bboxes1, torch.rand((num_bbox, 1))), 1)
31+
bboxes2 = torch.cat((bboxes2, torch.rand((num_bbox, 1))), 1)
32+
gious = self(bboxes1, bboxes2, "giou", True)
33+
assert gious.size() == (num_bbox,), gious.size()
34+
assert torch.all(gious >= -1)
35+
assert torch.all(gious <= 1)
36+
37+
# Test where is_aligned is True, bboxes1.size(-2) == 0
38+
bboxes1 = torch.empty((0, 4))
39+
bboxes2 = torch.empty((0, 4))
40+
gious = self(bboxes1, bboxes2, "giou", True)
41+
assert gious.size() == (0,), gious.size()
42+
assert torch.all(gious == torch.empty((0,)))
43+
assert torch.all(gious >= -1)
44+
assert torch.all(gious <= 1)
45+
46+
# Test where is_aligned is True, and bboxes.ndims > 2
47+
bboxes1, num_bbox = _construct_bbox()
48+
bboxes2, _ = _construct_bbox(num_bbox)
49+
bboxes1 = bboxes1.unsqueeze(0).repeat(2, 1, 1)
50+
# test assertion when batch dim is not the same
51+
with pytest.raises(ValueError, match="The batch dimension of bboxes must be the same."):
52+
self(bboxes1, bboxes2.unsqueeze(0).repeat(3, 1, 1), "giou", True)
53+
bboxes2 = bboxes2.unsqueeze(0).repeat(2, 1, 1)
54+
gious = self(bboxes1, bboxes2, "giou", True)
55+
assert torch.all(gious >= -1)
56+
assert torch.all(gious <= 1)
57+
assert gious.size() == (2, num_bbox)
58+
bboxes1 = bboxes1.unsqueeze(0).repeat(2, 1, 1, 1)
59+
bboxes2 = bboxes2.unsqueeze(0).repeat(2, 1, 1, 1)
60+
gious = self(bboxes1, bboxes2, "giou", True)
61+
assert torch.all(gious >= -1)
62+
assert torch.all(gious <= 1)
63+
assert gious.size() == (2, 2, num_bbox)
64+
65+
# Test where is_aligned is False
66+
bboxes1, num_bbox1 = _construct_bbox()
67+
bboxes2, num_bbox2 = _construct_bbox()
68+
gious = self(bboxes1, bboxes2, "giou")
69+
assert torch.all(gious >= -1)
70+
assert torch.all(gious <= 1)
71+
assert gious.size() == (num_bbox1, num_bbox2)
72+
73+
# Test where is_aligned is False, and bboxes.ndims > 2
74+
bboxes1 = bboxes1.unsqueeze(0).repeat(2, 1, 1)
75+
bboxes2 = bboxes2.unsqueeze(0).repeat(2, 1, 1)
76+
gious = self(bboxes1, bboxes2, "giou")
77+
assert torch.all(gious >= -1)
78+
assert torch.all(gious <= 1)
79+
assert gious.size() == (2, num_bbox1, num_bbox2)
80+
bboxes1 = bboxes1.unsqueeze(0)
81+
bboxes2 = bboxes2.unsqueeze(0)
82+
gious = self(bboxes1, bboxes2, "giou")
83+
assert torch.all(gious >= -1)
84+
assert torch.all(gious <= 1)
85+
assert gious.size() == (1, 2, num_bbox1, num_bbox2)
86+
87+
# Test where is_aligned is False, bboxes1.size(-2) == 0
88+
gious = self(torch.empty(1, 2, 0, 4), bboxes2, "giou")
89+
assert torch.all(gious == torch.empty(1, 2, 0, bboxes2.size(-2)))
90+
assert torch.all(gious >= -1)
91+
assert torch.all(gious <= 1)
92+
93+
# test allclose between bbox_overlaps and the original official
94+
# implementation.
95+
bboxes1 = torch.FloatTensor(
96+
[
97+
[0, 0, 10, 10],
98+
[10, 10, 20, 20],
99+
[32, 32, 38, 42],
100+
],
101+
)
102+
bboxes2 = torch.FloatTensor(
103+
[
104+
[0, 0, 10, 20],
105+
[0, 10, 10, 19],
106+
[10, 10, 20, 20],
107+
],
108+
)
109+
gious = bbox_overlaps(bboxes1, bboxes2, "giou", is_aligned=True, eps=eps)
110+
gious = gious.numpy().round(4)
111+
# the gt is got with four decimal precision.
112+
expected_gious = np.array([0.5000, -0.0500, -0.8214])
113+
assert np.allclose(gious, expected_gious, rtol=0, atol=eps)
114+
115+
# test mode 'iof'
116+
ious = bbox_overlaps(bboxes1, bboxes2, "iof", is_aligned=True, eps=eps)
117+
assert torch.all(ious >= -1)
118+
assert torch.all(ious <= 1)
119+
assert ious.size() == (bboxes1.size(0),)
120+
ious = bbox_overlaps(bboxes1, bboxes2, "iof", eps=eps)
121+
assert torch.all(ious >= -1)
122+
assert torch.all(ious <= 1)
123+
assert ious.size() == (bboxes1.size(0), bboxes2.size(0))

0 commit comments

Comments
 (0)