Skip to content

Commit 29f8246

Browse files
committed
Fix guided decoding
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
1 parent e8a6a8d commit 29f8246

File tree

2 files changed

+41
-44
lines changed

2 files changed

+41
-44
lines changed

.github/workflows/vllm_ascend_test_full.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ jobs:
134134
pytest -sv tests/e2e/singlecard/test_camem.py
135135
pytest -sv tests/e2e/singlecard/test_chunked.py
136136
pytest -sv tests/e2e/singlecard/test_embedding.py
137-
#pytest -sv tests/e2e/singlecard/test_guided_decoding.py
137+
pytest -sv tests/e2e/singlecard/test_guided_decoding.py
138138
#pytest -sv tests/e2e/singlecard/test_ilama_lora.py
139139
pytest -sv tests/e2e/singlecard/test_profile_execute_duration.py
140140
pytest -sv tests/e2e/singlecard/test_quantization.py

tests/e2e/singlecard/test_guided_decoding.py

Lines changed: 40 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,19 @@
1818
#
1919
import json
2020
import os
21-
from typing import Any, Dict
2221

2322
import jsonschema
2423
import pytest
2524
import regex as re
26-
from vllm.outputs import RequestOutput
27-
from vllm.sampling_params import SamplingParams
2825

2926
from vllm_ascend.utils import vllm_version_is
3027

3128
if vllm_version_is("0.10.2"):
32-
from vllm.sampling_params import \
33-
GuidedDecodingParams as StructuredOutputsParams
29+
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
3430
else:
35-
from vllm.sampling_params import StructuredOutputsParams
31+
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
32+
33+
from vllm.outputs import RequestOutput
3634

3735
from tests.e2e.conftest import VllmRunner
3836

@@ -90,36 +88,31 @@ def sample_json_schema():
9088
}
9189

9290

93-
def construct_sampling_params(
94-
struct_param, sampling_kwargs: Dict[str, Any]) -> SamplingParams:
95-
if vllm_version_is("0.10.2"):
96-
return SamplingParams(guided_decoding=struct_param, **sampling_kwargs)
97-
else:
98-
return SamplingParams(structured_outputs=struct_param,
99-
**sampling_kwargs)
100-
101-
10291
@pytest.mark.parametrize("guided_decoding_backend", GuidedDecodingBackend)
10392
def test_guided_json_completion(guided_decoding_backend: str,
10493
sample_json_schema):
105-
struct_output_params = StructuredOutputsParams(json=sample_json_schema, )
106-
sampling_params = construct_sampling_params(struct_output_params, {
107-
"temperature": 1.0,
108-
"max_tokens": 500,
109-
})
110-
111-
runner_kwargs: Dict[str, Any] = {
112-
"seed": 0,
113-
}
11494
if vllm_version_is("0.10.2"):
115-
runner_kwargs["guided_decoding_backend"] = guided_decoding_backend
95+
sampling_params = SamplingParams(
96+
temperature=1.0,
97+
max_tokens=500,
98+
guided_decoding=GuidedDecodingParams(json=sample_json_schema))
99+
runner_kwargs = {
100+
"seed": 0,
101+
"guided_decoding_backend": guided_decoding_backend,
102+
}
116103
else:
117-
runner_kwargs["structured_outputs_config"] = {
118-
"backend": guided_decoding_backend
104+
sampling_params = SamplingParams(
105+
temperature=1.0,
106+
max_tokens=500,
107+
structured_outputs=StructuredOutputsParams(
108+
json=sample_json_schema))
109+
runner_kwargs = {
110+
"seed": 0,
111+
"structured_outputs_config": {
112+
"backend": guided_decoding_backend
113+
},
119114
}
120-
121-
with VllmRunner(MODEL_NAME,
122-
**runner_kwargs) as vllm_model: # type: ignore[arg-type]
115+
with VllmRunner(MODEL_NAME, **runner_kwargs) as vllm_model:
123116
prompts = [
124117
f"Give an example JSON for an employee profile "
125118
f"that fits this schema: {sample_json_schema}"
@@ -147,21 +140,25 @@ def test_guided_json_completion(guided_decoding_backend: str,
147140
def test_guided_regex(guided_decoding_backend: str, sample_regex):
148141
if guided_decoding_backend == "outlines":
149142
pytest.skip("Outlines doesn't support regex-based guided decoding.")
150-
151-
struct_output_params = StructuredOutputsParams(json=sample_regex, )
152-
sampling_params = construct_sampling_params(struct_output_params, {
153-
"temperature": 0.8,
154-
"top_p": 0.95,
155-
})
156-
157-
runner_kwargs: Dict[str, Any] = {
158-
"seed": 0,
159-
}
160143
if vllm_version_is("0.10.2"):
161-
runner_kwargs["guided_decoding_backend"] = guided_decoding_backend
144+
sampling_params = SamplingParams(
145+
temperature=0.8,
146+
top_p=0.95,
147+
guided_decoding=GuidedDecodingParams(regex=sample_regex))
148+
runner_kwargs = {
149+
"seed": 0,
150+
"guided_decoding_backend": guided_decoding_backend,
151+
}
162152
else:
163-
runner_kwargs["structured_outputs_config"] = {
164-
"backend": guided_decoding_backend
153+
sampling_params = SamplingParams(
154+
temperature=0.8,
155+
top_p=0.95,
156+
structured_outputs=StructuredOutputsParams(regex=sample_regex))
157+
runner_kwargs = {
158+
"seed": 0,
159+
"structured_outputs_config": {
160+
"backend": guided_decoding_backend
161+
},
165162
}
166163

167164
with VllmRunner(MODEL_NAME, **runner_kwargs) as vllm_model:

0 commit comments

Comments (0)