Skip to content

Commit 90bd2ab

Browse files
[Model] Update pooling model interface (#21058)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent 9fb2d22 commit 90bd2ab

File tree

17 files changed

+247
-345
lines changed

17 files changed

+247
-345
lines changed

tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@
1111
from vllm.model_executor.layers.pooler import Pooler, PoolingType
1212
from vllm.model_executor.models.gemma2 import Gemma2Model
1313
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
14-
from vllm.model_executor.pooling_metadata import PoolingMetadata
15-
from vllm.sequence import IntermediateTensors, PoolerOutput
14+
from vllm.sequence import IntermediateTensors
1615

1716

1817
class MyGemma2Embedding(nn.Module):
18+
19+
is_pooling_model = True
20+
1921
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
2022

2123
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -24,7 +26,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
2426
self.model = Gemma2Model(vllm_config=vllm_config,
2527
prefix=maybe_prefix(prefix, "model"))
2628

27-
self._pooler = Pooler.from_config_with_defaults(
29+
self.pooler = Pooler.from_config_with_defaults(
2830
vllm_config.model_config.pooler_config,
2931
pooling_type=PoolingType.LAST,
3032
normalize=True,
@@ -54,13 +56,6 @@ def forward(
5456
# Return all-zero embeddings
5557
return torch.zeros_like(hidden_states)
5658

57-
def pooler(
58-
self,
59-
hidden_states: torch.Tensor,
60-
pooling_metadata: PoolingMetadata,
61-
) -> Optional[PoolerOutput]:
62-
return self._pooler(hidden_states, pooling_metadata)
63-
6459
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
6560

6661
weights = self.hf_to_vllm_mapper.apply(weights)

vllm/entrypoints/openai/protocol.py

Lines changed: 5 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1237,10 +1237,6 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
12371237
user: Optional[str] = None
12381238
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
12391239

1240-
# --8<-- [start:embedding-pooling-params]
1241-
additional_data: Optional[Any] = None
1242-
# --8<-- [end:embedding-pooling-params]
1243-
12441240
# --8<-- [start:embedding-extra-params]
12451241
add_special_tokens: bool = Field(
12461242
default=True,
@@ -1259,8 +1255,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
12591255
# --8<-- [end:embedding-extra-params]
12601256

12611257
def to_pooling_params(self):
1262-
return PoolingParams(dimensions=self.dimensions,
1263-
additional_data=self.additional_data)
1258+
return PoolingParams(dimensions=self.dimensions)
12641259

12651260

12661261
class EmbeddingChatRequest(OpenAIBaseModel):
@@ -1272,10 +1267,6 @@ class EmbeddingChatRequest(OpenAIBaseModel):
12721267
user: Optional[str] = None
12731268
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
12741269

1275-
# --8<-- [start:chat-embedding-pooling-params]
1276-
additional_data: Optional[Any] = None
1277-
# --8<-- [end:chat-embedding-pooling-params]
1278-
12791270
# --8<-- [start:chat-embedding-extra-params]
12801271
add_special_tokens: bool = Field(
12811272
default=False,
@@ -1323,8 +1314,7 @@ def check_generation_prompt(cls, data):
13231314
return data
13241315

13251316
def to_pooling_params(self):
1326-
return PoolingParams(dimensions=self.dimensions,
1327-
additional_data=self.additional_data)
1317+
return PoolingParams(dimensions=self.dimensions)
13281318

13291319

13301320
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
@@ -1340,10 +1330,6 @@ class ScoreRequest(OpenAIBaseModel):
13401330
text_2: Union[list[str], str, ScoreMultiModalParam]
13411331
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
13421332

1343-
# --8<-- [start:score-pooling-params]
1344-
additional_data: Optional[Any] = None
1345-
# --8<-- [end:score-pooling-params]
1346-
13471333
# --8<-- [start:score-extra-params]
13481334

13491335
mm_processor_kwargs: Optional[dict[str, Any]] = Field(
@@ -1362,8 +1348,7 @@ class ScoreRequest(OpenAIBaseModel):
13621348
# --8<-- [end:score-extra-params]
13631349

13641350
def to_pooling_params(self, *, use_cross_encoder: bool = False):
1365-
return PoolingParams(use_cross_encoder=use_cross_encoder,
1366-
additional_data=self.additional_data)
1351+
return PoolingParams(use_cross_encoder=use_cross_encoder)
13671352

13681353

13691354
class RerankRequest(OpenAIBaseModel):
@@ -1373,10 +1358,6 @@ class RerankRequest(OpenAIBaseModel):
13731358
top_n: int = Field(default_factory=lambda: 0)
13741359
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
13751360

1376-
# --8<-- [start:rerank-pooling-params]
1377-
additional_data: Optional[Any] = None
1378-
# --8<-- [end:rerank-pooling-params]
1379-
13801361
# --8<-- [start:rerank-extra-params]
13811362

13821363
mm_processor_kwargs: Optional[dict[str, Any]] = Field(
@@ -1395,8 +1376,7 @@ class RerankRequest(OpenAIBaseModel):
13951376
# --8<-- [end:rerank-extra-params]
13961377

13971378
def to_pooling_params(self, *, use_cross_encoder: bool = False):
1398-
return PoolingParams(use_cross_encoder=use_cross_encoder,
1399-
additional_data=self.additional_data)
1379+
return PoolingParams(use_cross_encoder=use_cross_encoder)
14001380

14011381

14021382
class RerankDocument(BaseModel):
@@ -1534,10 +1514,6 @@ class ClassificationRequest(OpenAIBaseModel):
15341514
truncate_prompt_tokens: Optional[int] = None
15351515
user: Optional[str] = None
15361516

1537-
# --8<-- [start:classification-pooling-params]
1538-
additional_data: Optional[Any] = None
1539-
# --8<-- [end:classification-pooling-params]
1540-
15411517
# --8<-- [start:classification-extra-params]
15421518
priority: int = Field(
15431519
default=0,
@@ -1550,7 +1526,7 @@ class ClassificationRequest(OpenAIBaseModel):
15501526
# --8<-- [end:classification-extra-params]
15511527

15521528
def to_pooling_params(self):
1553-
return PoolingParams(additional_data=self.additional_data)
1529+
return PoolingParams()
15541530

15551531

15561532
class ClassificationData(OpenAIBaseModel):

0 commit comments

Comments
 (0)