diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py index aff3498567d2..797353e4f7a8 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py @@ -11,11 +11,13 @@ from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.models.gemma2 import Gemma2Model from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix -from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors class MyGemma2Embedding(nn.Module): + + is_pooling_model = True + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -24,7 +26,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = Gemma2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) - self._pooler = Pooler.from_config_with_defaults( + self.pooler = Pooler.from_config_with_defaults( vllm_config.model_config.pooler_config, pooling_type=PoolingType.LAST, normalize=True, @@ -54,13 +56,6 @@ def forward( # Return all-zero embeddings return torch.zeros_like(hidden_states) - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): weights = self.hf_to_vllm_mapper.apply(weights) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 16cb5b75032c..a421ed1fc327 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1237,10 +1237,6 @@ class 
EmbeddingCompletionRequest(OpenAIBaseModel): user: Optional[str] = None truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None - # --8<-- [start:embedding-pooling-params] - additional_data: Optional[Any] = None - # --8<-- [end:embedding-pooling-params] - # --8<-- [start:embedding-extra-params] add_special_tokens: bool = Field( default=True, @@ -1259,8 +1255,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel): # --8<-- [end:embedding-extra-params] def to_pooling_params(self): - return PoolingParams(dimensions=self.dimensions, - additional_data=self.additional_data) + return PoolingParams(dimensions=self.dimensions) class EmbeddingChatRequest(OpenAIBaseModel): @@ -1272,10 +1267,6 @@ class EmbeddingChatRequest(OpenAIBaseModel): user: Optional[str] = None truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None - # --8<-- [start:chat-embedding-pooling-params] - additional_data: Optional[Any] = None - # --8<-- [end:chat-embedding-pooling-params] - # --8<-- [start:chat-embedding-extra-params] add_special_tokens: bool = Field( default=False, @@ -1323,8 +1314,7 @@ def check_generation_prompt(cls, data): return data def to_pooling_params(self): - return PoolingParams(dimensions=self.dimensions, - additional_data=self.additional_data) + return PoolingParams(dimensions=self.dimensions) EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest] @@ -1340,10 +1330,6 @@ class ScoreRequest(OpenAIBaseModel): text_2: Union[list[str], str, ScoreMultiModalParam] truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None - # --8<-- [start:score-pooling-params] - additional_data: Optional[Any] = None - # --8<-- [end:score-pooling-params] - # --8<-- [start:score-extra-params] mm_processor_kwargs: Optional[dict[str, Any]] = Field( @@ -1362,8 +1348,7 @@ class ScoreRequest(OpenAIBaseModel): # --8<-- [end:score-extra-params] def to_pooling_params(self, *, use_cross_encoder: bool = False): - return 
PoolingParams(use_cross_encoder=use_cross_encoder, - additional_data=self.additional_data) + return PoolingParams(use_cross_encoder=use_cross_encoder) class RerankRequest(OpenAIBaseModel): @@ -1373,10 +1358,6 @@ class RerankRequest(OpenAIBaseModel): top_n: int = Field(default_factory=lambda: 0) truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None - # --8<-- [start:rerank-pooling-params] - additional_data: Optional[Any] = None - # --8<-- [end:rerank-pooling-params] - # --8<-- [start:rerank-extra-params] mm_processor_kwargs: Optional[dict[str, Any]] = Field( @@ -1395,8 +1376,7 @@ class RerankRequest(OpenAIBaseModel): # --8<-- [end:rerank-extra-params] def to_pooling_params(self, *, use_cross_encoder: bool = False): - return PoolingParams(use_cross_encoder=use_cross_encoder, - additional_data=self.additional_data) + return PoolingParams(use_cross_encoder=use_cross_encoder) class RerankDocument(BaseModel): @@ -1534,10 +1514,6 @@ class ClassificationRequest(OpenAIBaseModel): truncate_prompt_tokens: Optional[int] = None user: Optional[str] = None - # --8<-- [start:classification-pooling-params] - additional_data: Optional[Any] = None - # --8<-- [end:classification-pooling-params] - # --8<-- [start:classification-extra-params] priority: int = Field( default=0, @@ -1550,7 +1526,7 @@ class ClassificationRequest(OpenAIBaseModel): # --8<-- [end:classification-extra-params] def to_pooling_params(self): - return PoolingParams(additional_data=self.additional_data) + return PoolingParams() class ClassificationData(OpenAIBaseModel): diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index b378a3db0322..74916492f574 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -3,22 +3,25 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from enum import IntEnum -from typing import Callable, Optional, TypeVar, Union +from typing import Callable, Literal, Optional, 
TypeVar, Union import torch import torch.nn as nn import torch.nn.functional as F from transformers import PretrainedConfig +from typing_extensions import assert_never from vllm.config import ModelConfig, PoolerConfig from vllm.model_executor.pooling_metadata import ( # noqa: E501 PoolingMetadata as V0PoolingMetadata) from vllm.model_executor.pooling_metadata import PoolingTensors +from vllm.pooling_params import PoolingParams from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput from vllm.utils import resolve_obj_by_qualname from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata] +PoolingTask = Literal["encode", "embed", "classify", "score"] class PoolingType(IntEnum): @@ -64,6 +67,48 @@ def from_config_with_defaults( ) +class Pooler(nn.Module, ABC): + """The interface required for all poolers used in pooling models in vLLM.""" + + @staticmethod + def from_config_with_defaults( + pooler_config: PoolerConfig, + pooling_type: PoolingType, + normalize: bool, + softmax: bool, + step_tag_id: Optional[int] = None, + returned_token_ids: Optional[list[int]] = None, + ) -> "Pooler": + resolved_config = ResolvedPoolingConfig.from_config_with_defaults( + pooler_config=pooler_config, + pooling_type=pooling_type, + normalize=normalize, + softmax=softmax, + step_tag_id=step_tag_id, + returned_token_ids=returned_token_ids, + ) + + if pooling_type == PoolingType.STEP: + return StepPooler.from_config(resolved_config) + + return SimplePooler.from_config(resolved_config) + + def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: + """ + Construct the pooling parameters to use for a task, + or `None` if the task is not supported. 
+ """ + return None + + @abstractmethod + def forward( + self, + hidden_states: Union[list[torch.Tensor], torch.Tensor], + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + raise NotImplementedError + + def get_prompt_lens( hidden_states: Union[torch.Tensor, list[torch.Tensor]], pooling_metadata: PoolingMetadata, @@ -104,17 +149,6 @@ def build_output(all_data: torch.Tensor) -> PoolerOutput: return PoolerOutput(outputs=all_outputs) -class BasePooler(nn.Module): - - @abstractmethod - def forward( - self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], - pooling_metadata: PoolingMetadata, - ) -> PoolerOutput: - raise NotImplementedError - - class PoolingMethod(nn.Module, ABC): @staticmethod @@ -130,6 +164,10 @@ def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod": raise NotImplementedError(f"Unsupported method: {pooling_type}") + @abstractmethod + def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: + raise NotImplementedError + @abstractmethod def forward_one( self, @@ -168,6 +206,14 @@ def forward( class CLSPool(PoolingMethod): + def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: + # The equalities are split up to keep mypy happy + if (task == "encode" or task == "embed" or task == "classify" + or task == "score"): + return PoolingParams() + + assert_never(task) + def forward_one( self, hidden_states: torch.Tensor, @@ -190,6 +236,14 @@ def forward_all( class LastPool(PoolingMethod): + def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: + # The equalities are split up to keep mypy happy + if (task == "encode" or task == "embed" or task == "classify" + or task == "score"): + return PoolingParams() + + assert_never(task) + def forward_one( self, hidden_states: torch.Tensor, @@ -208,6 +262,16 @@ def forward_all( class AllPool(PoolingMethod): + def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: + if task == "encode": + return PoolingParams() + 
+ # The equalities are split up to keep mypy happy + if task == "embed" or task == "classify" or task == "score": + return None + + assert_never(task) + def forward_one( self, hidden_states: torch.Tensor, @@ -235,6 +299,14 @@ def forward_all( class MeanPool(PoolingMethod): + def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: + # The equalities are split up to keep mypy happy + if (task == "encode" or task == "embed" or task == "classify" + or task == "score"): + return PoolingParams() + + assert_never(task) + def forward_one( self, hidden_states: torch.Tensor, @@ -345,25 +417,6 @@ def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: class PoolerHead(nn.Module): - @classmethod - def from_config_with_defaults( - cls, - pooler_config: PoolerConfig, - pooling_type: PoolingType, - normalize: bool, - softmax: bool, - ) -> "PoolerHead": - resolved_config = ResolvedPoolingConfig.from_config_with_defaults( - pooler_config=pooler_config, - pooling_type=pooling_type, - normalize=normalize, - softmax=softmax, - step_tag_id=None, - returned_token_ids=None, - ) - - return cls.from_config(resolved_config) - @classmethod def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "PoolerHead": if pooler_config.normalize and pooler_config.softmax: @@ -424,21 +477,17 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], return self.activation(pooled_data) -class SimplePooler(BasePooler): +class SimplePooler(Pooler): """A layer that pools specific information from hidden states. This layer does the following: 1. Extracts specific tokens or aggregates data based on pooling method. 2. Normalizes output if specified. 3. Returns structured results as `PoolerOutput`. - - Attributes: - pooling_type: The type of pooling to use. - normalize: Whether to normalize the pooled data. 
""" @classmethod - def from_config_with_defaults( + def from_config_with_defaults( # type: ignore[override] cls, pooler_config: PoolerConfig, pooling_type: PoolingType, @@ -471,6 +520,9 @@ def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None: self.pooling = pooling self.head = head + def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: + return self.pooling.get_pooling_params(task) + def forward( self, hidden_states: Union[torch.Tensor, list[torch.Tensor]], @@ -481,7 +533,7 @@ def forward( return build_output(pooled_data) -class StepPooler(BasePooler): +class StepPooler(Pooler): @classmethod def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "StepPooler": @@ -543,6 +595,16 @@ def extract_states( return pooled_data + def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: + if task == "encode": + return PoolingParams(logits_processing_needs_token_ids=True) + + # The equalities are split up to keep mypy happy + if task == "embed" or task == "classify" or task == "score": + return None + + assert_never(task) + def forward( self, hidden_states: Union[torch.Tensor, list[torch.Tensor]], @@ -553,32 +615,6 @@ def forward( return build_output(pooled_data) -class Pooler(nn.Module): - - @staticmethod - def from_config_with_defaults( - pooler_config: PoolerConfig, - pooling_type: PoolingType, - normalize: bool, - softmax: bool, - step_tag_id: Optional[int] = None, - returned_token_ids: Optional[list[int]] = None, - ) -> BasePooler: - resolved_config = ResolvedPoolingConfig.from_config_with_defaults( - pooler_config=pooler_config, - pooling_type=pooling_type, - normalize=normalize, - softmax=softmax, - step_tag_id=step_tag_id, - returned_token_ids=returned_token_ids, - ) - - if pooling_type == PoolingType.STEP: - return StepPooler.from_config(resolved_config) - - return SimplePooler.from_config(resolved_config) - - PoolingFn = Callable[ [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata], 
Union[torch.Tensor, list[torch.Tensor]]] @@ -618,6 +654,18 @@ def _get_act_fn(self, use_cross_encoder: bool): return (self.cross_encoder_act_fn if use_cross_encoder else self.classification_act_fn) + def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: + if task == "encode": + return PoolingParams() + if task == "embed": + return None + if task == "classify": + return PoolingParams() + if task == "score": + return PoolingParams(use_cross_encoder=True) + + assert_never(task) + def forward( self, hidden_states: Union[torch.Tensor, list[torch.Tensor]], diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 5c09ac306052..f319c0c4441a 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast import torch import torch.nn as nn @@ -42,13 +42,14 @@ def _create_pooling_model_cls( default_softmax: bool, ) -> _T: # Lazy import - from vllm.model_executor.layers.pooler import Pooler, PoolerOutput - from vllm.model_executor.pooling_metadata import PoolingMetadata + from vllm.model_executor.layers.pooler import Pooler from .utils import AutoWeightsLoader, WeightsMapper class ModelForPooling(orig_cls, VllmModelForPooling): + is_pooling_model = True + def __init__( self, *, @@ -66,27 +67,20 @@ def __init__( delattr(self, attr) # If the model already defines a pooler instance, don't overwrite it - if not getattr(self, "_pooler", None): + if not getattr(self, "pooler", None): self._init_pooler(vllm_config, prefix=prefix) def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self._pooler = Pooler.from_config_with_defaults( + 
self.pooler = Pooler.from_config_with_defaults( pooler_config, pooling_type=default_pooling_type, normalize=default_normalize, softmax=default_softmax, ) - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> PoolerOutput: - return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): # TODO: Support uninitialized params tracking @@ -171,10 +165,8 @@ def as_seq_cls_model(cls: _T) -> _T: # Lazy import from vllm.model_executor.layers.linear import RowParallelLinear from vllm.model_executor.layers.pooler import (ClassifierPooler, - PoolerOutput, PoolingType, - SimplePooler) + PoolingType, SimplePooler) from vllm.model_executor.models.interfaces import SupportsCrossEncoding - from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors from .utils import maybe_prefix @@ -213,7 +205,7 @@ def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): softmax=True, ) - self._pooler = ClassifierPooler( + self.pooler = ClassifierPooler( vllm_config.model_config, pooling=pooler.pooling, classifier=self._classifier, @@ -234,13 +226,6 @@ def forward( return super().forward(input_ids, positions, intermediate_tensors, inputs_embeds) - def pooler( - self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], - pooling_metadata: PoolingMetadata, - ) -> PoolerOutput: - return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): tokens = getattr(self.config, "classifier_from_token", None) method = getattr(self.config, "method", None) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 65e6428f4912..bd4445c49a03 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -18,12 +18,14 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.pooler import (ClassifierPooler, 
Pooler, - PoolingMethod, PoolingType) + PoolingMethod, PoolingTask, + PoolingType) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.pooling_params import PoolingParams +from vllm.sequence import IntermediateTensors from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix @@ -80,7 +82,7 @@ def forward( return embeddings -class BertPooler(nn.Module): +class BertPooler(Pooler): def __init__(self, config: BertConfig): super().__init__() @@ -89,6 +91,9 @@ def __init__(self, config: BertConfig): self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() + def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: + return self.pooling.get_pooling_params(task) + def forward( self, hidden_states: Union[torch.Tensor, list[torch.Tensor]], @@ -319,6 +324,9 @@ def forward(self, hidden_states: torch.Tensor, class BertModel(nn.Module, SupportsQuant): + + is_pooling_model = True + packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]} def __init__(self, @@ -403,12 +411,15 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant): _pooler: An instance of Pooler used for pooling operations. 
""" + is_pooling_model = True + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + pooler_config = vllm_config.model_config.pooler_config self.model = self._build_model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) - self._pooler = self._build_pooler(pooler_config) + self.pooler = self._build_pooler(pooler_config) def forward( self, @@ -422,13 +433,6 @@ def forward( inputs_embeds=inputs_embeds, intermediate_tensors=intermediate_tensors) - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): weights_list = list(weights) @@ -466,6 +470,8 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only, _pooler: An instance of Pooler used for pooling operations. """ + is_pooling_model = True + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -476,7 +482,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): embedding_class=BertEmbedding, add_pooling_layer=True) self.classifier = nn.Linear(config.hidden_size, config.num_labels) - self._pooler = ClassifierPooler( + self.pooler = ClassifierPooler( vllm_config.model_config, pooling=self.bert.pooler, classifier=self.classifier, @@ -487,13 +493,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loaded_params = loader.load_weights(weights) return loaded_params - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def forward( self, input_ids: Optional[torch.Tensor], diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 27021550f998..82883bfa890d 100644 --- a/vllm/model_executor/models/gpt2.py +++ 
b/vllm/model_executor/models/gpt2.py @@ -40,9 +40,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors from ..layers.pooler import Pooler, PoolingType from .interfaces import SupportsPP @@ -332,6 +331,8 @@ class GPT2ForSequenceClassification(nn.Module): _pooler: An instance of Pooler used for pooling operations. """ + is_pooling_model = True + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -339,7 +340,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "gpt2")) self.score = nn.Linear(config.n_embd, config.num_labels, bias=False) pooler_config = vllm_config.model_config.pooler_config - self._pooler = Pooler.from_config_with_defaults( + self.pooler = Pooler.from_config_with_defaults( pooler_config, pooling_type=PoolingType.LAST, normalize=False, @@ -349,13 +350,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) return loader.load_weights(weights) - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index dfec8a51c4c2..ba0e22892d86 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from array import array -from typing import Optional 
import torch import torch.nn as nn @@ -195,6 +194,8 @@ class GritLM(LlamaForCausalLM, SupportsV0Only): - "<|user|>\nPROMPT\n<|assistant|>\n" """ + is_pooling_model = True + def __init__( self, vllm_config: VllmConfig, @@ -214,11 +215,4 @@ def __init__( super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) - self._pooler = GritLMPooler(vllm_config.model_config) - - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) + self.pooler = GritLMPooler(vllm_config.model_config) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 9655bdf6f3e3..417f90594497 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -119,13 +119,6 @@ def get_input_embeddings( ... -# We can't use runtime_checkable with ClassVar for issubclass checks -# so we need to treat the class as an instance and use isinstance instead -@runtime_checkable -class _SupportsMultiModalType(Protocol): - supports_multimodal: Literal[True] - - @overload def supports_multimodal( model: type[object]) -> TypeIs[type[SupportsMultiModal]]: @@ -140,10 +133,7 @@ def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]: def supports_multimodal( model: Union[type[object], object], ) -> Union[TypeIs[type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]: - if isinstance(model, type): - return isinstance(model, _SupportsMultiModalType) - - return isinstance(model, SupportsMultiModal) + return getattr(model, "supports_multimodal", False) @runtime_checkable @@ -174,13 +164,6 @@ def post_process_tokens(cls, prompt: TokensPrompt) -> None: ... 
-# We can't use runtime_checkable with ClassVar for issubclass checks -# so we need to treat the class as an instance and use isinstance instead -@runtime_checkable -class _SupportsScoreTemplateType(Protocol): - supports_score_template: Literal[True] - - @overload def supports_score_template( model: type[object]) -> TypeIs[type[SupportsScoreTemplate]]: @@ -195,11 +178,7 @@ def supports_score_template(model: object) -> TypeIs[SupportsScoreTemplate]: def supports_score_template( model: Union[type[object], object], ) -> Union[TypeIs[type[SupportsScoreTemplate]], TypeIs[SupportsScoreTemplate]]: - - if isinstance(model, type): - return isinstance(model, _SupportsScoreTemplateType) - - return isinstance(model, SupportsScoreTemplate) + return getattr(model, "supports_score_template", False) @runtime_checkable @@ -409,11 +388,6 @@ class HasInnerState(Protocol): """ -@runtime_checkable -class _HasInnerStateType(Protocol): - has_inner_state: ClassVar[Literal[True]] - - @overload def has_inner_state(model: object) -> TypeIs[HasInnerState]: ... @@ -427,10 +401,7 @@ def has_inner_state(model: type[object]) -> TypeIs[type[HasInnerState]]: def has_inner_state( model: Union[type[object], object] ) -> Union[TypeIs[type[HasInnerState]], TypeIs[HasInnerState]]: - if isinstance(model, type): - return isinstance(model, _HasInnerStateType) - - return isinstance(model, HasInnerState) + return getattr(model, "has_inner_state", False) @runtime_checkable @@ -446,11 +417,6 @@ class IsAttentionFree(Protocol): """ -@runtime_checkable -class _IsAttentionFreeType(Protocol): - is_attention_free: ClassVar[Literal[True]] - - @overload def is_attention_free(model: object) -> TypeIs[IsAttentionFree]: ... 
@@ -464,10 +430,7 @@ def is_attention_free(model: type[object]) -> TypeIs[type[IsAttentionFree]]: def is_attention_free( model: Union[type[object], object] ) -> Union[TypeIs[type[IsAttentionFree]], TypeIs[IsAttentionFree]]: - if isinstance(model, type): - return isinstance(model, _IsAttentionFreeType) - - return isinstance(model, IsAttentionFree) + return getattr(model, "is_attention_free", False) @runtime_checkable @@ -502,11 +465,6 @@ def get_mamba_state_shape_from_config( ... -@runtime_checkable -class _IsHybridType(Protocol): - is_hybrid: ClassVar[Literal[True]] - - @overload def is_hybrid(model: object) -> TypeIs[IsHybrid]: ... @@ -520,10 +478,7 @@ def is_hybrid(model: type[object]) -> TypeIs[type[IsHybrid]]: def is_hybrid( model: Union[type[object], object] ) -> Union[TypeIs[type[IsHybrid]], TypeIs[IsHybrid]]: - if isinstance(model, type): - return isinstance(model, _IsHybridType) - - return isinstance(model, IsHybrid) + return getattr(model, "is_hybrid", False) @runtime_checkable @@ -598,11 +553,6 @@ class HasNoOps(Protocol): has_noops: ClassVar[Literal[True]] = True -@runtime_checkable -class _HasNoOpsType(Protocol): - has_noops: ClassVar[Literal[True]] - - @overload def has_noops(model: object) -> TypeIs[HasNoOps]: ... 
@@ -616,10 +566,7 @@ def has_noops(model: type[object]) -> TypeIs[type[HasNoOps]]: def has_noops( model: Union[type[object], object] ) -> Union[TypeIs[type[HasNoOps]], TypeIs[HasNoOps]]: - if isinstance(model, type): - return isinstance(model, _HasNoOpsType) - - return isinstance(model, HasNoOps) + return getattr(model, "has_noops", False) @runtime_checkable @@ -643,11 +590,7 @@ def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]: def _supports_cross_encoding( model: Union[type[object], object], ) -> Union[TypeIs[type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: - - if isinstance(model, type): - return isinstance(model, SupportsCrossEncoding) - - return isinstance(model, SupportsCrossEncoding) + return getattr(model, "supports_cross_encoding", False) def supports_cross_encoding( @@ -658,8 +601,9 @@ def supports_cross_encoding( def has_step_pooler(model: Union[type[object], object]) -> bool: """Check if the model uses step pooler.""" - return is_pooling_model(model) and any( - type(module).__name__ == "StepPooler" for module in model.modules()) + from vllm.model_executor.layers.pooler import StepPooler + + return is_pooling_model(model) and isinstance(model.pooler, StepPooler) class SupportsQuant: @@ -770,10 +714,7 @@ def supports_transcription(model: object) -> TypeIs[SupportsTranscription]: def supports_transcription( model: Union[type[object], object], ) -> Union[TypeIs[type[SupportsTranscription]], TypeIs[SupportsTranscription]]: - if isinstance(model, type): - return isinstance(model, SupportsTranscription) - - return isinstance(model, SupportsTranscription) + return getattr(model, "supports_transcription", False) @runtime_checkable @@ -796,7 +737,4 @@ def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]: def supports_v0_only( model: Union[type[object], object], ) -> Union[TypeIs[type[SupportsV0Only]], TypeIs[SupportsV0Only]]: - if isinstance(model, type): - return isinstance(model, SupportsV0Only) - - return 
isinstance(model, SupportsV0Only) + return getattr(model, "supports_v0_only", False) diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index 4a1ea74a218a..4d68227b2af8 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -1,8 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import (TYPE_CHECKING, Optional, Protocol, Union, overload, - runtime_checkable) +from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol, + Union, overload, runtime_checkable) import torch import torch.nn as nn @@ -13,8 +12,7 @@ if TYPE_CHECKING: from vllm.config import VllmConfig - from vllm.model_executor.layers.pooler import PoolerOutput - from vllm.model_executor.pooling_metadata import PoolingMetadata + from vllm.model_executor.layers.pooler import Pooler from vllm.model_executor.sampling_metadata import SamplingMetadata logger = init_logger(__name__) @@ -130,16 +128,20 @@ def is_text_generation_model( @runtime_checkable -class VllmModelForPooling(VllmModel[T], Protocol[T]): +class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]): """The interface required for all pooling models in vLLM.""" - def pooler( - self, - hidden_states: T, - pooling_metadata: "PoolingMetadata", - ) -> "PoolerOutput": - """Only called on TP rank 0.""" - ... + is_pooling_model: ClassVar[Literal[True]] = True + """ + A flag that indicates this model supports pooling. + + Note: + There is no need to redefine this flag if this class is in the + MRO of your model class. 
+ """ + + pooler: "Pooler" + """The pooler is only called on TP rank 0.""" @overload @@ -158,7 +160,4 @@ def is_pooling_model( if not is_vllm_model(model): return False - if isinstance(model, type): - return isinstance(model, VllmModelForPooling) - - return isinstance(model, VllmModelForPooling) + return getattr(model, "is_pooling_model", False) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index e8549b4e0538..d9bbee0a2463 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -28,9 +28,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP from .utils import (is_pp_missing_parameter, @@ -404,6 +403,8 @@ def load_weights(self, weights: Iterable[tuple[str, class InternLM2ForRewardModel(InternLM2ForCausalLM): + is_pooling_model = True + def __init__( self, *, @@ -428,7 +429,7 @@ def __init__( ) pooler_config = vllm_config.model_config.pooler_config - self._pooler = Pooler.from_config_with_defaults( + self.pooler = Pooler.from_config_with_defaults( pooler_config, pooling_type=PoolingType.ALL, normalize=False, @@ -446,10 +447,3 @@ def forward( inputs_embeds) logits, _ = self.v_head(hidden_states) return logits - - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 233c222963be..e95f3491c6b6 100644 --- a/vllm/model_executor/models/jamba.py +++ 
b/vllm/model_executor/models/jamba.py @@ -27,9 +27,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, @@ -563,6 +562,8 @@ def _is_moe_layer(name: str): class JambaForSequenceClassification(JambaForCausalLM): + is_pooling_model = True + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) @@ -590,16 +591,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): softmax=False, ) - self._pooler = ClassifierPooler( + self.pooler = ClassifierPooler( vllm_config.model_config, pooling=pooler.pooling, classifier=self.score, act_fn=pooler.head.activation, ) - - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) diff --git a/vllm/model_executor/models/jina_vl.py b/vllm/model_executor/models/jina_vl.py index 78e58896e0d8..6b191b09b4bf 100644 --- a/vllm/model_executor/models/jina_vl.py +++ b/vllm/model_executor/models/jina_vl.py @@ -13,9 +13,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.pooler import Pooler, PoolingType -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors from .interfaces import (SupportsCrossEncoding, SupportsMultiModal, SupportsScoreTemplate) 
@@ -72,6 +71,8 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration, SupportsCrossEncoding, SupportsMultiModal, SupportsScoreTemplate): + + is_pooling_model = True weight_mapper = WeightsMapper( orig_to_new_prefix={ "score.0.": "score.dense.", @@ -95,7 +96,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.score = JinaVLScorer(config) - self._pooler = Pooler.from_config_with_defaults( + self.pooler = Pooler.from_config_with_defaults( pooler_config, pooling_type=PoolingType.LAST, normalize=False, @@ -137,14 +138,6 @@ def forward( logits = self.score(hidden_states) - self.LOGIT_BIAS return logits - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.weight_mapper) diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index e094ff163572..94a7ddcc01c9 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -13,14 +13,16 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.linear import (QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.pooler import (BasePooler, ClassifierPooler, - PoolingMethod, PoolingType) +from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler, + PoolingMethod, PoolingTask, + PoolingType) from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.sequence import IntermediateTensors, PoolerOutput +from 
vllm.pooling_params import PoolingParams +from vllm.sequence import IntermediateTensors from .interfaces import SupportsCrossEncoding, SupportsV0Only from .utils import WeightsMapper, maybe_prefix @@ -253,7 +255,7 @@ def forward( return norm_outputs -class ModernBertPooler(BasePooler): +class ModernBertPooler(Pooler): def __init__(self, config: ModernBertConfig): super().__init__() @@ -268,6 +270,9 @@ def __init__(self, config: ModernBertConfig): eps=config.norm_eps, bias=config.norm_bias) + def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: + return self.pooling.get_pooling_params(task) + def forward( self, hidden_states: Union[torch.Tensor, list[torch.Tensor]], @@ -281,6 +286,8 @@ def forward( class ModernBertForSequenceClassification(nn.Module, SupportsV0Only, SupportsCrossEncoding): + is_pooling_model = True + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -288,7 +295,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = ModernBertModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert")) self.classifier = nn.Linear(config.hidden_size, config.num_labels) - self._pooler = ClassifierPooler( + self.pooler = ClassifierPooler( vllm_config.model_config, pooling=ModernBertPooler(config), classifier=self.classifier, @@ -321,13 +328,6 @@ def weight_filter(): default_weight_loader) weight_loader(param, loaded_weight) - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def forward( self, input_ids: Optional[torch.LongTensor], diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index a36f24bc80ec..d51fcec07fd6 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ 
b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -24,12 +24,13 @@ from transformers import BatchFeature from vllm.config import VllmConfig +from vllm.model_executor.layers.pooler import (AllPool, PoolerHead, + PoolerIdentity, SimplePooler) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import (IsAttentionFree, SupportsMultiModal, SupportsV0Only) from vllm.model_executor.models.utils import AutoWeightsLoader -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs) @@ -37,8 +38,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.sequence import (IntermediateTensors, PoolerOutput, - PoolingSequenceGroupOutput) +from vllm.sequence import IntermediateTensors class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo): @@ -116,7 +116,9 @@ def apply( dummy_inputs=PrithviGeoSpatialMAEInputBuilder) class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal, SupportsV0Only): - """ Prithvi Masked Autoencoder""" + """Prithvi Masked Autoencoder""" + + is_pooling_model = True @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -162,6 +164,8 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): "Only SemanticSegmentationTask is supported for now " "by PrithviGeospatialMAE.") + self.pooler = SimplePooler(AllPool(), PoolerHead(PoolerIdentity())) + def _parse_and_validate_multimodal_data( self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -189,7 +193,6 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ): - pixel_values, location_coords = ( self._parse_and_validate_multimodal_data(**kwargs)) model_output 
= self.model(pixel_values, @@ -197,13 +200,6 @@ def forward( return model_output.output - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)]) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_list = [] diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index 9a8508081678..58f95d6eebfb 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -16,8 +16,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.pooler import Pooler, PoolingType, SimplePooler -from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP from .qwen2 import Qwen2Model @@ -25,6 +24,10 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP): + + is_pooling_model = True + pooler: SimplePooler + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -61,7 +64,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config=quant_config, return_bias=False), ) - self._pooler: SimplePooler self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -80,13 +82,6 @@ def forward( logits = self.score(hidden_states) return logits - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self, @@ -96,11 +91,11 @@ def load_weights(self, weights: Iterable[tuple[str, class Qwen2ForRewardModel(Qwen2RewardBaseModel): - def __init__(self, *, 
vllm_config, prefix=""): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config.model_config.hf_config.num_labels = 1 super().__init__(vllm_config=vllm_config, prefix=prefix) pooler_config = vllm_config.model_config.pooler_config - self._pooler = Pooler.from_config_with_defaults( + self.pooler = Pooler.from_config_with_defaults( pooler_config, pooling_type=PoolingType.ALL, normalize=False, @@ -109,11 +104,11 @@ def __init__(self, *, vllm_config, prefix=""): class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel): - def __init__(self, *, vllm_config, prefix=""): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config.model_config.hf_config.num_labels = 2 super().__init__(vllm_config=vllm_config, prefix=prefix) pooler_config = vllm_config.model_config.pooler_config - self._pooler = Pooler.from_config_with_defaults( + self.pooler = Pooler.from_config_with_defaults( pooler_config, pooling_type=PoolingType.STEP, normalize=False, diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 55ebb6e9e2a4..7d3b56ced5c4 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -15,8 +15,7 @@ from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, maybe_prefix) -from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors from .bert_with_rope import BertWithRope, JinaRobertaModel from .interfaces import SupportsCrossEncoding, SupportsV0Only @@ -165,6 +164,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, _pooler: An instance of Pooler used for pooling operations. 
""" + is_pooling_model = True jina_to_vllm_mapper = WeightsMapper( orig_to_new_substr={ 'emb_ln': "embeddings.LayerNorm", @@ -188,7 +188,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): add_pooling_layer=False) self.classifier = RobertaClassificationHead(config) - self._pooler = ClassifierPooler( + self.pooler = ClassifierPooler( vllm_config.model_config, pooling=CLSPool(), classifier=self.classifier, @@ -198,13 +198,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper) - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def forward( self, input_ids: Optional[torch.Tensor], diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 106f3e8b22b7..1a7305727e11 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Optional import msgspec @@ -15,24 +15,31 @@ class PoolingParams( msgspec.Struct, omit_defaults=True, # type: ignore[call-arg] array_like=True): # type: ignore[call-arg] - """API parameters for pooling models. This is currently a placeholder. + """API parameters for pooling models. Attributes: dimensions: Reduce the dimensions of embeddings if model support matryoshka representation. - additional_data: Any additional data needed for pooling.
""" dimensions: Optional[int] = None + use_cross_encoder: bool = False - additional_data: Optional[Any] = None + """Internal use only.""" + + logits_processing_needs_token_ids: bool = False + """Internal use only.""" + output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY def clone(self) -> "PoolingParams": """Returns a deep copy of the PoolingParams instance.""" - return PoolingParams(dimensions=self.dimensions, - use_cross_encoder=self.use_cross_encoder, - additional_data=self.additional_data) + return PoolingParams( + dimensions=self.dimensions, + use_cross_encoder=self.use_cross_encoder, + logits_processing_needs_token_ids=self. + logits_processing_needs_token_ids, + ) def verify(self, model_config: "ModelConfig") -> None: if self.dimensions is not None: @@ -54,10 +61,12 @@ def verify(self, model_config: "ModelConfig") -> None: raise ValueError("Dimensions must be greater than 0") def __repr__(self) -> str: - return (f"PoolingParams(" - f"dimensions={self.dimensions}, " - f"use_cross_encoder={self.use_cross_encoder}, " - f"additional_metadata={self.additional_data})") + return ( + f"PoolingParams(" + f"dimensions={self.dimensions}, " + f"use_cross_encoder={self.use_cross_encoder}, " + f"logits_processing_needs_token_ids={self.logits_processing_needs_token_ids})" + ) def __post_init__(self) -> None: assert self.output_kind == RequestOutputKind.FINAL_ONLY,\