From 714292e8a50b110e77f45099249be818554217ee Mon Sep 17 00:00:00 2001 From: william-baker-inflection Date: Fri, 6 Jun 2025 07:16:26 -0700 Subject: [PATCH] generalize structured output manager and backends Signed-off-by: william-baker-inflection --- vllm/v1/core/sched/output.py | 13 +- vllm/v1/core/sched/scheduler.py | 22 +- vllm/v1/engine/processor.py | 60 +--- vllm/v1/structured_output/__init__.py | 313 ++++++++++++------ .../structured_output/backend_bitmasking.py | 189 +++++++++++ vllm/v1/structured_output/backend_guidance.py | 21 +- vllm/v1/structured_output/backend_types.py | 60 ++-- vllm/v1/structured_output/backend_xgrammar.py | 23 +- vllm/v1/structured_output/worker_backend.py | 56 ++++ .../worker_backend_bitmasking_gpu.py | 109 ++++++ .../worker_backend_bitmasking_tpu.py | 169 ++++++++++ vllm/v1/worker/gpu_model_runner.py | 76 +---- vllm/v1/worker/tpu_model_runner.py | 103 +----- 13 files changed, 829 insertions(+), 385 deletions(-) create mode 100644 vllm/v1/structured_output/backend_bitmasking.py create mode 100644 vllm/v1/structured_output/worker_backend.py create mode 100644 vllm/v1/structured_output/worker_backend_bitmasking_gpu.py create mode 100644 vllm/v1/structured_output/worker_backend_bitmasking_tpu.py diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index b404c70eb1e4..01e4d52f079a 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -7,8 +7,6 @@ from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: - import numpy as np - import numpy.typing as npt from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorMetadata) @@ -16,6 +14,8 @@ from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams from vllm.v1.request import Request + from vllm.v1.structured_output.backend_types import ( + StructuredOutputBatchMetaData) @dataclass @@ -144,11 +144,10 @@ class SchedulerOutput: # Used to free the encoder cache. free_encoder_input_ids: list[tuple[str, int]] - # Dict of request ids to their index within the batch - # for filling the next token bitmask - structured_output_request_ids: dict[str, int] - # the bitmask for the whole batch - grammar_bitmask: Optional[npt.NDArray[np.int32]] + # Meta data for structured output batches + # By default this holds only the structured_output_request_ids + # but backends may extend this to hold more data for the batch + structured_output_meta: Optional[StructuredOutputBatchMetaData] # KV Cache Connector metadata. kv_connector_metadata: Optional[KVConnectorMetadata] = None diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index f3b5c74829a9..632825aecdb2 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -505,11 +505,15 @@ def schedule(self) -> SchedulerOutput: self.kv_cache_manager.get_num_common_prefix_blocks( any_request, len(self.running))) - grammar_bitmask = self.structured_output_manager.grammar_bitmask( - self.requests, - structured_output_request_ids, - scheduled_spec_decode_tokens, - ) + if self.structured_output_manager.backend is not None: + structured_output_meta = self.structured_output_manager.init_batch( + self.requests, + structured_output_request_ids, + scheduled_spec_decode_tokens, + ) + else: + structured_output_meta = None + # Construct the scheduler output. new_reqs_data = [ NewRequestData.from_request(req, @@ -548,9 +552,7 @@ def schedule(self) -> SchedulerOutput: # the previous and the current steps. 
finished_req_ids=self.finished_req_ids, free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), - structured_output_request_ids=structured_output_request_ids, - grammar_bitmask=grammar_bitmask, - ) + structured_output_meta=structured_output_meta) # NOTE(Kuntai): this function is designed for multiple purposes: # 1. Plan the KV cache store @@ -784,8 +786,8 @@ def update_from_output( # NOTE: structured_output_request # should not be None if use_structured_output, we have # check above, so safe to ignore type warning - request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] - req_id, new_token_ids) + self.structured_output_manager.accept_tokens( + request, req_id, new_token_ids) # Add newly generated spec token ids to the request. if spec_token_ids is not None: diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index e28879d40460..f991ceca1bbe 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -21,10 +21,7 @@ from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.mm_input_cache import MirroredProcessingCache -from vllm.v1.structured_output.backend_guidance import ( - validate_guidance_grammar) -from vllm.v1.structured_output.backend_xgrammar import ( - validate_xgrammar_grammar) +from vllm.v1.structured_output import StructuredOutputManager class Processor: @@ -81,7 +78,7 @@ def _validate_sampling_params( params: SamplingParams, lora_request: Optional[LoRARequest], ) -> None: - self._validate_structured_output(params) + StructuredOutputManager.validate_request(params, self.vllm_config) self._validate_logit_bias(params) if params.allowed_token_ids is None: @@ -148,59 +145,6 @@ def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") - def _validate_structured_output(self, params: SamplingParams) -> None: - if not params.guided_decoding or not self.decoding_config: - return - - engine_level_backend = self.decoding_config.backend - if params.guided_decoding.backend: - # Request-level backend selection is not supported in V1. - # The values may differ if `params` is reused and was set - # to a specific backend based on `auto` behavior in a previous - # request. We remember that it was set as a result of `auto` - # using the `_auto` option set on the backend in the params. - if (params.guided_decoding.backend != engine_level_backend - and not (engine_level_backend == "auto" - and params.guided_decoding.backend_was_auto)): - raise ValueError( - "Request-level structured output backend selection is no " - "longer supported. The request specified " - f"'{params.guided_decoding.backend}', but vLLM was " - f"initialised with '{engine_level_backend}'. This error " - "can be resolved by removing backend selection from the " - "request.") - else: - params.guided_decoding.backend = engine_level_backend - - # Request content validation - if engine_level_backend.startswith("xgrammar"): - # xgrammar with no fallback - validate_xgrammar_grammar(params) - elif engine_level_backend.startswith("guidance"): - # TODO: ideally we would have the LLTokenizer here as Lark syntax - # allows <|special_token|> and similar, see - # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens - # Without tokenizer these are disallowed in grammars. 
- validate_guidance_grammar(params, tokenizer=None) - else: - # NOTE: engine_level_backend must be "auto" here, because we have - # checked supported_backends above. - # "auto" is an opt-in to opinionated behavior where we try to - # choose a backend based on request contents. This is not the - # default as it is less predictable and subject to change - # between releases as feature support changes. - try: - validate_xgrammar_grammar(params) - params.guided_decoding.backend = "xgrammar" - except ValueError: - # The request either failed validation - # or includes some jsonschema feature(s) that - # are not supported in xgrammar. Fall back to guidance. - validate_guidance_grammar(params, tokenizer=None) - params.guided_decoding.backend = "guidance" - # Remember that this backend was set automatically - params.guided_decoding.backend_was_auto = True - def process_inputs( self, request_id: str, diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index b2b0ee796954..f291782481ed 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -8,17 +8,26 @@ from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.reasoning import ReasoningParserManager +from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.utils import LazyLoader -from vllm.v1.structured_output.backend_guidance import GuidanceBackend -from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, - StructuredOutputGrammar) -from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend +from vllm.v1.structured_output.backend_guidance import ( + GuidanceBackend, validate_guidance_grammar) +from vllm.v1.structured_output.backend_types import ( + StructuredOutputBackend, StructuredOutputBatchMetaData, + StructuredOutputGrammar) +from vllm.v1.structured_output.backend_xgrammar import ( + XgrammarBackend, validate_xgrammar_grammar) +from vllm.v1.structured_output.worker_backend import ( + StructuredOutputWorkerBackend) +from vllm.v1.structured_output.worker_backend_bitmasking_gpu import ( + BitmaskGPUBackend) +from vllm.v1.structured_output.worker_backend_bitmasking_tpu import ( + BitmaskTPUBackend) if TYPE_CHECKING: - import numpy as np - import numpy.typing as npt import torch from vllm.reasoning import ReasoningParser @@ -30,7 +39,12 @@ class StructuredOutputManager: - """Engine-level manager for structured output requests.""" + """Engine-level manager for structured output requests. + This manager holds a backend property used to initialise and + compile grammars + Each v1 request will then have the compiled grammar assigned to + request.structured_output_request.grammar + """ def __init__(self, vllm_config: VllmConfig): self.backend: Optional[StructuredOutputBackend] = None @@ -38,7 +52,6 @@ def __init__(self, vllm_config: VllmConfig): self.vllm_config = vllm_config self._grammar_bitmask: Optional[torch.Tensor] = None - self._full_mask = torch.tensor(-1, dtype=torch.int32) # The default max_workers if not specified is the number of CPUs * 5, # which is way too high since these tasks are CPU-bound, not I/O bound. 
@@ -57,35 +70,63 @@ def __init__(self, vllm_config: VllmConfig): reasoning_backend) self.reasoner = reasoner_cls(tokenizer=self.tokenizer) - def grammar_init(self, request: Request) -> None: - if request.structured_output_request is None: + @staticmethod + def get_worker_backend( + vllm_config: VllmConfig) -> StructuredOutputWorkerBackend: + if current_platform.is_tpu(): + return BitmaskTPUBackend(vllm_config) + else: + return BitmaskGPUBackend(vllm_config) + + def init_backend(self, backend: str) -> None: + """ + Initialize the backend for structured output processing. + This method is called when the engine starts up and is responsible + for setting up the backend for structured output requests. + """ + if self.backend is not None: return + if backend == "auto": + if self.vllm_config.decoding_config.backend != "auto": + backend = self.vllm_config.decoding_config.backend + else: + backend = "xgrammar" # default to xgrammar - if TYPE_CHECKING: - assert request.sampling_params.guided_decoding is not None + vocab_size = self.vllm_config.model_config.get_vocab_size() - # Initialize the backend the first time it is needed. - # - # NOTE: We only support a single backend. We do NOT support different - # backends on a per-request basis in V1 (for now, anyway...). - if self.backend is None: - backend = request.sampling_params.guided_decoding.backend - vocab_size = self.vllm_config.model_config.get_vocab_size() + if backend in ["xgrammar", "guidance"]: # Bitmasking Backends if backend == "xgrammar": self.backend = XgrammarBackend( self.vllm_config, tokenizer=self.tokenizer, vocab_size=vocab_size, + reasoner=self.reasoner, ) - elif backend == "guidance": - self.backend = GuidanceBackend( + else: # Guidance + self.backend = GuidanceBackend( # type: ignore[assignment] self.vllm_config, tokenizer=self.tokenizer, vocab_size=vocab_size, + reasoner=self.reasoner, ) - else: - raise ValueError( - f"Unsupported structured output backend: {backend}") + else: + raise ValueError( + f"Unsupported structured output backend: {backend}") + + def grammar_init(self, request: Request) -> None: + if request.structured_output_request is None: + return + + if TYPE_CHECKING: + assert request.sampling_params.guided_decoding is not None + + # Initialize the backend the first time it is needed. + # + # NOTE: We only support a single backend. We do NOT support different + # backends on a per-request basis in V1 (for now, anyway...). + if self.backend is None: + self.init_backend(request.sampling_params.guided_decoding.backend + ) # type: ignore[union-attr] grammar = self.executor.submit(self._async_create_grammar, request) request.structured_output_request.grammar = grammar # type: ignore[assignment] @@ -106,89 +147,63 @@ def _async_create_grammar( assert self.backend is not None return self.backend.compile_grammar(request_type, grammar_spec) - def grammar_bitmask( - self, - requests: dict[str, Request], + def accept_tokens(self, request: Request, req_id: str, + tokens: list[int]) -> bool: + """ + Validates whether the provided tokens are acceptable based on + the grammar defined in the structured output request. + + Called in v1.core.sched.Scheduler.update_from_output after + tokens have been accepted + Args: + request (Request): The request object containing the + structured output request and its associated grammar. + req_id (str): The unique identifier for the request. + tokens (list[int]): A list of integer tokens to be validated. + Returns: + bool: True if the FSM was advanced successfully. 
+ False if the FSM failed to advance. + """ + assert request.structured_output_request is not None and \ + request.structured_output_request.grammar is not None + return request.structured_output_request.grammar.accept_tokens( + req_id, tokens) + + def init_batch( + self, requests: dict[str, Request], structured_output_request_ids: dict[str, int], - scheduled_spec_decode_tokens: dict[str, list[int]], - ) -> Optional[npt.NDArray[np.int32]]: - # Prepare the structured output bitmask for this batch. + scheduled_spec_decode_tokens: dict[str, list[int]] + ) -> StructuredOutputBatchMetaData | None: + """ + Called in the v1/core/sched/Scheduler.schedule to initialize + the batch of requests. + At this point, we have completed scheduling for the current step. + The `structured_output_request_ids` dictionary maps request IDs + that use structured output to their corresponding indices in the + running queue. + Args: + requests (dict[str, Request]): A dictionary mapping request IDs + to their corresponding `Request` objects. + structured_output_request_ids (dict[str, int]): A dictionary mapping + request IDs that use structured output to their respective + indices in the running queue. + scheduled_spec_decode_tokens (dict[str, list[int]]): A dictionary + mapping request IDs to lists of token IDs that are scheduled + for decoding. + Returns: + StructuredOutputBatchMetaData: Metadata for the initialized batch + of structured output requests. + """ + + assert self.backend is not None if not structured_output_request_ids: return None - - max_num_spec_tokens = 0 - if self.vllm_config.speculative_config is not None: - max_num_spec_tokens = \ - self.vllm_config.speculative_config.num_speculative_tokens - - if self._grammar_bitmask is None: - assert self.backend is not None - max_batch_size = self.vllm_config.scheduler_config.max_num_seqs - - # Allocate a bitmask for each token needing to be checked: - # one for each speculative position, and one more for the - # bonus token / non-speculative token. - self._grammar_bitmask = \ - self.backend.allocate_token_bitmask( - max_batch_size * (1 + max_num_spec_tokens)) - - bitmask_tensor = self._grammar_bitmask - # Generate a batched bitmask for all structured output requests. - # When speculative decoding is enabled, we need to include multiple - # masks for each request, one for each possible bonus token position. - # These are stored inline in the tensor and unpacked by the gpu runner. - cumulative_index = 0 - ordered_seq = sorted(structured_output_request_ids.items(), - key=lambda x: x[1]) - - # Note that for thinking support, we will need to - # reset the relevant part of the bitmask for consequent - # request here. - bitmask_tensor[:(len(ordered_seq) * (1 + max_num_spec_tokens))].fill_( - self._full_mask) - - # NOTE: This outer loop can likely be parallelized to improve - # performance of bitmask generation for large batches. 
- for req_id, _ in ordered_seq: - request = requests[req_id] - structured_output_request = request.structured_output_request - - if TYPE_CHECKING: - assert structured_output_request is not None - assert structured_output_request.grammar is not None - apply_bitmask: bool = True - if self.reasoner is not None: - if structured_output_request.reasoning_ended is None: - structured_output_request.reasoning_ended = \ - self.reasoner.is_reasoning_end(request.prompt_token_ids) - apply_bitmask = structured_output_request.reasoning_ended - - state_advancements = 0 - req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None] - for i, token in enumerate(req_tokens): - if apply_bitmask and not \ - structured_output_request.grammar.is_terminated(): - structured_output_request.grammar.fill_bitmask( - bitmask_tensor, cumulative_index) - if token is not None: - # In order to generate the correct bitmask for each - # position in the speculative sequence, we advance - # the FSM state for each speculative token and rollback - # to restore the previous state when we are finished. - assert structured_output_request.grammar.accept_tokens( - req_id, [token]) - state_advancements += 1 - cumulative_index += 1 - if state_advancements > 0: - structured_output_request.grammar.rollback(state_advancements) - - if cumulative_index < bitmask_tensor.shape[0]: - bitmask_tensor = bitmask_tensor[:cumulative_index] - - # After finishing with the xgrammar operations, we convert to - # np.ndarray, because that is much more efficient for serialization - # and deserialization when sending this to the GPU workers. - return bitmask_tensor.numpy() + else: + return self.backend.init_batch( + requests, + structured_output_request_ids, + scheduled_spec_decode_tokens, + ) def should_advance(self, request: Request) -> bool: if not request.use_structured_output: @@ -220,3 +235,93 @@ def should_advance(self, request: Request) -> bool: def clear_backend(self) -> None: if self.backend is not None: self.backend.destroy() + + def precompile(self, dummy_logits: torch.Tensor, **kwargs): + """ + Allow backend precompilation for the device + - Currently only used in the TPU model runner + + Args: + num_reqs_paddings (List[int]): A list of padding sizes for the + number of requests. + vocab_size (int): The size of the vocabulary. + device (torch.device): The device on which the model is running. + hidden_states_dtype (torch.dtype): The data type of the + hidden states. + """ + assert self.backend is not None + self.backend.precompile(dummy_logits, **kwargs) + + @staticmethod + def validate_request(params: SamplingParams, + vllm_config: VllmConfig) -> None: + """ + Validate the request for structured output. + This method checks the request for any errors or inconsistencies + + If one backend fails validation, we try the next one. + + The SamplingParams object is modified to set the backend and + backend_was_auto attributes based on the validation results. + + This needs to be a static method as it is called from the request + Processor which runs in a different process + + Args: + params (SamplingParams): The sampling parameters for the request. + + Raises: + ValueError: If the request contains an invalid backend or if the + request-level backend selection is not supported. + """ + if not params.guided_decoding or not vllm_config.decoding_config: + return + + engine_level_backend = vllm_config.decoding_config.backend + if params.guided_decoding.backend: + # Request-level backend selection is not supported in V1. 
+ # The values may differ if `params` is reused and was set + # to a specific backend based on `auto` behavior in a previous + # request. We remember that it was set as a result of `auto` + # using the `_auto` option set on the backend in the params. + if (params.guided_decoding.backend != engine_level_backend + and not (engine_level_backend == "auto" + and params.guided_decoding.backend_was_auto)): + raise ValueError( + "Request-level structured output backend selection is no " + "longer supported. The request specified " + f"'{params.guided_decoding.backend}', but vLLM was " + f"initialised with '{engine_level_backend}'. This error " + "can be resolved by removing backend selection from the " + "request.") + else: + params.guided_decoding.backend = engine_level_backend + + # Request content validation + if engine_level_backend.startswith("xgrammar"): + # xgrammar with no fallback + validate_xgrammar_grammar(params) + elif engine_level_backend.startswith("guidance"): + # TODO: ideally we would have the LLTokenizer here as Lark syntax + # allows <|special_token|> and similar, see + # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens + # Without tokenizer these are disallowed in grammars. + validate_guidance_grammar(params, tokenizer=None) + else: + # NOTE: engine_level_backend must be "auto" here, because we have + # checked supported_backends above. + # "auto" is an opt-in to opinionated behavior where we try to + # choose a backend based on request contents. This is not the + # default as it is less predictable and subject to change + # between releases as feature support changes. + try: + validate_xgrammar_grammar(params) + params.guided_decoding.backend = "xgrammar" + except ValueError: + # The request either failed validation + # or includes some jsonschema feature(s) that + # are not supported in xgrammar. Fall back to guidance. + validate_guidance_grammar(params, tokenizer=None) + params.guided_decoding.backend = "guidance" + # Remember that this backend was set automatically + params.guided_decoding.backend_was_auto = True diff --git a/vllm/v1/structured_output/backend_bitmasking.py b/vllm/v1/structured_output/backend_bitmasking.py new file mode 100644 index 000000000000..7db86cd32001 --- /dev/null +++ b/vllm/v1/structured_output/backend_bitmasking.py @@ -0,0 +1,189 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from abc import abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +import numpy as np +import numpy.typing as npt +import torch + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.v1.structured_output.backend_types import ( + StructuredOutputBackend, StructuredOutputBatchMetaData, + StructuredOutputGrammar, StructuredOutputOptions) + +if TYPE_CHECKING: + + from vllm.reasoning import ReasoningParser + from vllm.transformers_utils.tokenizer import AnyTokenizer + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +@dataclass +class BitmaskSOBatchMetaData(StructuredOutputBatchMetaData): + """ + This class is used to store the bitmask for structured output requests. + It is used to pass the bitmask to the GPU workers. 
+ """ + + grammar_bitmask: torch.Tensor + + +class BitmaskStructuredOutputBackend(StructuredOutputBackend): + + def __init__(self, vllm_config: VllmConfig, tokenizer: AnyTokenizer, + vocab_size: int, reasoner: ReasoningParser): + super().__init__(vllm_config, tokenizer, vocab_size, reasoner) + self._grammar_bitmask: Optional[torch.Tensor] = None + self._full_mask = torch.tensor(-1, dtype=torch.int32) + + def grammar_bitmask( + self, + requests: dict[str, Request], + structured_output_request_ids: dict[str, int], + scheduled_spec_decode_tokens: dict[str, list[int]], + ) -> Optional[npt.NDArray[np.int32]]: + """ + Method used by XGrammar and Guidance to process and filter all logits + """ + + max_num_spec_tokens = 0 + if self.vllm_config.speculative_config is not None: + max_num_spec_tokens = \ + self.vllm_config.speculative_config.num_speculative_tokens + + if self._grammar_bitmask is None: + max_batch_size = self.vllm_config.scheduler_config.max_num_seqs + # Allocate a bitmask for each token needing to be checked: + # one for each speculative position, and one more for the + # bonus token / non-speculative token. + self._grammar_bitmask = \ + self.allocate_token_bitmask( + max_batch_size * (1 + max_num_spec_tokens)) + + bitmask_tensor = self._grammar_bitmask + # Generate a batched bitmask for all structured output requests. + # When speculative decoding is enabled, we need to include multiple + # masks for each request, one for each possible bonus token position. + # These are stored inline in the tensor and unpacked by the gpu runner. + cumulative_index = 0 + ordered_seq = sorted(structured_output_request_ids.items(), + key=lambda x: x[1]) + + # Note that for thinking support, we will need to + # reset the relevant part of the bitmask for consequent + # request here. + bitmask_tensor[:(len(ordered_seq) * (1 + max_num_spec_tokens))].fill_( + self._full_mask) + + # NOTE: This outer loop can likely be parallelized to improve + # performance of bitmask generation for large batches. + for req_id, _ in ordered_seq: + request = requests[req_id] + structured_output_request = request.structured_output_request + if TYPE_CHECKING: + assert structured_output_request is not None + assert structured_output_request.grammar is not None + assert isinstance(structured_output_request.grammar, + BitmaskGrammar) + + apply_bitmask: bool = True + if self.reasoner is not None: + if structured_output_request.reasoning_ended is None: + structured_output_request.reasoning_ended = \ + self.reasoner.is_reasoning_end(request.prompt_token_ids) + apply_bitmask = structured_output_request.reasoning_ended + + state_advancements = 0 + req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None] + for i, token in enumerate(req_tokens): + if apply_bitmask and not \ + structured_output_request.grammar.is_terminated(): + structured_output_request.grammar.fill_bitmask( + bitmask_tensor, cumulative_index) + if token is not None: + # In order to generate the correct bitmask for each + # position in the speculative sequence, we advance + # the FSM state for each speculative token and rollback + # to restore the previous state when we are finished. 
+ assert structured_output_request.grammar.accept_tokens( + req_id, [token]) + state_advancements += 1 + cumulative_index += 1 + if state_advancements > 0: + structured_output_request.grammar.rollback(state_advancements) + + if cumulative_index < bitmask_tensor.shape[0]: + bitmask_tensor = bitmask_tensor[:cumulative_index] + + # After finishing with the xgrammar operations, we convert to + # np.ndarray, because that is much more efficient for serialization + # and deserialization when sending this to the GPU workers. + return bitmask_tensor.numpy() + + def init_batch( + self, requests: dict[str, Request], + structured_output_request_ids: dict[str, int], + scheduled_spec_decode_tokens: dict[str, list[int]] + ) -> StructuredOutputBatchMetaData: + bitmask = self.grammar_bitmask(requests, structured_output_request_ids, + scheduled_spec_decode_tokens) + return BitmaskSOBatchMetaData(structured_output_request_ids, bitmask) + + @abstractmethod + def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor: + """ + Allocates a token bitmask for the specified maximum number of sequences. + + Args: + max_num_seqs (int): The maximum number of sequences for which + to allocate the bitmask. + """ + + @abstractmethod + def compile_grammar(self, request_type: StructuredOutputOptions, + grammar_spec: str) -> StructuredOutputGrammar: + """ + Compiles a grammar specification into a structured output grammar. + + Args: + request_type (StructuredOutputOptions): The type of structured + output request. + grammar_spec (str): The grammar specification to compile. + + Returns: + StructuredOutputGrammar: The compiled structured output grammar. + """ + + @abstractmethod + def destroy(self): + pass + + +class BitmaskGrammar(StructuredOutputGrammar): + + @abstractmethod + def is_terminated(self) -> bool: + """ + Checks whether the structured output process has terminated. + + Returns: + bool: True if the process is terminated, False otherwise. + """ + + @abstractmethod + def reset(self): + """ + Resets the state of the structured output grammar. 
+ """ + + @abstractmethod + def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: + pass diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py index 02e7fc33f517..877a453d6f79 100644 --- a/vllm/v1/structured_output/backend_guidance.py +++ b/vllm/v1/structured_output/backend_guidance.py @@ -11,11 +11,13 @@ import torch +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.sampling_params import SamplingParams from vllm.utils import LazyLoader -from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, - StructuredOutputGrammar, +from vllm.v1.structured_output.backend_bitmasking import ( + BitmaskGrammar, BitmaskStructuredOutputBackend) +from vllm.v1.structured_output.backend_types import (StructuredOutputGrammar, StructuredOutputOptions) from vllm.v1.structured_output.request import get_structured_output_key @@ -23,6 +25,10 @@ import llguidance import llguidance.hf as llguidance_hf import llguidance.torch as llguidance_torch + + from vllm.reasoning import ReasoningParser + from vllm.transformers_utils.tokenizer import AnyTokenizer + else: llguidance = LazyLoader("llguidance", globals(), "llguidance") llguidance_hf = LazyLoader("llguidance.hf", globals(), "llguidance.hf") @@ -55,10 +61,13 @@ def process_for_additional_properties( return guide_json_obj -@dataclass -class GuidanceBackend(StructuredOutputBackend): +class GuidanceBackend(BitmaskStructuredOutputBackend): + + def __init__(self, vllm_config: VllmConfig, tokenizer: AnyTokenizer, + vocab_size: int, reasoner: ReasoningParser): + super().__init__(vllm_config, tokenizer, vocab_size, reasoner) + self.vocab_size = self.vllm_config.model_config.get_vocab_size() - def __post_init__(self): self.disable_any_whitespace = \ self.vllm_config.decoding_config.disable_any_whitespace self.disable_additional_properties = \ @@ -97,7 +106,7 @@ def destroy(self): @dataclass -class GuidanceGrammar(StructuredOutputGrammar): +class GuidanceGrammar(BitmaskGrammar): ll_matcher: llguidance.LLMatcher ll_tokenizer: llguidance.LLTokenizer vocab_size: int diff --git a/vllm/v1/structured_output/backend_types.py b/vllm/v1/structured_output/backend_types.py index d500783aa4b3..c94e09e3f0bc 100644 --- a/vllm/v1/structured_output/backend_types.py +++ b/vllm/v1/structured_output/backend_types.py @@ -12,7 +12,9 @@ import torch from vllm.config import VllmConfig + from vllm.reasoning import ReasoningParser from vllm.transformers_utils.tokenizer import AnyTokenizer + from vllm.v1.request import Request class StructuredOutputOptions(enum.Enum): @@ -68,39 +70,25 @@ def rollback(self, num_tokens: int) -> None: num_tokens (int): The number of tokens to roll back. """ - @abstractmethod - def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None: - """ - Fills the bitmask for a specific batch index. - - Args: - bitmask (torch.Tensor): The bitmask to fill - batch_index (int): The index in the bitmask to fill - """ - - @abstractmethod - def is_terminated(self) -> bool: - """ - Checks whether the structured output process has terminated. - - Returns: - bool: True if the process is terminated, False otherwise. - """ - @abstractmethod - def reset(self): - """ - Resets the state of the structured output grammar. 
- """ +@dataclass +class StructuredOutputBatchMetaData: + """Extend this class to add any additional metadata to the batch + """ + # Dict of request ids to their index within the batch + # for filling the next token bitmask + structured_output_request_ids: dict[str, int] -@dataclass class StructuredOutputBackend(ABC): """Engine-level backend for structured output requests.""" - vllm_config: VllmConfig - tokenizer: AnyTokenizer - vocab_size: int + def __init__(self, vllm_config: VllmConfig, tokenizer: AnyTokenizer, + vocab_size: int, reasoner: ReasoningParser): + self.vllm_config = vllm_config + self.tokenizer = tokenizer + self.vocab_size = vocab_size + self.reasoner = reasoner @abstractmethod def compile_grammar(self, request_type: StructuredOutputOptions, @@ -117,18 +105,18 @@ def compile_grammar(self, request_type: StructuredOutputOptions, StructuredOutputGrammar: The compiled structured output grammar. """ - @abstractmethod - def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor: - """ - Allocates a token bitmask for the specified maximum number of sequences. - - Args: - max_num_seqs (int): The maximum number of sequences for which - to allocate the bitmask. - """ + def init_batch( + self, requests: dict[str, Request], + structured_output_request_ids: dict[str, int], + scheduled_spec_decode_tokens: dict[str, list[int]] + ) -> StructuredOutputBatchMetaData: + return StructuredOutputBatchMetaData(structured_output_request_ids) @abstractmethod def destroy(self): """ Backend-specific cleanup. """ + + def precompile(self, dummy_logits: torch.Tensor, **kwargs): + return diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index 88544565e544..e91a0bca4167 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import json from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any @@ -10,12 +8,14 @@ import torch import vllm.envs +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.utils import LazyLoader -from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, - StructuredOutputGrammar, +from vllm.v1.structured_output.backend_bitmasking import ( + BitmaskGrammar, BitmaskStructuredOutputBackend) +from vllm.v1.structured_output.backend_types import (StructuredOutputGrammar, StructuredOutputOptions) from vllm.v1.structured_output.utils import (choice_as_grammar, convert_lark_to_ebnf, @@ -23,19 +23,26 @@ if TYPE_CHECKING: import xgrammar as xgr + + from vllm.reasoning import ReasoningParser + from vllm.transformers_utils.tokenizer import AnyTokenizer + else: xgr = LazyLoader("xgr", globals(), "xgrammar") logger = init_logger(__name__) -@dataclass -class XgrammarBackend(StructuredOutputBackend): +class XgrammarBackend(BitmaskStructuredOutputBackend): + + def __init__(self, vllm_config: VllmConfig, tokenizer: "AnyTokenizer", + vocab_size: int, reasoner: "ReasoningParser"): + super().__init__(vllm_config, tokenizer, vocab_size, reasoner) - def __post_init__(self): self.disable_any_whitespace = \ self.vllm_config.decoding_config.disable_any_whitespace + self.vocab_size = vllm_config.model_config.get_vocab_size() if 
isinstance(self.tokenizer, MistralTokenizer): # NOTE: ideally, xgrammar should handle this accordingly. # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 @@ -133,7 +140,7 @@ def destroy(self): @dataclass -class XgrammarGrammar(StructuredOutputGrammar): +class XgrammarGrammar(BitmaskGrammar): # NOTE: This would be a generic-enough class for # supporting different backends, in the future. # For now, just xgrammar. diff --git a/vllm/v1/structured_output/worker_backend.py b/vllm/v1/structured_output/worker_backend.py new file mode 100644 index 000000000000..8ada0816bf74 --- /dev/null +++ b/vllm/v1/structured_output/worker_backend.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +from abc import abstractmethod +from typing import TYPE_CHECKING + +import torch + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.worker.gpu_input_batch import InputBatch + +if TYPE_CHECKING: + pass + +logger = init_logger(__name__) + + +class StructuredOutputWorkerBackend: + + def __init__(self, vllm_config: VllmConfig): + self.vllm_config = vllm_config + + @abstractmethod + def filter_logits(self, input_batch: InputBatch, device: torch.device, + scheduler_output: SchedulerOutput, logits: torch.Tensor, + sample_hidden_states: torch.Tensor, **kwargs) -> None: + """ + Filters the logits produced by the model's forward pass. + + Called in v1.worker.XXXModelRunner.execute_model immediately + after the model forward pass. + + Args: + input_batch (InputBatch): The batch of input data being processed. + device (torch.device): The device on which the computation is + performed. + scheduler_output (SchedulerOutput): The output from the scheduler + containing additional information for processing. + logits (torch.Tensor): The raw logits output from the model's + forward pass. + sample_hidden_states (torch.Tensor): The hidden states of the + samples from the model's forward pass. 
+ """ + pass + + def precompile(self, dummy_logits: torch.Tensor, **kwargs): + return + + @abstractmethod + def supported_backends(self) -> list[str]: + """ + Specify the StructuredOutputBackend's the worker Supports + """ + pass diff --git a/vllm/v1/structured_output/worker_backend_bitmasking_gpu.py b/vllm/v1/structured_output/worker_backend_bitmasking_gpu.py new file mode 100644 index 000000000000..6b0b8aff513b --- /dev/null +++ b/vllm/v1/structured_output/worker_backend_bitmasking_gpu.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional, cast + +import numpy as np +import torch +import xgrammar as xgr + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.structured_output.backend_bitmasking import BitmaskSOBatchMetaData +from vllm.v1.structured_output.worker_backend import ( + StructuredOutputWorkerBackend) +from vllm.v1.worker.gpu_input_batch import InputBatch + +if TYPE_CHECKING: + pass + +logger = init_logger(__name__) + + +class BitmaskGPUBackend(StructuredOutputWorkerBackend): + + def __init__(self, vllm_config: VllmConfig): + super().__init__(vllm_config) + self._grammar_bitmask: Optional[torch.Tensor] = None + + @staticmethod + def apply_grammar_bitmask( + input_batch: InputBatch, + device: torch.device, + scheduler_output: SchedulerOutput, + logits: torch.Tensor, + ): + meta = cast(BitmaskSOBatchMetaData, + scheduler_output.structured_output_meta) + if meta.grammar_bitmask is None: + return + grammar_bitmask = meta.grammar_bitmask + + # We receive the structured output bitmask from the scheduler, + # compacted to contain bitmasks only for structured output requests. + # The order of the requests in the bitmask is not guaranteed to be the + # same as the order of the requests in the gpu runner's batch. We need + # to sort the bitmask to match the order of the requests used here. + + # Get the batch indices of the structured output requests. + # Keep track of the number of speculative tokens scheduled for every + # request in the batch, as the logit indices are offset by this amount. + struct_out_req_batch_indices: dict[str, int] = {} + cumulative_offset = 0 + seq = sorted(input_batch.req_id_to_index.items(), key=lambda x: x[1]) + for req_id, batch_index in seq: + logit_index = batch_index + cumulative_offset + cumulative_offset += len( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) + if req_id in meta.structured_output_request_ids: + struct_out_req_batch_indices[req_id] = logit_index + + out_indices = [] + + # Reorder the bitmask to match the order of the requests in the batch. + sorted_bitmask = np.zeros_like(grammar_bitmask, + shape=(logits.shape[0], + grammar_bitmask.shape[1])) + cumulative_index = 0 + seq = sorted(meta.structured_output_request_ids.items(), + key=lambda x: x[1]) + for req_id, _ in seq: + logit_index = struct_out_req_batch_indices[req_id] + num_spec_tokens = len( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) + for i in range(1 + num_spec_tokens): + sorted_bitmask[logit_index + i] = \ + grammar_bitmask[cumulative_index + i] + out_indices.append(logit_index + i) + cumulative_index += 1 + num_spec_tokens + grammar_bitmask = sorted_bitmask + + # Serialization of np.ndarray is much more efficient than a tensor, + # so we receive it in that format. 
+ grammar_bitmask = torch.from_numpy(grammar_bitmask) + + xgr.apply_token_bitmask_inplace( + logits, + grammar_bitmask.to(device, non_blocking=True), + indices=out_indices, + ) + + def filter_logits( + self, + input_batch: InputBatch, + device: torch.device, + scheduler_output: SchedulerOutput, + logits: torch.Tensor, + sample_hidden_states: torch.Tensor, + **kwargs, + ) -> None: + BitmaskGPUBackend.apply_grammar_bitmask( + input_batch, + device, + scheduler_output, + logits, + ) + + def supported_backends(self) -> list[str]: + return ["xgrammar", "guidance"] diff --git a/vllm/v1/structured_output/worker_backend_bitmasking_tpu.py b/vllm/v1/structured_output/worker_backend_bitmasking_tpu.py new file mode 100644 index 000000000000..32d3f351a2ab --- /dev/null +++ b/vllm/v1/structured_output/worker_backend_bitmasking_tpu.py @@ -0,0 +1,169 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional, cast + +import torch + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.utils import cdiv, is_pin_memory_available +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.structured_output.backend_bitmasking import BitmaskSOBatchMetaData +from vllm.v1.structured_output.worker_backend import ( + StructuredOutputWorkerBackend) +from vllm.v1.worker.gpu_input_batch import InputBatch + +if TYPE_CHECKING: + + pass + +logger = init_logger(__name__) + + +class BitmaskTPUBackend(StructuredOutputWorkerBackend): + + def __init__(self, vllm_config: VllmConfig): + super().__init__(vllm_config) + self._grammar_bitmask: Optional[torch.Tensor] = None + self.max_num_reqs: Optional[int] = None + self.tpu_vocab_size = self.vllm_config.model_config.get_vocab_size() + self.pin_memory = is_pin_memory_available() + self.require_structured_out_cpu = torch.Tensor() + self.structured_decode_arange = torch.Tensor() + self.grammar_bitmask_cpu = torch.Tensor() + + def init_tensors(self, max_num_reqs: int): + self.max_num_reqs = max_num_reqs + self.require_structured_out_cpu = torch.zeros( + (self.max_num_reqs), + dtype=torch.bool, + device="cpu", + pin_memory=self.pin_memory) + self.structured_decode_arange = torch.arange( + 0, 32, device="cpu", pin_memory=self.pin_memory) + self.grammar_bitmask_cpu = torch.zeros( + (self.max_num_reqs, cdiv(self.tpu_vocab_size, 32)), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + + def filter_logits( + self, + input_batch: InputBatch, + device: torch.device, + scheduler_output: SchedulerOutput, + logits: torch.Tensor, + sample_hidden_states: torch.Tensor, + **kwargs, + ) -> None: + if self.max_num_reqs is None: + assert "max_num_reqs" in kwargs, "max_num_reqs must be provided" + max_num_reqs = kwargs.get("max_num_reqs") + assert isinstance(max_num_reqs, int), \ + "max_num_reqs must be an integer" + self.init_tensors(max_num_reqs) + + require_struct_decoding, grammar_bitmask_padded, arange = \ + self.prepare_structured_decoding_input(logits, + scheduler_output, input_batch) + self.structured_decode(require_struct_decoding, grammar_bitmask_padded, + logits, arange) + + def prepare_structured_decoding_input( + self, logits: torch.Tensor, scheduler_output: SchedulerOutput, + input_batch: InputBatch + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + meta = cast(BitmaskSOBatchMetaData, + scheduler_output.structured_output_meta) + grammar_bitmask = meta.grammar_bitmask + assert grammar_bitmask is not None + num_reqs, _ = logits.shape + + # Reset pre-allocated 
tensors + self.grammar_bitmask_cpu.zero_() + self.require_structured_out_cpu.zero_() + + # We receive the structured output bitmask from the scheduler, but the + # indices of the requests in the batch may not match the indices of + # the bitmask since the scheduler doesn't know how the tpu runner is + # ordering the requests in the batch. We need to match the order of + # bitmask with the order of requests + struct_out_indices: list[int] = [] + mask_indices: list[int] = [] + assert scheduler_output.structured_output_meta is not None + for req_id in input_batch.req_ids: + mask_index = scheduler_output.structured_output_meta.\ + structured_output_request_ids.get(req_id) + if mask_index is None: + continue + batch_index = input_batch.req_id_to_index[req_id] + struct_out_indices.append(batch_index) + mask_indices.append(mask_index) + self.grammar_bitmask_cpu[struct_out_indices] = torch.from_numpy( + grammar_bitmask[mask_indices]) + # It's not guaranteed that all requests in this batch require + # structured output, so create a bool tensor to represent + # the requests that need structured output. + struct_out_indices = torch.tensor(struct_out_indices, dtype=torch.long) + self.require_structured_out_cpu[struct_out_indices] = True + return self.require_structured_out_cpu[:num_reqs].to(logits.device), \ + self.grammar_bitmask_cpu[:num_reqs].to(logits.device), \ + self.structured_decode_arange.to(logits.device) + + @torch.compile(backend="openxla", fullgraph=True, dynamic=False) + def structured_decode(self, require_struct_decoding: torch.Tensor, + grammar_bitmask: torch.Tensor, logits: torch.Tensor, + arange: torch.Tensor): + """Applies structured decoding by modifying logits in-place + where required. + + Args: + require_struct_decoding: [B] boolean tensor indicating + which batch items need structured decoding + grammar_bitmask: [B, vocab_size//32] packed bit tensor + containing valid token masks + logits: [B, vocab_size] tensor to modify in-place + arange: [32] tensor for bit unpacking, contains values [0..31] + """ + assert (logits.shape[0] == grammar_bitmask.shape[0]) + + # Unpack bits for all batch items at once + unpacked_bitmask = ( + torch.bitwise_right_shift( + grammar_bitmask[:, :, None], # [B, vocab_size//32, 1] + arange[None, None, :] # [1, 1, 32] + ) & 1) == 0 # Result: [B, vocab_size//32, 32] + + unpacked_bitmask = unpacked_bitmask.reshape( + logits.shape[0], -1)[:, :self.tpu_vocab_size] # [B, vocab_size] + + # Only apply mask where require_struct_decoding is True + mask_to_apply = unpacked_bitmask & \ + require_struct_decoding[:,None] # [B, vocab_size] + + # Apply mask in-place + logits.masked_fill_(mask_to_apply, -float("inf")) + + def precompile(self, dummy_logits: torch.Tensor, **kwargs): + if self.max_num_reqs is None: + assert "max_num_reqs" in kwargs, "max_num_reqs must be provided" + max_num_reqs = kwargs.get("max_num_reqs") + assert isinstance(max_num_reqs, int), \ + "max_num_reqs must be an integer" + self.init_tensors(max_num_reqs) + + num_reqs = dummy_logits.shape[0] + dummy_require_struct_decoding = \ + self.require_structured_out_cpu[:num_reqs].to(dummy_logits.device) + dummy_grammar_bitmask = \ + self.grammar_bitmask_cpu[:num_reqs].to(dummy_logits.device) + # The first dimension of the dummy logits and 2 dummy tensors above + # cannot be mark_dynamic because some operations in structured_decode + # require them to be static. 
+ arange = self.structured_decode_arange.to(dummy_logits.device) + self.structured_decode(dummy_require_struct_decoding, + dummy_grammar_bitmask, dummy_logits, arange) + + def supported_backends(self) -> list[str]: + return ["xgrammar", "guidance"] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a90c294a9749..44ceedc3866d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -53,6 +53,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.utils import is_spec_decode_supported +from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -74,11 +75,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): - def __init__( - self, - vllm_config: VllmConfig, - device: torch.device, - ): + def __init__(self, vllm_config: VllmConfig, device: torch.device): self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config @@ -89,6 +86,9 @@ def __init__( self.speculative_config = vllm_config.speculative_config self.prompt_adapter_config = vllm_config.prompt_adapter_config self.observability_config = vllm_config.observability_config + self.structured_output_worker = StructuredOutputManager.\ + get_worker_backend( + vllm_config) from vllm.model_executor.models.utils import set_cpu_offload_max_bytes set_cpu_offload_max_bytes( @@ -1040,65 +1040,6 @@ def _gather_mm_embeddings( def get_model(self) -> nn.Module: return self.model - def apply_grammar_bitmask( - self, - scheduler_output: "SchedulerOutput", - logits: torch.Tensor, - ): - grammar_bitmask = scheduler_output.grammar_bitmask - if grammar_bitmask is None: - return - - # We receive the structured output bitmask from the scheduler, - # compacted to contain bitmasks only for structured output requests. - # The order of the requests in the bitmask is not guaranteed to be the - # same as the order of the requests in the gpu runner's batch. We need - # to sort the bitmask to match the order of the requests used here. - - # Get the batch indices of the structured output requests. - # Keep track of the number of speculative tokens scheduled for every - # request in the batch, as the logit indices are offset by this amount. - struct_out_req_batch_indices: dict[str, int] = {} - cumulative_offset = 0 - seq = sorted(self.input_batch.req_id_to_index.items(), - key=lambda x: x[1]) - for req_id, batch_index in seq: - logit_index = batch_index + cumulative_offset - cumulative_offset += len( - scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) - if req_id in scheduler_output.structured_output_request_ids: - struct_out_req_batch_indices[req_id] = logit_index - - out_indices = [] - - # Reorder the bitmask to match the order of the requests in the batch. 
- sorted_bitmask = np.zeros_like(grammar_bitmask, - shape=(logits.shape[0], - grammar_bitmask.shape[1])) - cumulative_index = 0 - seq = sorted(scheduler_output.structured_output_request_ids.items(), - key=lambda x: x[1]) - for req_id, _ in seq: - logit_index = struct_out_req_batch_indices[req_id] - num_spec_tokens = len( - scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) - for i in range(1 + num_spec_tokens): - sorted_bitmask[logit_index + i] = \ - grammar_bitmask[cumulative_index + i] - out_indices.append(logit_index + i) - cumulative_index += 1 + num_spec_tokens - grammar_bitmask = sorted_bitmask - - # Serialization of np.ndarray is much more efficient than a tensor, - # so we receive it in that format. - grammar_bitmask = torch.from_numpy(grammar_bitmask) - - xgr.apply_token_bitmask_inplace( - logits, - grammar_bitmask.to(self.device, non_blocking=True), - indices=out_indices, - ) - def sync_and_slice_intermediate_tensors( self, num_tokens: int, intermediate_tensors: IntermediateTensors, sync_self: bool) -> IntermediateTensors: @@ -1290,9 +1231,10 @@ def execute_model( assert model_output_broadcast_data is not None logits = model_output_broadcast_data["logits"] - # Apply structured output bitmasks if present - if scheduler_output.grammar_bitmask is not None: - self.apply_grammar_bitmask(scheduler_output, logits) + if scheduler_output.structured_output_meta is not None: + self.structured_output_worker.filter_logits( + self.input_batch, self.device, scheduler_output, logits, + sample_hidden_states) # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 843bc36953b5..d173e00c9c17 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -41,6 +41,7 @@ ModelRunnerOutput) from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler +from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin @@ -116,6 +117,9 @@ def __init__( self.prompt_adapter_config = vllm_config.prompt_adapter_config self.observability_config = vllm_config.observability_config self.device_config = vllm_config.device_config + self.structured_output_worker = StructuredOutputManager.\ + get_worker_backend( + vllm_config) model_config = self.model_config cache_config = self.cache_config @@ -251,20 +255,6 @@ def __init__( # from the KV cache of `shared_kv_cache_layers[layer_name]`. self.shared_kv_cache_layers: dict[str, str] = {} - # tensors for structured decoding - self.grammar_bitmask_cpu = torch.zeros( - (self.max_num_reqs, cdiv(self.vocab_size, 32)), - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - self.require_structured_out_cpu = torch.zeros( - (self.max_num_reqs, 1), - dtype=torch.bool, - device="cpu", - pin_memory=self.pin_memory) - self.structured_decode_arange = torch.arange( - 0, 32, device="cpu", pin_memory=self.pin_memory) - # Get maximum number of mm items per modality (batch size). 
self.max_num_mm_items_by_modality = dict() if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0 @@ -864,12 +854,14 @@ def execute_model( logits = self.compute_logits(hidden_states) tpu_sampling_metadata = TPUSupportedSamplingMetadata.\ from_input_batch(self.input_batch, padded_num_reqs, self.device) - if scheduler_output.grammar_bitmask is not None: - require_struct_decoding, grammar_bitmask_padded, arange = \ - self.prepare_structured_decoding_input(logits, scheduler_output) - logits = self.structured_decode(require_struct_decoding, - grammar_bitmask_padded, logits, - arange) + if scheduler_output.structured_output_meta is not None: + self.structured_output_worker.filter_logits( + self.input_batch, + self.device, + scheduler_output, + logits, + hidden_states, + max_num_reqs=self.max_num_reqs) selected_token_ids = self.sample_from_logits_func( logits, tpu_sampling_metadata) # NOTE (NickLucche) Use the original logits (before any penalties or @@ -1205,16 +1197,8 @@ def _precompile_structured_decoding(self) -> None: dummy_logits = torch.zeros((num_reqs, self.vocab_size), device=self.device, dtype=self._hidden_states_dtype) - dummy_require_struct_decoding = \ - self.require_structured_out_cpu[:num_reqs].to(self.device) - dummy_grammar_bitmask = \ - self.grammar_bitmask_cpu[:num_reqs].to(self.device) - # The first dimension of the above 3 dummy tensors cannot be - # mark_dynamic because some operations in structured_decode require - # them to be static. - arange = self.structured_decode_arange.to(self.device) - self.structured_decode(dummy_require_struct_decoding, - dummy_grammar_bitmask, dummy_logits, arange) + self.structured_output_worker.precompile( + dummy_logits, max_num_reqs=self.max_num_reqs) logger.info(" -- num_seqs: %d", num_reqs) xm.wait_device_ops() end = time.perf_counter() @@ -1477,71 +1461,12 @@ def gather_logprobs(self, logits: torch.Tensor, self.model_config.max_logprobs, token_ids=sampled_tokens.squeeze(-1)) - @torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def structured_decode(self, require_struct_decoding: torch.Tensor, - grammar_bitmask: torch.Tensor, logits: torch.Tensor, - arange: torch.Tensor) -> torch.Tensor: - return torch.where( - require_struct_decoding, - self.apply_grammar_bitmask(logits, grammar_bitmask, arange), - logits) - - def apply_grammar_bitmask(self, logits: torch.Tensor, - grammar_bitmask: torch.Tensor, - arange: torch.Tensor): - assert (logits.shape[0] == grammar_bitmask.shape[0]) - logits_cloned = logits.clone() - for i in range(logits.shape[0]): - unpacked_bitmask = (torch.bitwise_right_shift( - grammar_bitmask[i][:, None], arange[None, :]) & 1) == 0 - unpacked_bitmask = unpacked_bitmask.reshape(-1)[:self.vocab_size] - logits_cloned[i] = logits_cloned[i].masked_fill( - unpacked_bitmask, -float("inf")) - return logits_cloned - def get_multimodal_embeddings(self, *args, **kwargs): return self.model.get_multimodal_embeddings(*args, **kwargs) def get_input_embeddings(self, *args, **kwargs): return self.model.get_input_embeddings(*args, **kwargs) - def prepare_structured_decoding_input( - self, logits: torch.Tensor, scheduler_output: "SchedulerOutput" - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - grammar_bitmask = scheduler_output.grammar_bitmask - assert grammar_bitmask is not None - num_reqs, _ = logits.shape - - # Reset pre-allocated tensors - self.grammar_bitmask_cpu.zero_() - self.require_structured_out_cpu.zero_() - - # We receive the structured output bitmask from the scheduler, but the - # indices 
of the requests in the batch may not match the indices of - # the bitmask since the scheduler doesn't know how the tpu runner is - # ordering the requests in the batch. We need to match the order of - # bitmask with the order of requests - struct_out_indices: list[int] = [] - mask_indices: list[int] = [] - for req_id in self.input_batch.req_ids: - mask_index = scheduler_output.structured_output_request_ids.get( - req_id) - if mask_index is None: - continue - batch_index = self.input_batch.req_id_to_index[req_id] - struct_out_indices.append(batch_index) - mask_indices.append(mask_index) - self.grammar_bitmask_cpu[struct_out_indices] = torch.from_numpy( - grammar_bitmask[mask_indices]) - # It's not guaranteed that all requests in this batch require - # structured output, so create a bool tensor to represent - # the requests that need structured output. - struct_out_indices = torch.tensor(struct_out_indices, dtype=torch.long) - self.require_structured_out_cpu[struct_out_indices] = True - return self.require_structured_out_cpu[:num_reqs].to(logits.device), \ - self.grammar_bitmask_cpu[:num_reqs].to(logits.device), \ - self.structured_decode_arange.to(logits.device) - def _get_mm_dummy_batch(self, modality: str, batch_size: int) -> BatchedTensorInputs: # Dummy data for pre-compiling multimodal models.
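
Usage sketch (not part of the patch): the split between the engine-side StructuredOutputBackend and the new worker-side StructuredOutputWorkerBackend means a device runner only needs an object that implements filter_logits() and supported_backends(), while the scheduler ships whatever per-batch payload the engine backend packs into StructuredOutputBatchMetaData. The class below is a hypothetical, minimal worker backend written against the interfaces introduced above; the class name, the banned_token_ids parameter, and the assumption that speculative decoding is disabled are illustrative only and not part of vLLM.

from __future__ import annotations

import torch

from vllm.config import VllmConfig
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.structured_output.worker_backend import (
    StructuredOutputWorkerBackend)
from vllm.v1.worker.gpu_input_batch import InputBatch


class BannedTokenWorkerBackend(StructuredOutputWorkerBackend):
    """Hypothetical worker backend: bans a fixed token set for every
    structured-output request. Ignores speculative decoding for brevity."""

    def __init__(self, vllm_config: VllmConfig, banned_token_ids: list[int]):
        super().__init__(vllm_config)
        self.banned_token_ids = banned_token_ids

    def filter_logits(self, input_batch: InputBatch, device: torch.device,
                      scheduler_output: SchedulerOutput, logits: torch.Tensor,
                      sample_hidden_states: torch.Tensor, **kwargs) -> None:
        meta = scheduler_output.structured_output_meta
        if meta is None or not self.banned_token_ids:
            return
        # Map structured-output request ids to their rows in this worker's
        # batch; with spec decode enabled the logit rows would additionally
        # be offset by the number of scheduled speculative tokens.
        rows = [
            input_batch.req_id_to_index[req_id]
            for req_id in meta.structured_output_request_ids
            if req_id in input_batch.req_id_to_index
        ]
        if not rows:
            return
        row_idx = torch.tensor(rows, dtype=torch.long, device=logits.device)
        banned = torch.tensor(self.banned_token_ids, dtype=torch.long,
                              device=logits.device)
        # Mask the banned tokens in place, mirroring how BitmaskGPUBackend
        # edits `logits` before sampling.
        logits[row_idx[:, None], banned] = float("-inf")

    def supported_backends(self) -> list[str]:
        # Engine-level grammar backends this worker claims to handle
        # (illustrative; a real backend would only list what it applies).
        return ["xgrammar", "guidance"]

Keeping the batch payload behind StructuredOutputBatchMetaData (rather than a raw numpy bitmask on SchedulerOutput) is what makes a non-bitmask worker like this possible: the bitmasking backends subclass it as BitmaskSOBatchMetaData, and other backends can attach whatever per-step data their filter_logits needs.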