diff --git a/tests/models/language/pooling/test_scoring.py b/tests/models/language/pooling/test_scoring.py
index ef9d5530cde..6b5ff706814 100644
--- a/tests/models/language/pooling/test_scoring.py
+++ b/tests/models/language/pooling/test_scoring.py
@@ -23,6 +23,15 @@
     "The capital of Germany is Berlin.",
 ]
 
+
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 DTYPE = "half"
diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py
index ccafc884612..26557be2c57 100644
--- a/tests/tokenization/test_detokenize.py
+++ b/tests/tokenization/test_detokenize.py
@@ -61,16 +61,17 @@ def _run_incremental_decode(tokenizer,
         skip_special_tokens=skip_special_tokens,
         spaces_between_special_tokens=spaces_between_special_tokens,
     )
-    request = EngineCoreRequest("",
-                                prompt_token_ids,
-                                None,
-                                None,
-                                None,
-                                params,
-                                None,
-                                None,
-                                0.0,
-                                None,
+    request = EngineCoreRequest(request_id="",
+                                prompt_token_ids=prompt_token_ids,
+                                token_type_ids=None,
+                                mm_inputs=None,
+                                mm_hashes=None,
+                                mm_placeholders=None,
+                                sampling_params=params,
+                                pooling_params=None,
+                                eos_token_id=None,
+                                arrival_time=0.0,
+                                lora_request=None,
                                 cache_salt=None,
                                 data_parallel_rank=None)
diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py
index e9c6f1f95cd..4f26db559fc 100644
--- a/tests/v1/core/test_kv_cache_utils.py
+++ b/tests/v1/core/test_kv_cache_utils.py
@@ -40,6 +40,7 @@ def make_request(request_id,
     return Request(
         request_id=request_id,
         prompt_token_ids=prompt_token_ids,
+        token_type_ids=None,
         multi_modal_inputs=multi_modal_inputs,
         multi_modal_hashes=mm_hashes,
         multi_modal_placeholders=mm_positions,
diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py
index 085616303d8..1f02ff7cfc7 100644
--- a/tests/v1/core/test_prefix_caching.py
+++ b/tests/v1/core/test_prefix_caching.py
@@ -35,6 +35,7 @@ def make_request(request_id,
     return Request(
         request_id=request_id,
         prompt_token_ids=prompt_token_ids,
+        token_type_ids=None,
         multi_modal_inputs=multi_modal_inputs,
         multi_modal_hashes=mm_hashes,
         multi_modal_placeholders=mm_positions,
diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index c719d1975bb..949b644cbd5 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -1330,6 +1330,7 @@ def create_requests_with_priority(
         request = Request(
             request_id=f"{i}",
             prompt_token_ids=[i] * num_tokens,
+            token_type_ids=None,
             sampling_params=sampling_params,
             pooling_params=None,
             multi_modal_inputs=mm_inputs,
@@ -1816,6 +1817,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
     request = Request(
         request_id="0",
         prompt_token_ids=[0, 1],
+        token_type_ids=None,
         multi_modal_inputs=None,
         multi_modal_hashes=None,
         multi_modal_placeholders=None,
diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py
index 02ca4498db1..74fce374988 100644
--- a/tests/v1/core/utils.py
+++ b/tests/v1/core/utils.py
@@ -138,6 +138,7 @@ def create_requests(
         request = Request(
             request_id=f"{i}",
             prompt_token_ids=prompt_token_ids,
+            token_type_ids=None,
             sampling_params=sampling_params,
             pooling_params=None,
             multi_modal_inputs=mm_inputs,
diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py
index bbdc73e9608..f5ddbeeb4fd 100644
--- a/tests/v1/engine/test_engine_core.py
+++ b/tests/v1/engine/test_engine_core.py
@@ -35,6 +35,7 @@ def make_request() -> EngineCoreRequest:
     return EngineCoreRequest(
         request_id=str(uuid.uuid4()),
         prompt_token_ids=PROMPT_TOKENS,
+        token_type_ids=None,
         mm_inputs=None,
         mm_hashes=None,
         mm_placeholders=None,
diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py
index 2ac6dc796bd..0e64230d548 100644
--- a/tests/v1/engine/test_engine_core_client.py
+++ b/tests/v1/engine/test_engine_core_client.py
@@ -51,6 +51,7 @@ def make_request(
     return EngineCoreRequest(
         request_id=str(uuid.uuid4()),
         prompt_token_ids=prompt_tokens_ids,
+        token_type_ids=None,
         mm_inputs=None,
         mm_hashes=None,
         mm_placeholders=None,
diff --git a/tests/v1/engine/test_fast_incdec_prefix_err.py b/tests/v1/engine/test_fast_incdec_prefix_err.py
index f028b4ab1d7..61a4126ff60 100644
--- a/tests/v1/engine/test_fast_incdec_prefix_err.py
+++ b/tests/v1/engine/test_fast_incdec_prefix_err.py
@@ -31,6 +31,7 @@ def test_fast_inc_detok_invalid_utf8_err_case():
                                 None,
                                 None,
                                 None,
+                                None,
                                 params,
                                 None,
                                 None,
diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py
index 949ab764e2e..c59439ed9c0 100644
--- a/tests/v1/engine/test_output_processor.py
+++ b/tests/v1/engine/test_output_processor.py
@@ -52,6 +52,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
     requests = [
         EngineCoreRequest(request_id=f"request-{idx}",
                           prompt_token_ids=prompt_tokens,
+                          token_type_ids=None,
                           arrival_time=0,
                           mm_inputs=None,
                           mm_hashes=None,
@@ -401,6 +402,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
     requests = [
         EngineCoreRequest(request_id=request_id_list[idx],
                           prompt_token_ids=prompt_tokens,
+                          token_type_ids=None,
                           arrival_time=0,
                           mm_inputs=None,
                           mm_hashes=None,
@@ -566,6 +568,7 @@ def test_stop_token(include_stop_str_in_output: bool,
     request = EngineCoreRequest(
         request_id=request_id,
         prompt_token_ids=prompt_tokens,
+        token_type_ids=None,
         arrival_time=0,
         mm_inputs=None,
         mm_hashes=None,
@@ -665,6 +668,7 @@ def test_stop_string(include_stop_str_in_output: bool,
         EngineCoreRequest(
             request_id=request_id_list[idx],
             prompt_token_ids=prompt_tokens,
+            token_type_ids=None,
             arrival_time=0,
             mm_inputs=None,
             mm_hashes=None,
@@ -781,6 +785,7 @@ def test_iteration_stats(dummy_test_vectors):
         EngineCoreRequest(
             request_id=f"request-{idx}",
             prompt_token_ids=prompt_tokens,
+            token_type_ids=None,
             arrival_time=0,
             mm_inputs=None,
             mm_hashes=None,
diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py
index 480a7074cdf..f435f531036 100644
--- a/tests/v1/kv_connector/unit/utils.py
+++ b/tests/v1/kv_connector/unit/utils.py
@@ -152,6 +152,7 @@ def create_request(
     req = Request(
         request_id=f"id-{request_id}",
         prompt_token_ids=prompt_token_ids,
+        token_type_ids=None,
         sampling_params=sampling_params,
         pooling_params=None,
         multi_modal_inputs=None,
diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py
index 215be09bf5a..d90f129d743 100644
--- a/tests/v1/tpu/worker/test_tpu_model_runner.py
+++ b/tests/v1/tpu/worker/test_tpu_model_runner.py
@@ -64,6 +64,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
         NewRequestData(
             req_id=req_id,
             prompt_token_ids=[1, 2, 3],
+            token_type_ids=None,
             mm_inputs=[],
             mm_hashes=[],
             mm_positions=[],
diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py
index 943a13debad..fb8ad382ead 100644
--- a/tests/v1/worker/test_gpu_input_batch.py
+++ b/tests/v1/worker/test_gpu_input_batch.py
@@ -194,6 +194,9 @@ def _construct_cached_request_state(req_id_suffix: int):
         np.random.randint(0, VOCAB_SIZE)
         for _ in range(np.random.randint(0, MAX_PROMPT_SIZE))
     ]
+    token_type_ids = [
+        np.random.randint(0, 2) for _ in range(len(prompt_token_ids))
+    ]
     output_token_ids = [
         np.random.randint(0, VOCAB_SIZE)
         for _ in range(np.random.randint(0, NUM_OUTPUT_TOKENS))
@@ -201,6 +204,7 @@ def _construct_cached_request_state(req_id_suffix: int):
     return CachedRequestState(
         req_id=f"req_id_{req_id_suffix}",
         prompt_token_ids=prompt_token_ids,
+        token_type_ids=token_type_ids,
         sampling_params=_create_sampling_params(),
         pooling_params=None,
         mm_inputs=[],
diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index e14fbe1e47e..2c7c44152f1 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -120,6 +120,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
         NewRequestData(
             req_id=req_id,
             prompt_token_ids=[1, 2, 3],
+            token_type_ids=None,
             mm_inputs=[],
             mm_hashes=[],
             mm_positions=[],
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index adef350931f..7ed08b5a86a 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1269,34 +1269,18 @@ def _cross_encoding_score(
 
         input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
 
-        if model_config.is_multimodal_model:
-            for q, d in input_pairs:
-                _, engine_prompt = get_score_prompt(
-                    model_config=model_config,
-                    data_1=q,
-                    data_2=d,
-                    tokenizer=tokenizer,
-                    tokenization_kwargs=tokenization_kwargs,
-                )
+        model_config = self.llm_engine.model_config
 
-                parsed_prompts.append(engine_prompt)
-        else:
-            for q, t in input_pairs:
-                if model_config.use_pad_token:
-                    # cross_encoder models defaults to using pad_token.
-                    prompt_inputs = tokenizer(
-                        text=q,  # type: ignore[arg-type]
-                        text_pair=t,  # type: ignore[arg-type]
-                        **tokenization_kwargs)
-                else:
-                    # `llm as reranker` models defaults to not using pad_token.
-                    prompt_inputs = tokenizer(
-                        text=q + t,  # type: ignore[operator]
-                        **tokenization_kwargs)
-                engine_prompt = TokensPrompt(
-                    prompt_token_ids=prompt_inputs["input_ids"],
-                    token_type_ids=prompt_inputs.get("token_type_ids"))
-                parsed_prompts.append(engine_prompt)
+        for q, d in input_pairs:
+            _, engine_prompt = get_score_prompt(
+                model_config=model_config,
+                data_1=q,
+                data_2=d,
+                tokenizer=tokenizer,
+                tokenization_kwargs=tokenization_kwargs,
+            )
+
+            parsed_prompts.append(engine_prompt)
 
         self._validate_and_add_requests(
             prompts=parsed_prompts,
diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py
index 4da2094147c..66f2af23a0d 100644
--- a/vllm/entrypoints/openai/serving_score.py
+++ b/vllm/entrypoints/openai/serving_score.py
@@ -188,56 +188,19 @@ async def _cross_encoding_score(
 
         input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
 
-        if self.model_config.is_multimodal_model:
-
-            preprocess_async = make_async(self._preprocess_score,
-                                          executor=self._tokenizer_executor)
-
-            preprocessed_prompts = await asyncio.gather(
-                *(preprocess_async(request=request,
-                                   tokenizer=tokenizer,
-                                   tokenization_kwargs=tokenization_kwargs,
-                                   data_1=t1,
-                                   data_2=t2) for t1, t2 in input_pairs))
-
-            for full_prompt, engine_prompt in preprocessed_prompts:
-                request_prompts.append(full_prompt)
-                engine_prompts.append(engine_prompt)
-
-        else:
-            tokenize_async = make_async(tokenizer.__call__,
-                                        executor=self._tokenizer_executor)
-            use_pad_token = self.model_config.use_pad_token
-
-            if use_pad_token:
-                # cross_encoder models defaults to using pad_token.
-                tokenized_prompts = await asyncio.gather(*(
-                    tokenize_async(
-                        text=t1,  # type: ignore[arg-type]
-                        text_pair=t2,  # type: ignore[arg-type]
-                        **tokenization_kwargs) for t1, t2 in input_pairs))
-            else:
-                # `llm as reranker` models defaults to not using pad_token.
-                tokenized_prompts = await asyncio.gather(*(
-                    tokenize_async(
-                        text=t1 +  # type: ignore[operator]
-                        t2,
-                        **tokenization_kwargs) for t1, t2 in input_pairs))
-
-            for prompt_inputs, (t1, t2) in zip(tokenized_prompts, input_pairs):
-                sep_token = tokenizer.sep_token if (tokenizer.sep_token
-                                                    and use_pad_token) else ''
-                request_prompt = f"{t1}{sep_token}{t2}"
-
-                input_ids = prompt_inputs["input_ids"]
-                text_token_prompt = \
-                    self._validate_input(request, input_ids, request_prompt)
-                engine_prompt = TokensPrompt(
-                    prompt_token_ids=text_token_prompt["prompt_token_ids"],
-                    token_type_ids=prompt_inputs.get("token_type_ids"))
-
-                request_prompts.append(request_prompt)
-                engine_prompts.append(engine_prompt)
+        preprocess_async = make_async(self._preprocess_score,
+                                      executor=self._tokenizer_executor)
+
+        preprocessed_prompts = await asyncio.gather(
+            *(preprocess_async(request=request,
+                               tokenizer=tokenizer,
+                               tokenization_kwargs=tokenization_kwargs,
+                               data_1=t1,
+                               data_2=t2) for t1, t2 in input_pairs))
+
+        for full_prompt, engine_prompt in preprocessed_prompts:
+            request_prompts.append(full_prompt)
+            engine_prompts.append(engine_prompt)
 
         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
diff --git a/vllm/entrypoints/score_utils.py b/vllm/entrypoints/score_utils.py
index f3f042355c9..7d420c19b87 100644
--- a/vllm/entrypoints/score_utils.py
+++ b/vllm/entrypoints/score_utils.py
@@ -184,13 +184,28 @@ def get_score_prompt(
         model_config,
         tokenizer,
     )
+    from vllm.model_executor.model_loader import get_model_cls
 
-    full_prompt = apply_score_template(model_config, prompt_1, prompt_2)
-
-    prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
+    model = get_model_cls(model_config)
+    if supports_score_template(model):
+        full_prompt = apply_score_template(model_config, prompt_1, prompt_2)
+        prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
+    elif model_config.use_pad_token:
+        # cross_encoder models defaults to using pad_token.
+        prompt_inputs = tokenizer(text=prompt_1,
+                                  text_pair=prompt_2,
+                                  **tokenization_kwargs)
+        full_prompt = tokenizer.decode(prompt_inputs["input_ids"])
+    else:
+        # `llm as reranker` models defaults to not using pad_token.
+        full_prompt = prompt_1 + prompt_2
+        prompt_inputs = tokenizer(text=full_prompt, **tokenization_kwargs)
 
     engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"])
+    if (token_type_ids := prompt_inputs.get("token_type_ids")) is not None:
+        engine_prompt["token_type_ids"] = token_type_ids
+
     post_process_tokens(model_config, engine_prompt)
 
     if mm_data is not None:
diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
index 504621c8abd..cc9d7bf8147 100644
--- a/vllm/model_executor/models/bert.py
+++ b/vllm/model_executor/models/bert.py
@@ -28,7 +28,7 @@
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import PoolingTask
 
-from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
+from .interfaces import SupportsCrossEncoding, SupportsQuant
 from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 
@@ -508,8 +508,8 @@ def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
         })
 
 
-class BertForSequenceClassification(nn.Module, SupportsV0Only,
-                                    SupportsCrossEncoding, SupportsQuant):
+class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
+                                    SupportsQuant):
     """A model that uses Bert to provide embedding functionalities.
 
     This class encapsulates the BertModel and provides an interface for
diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py
index 77e072c7927..feb549d44ea 100644
--- a/vllm/model_executor/models/roberta.py
+++ b/vllm/model_executor/models/roberta.py
@@ -20,7 +20,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .bert_with_rope import BertWithRope, JinaRobertaModel
-from .interfaces import SupportsCrossEncoding, SupportsV0Only
+from .interfaces import SupportsCrossEncoding
 
 
 class RobertaEmbedding(nn.Module):
@@ -153,8 +153,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         return loader.load_weights(weights_list, mapper=mapper)
 
 
-class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
-                                       SupportsV0Only):
+class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
     """A model that uses Roberta to provide embedding functionalities.
 
     This class encapsulates the BertModel and provides an interface for
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
index d34f3932780..e2ebef46522 100644
--- a/vllm/v1/core/sched/output.py
+++ b/vllm/v1/core/sched/output.py
@@ -24,6 +24,7 @@ class NewRequestData:
 
     req_id: str
     prompt_token_ids: list[int]
+    token_type_ids: Optional[list[int]]
    mm_inputs: list[MultiModalKwargs]
     mm_hashes: list[str]
     mm_positions: list[PlaceholderRange]
@@ -42,6 +43,7 @@ def from_request(
         return cls(
             req_id=request.request_id,
             prompt_token_ids=request.prompt_token_ids,
+            token_type_ids=request.token_type_ids,
             mm_inputs=request.mm_inputs,
             mm_hashes=request.mm_hashes,
             mm_positions=request.mm_positions,
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index 79dc80d8fc5..81c746dd347 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -49,6 +49,7 @@ class EngineCoreRequest(
 
     request_id: str
     prompt_token_ids: list[int]
+    token_type_ids: Optional[list[int]]
     mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
     mm_hashes: Optional[list[str]]
     mm_placeholders: Optional[list[PlaceholderRange]]
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 0f2f404a130..7bbe632bca6 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -339,6 +339,7 @@ def process_inputs(
         return decoder_inputs.get("prompt"), EngineCoreRequest(
             request_id=request_id,
             prompt_token_ids=decoder_inputs["prompt_token_ids"],
+            token_type_ids=decoder_inputs.get("token_type_ids"),
             mm_inputs=sorted_mm_inputs,
             mm_hashes=sorted_mm_hashes,
             mm_placeholders=sorted_mm_positions,
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index 85f5dcb92eb..ce49f5054ef 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -24,6 +24,7 @@ def __init__(
         self,
         request_id: str,
         prompt_token_ids: list[int],
+        token_type_ids: Optional[list[int]],
         multi_modal_inputs: Optional[list[MultiModalKwargs]],
         multi_modal_hashes: Optional[list[str]],
         multi_modal_placeholders: Optional[list[PlaceholderRange]],
@@ -74,6 +75,7 @@ def __init__(
                 "sampling_params and pooling_params can't both be unset")
 
         self.prompt_token_ids = prompt_token_ids
+        self.token_type_ids = token_type_ids
         self.num_prompt_tokens = len(self.prompt_token_ids)
         self._output_token_ids: list[int] = []
         self._all_token_ids: list[int] = self.prompt_token_ids.copy()
@@ -119,6 +121,7 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
             request_id=request.request_id,
             client_index=request.client_index,
             prompt_token_ids=request.prompt_token_ids,
+            token_type_ids=request.token_type_ids,
             multi_modal_inputs=request.mm_inputs,
             multi_modal_hashes=request.mm_hashes,
             multi_modal_placeholders=request.mm_placeholders,
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index c63041600f3..a2a0df7c21f 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-# Datastructures defining a GPU input batch
+# Datastructures defining an input batch
 
 from dataclasses import dataclass
 from typing import Optional, cast
@@ -29,6 +29,7 @@ class CachedRequestState:
 
     req_id: str
     prompt_token_ids: list[int]
+    token_type_ids: Optional[list[int]]
     mm_inputs: list[MultiModalKwargs]
     mm_positions: list[PlaceholderRange]
     sampling_params: Optional[SamplingParams]
@@ -93,6 +94,8 @@ def __init__(
             pin_memory=False,
         )
         self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
+        self.token_type_ids_cpu_tensor = None
+        self._token_type_ids_cpu = None
         self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
         self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
         self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
@@ -240,6 +243,22 @@ def __init__(
         self.pooling_params: dict[str, PoolingParams] = {}
 
+    @property
+    def token_type_ids_cpu(self) -> np.ndarray:
+        if self._token_type_ids_cpu is None:
+            self.token_type_ids_cpu_tensor = torch.zeros(
+                self.token_ids_cpu_tensor.shape,
+                device="cpu",
+                dtype=torch.int8,
+                pin_memory=False,
+            )
+            self._token_type_ids_cpu = cast(
+                torch.Tensor, self.token_type_ids_cpu_tensor).numpy()
+        return self._token_type_ids_cpu
+
+    def has_token_types(self) -> bool:
+        return self._token_type_ids_cpu is not None
+
     @property
     def req_ids(self) -> list[str]:
         # None elements should only be present transiently
@@ -284,6 +303,9 @@ def add_request(
         self.num_prompt_tokens[req_index] = num_prompt_tokens
         self.token_ids_cpu[
             req_index, :num_prompt_tokens] = request.prompt_token_ids
+        if request.token_type_ids is not None:
+            self.token_type_ids_cpu[
+                req_index, :num_prompt_tokens] = request.token_type_ids
         start_idx = num_prompt_tokens
         end_idx = start_idx + len(request.output_token_ids)
         self.token_ids_cpu[req_index,
@@ -475,6 +497,10 @@ def swap_states(self, i1: int, i2: int) -> None:
         tmp = self.token_ids_cpu[i1, ...].copy()
         self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
         self.token_ids_cpu[i2, ...] = tmp
+        if self.has_token_types():
+            tmp2 = self.token_type_ids_cpu[i1, ...].copy()
+            self.token_type_ids_cpu[i1, ...] = self.token_type_ids_cpu[i2, ...]
+            self.token_type_ids_cpu[i2, ...] = tmp2
 
         swap_dict_values(self.generators, i1, i2)
         swap_dict_values(self.bad_words_token_ids, i1, i2)
@@ -545,6 +571,9 @@ def condense(self) -> None:
             num_tokens = self.num_tokens[last_req_index]
             self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                 last_req_index, :num_tokens]
+            if self.has_token_types():
+                self.token_type_ids_cpu[empty_index, :num_tokens] = \
+                    self.token_type_ids_cpu[last_req_index, :num_tokens]
             self.num_tokens[empty_index] = num_tokens
             self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                 last_req_index]
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 84ad582c9c9..8fcaf1ce178 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import gc
+import inspect
 import time
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, Optional, Union, cast
@@ -41,6 +42,8 @@
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
+from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
+from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
                         is_pin_memory_available, round_up, supports_dynamo)
@@ -252,7 +255,8 @@ def __init__(
         self.slot_mapping = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int64,
                                         device=self.device)
-
+        self.token_type_ids: Optional[torch.Tensor] = None
+        self.supports_token_type_ids: bool = False
         # None in the first PP rank. The rest are set after load_model.
         self.intermediate_tensors: Optional[IntermediateTensors] = None
@@ -321,6 +325,19 @@ def __init__(
         # from the KV cache of `shared_kv_cache_layers[layer_name]`.
         self.shared_kv_cache_layers: dict[str, str] = {}
 
+    def get_token_type_ids(self) -> torch.Tensor:
+        if self.token_type_ids is None:
+            self.token_type_ids = torch.zeros(self.max_num_tokens,
+                                              dtype=torch.int32,
+                                              device=self.device)
+        return self.token_type_ids
+
+    def _maybe_add_model_args(self, num_tokens: int, model_kwargs: dict[str,
+                                                                        Any]):
+        if self.supports_token_type_ids:
+            model_kwargs["token_type_ids"] =\
+                self.get_token_type_ids()[:num_tokens]
+
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
         """
         Update the order of requests in the batch based on the attention
@@ -436,6 +453,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
                 prompt_token_ids=new_req_data.prompt_token_ids,
+                token_type_ids=new_req_data.token_type_ids,
                 mm_inputs=new_req_data.mm_inputs,
                 mm_positions=new_req_data.mm_positions,
                 sampling_params=sampling_params,
@@ -691,6 +709,13 @@ def _prepare_inputs(
                            0,
                            torch.from_numpy(token_indices),
                            out=self.input_ids_cpu[:total_num_scheduled_tokens])
+        if self.input_batch.token_type_ids_cpu_tensor is not None:
+            token_type_ids = torch.index_select(
+                self.input_batch.token_type_ids_cpu_tensor.flatten(), 0,
+                torch.from_numpy(token_indices))
+            # Copy the tensors to the GPU.
+            self.get_token_type_ids()[:total_num_scheduled_tokens]\
+                .copy_(token_type_ids, non_blocking=True)
 
         self.input_batch.block_table.compute_slot_mapping(
             req_indices, positions_np)
@@ -1464,13 +1489,16 @@ def execute_model(
         else:
             mm_embeds = []
 
+        model_kwargs: dict[str, Any] = {}
+
         if self.is_multimodal_model and get_pp_group().is_first_rank:
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
             # as input to the multimodal model, even when the input is text.
             input_ids = self.input_ids[:num_scheduled_tokens]
+            self._maybe_add_model_args(num_scheduled_tokens, model_kwargs)
 
-            model_kwargs = self._init_model_kwargs_for_multimodal_model(
+            model_mm_kwargs = self._init_model_kwargs_for_multimodal_model(
                 scheduler_output=scheduler_output)
             inputs_embeds = self.model.get_input_embeddings(
                 input_ids=input_ids,
@@ -1487,8 +1515,9 @@ def execute_model(
             # multimodal models, it is not desirable for performance since
             # then the embedding layer is not included in the CUDA graph.
             input_ids = self.input_ids[:num_input_tokens]
+            self._maybe_add_model_args(num_input_tokens, model_kwargs)
             inputs_embeds = None
-            model_kwargs = {}
+            model_mm_kwargs = {}
         if self.uses_mrope:
             positions = self.mrope_positions[:, :num_input_tokens]
         else:
@@ -1522,9 +1551,10 @@ def execute_model(
                 intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
                 **MultiModalKwargs.as_kwargs(
-                    model_kwargs,
+                    model_mm_kwargs,
                     device=self.device,
                 ),
+                **model_kwargs,
             )
 
         self.maybe_wait_for_kv_save()
@@ -1858,6 +1888,14 @@ def update_config(self, overrides: dict[str, Any]) -> None:
             new_config = update_config(config, config_overrides)
             setattr(self, config_name, new_config)
 
+    def _get_tokenizer(self) -> AnyTokenizer:
+        tokenizer_group = init_tokenizer_from_configs(
+            model_config=self.model_config,
+            scheduler_config=self.scheduler_config,
+            lora_config=self.lora_config)
+
+        return tokenizer_group.get_lora_tokenizer()
+
     def load_model(self, eep_scale_up: bool = False) -> None:
         """
         Args:
@@ -1941,6 +1979,26 @@ def load_model(self, eep_scale_up: bool = False) -> None:
                     fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
                     backend=backend)
 
+        model_supports_token_type_ids = 'token_type_ids' in \
+            inspect.getfullargspec(self.model.forward).args
+
+        if not self.model_config.skip_tokenizer_init:
+            tokenizer = self._get_tokenizer()
+            if not isinstance(tokenizer, MistralTokenizer):
+                tok_output = tokenizer(text="foo")
+                if "token_type_ids" in tok_output:
+                    if not model_supports_token_type_ids:
+                        logger.warning(
+                            "Tokenizer returns token_type_ids but "
+                            "model forward() doesn't support that "
+                            "argument")
+                    else:
+                        self.supports_token_type_ids = True
+
+        if self.supports_token_type_ids:
+            # pre-allocate tensor
+            self.get_token_type_ids()
+
     def reload_weights(self) -> None:
         assert getattr(self, "model", None) is not None, \
             "Cannot reload weights before model is loaded."
@@ -2167,15 +2225,17 @@ def _dummy_run(
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
             model = self.model
+            model_kwargs: dict[str, Any] = {}
+            self._maybe_add_model_args(num_tokens, model_kwargs)
             if self.is_multimodal_model:
-                model_kwargs = self._init_model_kwargs_for_multimodal_model(
+                model_mm_kwargs = self._init_model_kwargs_for_multimodal_model(
                     num_reqs=num_reqs)
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
             else:
                 input_ids = self.input_ids[:num_tokens]
                 inputs_embeds = None
-                model_kwargs = {}
+                model_mm_kwargs = {}
 
             if self.uses_mrope:
                 positions = self.mrope_positions[:, :num_tokens]
@@ -2206,9 +2266,10 @@ def _dummy_run(
                 intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
                 **MultiModalKwargs.as_kwargs(
-                    model_kwargs,
+                    model_mm_kwargs,
                     device=self.device,
                 ),
+                **model_kwargs,
             )
 
             if self.use_aux_hidden_state_outputs:
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 59cbb015057..fe8b99a2515 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -423,6 +423,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
                 prompt_token_ids=new_req_data.prompt_token_ids,
+                token_type_ids=new_req_data.token_type_ids,
                 mm_inputs=new_req_data.mm_inputs,
                 mm_positions=new_req_data.mm_positions,
                 sampling_params=sampling_params,