Adjust clamping for rotated bboxes #9112

Merged
19 changes: 2 additions & 17 deletions test/common_utils.py
@@ -21,7 +21,7 @@
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
from torchvision import io, tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.v2.functional import clamp_bounding_boxes, to_image, to_pil_image
from torchvision.transforms.v2.functional import to_image, to_pil_image


IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
@@ -410,7 +410,7 @@ def make_bounding_boxes(
canvas_size=DEFAULT_SIZE,
*,
format=tv_tensors.BoundingBoxFormat.XYXY,
clamping_mode="hard", # TODOBB
clamping_mode="soft",
num_boxes=1,
dtype=None,
device="cpu",
@@ -469,21 +469,6 @@ def sample_position(values, max_value):
else:
raise ValueError(f"Format {format} is not supported")
out_boxes = torch.stack(parts, dim=-1).to(dtype=dtype, device=device)
if tv_tensors.is_rotated_bounding_format(format):
# The rotated bounding boxes are not guaranteed to be within the canvas by design,
# so we apply clamping. We also add a 2 buffer to the canvas size to avoid
# numerical issues during the testing
buffer = 4
out_boxes = clamp_bounding_boxes(
out_boxes,
format=format,
canvas_size=(canvas_size[0] - buffer, canvas_size[1] - buffer),
clamping_mode=clamping_mode,
)
if format is tv_tensors.BoundingBoxFormat.XYWHR or format is tv_tensors.BoundingBoxFormat.CXCYWHR:
out_boxes[:, :2] += buffer // 2
elif format is tv_tensors.BoundingBoxFormat.XYXYXYXY:
out_boxes[:, :] += buffer // 2
return tv_tensors.BoundingBoxes(out_boxes, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode)


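With the canvas-buffer workaround removed from `make_bounding_boxes`, rotated boxes are no longer pre-clamped at construction time; instead, the `clamping_mode` the helper was called with is recorded on the returned tv_tensor. A minimal sketch of the resulting construction, assuming this PR's `clamping_mode` API (the coordinate values are illustrative):

```python
import torch
from torchvision import tv_tensors

# The clamping intent now travels with the tensor rather than being
# baked into the coordinates via a buffered clamp at construction time.
boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[32.0, 32.0, 20.0, 10.0, 30.0]]),  # cx, cy, w, h, angle
    format=tv_tensors.BoundingBoxFormat.CXCYWHR,
    canvas_size=(64, 64),
    clamping_mode="soft",  # the new default throughout this PR
)
assert boxes.clamping_mode == "soft"
```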
35 changes: 16 additions & 19 deletions test/test_transforms_v2.py
@@ -551,6 +551,7 @@ def affine_bounding_boxes(bounding_boxes):
),
format=format,
canvas_size=canvas_size,
clamping_mode=clamping_mode,
)


@@ -639,6 +640,7 @@ def affine_rotated_bounding_boxes(bounding_boxes):
).reshape(bounding_boxes.shape),
format=format,
canvas_size=canvas_size,
clamping_mode=clamping_mode,
)


@@ -1305,7 +1307,7 @@ def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes: tv_tensors.B
if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
else reference_affine_bounding_boxes_helper
)
return helper(bounding_boxes, affine_matrix=affine_matrix)
return helper(bounding_boxes, affine_matrix=affine_matrix, clamp=False)

@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize(
@@ -1914,7 +1916,7 @@ def _reference_vertical_flip_bounding_boxes(self, bounding_boxes: tv_tensors.Bou
if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
else reference_affine_bounding_boxes_helper
)
return helper(bounding_boxes, affine_matrix=affine_matrix)
return helper(bounding_boxes, affine_matrix=affine_matrix, clamp=False)

@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
@@ -2079,7 +2081,6 @@ def test_functional(self, make_input):
(F.rotate_image, torch.Tensor),
(F._geometry._rotate_image_pil, PIL.Image.Image),
(F.rotate_image, tv_tensors.Image),
(F.rotate_bounding_boxes, tv_tensors.BoundingBoxes),
(F.rotate_mask, tv_tensors.Mask),
(F.rotate_video, tv_tensors.Video),
(F.rotate_keypoints, tv_tensors.KeyPoints),
@@ -2229,29 +2230,26 @@ def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, cen
clamp=False,
)

return F.clamp_bounding_boxes(self._recenter_bounding_boxes_after_expand(output, recenter_xy=recenter_xy)).to(
bounding_boxes
)
return self._recenter_bounding_boxes_after_expand(output, recenter_xy=recenter_xy).to(bounding_boxes)

@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
@pytest.mark.parametrize("expand", [False, True])
@pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
def test_functional_bounding_boxes_correctness(self, format, angle, expand, center):
bounding_boxes = make_bounding_boxes(format=format)
bounding_boxes = make_bounding_boxes(format=format, clamping_mode="none")

actual = F.rotate(bounding_boxes, angle=angle, expand=expand, center=center)
expected = self._reference_rotate_bounding_boxes(bounding_boxes, angle=angle, expand=expand, center=center)

torch.testing.assert_close(actual, expected)
torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0)
torch.testing.assert_close(actual, expected)

@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("expand", [False, True])
@pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
@pytest.mark.parametrize("seed", list(range(5)))
def test_transform_bounding_boxes_correctness(self, format, expand, center, seed):
bounding_boxes = make_bounding_boxes(format=format)
bounding_boxes = make_bounding_boxes(format=format, clamping_mode="none")

transform = transforms.RandomRotation(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES, expand=expand, center=center)

@@ -2262,9 +2260,8 @@ def test_transform_bounding_boxes_correctness(self, format, expand, center, seed
actual = transform(bounding_boxes)

expected = self._reference_rotate_bounding_boxes(bounding_boxes, **params, expand=expand, center=center)

torch.testing.assert_close(actual, expected)
torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0)
torch.testing.assert_close(actual, expected)

def _recenter_keypoints_after_expand(self, keypoints, *, recenter_xy):
x, y = recenter_xy
@@ -4349,7 +4346,6 @@ def test_functional(self, make_input):
(F.resized_crop_image, torch.Tensor),
(F._geometry._resized_crop_image_pil, PIL.Image.Image),
(F.resized_crop_image, tv_tensors.Image),
(F.resized_crop_bounding_boxes, tv_tensors.BoundingBoxes),
(F.resized_crop_mask, tv_tensors.Mask),
(F.resized_crop_video, tv_tensors.Video),
(F.resized_crop_keypoints, tv_tensors.KeyPoints),
@@ -4415,6 +4411,7 @@ def _reference_resized_crop_bounding_boxes(self, bounding_boxes, *, top, left, h
[0, 0, 1],
],
)

affine_matrix = (resize_affine_matrix @ crop_affine_matrix)[:2, :]

helper = (
@@ -4423,15 +4420,15 @@
else reference_affine_bounding_boxes_helper
)

return helper(
bounding_boxes,
affine_matrix=affine_matrix,
new_canvas_size=size,
)
return helper(bounding_boxes, affine_matrix=affine_matrix, new_canvas_size=size, clamp=False)

@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
def test_functional_bounding_boxes_correctness(self, format):
bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format)
# Note that we don't want to clamp: _reference_resized_crop_bounding_boxes
# fuses the crop and the resize into a single affine matrix, so the
# intermediate clamp to the crop canvas never happens.
bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format, clamping_mode="none")

actual = F.resized_crop(bounding_boxes, **self.CROP_KWARGS, size=self.OUTPUT_SIZE)
expected = self._reference_resized_crop_bounding_boxes(
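The reference helper above fuses the crop and the resize into one affine before mapping box corners, which is why clamping is deferred to the very end. A small standalone sketch of that composition; the values here are illustrative, and only the fusion `(resize_affine_matrix @ crop_affine_matrix)[:2, :]` mirrors the PR:

```python
import torch

# Crop translates by (-left, -top); resize scales per axis.
top, left = 10.0, 20.0
scale_h, scale_w = 2.0, 1.5
crop_affine = torch.tensor([[1.0, 0.0, -left], [0.0, 1.0, -top], [0.0, 0.0, 1.0]])
resize_affine = torch.tensor([[scale_w, 0.0, 0.0], [0.0, scale_h, 0.0], [0.0, 0.0, 1.0]])

# Fusing both into a single matrix maps each corner exactly once, so there
# is no intermediate crop canvas on which clamping could take place.
fused = (resize_affine @ crop_affine)[:2, :]
corner = torch.tensor([30.0, 40.0, 1.0])
print(fused @ corner)  # crop-then-resize in one step -> tensor([15., 60.])
```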
5 changes: 5 additions & 0 deletions test/test_tv_tensors.py
@@ -406,3 +406,8 @@ def test_return_type_input():
tv_tensors.set_return_type("typo")

tv_tensors.set_return_type("tensor")


def test_box_clamping_mode_default():
assert tv_tensors.BoundingBoxes([0, 0, 10, 10], format="XYXY", canvas_size=(100, 100)).clamping_mode == "soft"
assert tv_tensors.BoundingBoxes([0, 0, 10, 10, 0], format="XYWHR", canvas_size=(100, 100)).clamping_mode == "soft"
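
The new test pins the default mode to "soft" for both axis-aligned and rotated formats. A short usage sketch, under the assumption (consistent with the dispatcher changes below) that `tv_tensors.wrap(..., like=inpt)` carries `clamping_mode` through transforms:

```python
import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

boxes = tv_tensors.BoundingBoxes(
    [[0.0, 0.0, 10.0, 10.0]], format="XYXY", canvas_size=(100, 100)
)
assert boxes.clamping_mode == "soft"  # the default, as asserted in the test above

# Dispatchers now forward inpt.clamping_mode to the kernels and wrap the
# result like the input, so the mode should carry through a transform.
rotated = F.rotate(boxes, angle=30)
assert rotated.clamping_mode == "soft"
```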
65 changes: 46 additions & 19 deletions torchvision/transforms/v2/functional/_geometry.py
@@ -522,7 +522,7 @@ def resize_bounding_boxes(
size: Optional[list[int]],
max_size: Optional[int] = None,
format: tv_tensors.BoundingBoxFormat = tv_tensors.BoundingBoxFormat.XYXY,
clamping_mode: CLAMPING_MODE_TYPE = "hard", # TODOBB soft
clamping_mode: CLAMPING_MODE_TYPE = "soft",
) -> tuple[torch.Tensor, tuple[int, int]]:
# We set the default format as `tv_tensors.BoundingBoxFormat.XYXY`
# to ensure backward compatibility.
@@ -1108,15 +1108,16 @@ def _affine_bounding_boxes_with_expand(
shear: list[float],
center: Optional[list[float]] = None,
expand: bool = False,
clamping_mode: CLAMPING_MODE_TYPE = "hard", # TODOBB soft
clamping_mode: CLAMPING_MODE_TYPE = "soft",
) -> tuple[torch.Tensor, tuple[int, int]]:
if bounding_boxes.numel() == 0:
return bounding_boxes, canvas_size

original_shape = bounding_boxes.shape
dtype = bounding_boxes.dtype
need_cast = not bounding_boxes.is_floating_point()
bounding_boxes = bounding_boxes.float() if need_cast else bounding_boxes.clone()
acceptable_dtypes = [torch.float64] # Ensure consistency between CPU and GPU.
need_cast = dtype not in acceptable_dtypes
bounding_boxes = bounding_boxes.to(torch.float64) if need_cast else bounding_boxes.clone()
device = bounding_boxes.device
is_rotated = tv_tensors.is_rotated_bounding_format(format)
intermediate_format = tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else tv_tensors.BoundingBoxFormat.XYXY
@@ -1210,7 +1211,7 @@ def affine_bounding_boxes(
scale: float,
shear: list[float],
center: Optional[list[float]] = None,
clamping_mode: CLAMPING_MODE_TYPE = "hard", # TODOBB soft
clamping_mode: CLAMPING_MODE_TYPE = "soft",
) -> torch.Tensor:
out_box, _ = _affine_bounding_boxes_with_expand(
bounding_boxes,
@@ -1448,6 +1449,7 @@ def rotate_bounding_boxes(
angle: float,
expand: bool = False,
center: Optional[list[float]] = None,
clamping_mode: CLAMPING_MODE_TYPE = "soft",
) -> tuple[torch.Tensor, tuple[int, int]]:
return _affine_bounding_boxes_with_expand(
bounding_boxes,
Expand All @@ -1459,6 +1461,7 @@ def rotate_bounding_boxes(
shear=[0.0, 0.0],
center=center,
expand=expand,
clamping_mode=clamping_mode,
)


@@ -1473,6 +1476,7 @@ def _rotate_bounding_boxes_dispatch(
angle=angle,
expand=expand,
center=center,
clamping_mode=inpt.clamping_mode,
)
return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)

@@ -1739,7 +1743,7 @@ def pad_bounding_boxes(
canvas_size: tuple[int, int],
padding: list[int],
padding_mode: str = "constant",
clamping_mode: CLAMPING_MODE_TYPE = "hard", # TODOBB soft
clamping_mode: CLAMPING_MODE_TYPE = "soft",
) -> tuple[torch.Tensor, tuple[int, int]]:
if padding_mode not in ["constant"]:
# TODO: add support of other padding modes
@@ -1857,7 +1861,7 @@ def crop_bounding_boxes(
left: int,
height: int,
width: int,
clamping_mode: CLAMPING_MODE_TYPE = "hard", # TODOBB soft
clamping_mode: CLAMPING_MODE_TYPE = "soft",
) -> tuple[torch.Tensor, tuple[int, int]]:

# Crop or implicit pad if left and/or top have negative values:
@@ -2097,7 +2101,7 @@ def perspective_bounding_boxes(
startpoints: Optional[list[list[int]]],
endpoints: Optional[list[list[int]]],
coefficients: Optional[list[float]] = None,
clamping_mode: CLAMPING_MODE_TYPE = "hard", # TODOBB soft
clamping_mode: CLAMPING_MODE_TYPE = "soft",
) -> torch.Tensor:
if bounding_boxes.numel() == 0:
return bounding_boxes
@@ -2412,7 +2416,7 @@ def elastic_bounding_boxes(
format: tv_tensors.BoundingBoxFormat,
canvas_size: tuple[int, int],
displacement: torch.Tensor,
clamping_mode: CLAMPING_MODE_TYPE = "hard", # TODOBB soft
clamping_mode: CLAMPING_MODE_TYPE = "soft",
) -> torch.Tensor:
expected_shape = (1, canvas_size[0], canvas_size[1], 2)
if not isinstance(displacement, torch.Tensor):
@@ -2433,19 +2437,19 @@

original_shape = bounding_boxes.shape
# TODO: first cast to float if bbox is int64 before convert_bounding_box_format
intermediate_format = tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else tv_tensors.BoundingBoxFormat.XYXY
intermediate_format = tv_tensors.BoundingBoxFormat.CXCYWHR if is_rotated else tv_tensors.BoundingBoxFormat.XYXY

bounding_boxes = (
convert_bounding_box_format(bounding_boxes.clone(), old_format=format, new_format=intermediate_format)
).reshape(-1, 8 if is_rotated else 4)
).reshape(-1, 5 if is_rotated else 4)

id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype)
# We construct an approximation of inverse grid as inv_grid = id_grid - displacement
# This is not an exact inverse of the grid
inv_grid = id_grid.sub_(displacement)

# Get points from bboxes
points = bounding_boxes if is_rotated else bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]]
points = bounding_boxes[:, :2] if is_rotated else bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]]
points = points.reshape(-1, 2)
if points.is_floating_point():
points = points.ceil_()
@@ -2457,8 +2461,8 @@
transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)

if is_rotated:
transformed_points = transformed_points.reshape(-1, 8)
out_bboxes = _parallelogram_to_bounding_boxes(transformed_points).to(bounding_boxes.dtype)
transformed_points = transformed_points.reshape(-1, 2)
out_bboxes = torch.cat([transformed_points, bounding_boxes[:, 2:]], dim=1).to(bounding_boxes.dtype)
else:
transformed_points = transformed_points.reshape(-1, 4, 2)
out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
Expand Down Expand Up @@ -2619,11 +2623,18 @@ def center_crop_bounding_boxes(
format: tv_tensors.BoundingBoxFormat,
canvas_size: tuple[int, int],
output_size: list[int],
clamping_mode: CLAMPING_MODE_TYPE = "soft",
) -> tuple[torch.Tensor, tuple[int, int]]:
crop_height, crop_width = _center_crop_parse_output_size(output_size)
crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *canvas_size)
return crop_bounding_boxes(
bounding_boxes, format, top=crop_top, left=crop_left, height=crop_height, width=crop_width
bounding_boxes,
format,
top=crop_top,
left=crop_left,
height=crop_height,
width=crop_width,
clamping_mode=clamping_mode,
)


@@ -2632,7 +2643,11 @@ def _center_crop_bounding_boxes_dispatch(
inpt: tv_tensors.BoundingBoxes, output_size: list[int]
) -> tv_tensors.BoundingBoxes:
output, canvas_size = center_crop_bounding_boxes(
inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, output_size=output_size
inpt.as_subclass(torch.Tensor),
format=inpt.format,
canvas_size=inpt.canvas_size,
output_size=output_size,
clamping_mode=inpt.clamping_mode,
)
return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)

@@ -2779,17 +2794,29 @@ def resized_crop_bounding_boxes(
height: int,
width: int,
size: list[int],
clamping_mode: CLAMPING_MODE_TYPE = "soft",
) -> tuple[torch.Tensor, tuple[int, int]]:
bounding_boxes, canvas_size = crop_bounding_boxes(bounding_boxes, format, top, left, height, width)
return resize_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size, size=size)
bounding_boxes, canvas_size = crop_bounding_boxes(
bounding_boxes, format, top, left, height, width, clamping_mode=clamping_mode
)
return resize_bounding_boxes(
bounding_boxes, format=format, canvas_size=canvas_size, size=size, clamping_mode=clamping_mode
)


@_register_kernel_internal(resized_crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _resized_crop_bounding_boxes_dispatch(
inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int, size: list[int], **kwargs
) -> tv_tensors.BoundingBoxes:
output, canvas_size = resized_crop_bounding_boxes(
inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width, size=size
inpt.as_subclass(torch.Tensor),
format=inpt.format,
top=top,
left=left,
height=height,
width=width,
size=size,
clamping_mode=inpt.clamping_mode,
)
return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)

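One behavioral change worth noting in `elastic_bounding_boxes`: rotated boxes now go through the `CXCYWHR` intermediate format instead of `XYXYXYXY`, so only each box's center is pushed through the inverse displacement grid while width, height, and angle are preserved, rather than displacing all four corners and refitting a parallelogram. A minimal sketch of that idea, with the grid lookup replaced by a fixed offset for illustration:

```python
import torch

# One CXCYWHR box: center (32, 32), 20x10 extent, rotated 30 degrees.
boxes_cxcywhr = torch.tensor([[32.0, 32.0, 20.0, 10.0, 30.0]])

# Stand-in for sampling the inverse displacement grid at the box center.
center_offset = torch.tensor([1.5, -0.5])
displaced_centers = boxes_cxcywhr[:, :2] + center_offset

# Mirrors the PR's torch.cat([transformed_points, bounding_boxes[:, 2:]], dim=1):
# the center moves, (w, h, angle) stay untouched.
out = torch.cat([displaced_centers, boxes_cxcywhr[:, 2:]], dim=1)
print(out)  # tensor([[33.5000, 31.5000, 20.0000, 10.0000, 30.0000]])
```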