
Commit 714292e

generalize structured output manager and backends
Signed-off-by: william-baker-inflection <william.baker@inflection.ai>
1 parent 8267f99

13 files changed: +829 −385 lines

vllm/v1/core/sched/output.py

Lines changed: 6 additions & 7 deletions
@@ -7,15 +7,15 @@
 from typing import TYPE_CHECKING, Optional

 if TYPE_CHECKING:
-    import numpy as np
-    import numpy.typing as npt

     from vllm.distributed.kv_transfer.kv_connector.v1.base import (
         KVConnectorMetadata)
     from vllm.lora.request import LoRARequest
     from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
     from vllm.sampling_params import SamplingParams
     from vllm.v1.request import Request
+    from vllm.v1.structured_output.backend_types import (
+        StructuredOutputBatchMetaData)


 @dataclass
@@ -144,11 +144,10 @@ class SchedulerOutput:
     # Used to free the encoder cache.
     free_encoder_input_ids: list[tuple[str, int]]

-    # Dict of request ids to their index within the batch
-    # for filling the next token bitmask
-    structured_output_request_ids: dict[str, int]
-    # the bitmask for the whole batch
-    grammar_bitmask: Optional[npt.NDArray[np.int32]]
+    # Meta data for structured output batches
+    # By default this holds only the structured_output_request_ids
+    # but backends may extend this to hold more data for the batch
+    structured_output_meta: Optional[StructuredOutputBatchMetaData]

     # KV Cache Connector metadata.
     kv_connector_metadata: Optional[KVConnectorMetadata] = None
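
The new field references StructuredOutputBatchMetaData from vllm/v1/structured_output/backend_types.py, one of the 13 changed files but not shown in this section. Below is a minimal sketch of what that base type plausibly looks like, assuming it mirrors the field comment above: a base class carrying only the request-id to batch-index mapping, plus a hypothetical bitmask subclass standing in for backends that still build a per-batch grammar bitmask. Everything beyond the names visible in this diff is illustrative.

    # Hedged sketch, not the commit's actual backend_types.py.
    from dataclasses import dataclass, field
    from typing import Optional

    import numpy as np
    import numpy.typing as npt


    @dataclass
    class StructuredOutputBatchMetaData:
        # Request id -> index within the batch, i.e. what SchedulerOutput
        # previously carried directly as structured_output_request_ids.
        structured_output_request_ids: dict[str, int] = field(
            default_factory=dict)


    @dataclass
    class BitmaskBatchMetaData(StructuredOutputBatchMetaData):  # hypothetical
        # Per-batch token bitmask that used to live on SchedulerOutput as
        # grammar_bitmask; a bitmask-style backend could attach it here.
        grammar_bitmask: Optional[npt.NDArray[np.int32]] = None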

vllm/v1/core/sched/scheduler.py

Lines changed: 12 additions & 10 deletions
@@ -505,11 +505,15 @@ def schedule(self) -> SchedulerOutput:
                 self.kv_cache_manager.get_num_common_prefix_blocks(
                     any_request, len(self.running)))

-        grammar_bitmask = self.structured_output_manager.grammar_bitmask(
-            self.requests,
-            structured_output_request_ids,
-            scheduled_spec_decode_tokens,
-        )
+        if self.structured_output_manager.backend is not None:
+            structured_output_meta = self.structured_output_manager.init_batch(
+                self.requests,
+                structured_output_request_ids,
+                scheduled_spec_decode_tokens,
+            )
+        else:
+            structured_output_meta = None
+
         # Construct the scheduler output.
         new_reqs_data = [
             NewRequestData.from_request(req,
@@ -548,9 +552,7 @@
             # the previous and the current steps.
             finished_req_ids=self.finished_req_ids,
             free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
-            structured_output_request_ids=structured_output_request_ids,
-            grammar_bitmask=grammar_bitmask,
-        )
+            structured_output_meta=structured_output_meta)

         # NOTE(Kuntai): this function is designed for multiple purposes:
         # 1. Plan the KV cache store
@@ -784,8 +786,8 @@ def update_from_output(
                 # NOTE: structured_output_request
                 # should not be None if use_structured_output, we have
                 # check above, so safe to ignore type warning
-                request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
-                    req_id, new_token_ids)
+                self.structured_output_manager.accept_tokens(
+                    request, req_id, new_token_ids)

             # Add newly generated spec token ids to the request.
             if spec_token_ids is not None:
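
The scheduler now routes structured output work through two manager hooks, init_batch and accept_tokens, rather than building the grammar bitmask and advancing per-request grammars itself. The manager implementation is not in the hunks shown here; the sketch below only infers plausible signatures from the call sites above, and the delegation to a backend-level init_batch is an assumption.

    # Hedged sketch of the manager hooks, inferred from the call sites above.
    from typing import TYPE_CHECKING, Optional

    if TYPE_CHECKING:
        from vllm.v1.request import Request
        from vllm.v1.structured_output.backend_types import (
            StructuredOutputBackend, StructuredOutputBatchMetaData)


    class StructuredOutputManager:
        backend: Optional["StructuredOutputBackend"] = None

        def init_batch(
            self,
            requests: dict[str, "Request"],
            structured_output_request_ids: dict[str, int],
            scheduled_spec_decode_tokens: dict[str, list[int]],
        ) -> "StructuredOutputBatchMetaData":
            # Let the active backend build whatever per-batch state it needs
            # (e.g. a grammar bitmask); the scheduler only calls this after
            # checking that self.backend is not None.
            assert self.backend is not None
            return self.backend.init_batch(  # assumed backend-side hook
                requests, structured_output_request_ids,
                scheduled_spec_decode_tokens)

        def accept_tokens(self, request: "Request", req_id: str,
                          new_token_ids: list[int]) -> None:
            # Advance the request's grammar with the newly sampled tokens;
            # the scheduler previously called
            # request.structured_output_request.grammar.accept_tokens directly.
            request.structured_output_request.grammar.accept_tokens(
                req_id, new_token_ids)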

vllm/v1/engine/processor.py

Lines changed: 2 additions & 58 deletions
@@ -21,10 +21,7 @@
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
-from vllm.v1.structured_output.backend_guidance import (
-    validate_guidance_grammar)
-from vllm.v1.structured_output.backend_xgrammar import (
-    validate_xgrammar_grammar)
+from vllm.v1.structured_output import StructuredOutputManager


 class Processor:
@@ -81,7 +78,7 @@ def _validate_sampling_params(
         params: SamplingParams,
         lora_request: Optional[LoRARequest],
     ) -> None:
-        self._validate_structured_output(params)
+        StructuredOutputManager.validate_request(params, self.vllm_config)
         self._validate_logit_bias(params)

         if params.allowed_token_ids is None:
@@ -148,59 +145,6 @@ def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")

-    def _validate_structured_output(self, params: SamplingParams) -> None:
-        if not params.guided_decoding or not self.decoding_config:
-            return
-
-        engine_level_backend = self.decoding_config.backend
-        if params.guided_decoding.backend:
-            # Request-level backend selection is not supported in V1.
-            # The values may differ if `params` is reused and was set
-            # to a specific backend based on `auto` behavior in a previous
-            # request. We remember that it was set as a result of `auto`
-            # using the `_auto` option set on the backend in the params.
-            if (params.guided_decoding.backend != engine_level_backend
-                    and not (engine_level_backend == "auto"
-                             and params.guided_decoding.backend_was_auto)):
-                raise ValueError(
-                    "Request-level structured output backend selection is no "
-                    "longer supported. The request specified "
-                    f"'{params.guided_decoding.backend}', but vLLM was "
-                    f"initialised with '{engine_level_backend}'. This error "
-                    "can be resolved by removing backend selection from the "
-                    "request.")
-        else:
-            params.guided_decoding.backend = engine_level_backend
-
-        # Request content validation
-        if engine_level_backend.startswith("xgrammar"):
-            # xgrammar with no fallback
-            validate_xgrammar_grammar(params)
-        elif engine_level_backend.startswith("guidance"):
-            # TODO: ideally we would have the LLTokenizer here as Lark syntax
-            # allows <|special_token|> and similar, see
-            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
-            # Without tokenizer these are disallowed in grammars.
-            validate_guidance_grammar(params, tokenizer=None)
-        else:
-            # NOTE: engine_level_backend must be "auto" here, because we have
-            # checked supported_backends above.
-            # "auto" is an opt-in to opinionated behavior where we try to
-            # choose a backend based on request contents. This is not the
-            # default as it is less predictable and subject to change
-            # between releases as feature support changes.
-            try:
-                validate_xgrammar_grammar(params)
-                params.guided_decoding.backend = "xgrammar"
-            except ValueError:
-                # The request either failed validation
-                # or includes some jsonschema feature(s) that
-                # are not supported in xgrammar. Fall back to guidance.
-                validate_guidance_grammar(params, tokenizer=None)
-                params.guided_decoding.backend = "guidance"
-                # Remember that this backend was set automatically
-                params.guided_decoding.backend_was_auto = True
-
     def process_inputs(
         self,
         request_id: str,
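
With the method removed, structured output validation is reached through StructuredOutputManager.validate_request(params, self.vllm_config). Only that call site appears in this section; the sketch below is an assumed shape for the hook, with the backend-dispatch body left as a placeholder for the logic deleted above. Only the name validate_request and its two arguments come from the diff.

    # Hedged sketch; the real validate_request lives under
    # vllm/v1/structured_output and is not shown in this section.
    from vllm.config import VllmConfig
    from vllm.sampling_params import SamplingParams


    class StructuredOutputManager:

        @staticmethod  # could equally be a classmethod; assumed
        def validate_request(params: SamplingParams,
                             vllm_config: VllmConfig) -> None:
            guided = params.guided_decoding
            if guided is None:
                return
            engine_level_backend = vllm_config.decoding_config.backend
            # The xgrammar / guidance / "auto" validation and fallback that
            # used to live in Processor._validate_structured_output would be
            # dispatched to the configured backend here.
            ...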
