[RFC][core][V1] generalize structured output manager and backends #17503
base: main
Changes from all commits
@@ -21,10 +21,7 @@
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
-from vllm.v1.structured_output.backend_guidance import (
-    validate_guidance_grammar)
-from vllm.v1.structured_output.backend_xgrammar import (
-    validate_xgrammar_grammar)
+from vllm.v1.structured_output import StructuredOutputManager
 
 
 class Processor:
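With this hunk, Processor depends only on the structured_output package and no longer imports backend modules directly. Purely as an illustration of the "generalize ... backends" direction in the RFC title (not code contained in this diff), one plausible way to express that is a validator interface plus a registry keyed by backend name; the two validator helpers below are the ones removed from Processor's imports, while the Protocol and registry are assumptions.

```python
# Illustrative sketch only, not part of this PR's diff.
from typing import Protocol

from vllm.sampling_params import SamplingParams
from vllm.v1.structured_output.backend_guidance import validate_guidance_grammar
from vllm.v1.structured_output.backend_xgrammar import validate_xgrammar_grammar


class GrammarValidator(Protocol):
    """Anything that checks a request's grammar and raises ValueError on failure."""

    def __call__(self, params: SamplingParams) -> None: ...


# Hypothetical mapping from engine-level backend name to its validator.
GRAMMAR_VALIDATORS: dict[str, GrammarValidator] = {
    "xgrammar": validate_xgrammar_grammar,
    "guidance": lambda params: validate_guidance_grammar(params, tokenizer=None),
}
```

With such a registry, callers like Processor never branch on backend names themselves, and a new backend registers a validator instead of adding another `elif`.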
@@ -81,7 +78,7 @@ def _validate_sampling_params(
         params: SamplingParams,
         lora_request: Optional[LoRARequest],
     ) -> None:
-        self._validate_structured_output(params)
+        StructuredOutputManager.validate_request(params, self.vllm_config)
         self._validate_logit_bias(params)
 
         if params.allowed_token_ids is None:

Review comment (on the new call site): I don't see much added value if this is simply a validation function. If you want this validation function, then it is probably more useful to create a …
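For context on what flows through this call: structured output requests arrive on `SamplingParams.guided_decoding`. Below is a minimal usage sketch of such a request; `GuidedDecodingParams` is vLLM's public sampling API, and the JSON schema is an arbitrary example, not anything taken from this PR.

```python
# Minimal sketch of a request that reaches this validation path.
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

person_schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name"],
}

params = SamplingParams(
    max_tokens=64,
    guided_decoding=GuidedDecodingParams(json=person_schema),
)
# With this diff, Processor hands `params` (plus the engine's VllmConfig) to
# StructuredOutputManager.validate_request instead of calling its private
# _validate_structured_output helper.
```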
@@ -148,59 +145,6 @@ def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")
 
-    def _validate_structured_output(self, params: SamplingParams) -> None:
-        if not params.guided_decoding or not self.decoding_config:
-            return
-
-        engine_level_backend = self.decoding_config.backend
-        if params.guided_decoding.backend:
-            # Request-level backend selection is not supported in V1.
-            # The values may differ if `params` is reused and was set
-            # to a specific backend based on `auto` behavior in a previous
-            # request. We remember that it was set as a result of `auto`
-            # using the `_auto` option set on the backend in the params.
-            if (params.guided_decoding.backend != engine_level_backend
-                    and not (engine_level_backend == "auto"
-                             and params.guided_decoding.backend_was_auto)):
-                raise ValueError(
-                    "Request-level structured output backend selection is no "
-                    "longer supported. The request specified "
-                    f"'{params.guided_decoding.backend}', but vLLM was "
-                    f"initialised with '{engine_level_backend}'. This error "
-                    "can be resolved by removing backend selection from the "
-                    "request.")
-        else:
-            params.guided_decoding.backend = engine_level_backend
-
-        # Request content validation
-        if engine_level_backend.startswith("xgrammar"):
-            # xgrammar with no fallback
-            validate_xgrammar_grammar(params)
-        elif engine_level_backend.startswith("guidance"):
-            # TODO: ideally we would have the LLTokenizer here as Lark syntax
-            # allows <|special_token|> and similar, see
-            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
-            # Without tokenizer these are disallowed in grammars.
-            validate_guidance_grammar(params, tokenizer=None)
-        else:
-            # NOTE: engine_level_backend must be "auto" here, because we have
-            # checked supported_backends above.
-            # "auto" is an opt-in to opinionated behavior where we try to
-            # choose a backend based on request contents. This is not the
-            # default as it is less predictable and subject to change
-            # between releases as feature support changes.
-            try:
-                validate_xgrammar_grammar(params)
-                params.guided_decoding.backend = "xgrammar"
-            except ValueError:
-                # The request either failed validation
-                # or includes some jsonschema feature(s) that
-                # are not supported in xgrammar. Fall back to guidance.
-                validate_guidance_grammar(params, tokenizer=None)
-                params.guided_decoding.backend = "guidance"
-            # Remember that this backend was set automatically
-            params.guided_decoding.backend_was_auto = True
-
     def process_inputs(
         self,
         request_id: str,
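The method removed above is exactly the logic that the new `StructuredOutputManager.validate_request(params, self.vllm_config)` call has to cover. Below is a hedged sketch of what such a classmethod could look like if it simply absorbs the removed code; the actual implementation in this PR may structure it differently. The field accesses (`vllm_config.decoding_config.backend`, `guided_decoding.backend_was_auto`) are carried over from the code above, but the method body itself is an assumption.

```python
# Sketch only: one plausible shape for the generalized entry point used at the
# new call site; it mirrors the removed Processor._validate_structured_output.
from vllm.config import VllmConfig
from vllm.sampling_params import SamplingParams
from vllm.v1.structured_output.backend_guidance import validate_guidance_grammar
from vllm.v1.structured_output.backend_xgrammar import validate_xgrammar_grammar


class StructuredOutputManager:

    @classmethod
    def validate_request(cls, params: SamplingParams,
                         vllm_config: VllmConfig) -> None:
        gd = params.guided_decoding
        decoding_config = vllm_config.decoding_config
        if gd is None or decoding_config is None:
            return

        engine_level_backend = decoding_config.backend
        if gd.backend:
            # Request-level backend selection is still rejected unless the
            # value was previously filled in by "auto" selection.
            if (gd.backend != engine_level_backend
                    and not (engine_level_backend == "auto"
                             and gd.backend_was_auto)):
                raise ValueError(
                    "Request-level structured output backend selection is "
                    "no longer supported.")
        else:
            gd.backend = engine_level_backend

        # Request content validation, as in the removed method.
        if engine_level_backend.startswith("xgrammar"):
            validate_xgrammar_grammar(params)
        elif engine_level_backend.startswith("guidance"):
            validate_guidance_grammar(params, tokenizer=None)
        else:
            # "auto": prefer xgrammar, fall back to guidance, and remember
            # that the choice was made automatically.
            try:
                validate_xgrammar_grammar(params)
                gd.backend = "xgrammar"
            except ValueError:
                validate_guidance_grammar(params, tokenizer=None)
                gd.backend = "guidance"
            gd.backend_was_auto = True
```

Because it is a classmethod taking only the params and the engine config, the call site needs no Processor state; whether it should stay a bare validation entry point or become something richer is what the review comment on the call site questions.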
Review comment: The a priori assumption is that we will always deal with a batch in the V1 scheduler design, hence I don't think we need to spell this out in the name as `BatchMetadata`. However, I do see the point of having a verbose name here.
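To make the naming question concrete, here is a purely hypothetical illustration of the kind of per-batch structured output metadata being discussed; neither the class name nor its fields come from this diff.

```python
# Hypothetical illustration of the naming discussion only; this class does not
# appear in the diff shown above. The reviewer's point: since the V1 scheduler
# always operates on a batch, the "Batch" qualifier may be redundant and a
# shorter name would carry the same information.
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class StructuredOutputBatchMetadata:  # vs. simply StructuredOutputMetadata
    # Request ids in the batch that carry an active grammar.
    structured_request_ids: list[str] = field(default_factory=list)
    # Backend-produced token bitmask for those requests, if any (type left
    # abstract on purpose in this sketch).
    grammar_bitmask: Optional[Any] = None
```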