Skip to content

Commit 9235ee1

Browse files
authored
Allow users to choose the bbox clamping mode (#9128)
1 parent 6aee5ed commit 9235ee1

File tree

8 files changed

+257
-39
lines changed

8 files changed

+257
-39
lines changed

test/common_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,7 @@ def make_bounding_boxes(
410410
canvas_size=DEFAULT_SIZE,
411411
*,
412412
format=tv_tensors.BoundingBoxFormat.XYXY,
413+
clamping_mode="hard", # TODOBB
413414
num_boxes=1,
414415
dtype=None,
415416
device="cpu",
@@ -474,13 +475,16 @@ def sample_position(values, max_value):
474475
# numerical issues during the testing
475476
buffer = 4
476477
out_boxes = clamp_bounding_boxes(
477-
out_boxes, format=format, canvas_size=(canvas_size[0] - buffer, canvas_size[1] - buffer)
478+
out_boxes,
479+
format=format,
480+
canvas_size=(canvas_size[0] - buffer, canvas_size[1] - buffer),
481+
clamping_mode=clamping_mode,
478482
)
479483
if format is tv_tensors.BoundingBoxFormat.XYWHR or format is tv_tensors.BoundingBoxFormat.CXCYWHR:
480484
out_boxes[:, :2] += buffer // 2
481485
elif format is tv_tensors.BoundingBoxFormat.XYXYXYXY:
482486
out_boxes[:, :] += buffer // 2
483-
return tv_tensors.BoundingBoxes(out_boxes, format=format, canvas_size=canvas_size)
487+
return tv_tensors.BoundingBoxes(out_boxes, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode)
484488

485489

486490
def make_detection_masks(size=DEFAULT_SIZE, *, num_masks=1, dtype=None, device="cpu"):

test/test_transforms_v2.py

Lines changed: 115 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,7 @@ def adapt_fill(value, *, dtype):
492492
def reference_affine_bounding_boxes_helper(bounding_boxes, *, affine_matrix, new_canvas_size=None, clamp=True):
493493
format = bounding_boxes.format
494494
canvas_size = new_canvas_size or bounding_boxes.canvas_size
495+
clamping_mode = bounding_boxes.clamping_mode
495496

496497
def affine_bounding_boxes(bounding_boxes):
497498
dtype = bounding_boxes.dtype
@@ -535,6 +536,7 @@ def affine_bounding_boxes(bounding_boxes):
535536
output,
536537
format=format,
537538
canvas_size=canvas_size,
539+
clamping_mode=clamping_mode,
538540
)
539541
else:
540542
# We leave the bounding box as float64 so the caller gets the full precision to perform any additional
@@ -557,6 +559,7 @@ def reference_affine_rotated_bounding_boxes_helper(
557559
):
558560
format = bounding_boxes.format
559561
canvas_size = new_canvas_size or bounding_boxes.canvas_size
562+
clamping_mode = bounding_boxes.clamping_mode
560563

561564
def affine_rotated_bounding_boxes(bounding_boxes):
562565
dtype = bounding_boxes.dtype
@@ -618,6 +621,7 @@ def affine_rotated_bounding_boxes(bounding_boxes):
618621
output.to(dtype=dtype, device=device),
619622
format=format,
620623
canvas_size=canvas_size,
624+
clamping_mode=clamping_mode,
621625
)
622626
if clamp
623627
else output.to(dtype=output.dtype, device=device)
@@ -831,7 +835,6 @@ def test_functional(self, size, make_input):
831835
(F.resize_image, torch.Tensor),
832836
(F._geometry._resize_image_pil, PIL.Image.Image),
833837
(F.resize_image, tv_tensors.Image),
834-
(F.resize_bounding_boxes, tv_tensors.BoundingBoxes),
835838
(F.resize_mask, tv_tensors.Mask),
836839
(F.resize_video, tv_tensors.Video),
837840
(F.resize_keypoints, tv_tensors.KeyPoints),
@@ -3289,7 +3292,6 @@ def test_functional(self, make_input):
32893292
(F.elastic_image, torch.Tensor),
32903293
(F._geometry._elastic_image_pil, PIL.Image.Image),
32913294
(F.elastic_image, tv_tensors.Image),
3292-
(F.elastic_bounding_boxes, tv_tensors.BoundingBoxes),
32933295
(F.elastic_mask, tv_tensors.Mask),
32943296
(F.elastic_video, tv_tensors.Video),
32953297
(F.elastic_keypoints, tv_tensors.KeyPoints),
@@ -5126,6 +5128,7 @@ def test_image_functional_correctness(self, coefficients, interpolation, fill):
51265128
def _reference_perspective_bounding_boxes(self, bounding_boxes, *, startpoints, endpoints):
51275129
format = bounding_boxes.format
51285130
canvas_size = bounding_boxes.canvas_size
5131+
clamping_mode = bounding_boxes.clamping_mode
51295132
dtype = bounding_boxes.dtype
51305133
device = bounding_boxes.device
51315134
is_rotated = tv_tensors.is_rotated_bounding_format(format)
@@ -5226,6 +5229,7 @@ def perspective_bounding_boxes(bounding_boxes):
52265229
output,
52275230
format=format,
52285231
canvas_size=canvas_size,
5232+
clamping_mode=clamping_mode,
52295233
).to(dtype=dtype, device=device)
52305234

52315235
return tv_tensors.BoundingBoxes(
@@ -5506,29 +5510,35 @@ def test_correctness_image(self, mean, std, dtype, fn):
55065510

55075511
class TestClampBoundingBoxes:
55085512
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
5513+
@pytest.mark.parametrize("clamping_mode", ("hard", "none")) # TODOBB add soft
55095514
@pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
55105515
@pytest.mark.parametrize("device", cpu_and_cuda())
5511-
def test_kernel(self, format, dtype, device):
5512-
bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)
5516+
def test_kernel(self, format, clamping_mode, dtype, device):
5517+
bounding_boxes = make_bounding_boxes(format=format, clamping_mode=clamping_mode, dtype=dtype, device=device)
55135518
check_kernel(
55145519
F.clamp_bounding_boxes,
55155520
bounding_boxes,
55165521
format=bounding_boxes.format,
55175522
canvas_size=bounding_boxes.canvas_size,
5523+
clamping_mode=clamping_mode,
55185524
)
55195525

55205526
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
5521-
def test_functional(self, format):
5522-
check_functional(F.clamp_bounding_boxes, make_bounding_boxes(format=format))
5527+
@pytest.mark.parametrize("clamping_mode", ("hard", "none")) # TODOBB add soft
5528+
def test_functional(self, format, clamping_mode):
5529+
check_functional(F.clamp_bounding_boxes, make_bounding_boxes(format=format, clamping_mode=clamping_mode))
55235530

55245531
def test_errors(self):
55255532
input_tv_tensor = make_bounding_boxes()
55265533
input_pure_tensor = input_tv_tensor.as_subclass(torch.Tensor)
55275534
format, canvas_size = input_tv_tensor.format, input_tv_tensor.canvas_size
55285535

5529-
for format_, canvas_size_ in [(None, None), (format, None), (None, canvas_size)]:
5536+
for format_, canvas_size_, clamping_mode_ in itertools.product(
5537+
(format, None), (canvas_size, None), (input_tv_tensor.clamping_mode, None)
5538+
):
55305539
with pytest.raises(
5531-
ValueError, match="For pure tensor inputs, `format` and `canvas_size` have to be passed."
5540+
ValueError,
5541+
match="For pure tensor inputs, `format`, `canvas_size` and `clamping_mode` have to be passed.",
55325542
):
55335543
F.clamp_bounding_boxes(input_pure_tensor, format=format_, canvas_size=canvas_size_)
55345544

@@ -5541,6 +5551,103 @@ def test_errors(self):
55415551
def test_transform(self):
55425552
check_transform(transforms.ClampBoundingBoxes(), make_bounding_boxes())
55435553

5554+
@pytest.mark.parametrize("rotated", (True, False))
5555+
@pytest.mark.parametrize("constructor_clamping_mode", ("hard", "none"))
5556+
@pytest.mark.parametrize("clamping_mode", ("hard", "none", None)) # TODOBB add soft here.
5557+
@pytest.mark.parametrize("pass_pure_tensor", (True, False))
5558+
@pytest.mark.parametrize("fn", [F.clamp_bounding_boxes, transform_cls_to_functional(transforms.ClampBoundingBoxes)])
5559+
def test_clamping_mode(self, rotated, constructor_clamping_mode, clamping_mode, pass_pure_tensor, fn):
5560+
# This test checks 2 things:
5561+
# - That passing clamping_mode=None to the clamp_bounding_boxes
5562+
# functional (or to the class) relies on the box's `.clamping_mode`
5563+
# attribute
5564+
# - That clamping happens when it should, and only when it should, i.e.
5565+
# when the clamping mode is not "none". It doesn't validate the
5566+
# nunmerical results, only that clamping happened. For that, we create
5567+
# a large 100x100 box inside of a small 10x10 image.
5568+
5569+
if pass_pure_tensor and fn is not F.clamp_bounding_boxes:
5570+
# Only the functional supports pure tensors, not the class
5571+
return
5572+
if pass_pure_tensor and clamping_mode is None:
5573+
# cannot leave clamping_mode=None when passing pure tensor
5574+
return
5575+
5576+
if rotated:
5577+
boxes = tv_tensors.BoundingBoxes(
5578+
[0, 0, 100, 100, 0], format="XYWHR", canvas_size=(10, 10), clamping_mode=constructor_clamping_mode
5579+
)
5580+
expected_clamped_output = torch.tensor([[0, 0, 10, 10, 0]])
5581+
else:
5582+
boxes = tv_tensors.BoundingBoxes(
5583+
[0, 100, 0, 100], format="XYXY", canvas_size=(10, 10), clamping_mode=constructor_clamping_mode
5584+
)
5585+
expected_clamped_output = torch.tensor([[0, 10, 0, 10]])
5586+
5587+
if pass_pure_tensor:
5588+
out = fn(
5589+
boxes.as_subclass(torch.Tensor),
5590+
format=boxes.format,
5591+
canvas_size=boxes.canvas_size,
5592+
clamping_mode=clamping_mode,
5593+
)
5594+
else:
5595+
out = fn(boxes, clamping_mode=clamping_mode)
5596+
5597+
clamping_mode_prevailing = constructor_clamping_mode if clamping_mode is None else clamping_mode
5598+
if clamping_mode_prevailing == "none":
5599+
assert_equal(boxes, out) # should be a pass-through
5600+
else:
5601+
assert_equal(out, expected_clamped_output)
5602+
5603+
5604+
class TestSetClampingMode:
5605+
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
5606+
@pytest.mark.parametrize("constructor_clamping_mode", ("hard", "none")) # TODOBB add soft
5607+
@pytest.mark.parametrize("desired_clamping_mode", ("hard", "none")) # TODOBB add soft
5608+
def test_setter(self, format, constructor_clamping_mode, desired_clamping_mode):
5609+
5610+
in_boxes = make_bounding_boxes(format=format, clamping_mode=constructor_clamping_mode)
5611+
out_boxes = transforms.SetClampingMode(clamping_mode=desired_clamping_mode)(in_boxes)
5612+
5613+
assert in_boxes.clamping_mode == constructor_clamping_mode # input is unchanged: no leak
5614+
assert out_boxes.clamping_mode == desired_clamping_mode
5615+
5616+
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
5617+
@pytest.mark.parametrize("constructor_clamping_mode", ("hard", "none")) # TODOBB add soft
5618+
def test_pipeline_no_leak(self, format, constructor_clamping_mode):
5619+
class AssertClampingMode(transforms.Transform):
5620+
def __init__(self, expected_clamping_mode):
5621+
super().__init__()
5622+
self.expected_clamping_mode = expected_clamping_mode
5623+
5624+
_transformed_types = (tv_tensors.BoundingBoxes,)
5625+
5626+
def transform(self, inpt, _):
5627+
assert inpt.clamping_mode == self.expected_clamping_mode
5628+
return inpt
5629+
5630+
t = transforms.Compose(
5631+
[
5632+
transforms.SetClampingMode("none"),
5633+
AssertClampingMode("none"),
5634+
transforms.SetClampingMode("hard"),
5635+
AssertClampingMode("hard"),
5636+
transforms.SetClampingMode("none"),
5637+
AssertClampingMode("none"),
5638+
transforms.ClampBoundingBoxes("hard"),
5639+
]
5640+
)
5641+
5642+
in_boxes = make_bounding_boxes(format=format, clamping_mode=constructor_clamping_mode)
5643+
out_boxes = t(in_boxes)
5644+
5645+
assert in_boxes.clamping_mode == constructor_clamping_mode # input is unchanged: no leak
5646+
5647+
# assert that the output boxes clamping_mode is the one set by the last SetClampingMode.
5648+
# ClampBoundingBoxes doesn't set clamping_mode.
5649+
assert out_boxes.clamping_mode == "none"
5650+
55445651

55455652
class TestClampKeyPoints:
55465653
@pytest.mark.parametrize("dtype", [torch.int64, torch.float32])

torchvision/transforms/v2/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
ScaleJitter,
4242
TenCrop,
4343
)
44-
from ._meta import ClampBoundingBoxes, ClampKeyPoints, ConvertBoundingBoxFormat
44+
from ._meta import ClampBoundingBoxes, ClampKeyPoints, ConvertBoundingBoxFormat, SetClampingMode
4545
from ._misc import (
4646
ConvertImageDtype,
4747
GaussianBlur,

torchvision/transforms/v2/_meta.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from typing import Any, Union
1+
from typing import Any, Optional, Union
22

33
from torchvision import tv_tensors
44
from torchvision.transforms.v2 import functional as F, Transform
5+
from torchvision.tv_tensors._bounding_boxes import CLAMPING_MODE_TYPE
56

67

78
class ConvertBoundingBoxFormat(Transform):
@@ -28,12 +29,19 @@ class ClampBoundingBoxes(Transform):
2829
2930
The clamping is done according to the bounding boxes' ``canvas_size`` meta-data.
3031
32+
Args:
33+
clamping_mode: TODOBB more docs. Default is None which relies on the input box' clamping_mode attribute.
34+
3135
"""
3236

37+
def __init__(self, clamping_mode: Optional[CLAMPING_MODE_TYPE] = None) -> None:
38+
super().__init__()
39+
self.clamping_mode = clamping_mode
40+
3341
_transformed_types = (tv_tensors.BoundingBoxes,)
3442

3543
def transform(self, inpt: tv_tensors.BoundingBoxes, params: dict[str, Any]) -> tv_tensors.BoundingBoxes:
36-
return F.clamp_bounding_boxes(inpt) # type: ignore[return-value]
44+
return F.clamp_bounding_boxes(inpt, clamping_mode=self.clamping_mode) # type: ignore[return-value]
3745

3846

3947
class ClampKeyPoints(Transform):
@@ -46,3 +54,19 @@ class ClampKeyPoints(Transform):
4654

4755
def transform(self, inpt: tv_tensors.KeyPoints, params: dict[str, Any]) -> tv_tensors.KeyPoints:
4856
return F.clamp_keypoints(inpt) # type: ignore[return-value]
57+
58+
59+
class SetClampingMode(Transform):
60+
"""TODOBB"""
61+
62+
def __init__(self, clamping_mode: CLAMPING_MODE_TYPE) -> None:
63+
super().__init__()
64+
# TODOBB validate mode
65+
self.clamping_mode = clamping_mode
66+
67+
_transformed_types = (tv_tensors.BoundingBoxes,)
68+
69+
def transform(self, inpt: tv_tensors.BoundingBoxes, params: dict[str, Any]) -> tv_tensors.BoundingBoxes:
70+
out: tv_tensors.BoundingBoxes = inpt.clone() # type: ignore[assignment]
71+
out.clamping_mode = self.clamping_mode
72+
return out

0 commit comments

Comments
 (0)