diff --git a/test/test_ops.py b/test/test_ops.py index 9cb0cddedf7..60044b47e68 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -917,6 +917,44 @@ def test_batched_nms_implementations(self, seed): empty = torch.empty((0,), dtype=torch.int64) torch.testing.assert_close(empty, ops.batched_nms(empty, None, None, None)) + @pytest.mark.parametrize("op_name", ["nms", "batched_nms"]) + def test_nms_compile_tensor_iou_threshold(self, op_name): + """Test that NMS operations can be compiled when iou_threshold is a tensor""" + # This test addresses GitHub issue pytorch/pytorch#156722 + + if not is_compile_supported("cpu"): + pytest.skip("torch.compile not supported on this platform") + + if op_name == "nms": + + def model_fn(boxes, scores): + return ops.nms(boxes, scores, torch.tensor(0.5)) + + boxes = torch.rand(100, 4) + scores = torch.rand(100) + args = (boxes, scores) + else: # batched_nms + + def model_fn(boxes, scores, idxs): + return ops.batched_nms(boxes, scores, idxs, torch.tensor(0.5)) + + boxes = torch.rand(100, 4) + scores = torch.rand(100) + idxs = torch.randint(0, 10, (100,)) + args = (boxes, scores, idxs) + + # Test that compilation works + compiled_fn = torch.compile(model_fn, backend="eager") + + # Get expected result without compilation + expected = model_fn(*args) + + # Get result with compilation + result = compiled_fn(*args) + + # Results should match + torch.testing.assert_close(result, expected, msg=f"{op_name} compilation with tensor iou_threshold failed") + optests.generate_opcheck_tests( testcase=TestNMS, diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 7fb8192e1cd..146efb87b42 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -45,6 +45,11 @@ def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(nms) _assert_has_ops() + + # Handle case where iou_threshold is a tensor (e.g., from torch.tensor(0.5)) + if isinstance(iou_threshold, torch.Tensor): + iou_threshold = iou_threshold.item() + return torch.ops.torchvision.nms(boxes, scores, iou_threshold) @@ -74,6 +79,11 @@ def batched_nms( """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(batched_nms) + + # Handle case where iou_threshold is a tensor (e.g., from torch.tensor(0.5)) + if isinstance(iou_threshold, torch.Tensor): + iou_threshold = iou_threshold.item() + # Benchmarks that drove the following thresholds are at # https://github.com/pytorch/vision/issues/1311#issuecomment-781329339 # and https://github.com/pytorch/vision/pull/8925