
[RFC][core][V1] generalize structured output manager and backends #17503

Open · wants to merge 1 commit into base: main
13 changes: 6 additions & 7 deletions vllm/v1/core/sched/output.py
@@ -7,15 +7,15 @@
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    import numpy as np
    import numpy.typing as npt

    from vllm.distributed.kv_transfer.kv_connector.v1.base import (
        KVConnectorMetadata)
    from vllm.lora.request import LoRARequest
    from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
    from vllm.sampling_params import SamplingParams
    from vllm.v1.request import Request
    from vllm.v1.structured_output.backend_types import (
        StructuredOutputBatchMetaData)


@dataclass
@@ -144,11 +144,10 @@ class SchedulerOutput:
    # Used to free the encoder cache.
    free_encoder_input_ids: list[tuple[str, int]]

    # Dict of request ids to their index within the batch
    # for filling the next token bitmask
    structured_output_request_ids: dict[str, int]
    # the bitmask for the whole batch
    grammar_bitmask: Optional[npt.NDArray[np.int32]]
    # Meta data for structured output batches
    # By default this holds only the structured_output_request_ids
    # but backends may extend this to hold more data for the batch
    structured_output_meta: Optional[StructuredOutputBatchMetaData]
Review comment (Collaborator):

Suggested change:
- structured_output_meta: Optional[StructuredOutputBatchMetaData]
+ structured_output_metadata: Optional[StructuredOutputBatchMetadata]

The a priori assumption is that we will always deal with a batch in the scheduler V1 design, so I don't think the name needs to spell out BatchMetadata explicitly. However, I do see the point of having the more verbose naming here.


    # KV Cache Connector metadata.
    kv_connector_metadata: Optional[KVConnectorMetadata] = None
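
For orientation, here is a minimal sketch of what the new batch-metadata type and a backend-specific extension could look like. The base class name and the structured_output_request_ids field come from the diff above; the XgrammarBatchMetaData subclass and its grammar_bitmask field are illustrative assumptions, not the PR's actual definitions:

```python
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import numpy.typing as npt


@dataclass
class StructuredOutputBatchMetaData:
    # Default batch metadata: only the mapping of request id -> index within
    # the batch, mirroring the structured_output_request_ids dict that used
    # to live directly on SchedulerOutput.
    structured_output_request_ids: dict[str, int] = field(default_factory=dict)


@dataclass
class XgrammarBatchMetaData(StructuredOutputBatchMetaData):
    # A bitmask-based backend (e.g. xgrammar) could extend the base metadata
    # with the per-batch token bitmask that SchedulerOutput previously carried
    # in its grammar_bitmask field.
    grammar_bitmask: Optional[npt.NDArray[np.int32]] = None
```

Keeping backend-specific data behind the metadata object is what lets SchedulerOutput stay backend-agnostic.
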
22 changes: 12 additions & 10 deletions vllm/v1/core/sched/scheduler.py
@@ -505,11 +505,15 @@ def schedule(self) -> SchedulerOutput:
                self.kv_cache_manager.get_num_common_prefix_blocks(
                    any_request, len(self.running)))

        grammar_bitmask = self.structured_output_manager.grammar_bitmask(
            self.requests,
            structured_output_request_ids,
            scheduled_spec_decode_tokens,
        )
        if self.structured_output_manager.backend is not None:
            structured_output_meta = self.structured_output_manager.init_batch(
                self.requests,
                structured_output_request_ids,
                scheduled_spec_decode_tokens,
            )
        else:
            structured_output_meta = None

        # Construct the scheduler output.
        new_reqs_data = [
            NewRequestData.from_request(req,
@@ -548,9 +552,7 @@ def schedule(self) -> SchedulerOutput:
            # the previous and the current steps.
            finished_req_ids=self.finished_req_ids,
            free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
            structured_output_request_ids=structured_output_request_ids,
            grammar_bitmask=grammar_bitmask,
        )
            structured_output_meta=structured_output_meta)

        # NOTE(Kuntai): this function is designed for multiple purposes:
        # 1. Plan the KV cache store
@@ -784,8 +786,8 @@ def update_from_output(
                # NOTE: structured_output_request
                # should not be None if use_structured_output, we have
                # check above, so safe to ignore type warning
                request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
                    req_id, new_token_ids)
                self.structured_output_manager.accept_tokens(
                    request, req_id, new_token_ids)

            # Add newly generated spec token ids to the request.
            if spec_token_ids is not None:
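
To make the scheduler-side changes concrete, below is a rough sketch of the manager interface the scheduler now calls. The init_batch and accept_tokens signatures follow the call sites in the diff above; the method bodies and the StructuredOutputBackend name are assumptions about how a backend-agnostic manager might delegate, not the PR's implementation:

```python
from typing import Optional


class StructuredOutputManager:
    # `backend` stays None until a structured-output request is seen, which is
    # why the scheduler guards init_batch() behind `backend is not None`.
    backend: Optional["StructuredOutputBackend"] = None

    def init_batch(
        self,
        requests: dict[str, "Request"],
        structured_output_request_ids: dict[str, int],
        scheduled_spec_decode_tokens: dict[str, list[int]],
    ) -> "StructuredOutputBatchMetaData":
        # Let the active backend build whatever per-batch metadata it needs:
        # a bitmask backend would compute the token bitmask here, while other
        # backends may only record the request-id -> batch-index mapping.
        assert self.backend is not None
        return self.backend.init_batch(requests, structured_output_request_ids,
                                       scheduled_spec_decode_tokens)

    def accept_tokens(self, request: "Request", req_id: str,
                      new_token_ids: list[int]) -> None:
        # Advance the per-request grammar state. This replaces the scheduler
        # reaching into request.structured_output_request.grammar directly.
        if request.structured_output_request is not None:
            request.structured_output_request.grammar.accept_tokens(
                req_id, new_token_ids)
```

Whether the backend hook is literally named init_batch is an open RFC detail; the point is that the scheduler only ever handles the opaque metadata object.
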
60 changes: 2 additions & 58 deletions vllm/v1/engine/processor.py
@@ -21,10 +21,7 @@
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.structured_output.backend_guidance import (
    validate_guidance_grammar)
from vllm.v1.structured_output.backend_xgrammar import (
    validate_xgrammar_grammar)
from vllm.v1.structured_output import StructuredOutputManager


class Processor:
@@ -81,7 +78,7 @@ def _validate_sampling_params(
        params: SamplingParams,
        lora_request: Optional[LoRARequest],
    ) -> None:
        self._validate_structured_output(params)
        StructuredOutputManager.validate_request(params, self.vllm_config)
Review comment (Collaborator):

I don't see much added value in this being just a validation function.

If you do want this validation function, it would probably be more useful to create a structured_outputs/utils.py. We are going to move some utility functions from v0 to v1 soon anyway, so it might be good to have a utils file live somewhere in v1/structured_outputs.

        self._validate_logit_bias(params)

        if params.allowed_token_ids is None:
@@ -148,59 +145,6 @@ def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

    def _validate_structured_output(self, params: SamplingParams) -> None:
        if not params.guided_decoding or not self.decoding_config:
            return

        engine_level_backend = self.decoding_config.backend
        if params.guided_decoding.backend:
            # Request-level backend selection is not supported in V1.
            # The values may differ if `params` is reused and was set
            # to a specific backend based on `auto` behavior in a previous
            # request. We remember that it was set as a result of `auto`
            # using the `_auto` option set on the backend in the params.
            if (params.guided_decoding.backend != engine_level_backend
                    and not (engine_level_backend == "auto"
                             and params.guided_decoding.backend_was_auto)):
                raise ValueError(
                    "Request-level structured output backend selection is no "
                    "longer supported. The request specified "
                    f"'{params.guided_decoding.backend}', but vLLM was "
                    f"initialised with '{engine_level_backend}'. This error "
                    "can be resolved by removing backend selection from the "
                    "request.")
        else:
            params.guided_decoding.backend = engine_level_backend

        # Request content validation
        if engine_level_backend.startswith("xgrammar"):
            # xgrammar with no fallback
            validate_xgrammar_grammar(params)
        elif engine_level_backend.startswith("guidance"):
            # TODO: ideally we would have the LLTokenizer here as Lark syntax
            # allows <|special_token|> and similar, see
            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
            # Without tokenizer these are disallowed in grammars.
            validate_guidance_grammar(params, tokenizer=None)
        else:
            # NOTE: engine_level_backend must be "auto" here, because we have
            # checked supported_backends above.
            # "auto" is an opt-in to opinionated behavior where we try to
            # choose a backend based on request contents. This is not the
            # default as it is less predictable and subject to change
            # between releases as feature support changes.
            try:
                validate_xgrammar_grammar(params)
                params.guided_decoding.backend = "xgrammar"
            except ValueError:
                # The request either failed validation
                # or includes some jsonschema feature(s) that
                # are not supported in xgrammar. Fall back to guidance.
                validate_guidance_grammar(params, tokenizer=None)
                params.guided_decoding.backend = "guidance"
            # Remember that this backend was set automatically
            params.guided_decoding.backend_was_auto = True

    def process_inputs(
        self,
        request_id: str,
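
For reference, the validation that used to live inline in the Processor roughly corresponds to a free-standing helper like the sketch below. The diff itself calls it as a classmethod, StructuredOutputManager.validate_request(params, self.vllm_config); the function name and body here are assumptions that restate the deleted logic, and such a helper is also where the reviewer's suggested v1/structured_output utils module could come in:

```python
from vllm.config import VllmConfig
from vllm.sampling_params import SamplingParams
from vllm.v1.structured_output.backend_guidance import validate_guidance_grammar
from vllm.v1.structured_output.backend_xgrammar import validate_xgrammar_grammar


def validate_structured_output_request(params: SamplingParams,
                                       vllm_config: VllmConfig) -> None:
    # Hypothetical utils-style helper; mirrors the removed
    # Processor._validate_structured_output shown above.
    if not params.guided_decoding or not vllm_config.decoding_config:
        return

    engine_level_backend = vllm_config.decoding_config.backend
    if params.guided_decoding.backend:
        # Request-level backend selection is rejected unless it matches the
        # engine-level backend (or was previously filled in by "auto").
        if (params.guided_decoding.backend != engine_level_backend
                and not (engine_level_backend == "auto"
                         and params.guided_decoding.backend_was_auto)):
            raise ValueError(
                "Request-level structured output backend selection is no "
                "longer supported.")
    else:
        params.guided_decoding.backend = engine_level_backend

    if engine_level_backend.startswith("xgrammar"):
        validate_xgrammar_grammar(params)
    elif engine_level_backend.startswith("guidance"):
        validate_guidance_grammar(params, tokenizer=None)
    else:
        # "auto": try xgrammar first and fall back to guidance, remembering
        # that the backend was chosen automatically.
        try:
            validate_xgrammar_grammar(params)
            params.guided_decoding.backend = "xgrammar"
        except ValueError:
            validate_guidance_grammar(params, tokenizer=None)
            params.guided_decoding.backend = "guidance"
        params.guided_decoding.backend_was_auto = True
```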