
Commit b3cc159

Author: libaokui
Commit message: fix ut

1 parent b72de59 · commit b3cc159

File tree: 1 file changed (+3, -6 lines)


tests/ut/ops/test_flash_comm1.py

Lines changed: 3 additions & 6 deletions
@@ -13,18 +13,18 @@
 # This file is a part of the vllm-ascend project.
 #
 
-from unittest.mock import patch
 
 import torch
 import importlib
 from tests.ut.base import TestBase
 from unittest.mock import MagicMock, patch
+
 from vllm.distributed.parallel_state import GroupCoordinator
 
 from vllm_ascend.ops import sequence_parallel
 
 
-class TestFusedExperts310(TestBase):
+class Test_Flash_Comm1(TestBase):
 
     @patch('vllm.distributed.tensor_model_parallel_all_gather')
     @patch('vllm.distributed.tensor_model_parallel_reduce_scatter')
@@ -39,14 +39,11 @@ def test_test_flash_comm1(self, mock_TP,
         hidden_size = 128
         tp_size = 4
         hidden_states = torch.randn(num_tokens, hidden_size)
+
         mock_tp_group = mock_get_tp_group.return_value
         assert mock_tp_group.world_size == 4  # manually assert the attributes exist
         assert mock_tp_group.rank_in_group == 0
 
-        # mock_get_tp_group.return_value = MagicMock()
-        # mock_get_tp_group.return_value.world_size = 4
-        # mock_get_tp_group.return_value.rank_in_group = 0
-
         lengths_sum_unpadding = hidden_states.shape[0]
         lengths_sum_padding = ((lengths_sum_unpadding + tp_size - 1) // tp_size) * tp_size
         padding_flag = True
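
For readers following the test change, here is a minimal, self-contained sketch of the two pieces the updated assertions rely on: a MagicMock standing in for the patched TP group, and ceiling division to pad the token count up to a multiple of tp_size. The pad_to_multiple helper and the explicit mock configuration are illustrative assumptions, not code from the vllm-ascend repository.

    from unittest.mock import MagicMock

    import torch


    def pad_to_multiple(length: int, tp_size: int) -> int:
        # Integer ceiling division: round length up to the next multiple of
        # tp_size, matching the lengths_sum_padding expression in the test.
        return ((length + tp_size - 1) // tp_size) * tp_size


    # Stand-in for mock_get_tp_group.return_value: attributes on a MagicMock
    # can be set explicitly, so assertions on world_size and rank_in_group
    # hold without creating a real distributed process group.
    mock_tp_group = MagicMock()
    mock_tp_group.world_size = 4
    mock_tp_group.rank_in_group = 0
    assert mock_tp_group.world_size == 4
    assert mock_tp_group.rank_in_group == 0

    num_tokens, hidden_size = 42, 128  # hypothetical values for illustration
    hidden_states = torch.randn(num_tokens, hidden_size)

    lengths_sum_unpadding = hidden_states.shape[0]
    lengths_sum_padding = pad_to_multiple(lengths_sum_unpadding,
                                          mock_tp_group.world_size)
    assert lengths_sum_padding == 44  # 42 rounded up to the next multiple of 4

Padding to a multiple of the TP world size matters because a reduce-scatter splits the token dimension evenly across the tensor-parallel ranks, so the sequence length must divide cleanly by tp_size.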
