Commit ae7a949

fix ut
Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent de6a131

File tree

2 files changed: +8 −3 lines

tests/ut/attention/test_mla_v1.py

Lines changed: 7 additions & 2 deletions
@@ -238,13 +238,18 @@ class TestAscendMLAImpl(TestBase):
                  new_callable=lambda: MagicMock(spec=GroupCoordinator))
     @patch("vllm.distributed.get_tensor_model_parallel_world_size",
            return_value=2)
-    @patch("vllm.config.get_current_vllm_config")
+    @patch("vllm_ascend.attention.mla_v1.get_current_vllm_config")
     @patch("vllm_ascend.attention.mla_v1.get_ascend_config")
-    def setUp(self, ascend_config, vllm_config, mock_get_tp_size, mock_tp):
+    def setUp(self, ascend_config, get_current_vllm_config, mock_get_tp_size, mock_tp):
         mock_tp.world_size = 2
+        vllm_config = MagicMock()
         speculative_config = MagicMock()
+        model_config = MagicMock()
         speculative_config.num_speculative_tokens = 4
         vllm_config.speculative_config = speculative_config
+        model_config.dtype = torch.float16
+        vllm_config.model_config = model_config
+        get_current_vllm_config.return_value = vllm_config
 
         num_heads = 256
         head_size = 1024
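
Note on the test change: `unittest.mock.patch` replaces a name where it is looked up, not where it is defined. Since the patch target moves to `vllm_ascend.attention.mla_v1.get_current_vllm_config`, the module apparently binds the getter into its own namespace via a from-import, so patching `vllm.config.get_current_vllm_config` would leave the test exercising the real function. A minimal self-contained sketch of this behavior, using two in-memory stand-in modules (`cfglib` and `mla` are illustrative names, not the real vllm layout):

```python
import sys
import types
from unittest import mock

# Stand-in for vllm.config: the module that *defines* the getter.
cfglib = types.ModuleType("cfglib")
cfglib.get_current_vllm_config = lambda: "real config"
sys.modules["cfglib"] = cfglib

# Stand-in for vllm_ascend.attention.mla_v1: it from-imports the getter,
# re-binding the function object into its own namespace.
mla = types.ModuleType("mla")
exec(
    "from cfglib import get_current_vllm_config\n"
    "def read_config():\n"
    "    return get_current_vllm_config()\n",
    mla.__dict__,
)
sys.modules["mla"] = mla

# Patching the defining module misses mla's already-bound reference...
with mock.patch("cfglib.get_current_vllm_config", return_value="mocked"):
    assert mla.read_config() == "real config"

# ...while patching the lookup site (what this commit switches to) works.
with mock.patch("mla.get_current_vllm_config", return_value="mocked"):
    assert mla.read_config() == "mocked"
```

With the lookup site patched, the added lines in the diff can then wire a fully mocked config (including `model_config.dtype`) through `get_current_vllm_config.return_value`.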

vllm_ascend/attention/mla_v1.py

Lines changed: 1 addition & 1 deletion
@@ -492,7 +492,7 @@ def __init__(
                                   1)  # 512: mask only support 512
 
         # Adapt torch air graph mode with spec decoding.
-        speculative_config = get_current_vllm_config().speculative_config
+        speculative_config = vllm_config.speculative_config
         if speculative_config is not None:
             self.spec_token_num = speculative_config.num_speculative_tokens
             assert self.spec_token_num > 0
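
The production-side hunk moves in the same direction: `__init__` now reads a `vllm_config` object already in scope rather than calling the process-global `get_current_vllm_config()` a second time, which is also easier to satisfy from a unit test. A hedged sketch of that shape (the class and config types below are illustrative stand-ins, not the real `AscendMLAImpl` signature):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SpeculativeConfig:
    num_speculative_tokens: int


@dataclass
class VllmConfig:
    speculative_config: Optional[SpeculativeConfig]


class MLAImplSketch:
    def __init__(self, vllm_config: VllmConfig) -> None:
        # Read the config handed in (or captured earlier) instead of
        # calling a process-global getter; a plain object suffices in tests.
        speculative_config = vllm_config.speculative_config
        if speculative_config is not None:
            self.spec_token_num = speculative_config.num_speculative_tokens
            assert self.spec_token_num > 0


impl = MLAImplSketch(VllmConfig(SpeculativeConfig(num_speculative_tokens=4)))
assert impl.spec_token_num == 4
```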
