# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/blob/main/tests/entrypoints/llm/test_accuracy.py
#
-
import gc
import multiprocessing
+import os
import sys
+import time
+from vllm import SamplingParams
from multiprocessing import Queue

+import signal
+import requests
+import subprocess
import lm_eval
import pytest
import torch

+SERVER_HOST = "127.0.0.1"
+SERVER_PORT = 8000
+HEALTH_URL = f"http://{SERVER_HOST}:{SERVER_PORT}/health"
+COMPLETIONS_URL = f"http://{SERVER_HOST}:{SERVER_PORT}/v1/completions"
+
# pre-trained model path on Hugging Face.
-MODEL_NAME = ["Qwen/Qwen2.5-0.5B-Instruct", "Qwen/Qwen2.5-VL-3B-Instruct"]
+# Qwen/Qwen2.5-0.5B-Instruct: accuracy test for unimodal model.
+# Qwen/Qwen2.5-VL-3B-Instruct: accuracy test for multimodal model.
+# Qwen/Qwen3-30B-A3B: accuracy test for expert parallelism (EP).
+# deepseek-ai/DeepSeek-V2-Lite: accuracy test for tensor parallelism (TP).
+MODEL_NAME = [
+    "Qwen/Qwen2.5-0.5B-Instruct", "Qwen/Qwen2.5-VL-3B-Instruct",
+    "Qwen/Qwen3-30B-A3B", "deepseek-ai/DeepSeek-V2-Lite"
+]
+# Qwen/Qwen2.5-0.5B-Instruct: accuracy test for data parallelism (DP).
+MODEL_NAME_DP = ["Qwen/Qwen2.5-0.5B-Instruct"]
+
# Benchmark configuration mapping models to evaluation tasks:
# - Text models: GSM8K (grade school math reasoning)
# - Vision-language model: MMMU Art & Design validation (multimodal understanding)
TASK = {
    "Qwen/Qwen2.5-0.5B-Instruct": "gsm8k",
-    "Qwen/Qwen2.5-VL-3B-Instruct": "mmmu_val_art_and_design"
+    "Qwen/Qwen2.5-VL-3B-Instruct": "mmmu_val_art_and_design",
+    "Qwen/Qwen3-30B-A3B": "gsm8k",
+    "deepseek-ai/DeepSeek-V2-Lite": "gsm8k"
}
# Answer validation requiring format consistency.
FILTER = {
    "Qwen/Qwen2.5-0.5B-Instruct": "exact_match,strict-match",
-    "Qwen/Qwen2.5-VL-3B-Instruct": "acc,none"
+    "Qwen/Qwen2.5-VL-3B-Instruct": "acc,none",
+    "Qwen/Qwen3-30B-A3B": "exact_match,strict-match",
+    "deepseek-ai/DeepSeek-V2-Lite": "exact_match,strict-match"
}
# Measured accuracy may deviate from the baseline by at most 0.03 (absolute).
RTOL = 0.03
# Baseline accuracy after vLLM optimization.
EXPECTED_VALUE = {
    "Qwen/Qwen2.5-0.5B-Instruct": 0.316,
-    "Qwen/Qwen2.5-VL-3B-Instruct": 0.541
+    "Qwen/Qwen2.5-VL-3B-Instruct": 0.541,
+    "Qwen/Qwen3-30B-A3B": 0.888,
+    "deepseek-ai/DeepSeek-V2-Lite": 0.376
}
# Maximum context length configuration for each model.
MAX_MODEL_LEN = {
    "Qwen/Qwen2.5-0.5B-Instruct": 4096,
-    "Qwen/Qwen2.5-VL-3B-Instruct": 8192
+    "Qwen/Qwen2.5-VL-3B-Instruct": 8192,
+    "Qwen/Qwen3-30B-A3B": 4096,
+    "deepseek-ai/DeepSeek-V2-Lite": 4096
}
# Model types distinguishing text-only and vision-language models.
MODEL_TYPE = {
    "Qwen/Qwen2.5-0.5B-Instruct": "vllm",
-    "Qwen/Qwen2.5-VL-3B-Instruct": "vllm-vlm"
+    "Qwen/Qwen2.5-VL-3B-Instruct": "vllm-vlm",
+    "Qwen/Qwen3-30B-A3B": "vllm",
+    "deepseek-ai/DeepSeek-V2-Lite": "vllm"
}
# Wrap prompts in a chat-style template.
-APPLY_CHAT_TEMPLATE = {"vllm": False, "vllm-vlm": True}
+APPLY_CHAT_TEMPLATE = {
+    "Qwen/Qwen2.5-0.5B-Instruct": False,
+    "Qwen/Qwen2.5-VL-3B-Instruct": True,
+    "Qwen/Qwen3-30B-A3B": False,
+    "deepseek-ai/DeepSeek-V2-Lite": False
+}
# Handle few-shot examples as multi-turn dialogues.
-FEWSHOT_AS_MULTITURN = {"vllm": False, "vllm-vlm": True}
+FEWSHOT_AS_MULTITURN = {
+    "Qwen/Qwen2.5-0.5B-Instruct": False,
+    "Qwen/Qwen2.5-VL-3B-Instruct": True,
+    "Qwen/Qwen3-30B-A3B": False,
+    "deepseek-ai/DeepSeek-V2-Lite": False
+}
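+# Extra engine arguments appended to the lm_eval model_args string;
+# the 4-NPU models need parallelism and eager-mode settings.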
+MORE_ARGS = {
+    "Qwen/Qwen2.5-0.5B-Instruct": None,
+    "Qwen/Qwen2.5-VL-3B-Instruct": None,
+    "Qwen/Qwen3-30B-A3B":
+    "tensor_parallel_size=4,enable_expert_parallel=True,enforce_eager=True",
+    "deepseek-ai/DeepSeek-V2-Lite":
+    "tensor_parallel_size=4,trust_remote_code=True,enforce_eager=True"
+}
+
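+# "spawn" gives each worker a fresh interpreter; a forked child could
+# inherit already-initialized NPU state, which is not fork-safe.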
+multiprocessing.set_start_method("spawn", force=True)


-def run_test(queue, model, max_model_len, model_type):
+def get_available_npu_count():
+    return torch.npu.device_count()
+
+
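+# Evaluate in a spawned child process so accelerator state is isolated
+# between parametrized runs; the score (or an error string) comes back
+# through the queue.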
+def run_test(queue, model, max_model_len, model_type, more_args):
    try:
        if model_type == "vllm-vlm":
            model_args = (f"pretrained={model},max_model_len={max_model_len},"
                          "dtype=auto,max_images=2")
        else:
            model_args = (f"pretrained={model},max_model_len={max_model_len},"
                          "dtype=auto")
+        if more_args is not None:
+            model_args = f"{model_args},{more_args}"
        results = lm_eval.simple_evaluate(
            model=model_type,
            model_args=model_args,
            tasks=TASK[model],
            batch_size="auto",
-            apply_chat_template=APPLY_CHAT_TEMPLATE[model_type],
-            fewshot_as_multiturn=FEWSHOT_AS_MULTITURN[model_type],
+            apply_chat_template=APPLY_CHAT_TEMPLATE[model],
+            fewshot_as_multiturn=FEWSHOT_AS_MULTITURN[model],
        )
        result = results["results"][TASK[model]][FILTER[model]]
        print("result:", result)
        queue.put(result)
    except Exception as e:
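+        # Exception objects may not pickle cleanly across the process
+        # boundary, so report the error as a plain string.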
-        queue.put(e)
+        error_msg = f"{type(e).__name__}: {str(e)}"
+        queue.put(error_msg)
        sys.exit(1)
    finally:
        gc.collect()
@@ -93,19 +152,89 @@ def run_test(queue, model, max_model_len, model_type):
@pytest.mark.parametrize("model", MODEL_NAME)
@pytest.mark.parametrize("VLLM_USE_V1", ["0", "1"])
def test_lm_eval_accuracy(monkeypatch: pytest.MonkeyPatch, model, VLLM_USE_V1):
+    npu_count = get_available_npu_count()
    if model == "Qwen/Qwen2.5-VL-3B-Instruct" and VLLM_USE_V1 == "1":
        pytest.skip(
            "Qwen2.5-VL-3B-Instruct is not supported when VLLM_USE_V1=1")
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", VLLM_USE_V1)
+    if (model == "Qwen/Qwen2.5-VL-3B-Instruct"
+            or model == "Qwen/Qwen2.5-0.5B-Instruct") and npu_count != 1:
+        pytest.skip(
+            "Qwen2.5-0.5B-Instruct and Qwen2.5-VL-3B-Instruct accuracy tests only run on a single NPU (tp=1)"
+        )
+    if (model == "Qwen/Qwen3-30B-A3B"
+            or model == "deepseek-ai/DeepSeek-V2-Lite") and (
+                os.getenv("VLLM_USE_V1") != "1" or npu_count != 4):
+        pytest.skip(
+            "Qwen3-30B-A3B and DeepSeek-V2-Lite accuracy tests require VLLM_USE_V1=1 and 4 NPUs"
+        )
+    with monkeypatch.context():
        result_queue: Queue[float] = multiprocessing.Queue()
        p = multiprocessing.Process(target=run_test,
                                    args=(result_queue, model,
                                          MAX_MODEL_LEN[model],
-                                          MODEL_TYPE[model]))
+                                          MODEL_TYPE[model], MORE_ARGS[model]))
        p.start()
        p.join()
        result = result_queue.get()
        print(result)
        assert (EXPECTED_VALUE[model] - RTOL < result < EXPECTED_VALUE[model] + RTOL), \
            f"Expected: {EXPECTED_VALUE[model]}±{RTOL} | Measured: {result}"
+
+
+@pytest.mark.parametrize("max_tokens", [10])
+@pytest.mark.parametrize("model", MODEL_NAME_DP)
+def test_lm_eval_accuracy_dp(model, max_tokens):
+    npu_count = get_available_npu_count()
+    if (model != "Qwen/Qwen2.5-0.5B-Instruct"
+            or os.getenv("VLLM_USE_V1") != "1" or npu_count != 4):
+        pytest.skip(
+            "DP accuracy test requires Qwen2.5-0.5B-Instruct, VLLM_USE_V1=1, and 4 NPUs"
+        )
+
+    log_file = open("accuracy.log", "a")
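+    # Serve the model over 4 NPUs: tensor parallel 2 x data parallel 2.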
+    cmd = [
+        "vllm", "serve", model, "--tensor_parallel_size", "2",
+        "--data_parallel_size", "2"
+    ]
+    server_proc = subprocess.Popen(cmd,
+                                   stdout=log_file,
+                                   stderr=subprocess.DEVNULL)
+
+    try:
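+        # Poll the health endpoint for up to 300 s while the server starts.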
+        for _ in range(300):
+            try:
+                r = requests.get(HEALTH_URL, timeout=1)
+                if r.status_code == 200:
+                    break
+            except requests.exceptions.RequestException:
+                pass
+            time.sleep(1)
+        else:
+            pytest.fail(
+                f"vLLM serve did not become healthy after 300s: {HEALTH_URL}")
+
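+        # Greedy decoding with a fixed seed keeps the completion deterministic.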
+        prompt = "bejing is a"
+        payload = {
+            "model": model,
+            "prompt": prompt,
+            "max_tokens": max_tokens,
+            "temperature": 0.0,
+            "top_p": 1.0,
+            "seed": 123
+        }
+        resp = requests.post(COMPLETIONS_URL, json=payload, timeout=30)
+        resp.raise_for_status()
+        data = resp.json()
+
+        generated = data["choices"][0]["text"].strip()
+        expected = "city in north china, it has many famous attractions"
+        assert generated == expected, f"Expected `{expected}`, got `{generated}`"
+
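+    # Ask the server to exit gracefully; escalate to SIGKILL if it hangs.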
+    finally:
+        server_proc.send_signal(signal.SIGINT)
+        try:
+            server_proc.wait(timeout=10)
+        except subprocess.TimeoutExpired:
+            server_proc.kill()
+        server_proc.wait()
+        log_file.close()