Skip to content

Commit ba70ee3

Browse files
committed
fix points_in_box bug.
1 parent 085f1d6 commit ba70ee3

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

mmcv/ops/points_in_boxes.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,9 @@ def points_in_boxes_part(points: Tensor, boxes: Tensor) -> Tensor:
4747
points_device = points.get_device()
4848
assert points_device == boxes.get_device(), \
4949
'Points and boxes should be put on the same device'
50-
if torch.cuda.current_device() != points_device:
51-
torch.cuda.set_device(points_device)
50+
if points_device != 'npu':
51+
if torch.cuda.current_device() != points_device:
52+
torch.cuda.set_device(points_device)
5253

5354
ext_module.points_in_boxes_part_forward(boxes.contiguous(),
5455
points.contiguous(),

0 commit comments

Comments
 (0)