@@ -18,21 +18,19 @@
 #
 import json
 import os
-from typing import Any, Dict
 
 import jsonschema
 import pytest
 import regex as re
-from vllm.outputs import RequestOutput
-from vllm.sampling_params import SamplingParams
 
 from vllm_ascend.utils import vllm_version_is
 
 if vllm_version_is("0.10.2"):
-    from vllm.sampling_params import \
-        GuidedDecodingParams as StructuredOutputsParams
+    from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 else:
-    from vllm.sampling_params import StructuredOutputsParams
+    from vllm.sampling_params import SamplingParams, StructuredOutputsParams
+
+from vllm.outputs import RequestOutput
 
 from tests.e2e.conftest import VllmRunner
 
@@ -90,36 +88,31 @@ def sample_json_schema():
     }
 
 
-def construct_sampling_params(
-        struct_param, sampling_kwargs: Dict[str, Any]) -> SamplingParams:
-    if vllm_version_is("0.10.2"):
-        return SamplingParams(guided_decoding=struct_param, **sampling_kwargs)
-    else:
-        return SamplingParams(structured_outputs=struct_param,
-                              **sampling_kwargs)
-
-
 @pytest.mark.parametrize("guided_decoding_backend", GuidedDecodingBackend)
 def test_guided_json_completion(guided_decoding_backend: str,
                                 sample_json_schema):
-    struct_output_params = StructuredOutputsParams(json=sample_json_schema, )
-    sampling_params = construct_sampling_params(struct_output_params, {
-        "temperature": 1.0,
-        "max_tokens": 500,
-    })
-
-    runner_kwargs: Dict[str, Any] = {
-        "seed": 0,
-    }
     if vllm_version_is("0.10.2"):
-        runner_kwargs["guided_decoding_backend"] = guided_decoding_backend
+        sampling_params = SamplingParams(
+            temperature=1.0,
+            max_tokens=500,
+            guided_decoding=GuidedDecodingParams(json=sample_json_schema))
+        runner_kwargs = {
+            "seed": 0,
+            "guided_decoding_backend": guided_decoding_backend,
+        }
     else:
-        runner_kwargs["structured_outputs_config"] = {
-            "backend": guided_decoding_backend
+        sampling_params = SamplingParams(
+            temperature=1.0,
+            max_tokens=500,
+            structured_outputs=StructuredOutputsParams(
+                json=sample_json_schema))
+        runner_kwargs = {
+            "seed": 0,
+            "structured_outputs_config": {
+                "backend": guided_decoding_backend
+            },
         }
-
-    with VllmRunner(MODEL_NAME,
-                    **runner_kwargs) as vllm_model:  # type: ignore[arg-type]
+    with VllmRunner(MODEL_NAME, **runner_kwargs) as vllm_model:
         prompts = [
             f"Give an example JSON for an employee profile "
             f"that fits this schema: {sample_json_schema}"
@@ -147,21 +140,25 @@ def test_guided_json_completion(guided_decoding_backend: str,
 def test_guided_regex(guided_decoding_backend: str, sample_regex):
     if guided_decoding_backend == "outlines":
         pytest.skip("Outlines doesn't support regex-based guided decoding.")
-
-    struct_output_params = StructuredOutputsParams(json=sample_regex, )
-    sampling_params = construct_sampling_params(struct_output_params, {
-        "temperature": 0.8,
-        "top_p": 0.95,
-    })
-
-    runner_kwargs: Dict[str, Any] = {
-        "seed": 0,
-    }
     if vllm_version_is("0.10.2"):
-        runner_kwargs["guided_decoding_backend"] = guided_decoding_backend
+        sampling_params = SamplingParams(
+            temperature=0.8,
+            top_p=0.95,
+            guided_decoding=GuidedDecodingParams(regex=sample_regex))
+        runner_kwargs = {
+            "seed": 0,
+            "guided_decoding_backend": guided_decoding_backend,
+        }
     else:
-        runner_kwargs["structured_outputs_config"] = {
-            "backend": guided_decoding_backend
+        sampling_params = SamplingParams(
+            temperature=0.8,
+            top_p=0.95,
+            structured_outputs=StructuredOutputsParams(regex=sample_regex))
+        runner_kwargs = {
+            "seed": 0,
+            "structured_outputs_config": {
+                "backend": guided_decoding_backend
+            },
+        }
 
     with VllmRunner(MODEL_NAME, **runner_kwargs) as vllm_model:
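Taken together, the hunks drop the shared construct_sampling_params helper and spell out the version gate in each test. In isolation the pattern looks roughly like the sketch below, a minimal illustration based only on what the diff shows; the schema dict and the "xgrammar" backend name are stand-ins, not values from this test suite.

# Minimal sketch of the version gate used in the tests above. Assumes
# vllm and vllm_ascend are importable; the schema dict and the backend
# name are illustrative stand-ins.
from vllm_ascend.utils import vllm_version_is

schema = {"type": "object", "properties": {"name": {"type": "string"}}}
backend = "xgrammar"  # stand-in for one of the parametrized backends

if vllm_version_is("0.10.2"):
    # vLLM 0.10.2 still calls the feature "guided decoding".
    from vllm.sampling_params import GuidedDecodingParams, SamplingParams
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=500,
        guided_decoding=GuidedDecodingParams(json=schema))
    runner_kwargs = {"seed": 0, "guided_decoding_backend": backend}
else:
    # Later versions rename it to "structured outputs" and move the
    # backend selection into a structured_outputs_config dict.
    from vllm.sampling_params import SamplingParams, StructuredOutputsParams
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=500,
        structured_outputs=StructuredOutputsParams(json=schema))
    runner_kwargs = {
        "seed": 0,
        "structured_outputs_config": {"backend": backend},
    }

# runner_kwargs then go to the VllmRunner constructor and sampling_params
# to the generate call inside the runner context, as in the tests above.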
|