
Commit 80cb38e

Adjust clamping for rotated bboxes (#9112)
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
1 parent e347ef9 commit 80cb38e
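
This commit switches the default clamping_mode for bounding boxes from "hard" to "soft" (removing the TODOBB markers), threads clamping_mode through the rotated-box kernels and their dispatchers, and updates the test references to compare unclamped outputs. A minimal sketch of the new default, mirroring the assertions added in test/test_tv_tensors.py below:

from torchvision import tv_tensors

# Axis-aligned and rotated boxes now both default to "soft" clamping.
boxes = tv_tensors.BoundingBoxes([0, 0, 10, 10], format="XYXY", canvas_size=(100, 100))
assert boxes.clamping_mode == "soft"

rboxes = tv_tensors.BoundingBoxes([0, 0, 10, 10, 0], format="XYWHR", canvas_size=(100, 100))
assert rboxes.clamping_mode == "soft"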

6 files changed, +222 -98 lines changed


test/common_utils.py

Lines changed: 2 additions & 17 deletions
@@ -21,7 +21,7 @@
 from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
 from torchvision import io, tv_tensors
 from torchvision.transforms._functional_tensor import _max_value as get_max_value
-from torchvision.transforms.v2.functional import clamp_bounding_boxes, to_image, to_pil_image
+from torchvision.transforms.v2.functional import to_image, to_pil_image


 IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
@@ -410,7 +410,7 @@ def make_bounding_boxes(
     canvas_size=DEFAULT_SIZE,
     *,
     format=tv_tensors.BoundingBoxFormat.XYXY,
-    clamping_mode="hard",  # TODOBB
+    clamping_mode="soft",
     num_boxes=1,
     dtype=None,
     device="cpu",
@@ -469,21 +469,6 @@ def sample_position(values, max_value):
     else:
         raise ValueError(f"Format {format} is not supported")
     out_boxes = torch.stack(parts, dim=-1).to(dtype=dtype, device=device)
-    if tv_tensors.is_rotated_bounding_format(format):
-        # The rotated bounding boxes are not guaranteed to be within the canvas by design,
-        # so we apply clamping. We also add a 2 buffer to the canvas size to avoid
-        # numerical issues during the testing
-        buffer = 4
-        out_boxes = clamp_bounding_boxes(
-            out_boxes,
-            format=format,
-            canvas_size=(canvas_size[0] - buffer, canvas_size[1] - buffer),
-            clamping_mode=clamping_mode,
-        )
-        if format is tv_tensors.BoundingBoxFormat.XYWHR or format is tv_tensors.BoundingBoxFormat.CXCYWHR:
-            out_boxes[:, :2] += buffer // 2
-        elif format is tv_tensors.BoundingBoxFormat.XYXYXYXY:
-            out_boxes[:, :] += buffer // 2
     return tv_tensors.BoundingBoxes(out_boxes, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode)
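
The deleted block above used to hard-clamp rotated boxes (with a 4-pixel buffer) inside the test helper itself; make_bounding_boxes now returns boxes as sampled and only records the requested clamping_mode on the wrapper. A sketch of the new calling pattern, assuming the import path used by the test suite:

from common_utils import make_bounding_boxes  # test helper from this file
from torchvision import tv_tensors

# Rotated boxes may now extend past the canvas by design; tests that compare
# against unclamped references request clamping_mode="none" explicitly.
boxes = make_bounding_boxes(format=tv_tensors.BoundingBoxFormat.XYWHR, clamping_mode="none")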
test/test_transforms_v2.py

Lines changed: 16 additions & 19 deletions
@@ -551,6 +551,7 @@ def affine_bounding_boxes(bounding_boxes):
             ),
             format=format,
             canvas_size=canvas_size,
+            clamping_mode=clamping_mode,
         )


@@ -639,6 +640,7 @@ def affine_rotated_bounding_boxes(bounding_boxes):
             ).reshape(bounding_boxes.shape),
             format=format,
             canvas_size=canvas_size,
+            clamping_mode=clamping_mode,
         )


@@ -1305,7 +1307,7 @@ def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes: tv_tensors.B
             if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
             else reference_affine_bounding_boxes_helper
         )
-        return helper(bounding_boxes, affine_matrix=affine_matrix)
+        return helper(bounding_boxes, affine_matrix=affine_matrix, clamp=False)

     @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize(
@@ -1914,7 +1916,7 @@ def _reference_vertical_flip_bounding_boxes(self, bounding_boxes: tv_tensors.Bou
             if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
             else reference_affine_bounding_boxes_helper
         )
-        return helper(bounding_boxes, affine_matrix=affine_matrix)
+        return helper(bounding_boxes, affine_matrix=affine_matrix, clamp=False)

     @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
@@ -2079,7 +2081,6 @@ def test_functional(self, make_input):
             (F.rotate_image, torch.Tensor),
             (F._geometry._rotate_image_pil, PIL.Image.Image),
             (F.rotate_image, tv_tensors.Image),
-            (F.rotate_bounding_boxes, tv_tensors.BoundingBoxes),
             (F.rotate_mask, tv_tensors.Mask),
             (F.rotate_video, tv_tensors.Video),
             (F.rotate_keypoints, tv_tensors.KeyPoints),
@@ -2229,29 +2230,26 @@ def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, cen
             clamp=False,
         )

-        return F.clamp_bounding_boxes(self._recenter_bounding_boxes_after_expand(output, recenter_xy=recenter_xy)).to(
-            bounding_boxes
-        )
+        return self._recenter_bounding_boxes_after_expand(output, recenter_xy=recenter_xy).to(bounding_boxes)

     @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
     @pytest.mark.parametrize("expand", [False, True])
     @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
     def test_functional_bounding_boxes_correctness(self, format, angle, expand, center):
-        bounding_boxes = make_bounding_boxes(format=format)
+        bounding_boxes = make_bounding_boxes(format=format, clamping_mode="none")

         actual = F.rotate(bounding_boxes, angle=angle, expand=expand, center=center)
         expected = self._reference_rotate_bounding_boxes(bounding_boxes, angle=angle, expand=expand, center=center)
-
-        torch.testing.assert_close(actual, expected)
         torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0)
+        torch.testing.assert_close(actual, expected)

     @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize("expand", [False, True])
     @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
     @pytest.mark.parametrize("seed", list(range(5)))
     def test_transform_bounding_boxes_correctness(self, format, expand, center, seed):
-        bounding_boxes = make_bounding_boxes(format=format)
+        bounding_boxes = make_bounding_boxes(format=format, clamping_mode="none")

         transform = transforms.RandomRotation(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES, expand=expand, center=center)

@@ -2262,9 +2260,8 @@ def test_transform_bounding_boxes_correctness(self, format, expand, center, seed
         actual = transform(bounding_boxes)

         expected = self._reference_rotate_bounding_boxes(bounding_boxes, **params, expand=expand, center=center)
-
-        torch.testing.assert_close(actual, expected)
         torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0)
+        torch.testing.assert_close(actual, expected)

     def _recenter_keypoints_after_expand(self, keypoints, *, recenter_xy):
         x, y = recenter_xy
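
The rotate correctness tests above now build their inputs with clamping_mode="none" so that the kernel output and the pure-affine reference agree exactly near the canvas borders, and they check canvas sizes before values. A condensed sketch of that setup:

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

# With clamping_mode="none", rotating a box and rotating its corners by hand
# should match exactly; no clamping is applied on either side.
boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[10.0, 10.0, 40.0, 30.0]]), format="XYXY",
    canvas_size=(100, 100), clamping_mode="none",
)
actual = F.rotate(boxes, angle=30, expand=True)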
@@ -4349,7 +4346,6 @@ def test_functional(self, make_input):
             (F.resized_crop_image, torch.Tensor),
             (F._geometry._resized_crop_image_pil, PIL.Image.Image),
             (F.resized_crop_image, tv_tensors.Image),
-            (F.resized_crop_bounding_boxes, tv_tensors.BoundingBoxes),
             (F.resized_crop_mask, tv_tensors.Mask),
             (F.resized_crop_video, tv_tensors.Video),
             (F.resized_crop_keypoints, tv_tensors.KeyPoints),
@@ -4415,6 +4411,7 @@ def _reference_resized_crop_bounding_boxes(self, bounding_boxes, *, top, left, h
                 [0, 0, 1],
             ],
         )
+
         affine_matrix = (resize_affine_matrix @ crop_affine_matrix)[:2, :]

         helper = (
@@ -4423,15 +4420,15 @@
             else reference_affine_bounding_boxes_helper
         )

-        return helper(
-            bounding_boxes,
-            affine_matrix=affine_matrix,
-            new_canvas_size=size,
-        )
+        return helper(bounding_boxes, affine_matrix=affine_matrix, new_canvas_size=size, clamp=False)

     @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     def test_functional_bounding_boxes_correctness(self, format):
-        bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format)
+        # Note that we don't want to clamp because in
+        # _reference_resized_crop_bounding_boxes we are fusing the crop and the
+        # resize operation, where none of the croppings happen - particularly,
+        # the intermediate one.
+        bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format, clamping_mode="none")

         actual = F.resized_crop(bounding_boxes, **self.CROP_KWARGS, size=self.OUTPUT_SIZE)
         expected = self._reference_resized_crop_bounding_boxes(
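
The added comment explains the clamping_mode="none" choice: the reference fuses the crop and the resize into a single affine map, so box corners never pass through the intermediate crop canvas and nothing can clamp them there; the kernel output must be left unclamped to match. A sketch of that fusion with hypothetical crop and resize parameters:

import torch

# Hypothetical values: crop at top=10, left=20, then resize by (sy, sx).
top, left, sy, sx = 10.0, 20.0, 0.5, 0.25
crop_affine_matrix = torch.tensor([[1.0, 0.0, -left], [0.0, 1.0, -top], [0.0, 0.0, 1.0]])
resize_affine_matrix = torch.tensor([[sx, 0.0, 0.0], [0.0, sy, 0.0], [0.0, 0.0, 1.0]])

# One fused map, applied once: there is no intermediate canvas at which
# clamping could occur.
affine_matrix = (resize_affine_matrix @ crop_affine_matrix)[:2, :]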

test/test_tv_tensors.py

Lines changed: 5 additions & 0 deletions
@@ -406,3 +406,8 @@ def test_return_type_input():
         tv_tensors.set_return_type("typo")

     tv_tensors.set_return_type("tensor")
+
+
+def test_box_clamping_mode_default():
+    assert tv_tensors.BoundingBoxes([0, 0, 10, 10], format="XYXY", canvas_size=(100, 100)).clamping_mode == "soft"
+    assert tv_tensors.BoundingBoxes([0, 0, 10, 10, 0], format="XYWHR", canvas_size=(100, 100)).clamping_mode == "soft"

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 46 additions & 19 deletions
@@ -522,7 +522,7 @@ def resize_bounding_boxes(
     size: Optional[list[int]],
     max_size: Optional[int] = None,
     format: tv_tensors.BoundingBoxFormat = tv_tensors.BoundingBoxFormat.XYXY,
-    clamping_mode: CLAMPING_MODE_TYPE = "hard",  # TODOBB soft
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ) -> tuple[torch.Tensor, tuple[int, int]]:
     # We set the default format as `tv_tensors.BoundingBoxFormat.XYXY`
     # to ensure backward compatibility.
@@ -1108,15 +1108,16 @@ def _affine_bounding_boxes_with_expand(
     shear: list[float],
     center: Optional[list[float]] = None,
     expand: bool = False,
-    clamping_mode: CLAMPING_MODE_TYPE = "hard",  # TODOBB soft
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ) -> tuple[torch.Tensor, tuple[int, int]]:
     if bounding_boxes.numel() == 0:
         return bounding_boxes, canvas_size

     original_shape = bounding_boxes.shape
     dtype = bounding_boxes.dtype
-    need_cast = not bounding_boxes.is_floating_point()
-    bounding_boxes = bounding_boxes.float() if need_cast else bounding_boxes.clone()
+    acceptable_dtypes = [torch.float64]  # Ensure consistency between CPU and GPU.
+    need_cast = dtype not in acceptable_dtypes
+    bounding_boxes = bounding_boxes.to(torch.float64) if need_cast else bounding_boxes.clone()
     device = bounding_boxes.device
     is_rotated = tv_tensors.is_rotated_bounding_format(format)
     intermediate_format = tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else tv_tensors.BoundingBoxFormat.XYXY
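
_affine_bounding_boxes_with_expand now upcasts every input to float64 instead of only promoting integer boxes to float32; per the inline comment, this keeps CPU and GPU results consistent. A distilled sketch of the cast-and-restore pattern, with a hypothetical function name and the restore step assumed from the saved dtype:

import torch

def affine_in_float64(boxes: torch.Tensor) -> torch.Tensor:
    dtype = boxes.dtype
    acceptable_dtypes = [torch.float64]  # ensure consistency between CPU and GPU
    need_cast = dtype not in acceptable_dtypes
    boxes = boxes.to(torch.float64) if need_cast else boxes.clone()
    # Stand-in for the affine math, which all happens in float64.
    out = boxes @ torch.eye(boxes.shape[-1], dtype=torch.float64)
    # Assumed: the result is cast back to the caller's dtype at the end.
    return out.to(dtype) if need_cast else out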
@@ -1210,7 +1211,7 @@ def affine_bounding_boxes(
     scale: float,
     shear: list[float],
     center: Optional[list[float]] = None,
-    clamping_mode: CLAMPING_MODE_TYPE = "hard",  # TODOBB soft
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ) -> torch.Tensor:
     out_box, _ = _affine_bounding_boxes_with_expand(
         bounding_boxes,
@@ -1448,6 +1449,7 @@ def rotate_bounding_boxes(
     angle: float,
     expand: bool = False,
     center: Optional[list[float]] = None,
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ) -> tuple[torch.Tensor, tuple[int, int]]:
     return _affine_bounding_boxes_with_expand(
         bounding_boxes,
@@ -1459,6 +1461,7 @@
         shear=[0.0, 0.0],
         center=center,
         expand=expand,
+        clamping_mode=clamping_mode,
     )


@@ -1473,6 +1476,7 @@ def _rotate_bounding_boxes_dispatch(
         angle=angle,
         expand=expand,
         center=center,
+        clamping_mode=inpt.clamping_mode,
     )
     return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)

@@ -1739,7 +1743,7 @@ def pad_bounding_boxes(
     canvas_size: tuple[int, int],
     padding: list[int],
     padding_mode: str = "constant",
-    clamping_mode: CLAMPING_MODE_TYPE = "hard",  # TODOBB soft
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ) -> tuple[torch.Tensor, tuple[int, int]]:
     if padding_mode not in ["constant"]:
         # TODO: add support of other padding modes
@@ -1857,7 +1861,7 @@ def crop_bounding_boxes(
     left: int,
     height: int,
     width: int,
-    clamping_mode: CLAMPING_MODE_TYPE = "hard",  # TODOBB soft
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ) -> tuple[torch.Tensor, tuple[int, int]]:

     # Crop or implicit pad if left and/or top have negative values:
@@ -2097,7 +2101,7 @@ def perspective_bounding_boxes(
     startpoints: Optional[list[list[int]]],
     endpoints: Optional[list[list[int]]],
     coefficients: Optional[list[float]] = None,
-    clamping_mode: CLAMPING_MODE_TYPE = "hard",  # TODOBB soft
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ) -> torch.Tensor:
     if bounding_boxes.numel() == 0:
         return bounding_boxes
@@ -2412,7 +2416,7 @@ def elastic_bounding_boxes(
     format: tv_tensors.BoundingBoxFormat,
     canvas_size: tuple[int, int],
     displacement: torch.Tensor,
-    clamping_mode: CLAMPING_MODE_TYPE = "hard",  # TODOBB soft
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ) -> torch.Tensor:
     expected_shape = (1, canvas_size[0], canvas_size[1], 2)
     if not isinstance(displacement, torch.Tensor):
@@ -2433,19 +2437,19 @@

     original_shape = bounding_boxes.shape
     # TODO: first cast to float if bbox is int64 before convert_bounding_box_format
-    intermediate_format = tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else tv_tensors.BoundingBoxFormat.XYXY
+    intermediate_format = tv_tensors.BoundingBoxFormat.CXCYWHR if is_rotated else tv_tensors.BoundingBoxFormat.XYXY

     bounding_boxes = (
         convert_bounding_box_format(bounding_boxes.clone(), old_format=format, new_format=intermediate_format)
-    ).reshape(-1, 8 if is_rotated else 4)
+    ).reshape(-1, 5 if is_rotated else 4)

     id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype)
     # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
     # This is not an exact inverse of the grid
     inv_grid = id_grid.sub_(displacement)

     # Get points from bboxes
-    points = bounding_boxes if is_rotated else bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]]
+    points = bounding_boxes[:, :2] if is_rotated else bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]]
     points = points.reshape(-1, 2)
     if points.is_floating_point():
         points = points.ceil_()
@@ -2457,8 +2461,8 @@
     transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)

     if is_rotated:
-        transformed_points = transformed_points.reshape(-1, 8)
-        out_bboxes = _parallelogram_to_bounding_boxes(transformed_points).to(bounding_boxes.dtype)
+        transformed_points = transformed_points.reshape(-1, 2)
+        out_bboxes = torch.cat([transformed_points, bounding_boxes[:, 2:]], dim=1).to(bounding_boxes.dtype)
     else:
         transformed_points = transformed_points.reshape(-1, 4, 2)
         out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
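
For rotated boxes, elastic_bounding_boxes now converts to CXCYWHR and displaces only the box center, carrying width, height, and angle over unchanged; previously it pushed all eight corner coordinates (XYXYXYXY) through the grid and re-fit a parallelogram. A sketch of the new per-box update with made-up numbers:

import torch

# One CXCYWHR box (cx, cy, w, h, angle) and the displaced center that the
# inverse-grid lookup at (cx, cy) might return.
boxes_cxcywhr = torch.tensor([[50.0, 40.0, 20.0, 10.0, 30.0]])
transformed_centers = torch.tensor([[52.0, 38.5]])

# Only the center moves; w, h, and the angle are kept as-is.
out_bboxes = torch.cat([transformed_centers, boxes_cxcywhr[:, 2:]], dim=1)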
@@ -2619,11 +2623,18 @@ def center_crop_bounding_boxes(
     format: tv_tensors.BoundingBoxFormat,
     canvas_size: tuple[int, int],
     output_size: list[int],
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ) -> tuple[torch.Tensor, tuple[int, int]]:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
     crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *canvas_size)
     return crop_bounding_boxes(
-        bounding_boxes, format, top=crop_top, left=crop_left, height=crop_height, width=crop_width
+        bounding_boxes,
+        format,
+        top=crop_top,
+        left=crop_left,
+        height=crop_height,
+        width=crop_width,
+        clamping_mode=clamping_mode,
     )


@@ -2632,7 +2643,11 @@ def _center_crop_bounding_boxes_dispatch(
     inpt: tv_tensors.BoundingBoxes, output_size: list[int]
 ) -> tv_tensors.BoundingBoxes:
     output, canvas_size = center_crop_bounding_boxes(
-        inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, output_size=output_size
+        inpt.as_subclass(torch.Tensor),
+        format=inpt.format,
+        canvas_size=inpt.canvas_size,
+        output_size=output_size,
+        clamping_mode=inpt.clamping_mode,
     )
     return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)

@@ -2779,17 +2794,29 @@ def resized_crop_bounding_boxes(
     height: int,
     width: int,
     size: list[int],
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ) -> tuple[torch.Tensor, tuple[int, int]]:
-    bounding_boxes, canvas_size = crop_bounding_boxes(bounding_boxes, format, top, left, height, width)
-    return resize_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size, size=size)
+    bounding_boxes, canvas_size = crop_bounding_boxes(
+        bounding_boxes, format, top, left, height, width, clamping_mode=clamping_mode
+    )
+    return resize_bounding_boxes(
+        bounding_boxes, format=format, canvas_size=canvas_size, size=size, clamping_mode=clamping_mode
+    )


 @_register_kernel_internal(resized_crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
 def _resized_crop_bounding_boxes_dispatch(
     inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int, size: list[int], **kwargs
 ) -> tv_tensors.BoundingBoxes:
     output, canvas_size = resized_crop_bounding_boxes(
-        inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width, size=size
+        inpt.as_subclass(torch.Tensor),
+        format=inpt.format,
+        top=top,
+        left=left,
+        height=height,
+        width=width,
+        size=size,
+        clamping_mode=inpt.clamping_mode,
     )
     return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)
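
Taken together, the dispatchers above now read clamping_mode off the input box, so a mode chosen at construction time flows through rotate, center_crop, and resized_crop. A hedged end-to-end sketch:

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[10.0, 10.0, 40.0, 30.0, 45.0]]),
    format="XYWHR",
    canvas_size=(100, 100),
    clamping_mode="none",  # forwarded to the kernel via _rotate_bounding_boxes_dispatch
)
rotated = F.rotate(boxes, angle=15)
# Assumed: the mode is preserved by tv_tensors.wrap(output, like=inpt, ...)
assert rotated.clamping_mode == "none"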