Skip to content

Add support for token_type_ids #19988

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 29 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
5dee54d
Add support for encoder embedding models
maxdebayser Jun 23, 2025
7eb9d28
Fix CUDA graphs for BERT models
maxdebayser Jul 1, 2025
67691e0
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 1, 2025
d3099a9
Fix cuda graph initialization of token type ids
maxdebayser Jul 1, 2025
613ff3b
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 2, 2025
20c41e4
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 2, 2025
ba86026
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 8, 2025
b4f5ead
Fix missing args
maxdebayser Jul 9, 2025
c4060d1
relax assertion
maxdebayser Jul 9, 2025
01d2a65
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 9, 2025
80930d8
fix missing arg
maxdebayser Jul 9, 2025
d881f0a
fix missing arg
maxdebayser Jul 10, 2025
90a25d0
remove model from unsupported list
maxdebayser Jul 10, 2025
6686550
fix missing arg
maxdebayser Jul 10, 2025
cc76777
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 10, 2025
136c9b3
fix tests
maxdebayser Jul 10, 2025
b232491
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 14, 2025
cf5e6b8
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 16, 2025
e19c738
fix tests
maxdebayser Jul 16, 2025
e255f30
fix tests
maxdebayser Jul 16, 2025
ee5950c
add missing arg
maxdebayser Jul 16, 2025
78a2e57
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 16, 2025
a5cfc84
add missing arg
maxdebayser Jul 16, 2025
63fd783
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 16, 2025
f58692c
Merge branch 'main' into v1_embeddings_full
maxdebayser Jul 20, 2025
eea55fb
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 25, 2025
f2d8e18
Merge branch 'v1_embeddings_full' of github.com:maxdebayser/vllm into…
maxdebayser Jul 25, 2025
12ae080
revert attn changes to simplify merge
maxdebayser Jul 28, 2025
f29da32
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions tests/models/language/pooling/test_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@
"The capital of Germany is Berlin.",
]


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


DTYPE = "half"


Expand Down
21 changes: 11 additions & 10 deletions tests/tokenization/test_detokenize.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,17 @@ def _run_incremental_decode(tokenizer,
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
request = EngineCoreRequest("",
prompt_token_ids,
None,
None,
None,
params,
None,
None,
0.0,
None,
request = EngineCoreRequest(request_id="",
prompt_token_ids=prompt_token_ids,
token_type_ids=None,
mm_inputs=None,
mm_hashes=None,
mm_placeholders=None,
sampling_params=params,
pooling_params=None,
eos_token_id=None,
arrival_time=0.0,
lora_request=None,
cache_salt=None,
data_parallel_rank=None)

Expand Down
1 change: 1 addition & 0 deletions tests/v1/core/test_kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def make_request(request_id,
return Request(
request_id=request_id,
prompt_token_ids=prompt_token_ids,
token_type_ids=None,
multi_modal_inputs=multi_modal_inputs,
multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions,
Expand Down
1 change: 1 addition & 0 deletions tests/v1/core/test_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def make_request(request_id,
return Request(
request_id=request_id,
prompt_token_ids=prompt_token_ids,
token_type_ids=None,
multi_modal_inputs=multi_modal_inputs,
multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions,
Expand Down
2 changes: 2 additions & 0 deletions tests/v1/core/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1330,6 +1330,7 @@ def create_requests_with_priority(
request = Request(
request_id=f"{i}",
prompt_token_ids=[i] * num_tokens,
token_type_ids=None,
sampling_params=sampling_params,
pooling_params=None,
multi_modal_inputs=mm_inputs,
Expand Down Expand Up @@ -1816,6 +1817,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
request = Request(
request_id="0",
prompt_token_ids=[0, 1],
token_type_ids=None,
multi_modal_inputs=None,
multi_modal_hashes=None,
multi_modal_placeholders=None,
Expand Down
1 change: 1 addition & 0 deletions tests/v1/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def create_requests(
request = Request(
request_id=f"{i}",
prompt_token_ids=prompt_token_ids,
token_type_ids=None,
sampling_params=sampling_params,
pooling_params=None,
multi_modal_inputs=mm_inputs,
Expand Down
1 change: 1 addition & 0 deletions tests/v1/engine/test_engine_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def make_request() -> EngineCoreRequest:
return EngineCoreRequest(
request_id=str(uuid.uuid4()),
prompt_token_ids=PROMPT_TOKENS,
token_type_ids=None,
mm_inputs=None,
mm_hashes=None,
mm_placeholders=None,
Expand Down
1 change: 1 addition & 0 deletions tests/v1/engine/test_engine_core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def make_request(
return EngineCoreRequest(
request_id=str(uuid.uuid4()),
prompt_token_ids=prompt_tokens_ids,
token_type_ids=None,
mm_inputs=None,
mm_hashes=None,
mm_placeholders=None,
Expand Down
1 change: 1 addition & 0 deletions tests/v1/engine/test_fast_incdec_prefix_err.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def test_fast_inc_detok_invalid_utf8_err_case():
None,
None,
None,
None,
params,
None,
None,
Expand Down
5 changes: 5 additions & 0 deletions tests/v1/engine/test_output_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
requests = [
EngineCoreRequest(request_id=f"request-{idx}",
prompt_token_ids=prompt_tokens,
token_type_ids=None,
arrival_time=0,
mm_inputs=None,
mm_hashes=None,
Expand Down Expand Up @@ -401,6 +402,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
requests = [
EngineCoreRequest(request_id=request_id_list[idx],
prompt_token_ids=prompt_tokens,
token_type_ids=None,
arrival_time=0,
mm_inputs=None,
mm_hashes=None,
Expand Down Expand Up @@ -566,6 +568,7 @@ def test_stop_token(include_stop_str_in_output: bool,
request = EngineCoreRequest(
request_id=request_id,
prompt_token_ids=prompt_tokens,
token_type_ids=None,
arrival_time=0,
mm_inputs=None,
mm_hashes=None,
Expand Down Expand Up @@ -665,6 +668,7 @@ def test_stop_string(include_stop_str_in_output: bool,
EngineCoreRequest(
request_id=request_id_list[idx],
prompt_token_ids=prompt_tokens,
token_type_ids=None,
arrival_time=0,
mm_inputs=None,
mm_hashes=None,
Expand Down Expand Up @@ -781,6 +785,7 @@ def test_iteration_stats(dummy_test_vectors):
EngineCoreRequest(
request_id=f"request-{idx}",
prompt_token_ids=prompt_tokens,
token_type_ids=None,
arrival_time=0,
mm_inputs=None,
mm_hashes=None,
Expand Down
1 change: 1 addition & 0 deletions tests/v1/kv_connector/unit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def create_request(
req = Request(
request_id=f"id-{request_id}",
prompt_token_ids=prompt_token_ids,
token_type_ids=None,
sampling_params=sampling_params,
pooling_params=None,
multi_modal_inputs=None,
Expand Down
1 change: 1 addition & 0 deletions tests/v1/tpu/worker/test_tpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
NewRequestData(
req_id=req_id,
prompt_token_ids=[1, 2, 3],
token_type_ids=None,
mm_inputs=[],
mm_hashes=[],
mm_positions=[],
Expand Down
4 changes: 4 additions & 0 deletions tests/v1/worker/test_gpu_input_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,17 @@ def _construct_cached_request_state(req_id_suffix: int):
np.random.randint(0, VOCAB_SIZE)
for _ in range(np.random.randint(0, MAX_PROMPT_SIZE))
]
token_type_ids = [
np.random.randint(0, 2) for _ in range(len(prompt_token_ids))
]
output_token_ids = [
np.random.randint(0, VOCAB_SIZE)
for _ in range(np.random.randint(0, NUM_OUTPUT_TOKENS))
]
return CachedRequestState(
req_id=f"req_id_{req_id_suffix}",
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
sampling_params=_create_sampling_params(),
pooling_params=None,
mm_inputs=[],
Expand Down
1 change: 1 addition & 0 deletions tests/v1/worker/test_gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
NewRequestData(
req_id=req_id,
prompt_token_ids=[1, 2, 3],
token_type_ids=None,
mm_inputs=[],
mm_hashes=[],
mm_positions=[],
Expand Down
6 changes: 3 additions & 3 deletions vllm/model_executor/models/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask

from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
from .interfaces import SupportsCrossEncoding, SupportsQuant
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix


Expand Down Expand Up @@ -508,8 +508,8 @@ def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
})


class BertForSequenceClassification(nn.Module, SupportsV0Only,
SupportsCrossEncoding, SupportsQuant):
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.

This class encapsulates the BertModel and provides an interface for
Expand Down
5 changes: 2 additions & 3 deletions vllm/model_executor/models/roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from vllm.sequence import IntermediateTensors

from .bert_with_rope import BertWithRope, JinaRobertaModel
from .interfaces import SupportsCrossEncoding, SupportsV0Only
from .interfaces import SupportsCrossEncoding


class RobertaEmbedding(nn.Module):
Expand Down Expand Up @@ -153,8 +153,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
return loader.load_weights(weights_list, mapper=mapper)


class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
SupportsV0Only):
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
"""A model that uses Roberta to provide embedding functionalities.

This class encapsulates the BertModel and provides an interface for
Expand Down
2 changes: 2 additions & 0 deletions vllm/v1/core/sched/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class NewRequestData:

req_id: str
prompt_token_ids: list[int]
token_type_ids: Optional[list[int]]
mm_inputs: list[MultiModalKwargs]
mm_hashes: list[str]
mm_positions: list[PlaceholderRange]
Expand All @@ -42,6 +43,7 @@ def from_request(
return cls(
req_id=request.request_id,
prompt_token_ids=request.prompt_token_ids,
token_type_ids=request.token_type_ids,
mm_inputs=request.mm_inputs,
mm_hashes=request.mm_hashes,
mm_positions=request.mm_positions,
Expand Down
1 change: 1 addition & 0 deletions vllm/v1/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class EngineCoreRequest(

request_id: str
prompt_token_ids: list[int]
token_type_ids: Optional[list[int]]
mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
mm_hashes: Optional[list[str]]
mm_placeholders: Optional[list[PlaceholderRange]]
Expand Down
1 change: 1 addition & 0 deletions vllm/v1/engine/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ def process_inputs(
return decoder_inputs.get("prompt"), EngineCoreRequest(
request_id=request_id,
prompt_token_ids=decoder_inputs["prompt_token_ids"],
token_type_ids=decoder_inputs.get("token_type_ids"),
mm_inputs=sorted_mm_inputs,
mm_hashes=sorted_mm_hashes,
mm_placeholders=sorted_mm_positions,
Expand Down
3 changes: 3 additions & 0 deletions vllm/v1/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(
self,
request_id: str,
prompt_token_ids: list[int],
token_type_ids: Optional[list[int]],
multi_modal_inputs: Optional[list[MultiModalKwargs]],
multi_modal_hashes: Optional[list[str]],
multi_modal_placeholders: Optional[list[PlaceholderRange]],
Expand Down Expand Up @@ -74,6 +75,7 @@ def __init__(
"sampling_params and pooling_params can't both be unset")

self.prompt_token_ids = prompt_token_ids
self.token_type_ids = token_type_ids
self.num_prompt_tokens = len(self.prompt_token_ids)
self._output_token_ids: list[int] = []
self._all_token_ids: list[int] = self.prompt_token_ids.copy()
Expand Down Expand Up @@ -119,6 +121,7 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
request_id=request.request_id,
client_index=request.client_index,
prompt_token_ids=request.prompt_token_ids,
token_type_ids=request.token_type_ids,
multi_modal_inputs=request.mm_inputs,
multi_modal_hashes=request.mm_hashes,
multi_modal_placeholders=request.mm_placeholders,
Expand Down
31 changes: 30 additions & 1 deletion vllm/v1/worker/gpu_input_batch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining a GPU input batch
# Datastructures defining an input batch

from dataclasses import dataclass
from typing import Optional, cast
Expand Down Expand Up @@ -29,6 +29,7 @@ class CachedRequestState:

req_id: str
prompt_token_ids: list[int]
token_type_ids: Optional[list[int]]
mm_inputs: list[MultiModalKwargs]
mm_positions: list[PlaceholderRange]
sampling_params: Optional[SamplingParams]
Expand Down Expand Up @@ -93,6 +94,8 @@ def __init__(
pin_memory=False,
)
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
self.token_type_ids_cpu_tensor = None
self._token_type_ids_cpu = None
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
Expand Down Expand Up @@ -240,6 +243,22 @@ def __init__(

self.pooling_params: dict[str, PoolingParams] = {}

@property
def token_type_ids_cpu(self) -> np.ndarray:
if self._token_type_ids_cpu is None:
self.token_type_ids_cpu_tensor = torch.zeros(
self.token_ids_cpu_tensor.shape,
device="cpu",
dtype=torch.int8,
pin_memory=False,
)
self._token_type_ids_cpu = cast(
torch.Tensor, self.token_type_ids_cpu_tensor).numpy()
return self._token_type_ids_cpu

def has_token_types(self) -> bool:
return self._token_type_ids_cpu is not None

@property
def req_ids(self) -> list[str]:
# None elements should only be present transiently
Expand Down Expand Up @@ -284,6 +303,9 @@ def add_request(
self.num_prompt_tokens[req_index] = num_prompt_tokens
self.token_ids_cpu[
req_index, :num_prompt_tokens] = request.prompt_token_ids
if request.token_type_ids is not None:
self.token_type_ids_cpu[
req_index, :num_prompt_tokens] = request.token_type_ids
start_idx = num_prompt_tokens
end_idx = start_idx + len(request.output_token_ids)
self.token_ids_cpu[req_index,
Expand Down Expand Up @@ -475,6 +497,10 @@ def swap_states(self, i1: int, i2: int) -> None:
tmp = self.token_ids_cpu[i1, ...].copy()
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
self.token_ids_cpu[i2, ...] = tmp
if self.has_token_types():
tmp2 = self.token_type_ids_cpu[i1, ...].copy()
self.token_type_ids_cpu[i1, ...] = self.token_type_ids_cpu[i2, ...]
self.token_type_ids_cpu[i2, ...] = tmp2

swap_dict_values(self.generators, i1, i2)
swap_dict_values(self.bad_words_token_ids, i1, i2)
Expand Down Expand Up @@ -545,6 +571,9 @@ def condense(self) -> None:
num_tokens = self.num_tokens[last_req_index]
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
last_req_index, :num_tokens]
if self.has_token_types():
self.token_type_ids_cpu[empty_index, :num_tokens] = \
self.token_type_ids_cpu[last_req_index, :num_tokens]
self.num_tokens[empty_index] = num_tokens
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
last_req_index]
Expand Down
Loading