Commit 28edc36

JC-ut0 authored and offline0806 committed
[V1][BUGFIX][0.10.1] FIX mtp on main branch (vllm-project#2632)
### What this PR does / why we need it?
Fix the MTP torchair bug caused by the torchair refactor and the MoE refactor.

Depends on PRs:
- fused moe fix: vllm-project#2627
- torchair multi DP fix: vllm-project#2626

### Does this PR introduce _any_ user-facing change?
When DP is enabled, running the MTP online server currently requires disabling server log stats with `--disable-log-stats`, because the metrics do not yet support multi DP.

### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main: vllm-project/vllm@7c8271c

Signed-off-by: xuyexiong <xuyexiong@huawei.com>
Signed-off-by: offline0806 <z00858301@china.huawei.com>
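For illustration, the user-facing workaround above can be reproduced in a minimal offline sketch. Argument names follow vLLM's standard engine arguments (`data_parallel_size`, `disable_log_stats`, `speculative_config`); the data-parallel size and prompt are placeholders rather than values taken from this PR, and the online-server equivalent is simply adding `--disable-log-stats` to `vllm serve`.

# Sketch only: offline analogue of the DP + MTP setup described above.
# `disable_log_stats`, `data_parallel_size` and `speculative_config` are
# standard vLLM engine arguments; the concrete values are illustrative.
from vllm import LLM, SamplingParams

llm = LLM(
    model="wemaster/deepseek_mtp_main_random_bf16",  # model used by the new e2e test
    data_parallel_size=2,        # DP enabled
    disable_log_stats=True,      # needed because metrics do not yet support multi DP
    speculative_config={
        "method": "deepseek_mtp",
        "num_speculative_tokens": 1,
    },
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0, max_tokens=32))
print(outputs[0].outputs[0].text)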
1 parent dbec5d4 commit 28edc36

File tree

4 files changed: 125 additions & 4 deletions

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
from __future__ import annotations

import os

import pytest
from vllm import SamplingParams

from tests.e2e.conftest import VllmRunner
from vllm_ascend.ascend_config import clear_ascend_config

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


@pytest.fixture
def sampling_config():
    return SamplingParams(temperature=0, max_tokens=256, ignore_eos=False)


@pytest.fixture
def model_name():
    return "wemaster/deepseek_mtp_main_random_bf16"


def test_mtp_torchair_correctness(
        sampling_config: SamplingParams,
        model_name: str,
):
    example_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    '''
    Compare the outputs of a original LLM and a speculative LLM
    should be the same when using mtp speculative decoding.
    '''
    clear_ascend_config()
    with VllmRunner(model_name,
                    tensor_parallel_size=1,
                    gpu_memory_utilization=0.7,
                    max_model_len=256,
                    enforce_eager=False,
                    additional_config={
                        "torchair_graph_config": {
                            "enabled": True,
                            "use_cached_graph": False,
                            "graph_batch_sizes": [1, 2, 4],
                        },
                    }) as ref_llm:
        ref_outputs = ref_llm.generate(example_prompts, sampling_config)
    clear_ascend_config()
    with VllmRunner(model_name,
                    tensor_parallel_size=1,
                    max_num_seqs=256,
                    gpu_memory_utilization=0.7,
                    distributed_executor_backend="mp",
                    enable_expert_parallel=True,
                    speculative_config={
                        "method": "deepseek_mtp",
                        "num_speculative_tokens": 1,
                    },
                    enforce_eager=False,
                    max_model_len=2000,
                    additional_config={
                        "torchair_graph_config": {
                            "enabled": True,
                            "use_cached_graph": False,
                            "graph_batch_sizes": [1, 2, 4],
                        }
                    }) as spec_llm:
        spec_outputs = spec_llm.generate(example_prompts, sampling_config)

    matches = 0
    misses = 0
    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
        ref_token_ids = ref_output[0][0]
        spec_token_ids = spec_output[0][0]
        if ref_token_ids == spec_token_ids[:len(ref_token_ids)]:
            matches += 1
        else:
            misses += 1
            print(f"ref_output: {ref_output[1][0]}")
            print(f"spec_output: {spec_output[1][0]}")

    # Heuristic: expect at least 66% of the prompts to match exactly
    # Upon failure, inspect the outputs to check for inaccuracy.
    assert matches > int(0.66 * len(ref_outputs))
    del spec_llm
    clear_ascend_config()

tests/ut/torchair/ops/test_torchair_fused_moe.py

Lines changed: 19 additions & 1 deletion
@@ -23,6 +23,8 @@
 from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase

 from vllm_ascend.ascend_forward_context import _get_fused_moe_state
+from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
+from vllm_ascend.quantization.quantizer import W8A8Quantizer
 from vllm_ascend.torchair.ops.torchair_fused_moe import (
     TorchairAscendFusedMoE, TorchairAscendUnquantizedFusedMoEMethod)
 from vllm_ascend.utils import AscendSocVersion, adapt_patch  # noqa E402
@@ -233,12 +235,28 @@ def test_init_with_quant(self, mock_dist_env, default_moe_config):
         mock_quant_config = MagicMock()
         mock_quant_method = MockFusedMoEMethod()
         mock_quant_config.get_quant_method.return_value = mock_quant_method
+        mock_quant_config.is_layer_skipped_ascend.return_value = False
+        with patch(
+                'vllm_ascend.quantization.quantizer.AscendQuantizer.get_quantizer',
+                return_value=W8A8Quantizer):
+            moe = TorchairAscendFusedMoE(**default_moe_config,
+                                         quant_config=mock_quant_config)
+
+        assert moe.quant_method is not None
+        assert isinstance(moe.quant_method, AscendFusedMoEMethod)
+
+    def test_init_with_mixed_quant(self, mock_dist_env, default_moe_config):
+        mock_quant_config = MagicMock()
+        mock_quant_method = MockFusedMoEMethod()
+        mock_quant_config.get_quant_method.return_value = mock_quant_method
+        mock_quant_config.is_layer_skipped_ascend.return_value = True

         moe = TorchairAscendFusedMoE(**default_moe_config,
                                      quant_config=mock_quant_config)

         assert moe.quant_method is not None
-        assert moe.quant_method == mock_quant_method
+        assert isinstance(moe.quant_method,
+                          TorchairAscendUnquantizedFusedMoEMethod)

     @pytest.mark.parametrize(
         "others_param",

vllm_ascend/torchair/ops/torchair_fused_moe.py

Lines changed: 8 additions & 1 deletion
@@ -45,6 +45,7 @@
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.ops.sequence_parallel import MetadataForPadding
+from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
 from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
 from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
                                get_all_reduce_merge_state,
@@ -1055,7 +1056,13 @@ def __init__(
             self.quant_method = TorchairAscendUnquantizedFusedMoEMethod(
                 self.moe)
         else:
-            self.quant_method = quant_config.get_quant_method(self, prefix)
+            if quant_config.is_layer_skipped_ascend(
+                    prefix, quant_config.packed_modules_mapping):
+                self.quant_method = TorchairAscendUnquantizedFusedMoEMethod(
+                    self.moe)
+            else:
+                self.quant_method = AscendFusedMoEMethod(
+                    quant_config, prefix, quant_config.packed_modules_mapping)

         assert self.quant_method is not None

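For readers skimming the hunk above, the new selection logic can be summarized in a standalone sketch. The two stub classes are hypothetical stand-ins for TorchairAscendUnquantizedFusedMoEMethod and AscendFusedMoEMethod (they are not the repo implementations); only the `is_layer_skipped_ascend` / `packed_modules_mapping` usage mirrors the change.

# Standalone sketch (not repo code) of how the MoE quant method is now chosen.
class UnquantizedMoEMethod:
    """Hypothetical stand-in for TorchairAscendUnquantizedFusedMoEMethod."""
    def __init__(self, moe):
        self.moe = moe

class QuantizedMoEMethod:
    """Hypothetical stand-in for AscendFusedMoEMethod."""
    def __init__(self, quant_config, prefix, packed_modules_mapping):
        self.prefix = prefix

def select_quant_method(moe, quant_config, prefix):
    # No quantization config at all: always take the unquantized kernel path.
    if quant_config is None:
        return UnquantizedMoEMethod(moe)
    # Mixed-precision checkpoints can mark individual layers as skipped;
    # such layers also fall back to the unquantized path.
    if quant_config.is_layer_skipped_ascend(prefix,
                                            quant_config.packed_modules_mapping):
        return UnquantizedMoEMethod(moe)
    # Otherwise build the Ascend quantized MoE method directly.
    return QuantizedMoEMethod(quant_config, prefix,
                              quant_config.packed_modules_mapping)

In other words, layers that a mixed-quant checkpoint marks as skipped now get the unquantized MoE method, while quantized layers get AscendFusedMoEMethod directly instead of going through quant_config.get_quant_method; the skipped-layer branch is covered by the new test_init_with_mixed_quant unit test above.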
vllm_ascend/worker/mtp_proposer_v1.py

Lines changed: 8 additions & 2 deletions
@@ -18,6 +18,8 @@
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
+from vllm_ascend.torchair.models.torchair_deepseek_mtp import \
+    TorchairDeepSeekMTP
 from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
 from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable

@@ -266,8 +268,12 @@ def load_model(self) -> None:
         with set_default_torch_dtype(
                 draft_model_config.dtype), set_current_vllm_config(
                     self.vllm_config):
-            self.model = CustomDeepSeekMTP(
-                vllm_config=self.vllm_config).to(target_device)
+            if self.torchair_graph_enabled:
+                self.model = TorchairDeepSeekMTP(
+                    vllm_config=self.vllm_config).to(target_device)
+            else:
+                self.model = CustomDeepSeekMTP(
+                    vllm_config=self.vllm_config).to(target_device)

         draft_attn_layer_names = (
             get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
