Skip to content

[Model] Update pooling model interface #21058

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jul 17, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.models.gemma2 import Gemma2Model
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.sequence import IntermediateTensors


class MyGemma2Embedding(nn.Module):
Expand All @@ -24,7 +23,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.model = Gemma2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))

self._pooler = Pooler.from_config_with_defaults(
self.pooler = Pooler.from_config_with_defaults(
vllm_config.model_config.pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
Expand Down Expand Up @@ -54,13 +53,6 @@ def forward(
# Return all-zero embeddings
return torch.zeros_like(hidden_states)

def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

weights = self.hf_to_vllm_mapper.apply(weights)
Expand Down
148 changes: 84 additions & 64 deletions vllm/model_executor/layers/pooler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,25 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import IntEnum
from typing import Callable, Optional, TypeVar, Union
from typing import Callable, Literal, Optional, TypeVar, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from typing_extensions import assert_never

from vllm.config import ModelConfig, PoolerConfig
from vllm.model_executor.pooling_metadata import ( # noqa: E501
PoolingMetadata as V0PoolingMetadata)
from vllm.model_executor.pooling_metadata import PoolingTensors
from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata

PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
PoolingTask = Literal["encode", "embed", "classify", "score"]


class PoolingType(IntEnum):
Expand Down Expand Up @@ -64,6 +67,48 @@ def from_config_with_defaults(
)


class Pooler(nn.Module, ABC):
"""The interface required for all poolers used in pooling models in vLLM."""

@staticmethod
def from_config_with_defaults(
    pooler_config: PoolerConfig,
    pooling_type: PoolingType,
    normalize: bool,
    softmax: bool,
    step_tag_id: Optional[int] = None,
    returned_token_ids: Optional[list[int]] = None,
) -> "Pooler":
    """Build a concrete pooler from ``pooler_config``.

    The remaining arguments act as model-supplied defaults for any
    fields the config leaves unset. STEP pooling is served by the
    specialized ``StepPooler``; every other pooling type maps to
    ``SimplePooler``.
    """
    resolved = ResolvedPoolingConfig.from_config_with_defaults(
        pooler_config=pooler_config,
        pooling_type=pooling_type,
        normalize=normalize,
        softmax=softmax,
        step_tag_id=step_tag_id,
        returned_token_ids=returned_token_ids,
    )

    pooler_cls = (StepPooler
                  if pooling_type == PoolingType.STEP else SimplePooler)
    return pooler_cls.from_config(resolved)

def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the intended use of get_pooling_params()? Will it get called from serving_embedding.py somehow?

Copy link
Member Author

@DarkLight1337 DarkLight1337 Jul 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will be called by:

  • LLMEngine (and its async version) to validate that the request is supported by the model.
  • The model runner, in order to get information such as use_cross_encoder and logits_processing_needs_token_ids.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The task will be set by our code at API level

Copy link
Member Author

@DarkLight1337 DarkLight1337 Jul 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example:

  • Score API: We set task="score"
  • LLMEngine: Call get_pooling_params with the task to see if it's supported
  • Model runner: Call get_pooling_params to pass use_cross_encoder to the pooler.

This abstraction lets each model define how to handle each task, instead of having static logic at the API level

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this is good, we're starting to accumulate too much logic at the entrypoint level.

Just to understand the last detail: is EmbeddingCompletionRequest.to_pooling_params() going to be replaced with something like EmbeddingCompletionRequest.to_pooling_task()

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, since we still have some parameters (e.g. dimensions) that need to be forwarded. I will add a task attribute to PoolingParams so that the task can be set in to_pooling_params

"""
Construct the pooling parameters to use for a task,
or `None` if the task is not supported.
"""
return None

@abstractmethod
def forward(
    self,
    hidden_states: Union[list[torch.Tensor], torch.Tensor],
    pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
    """Pool ``hidden_states`` into a ``PoolerOutput``.

    Subclasses implement the actual pooling strategy; ``hidden_states``
    may be a single tensor or a per-sequence list of tensors, and
    ``pooling_metadata`` supports both the V0 and V1 metadata formats
    (see the ``PoolingMetadata`` union above).
    """
    raise NotImplementedError


def get_prompt_lens(
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
Expand Down Expand Up @@ -104,17 +149,6 @@ def build_output(all_data: torch.Tensor) -> PoolerOutput:
return PoolerOutput(outputs=all_outputs)


class BasePooler(nn.Module):

@abstractmethod
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
raise NotImplementedError


class PoolingMethod(nn.Module, ABC):

@staticmethod
Expand Down Expand Up @@ -345,25 +379,6 @@ def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:

class PoolerHead(nn.Module):

@classmethod
def from_config_with_defaults(
cls,
pooler_config: PoolerConfig,
pooling_type: PoolingType,
normalize: bool,
softmax: bool,
) -> "PoolerHead":
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
pooler_config=pooler_config,
pooling_type=pooling_type,
normalize=normalize,
softmax=softmax,
step_tag_id=None,
returned_token_ids=None,
)

return cls.from_config(resolved_config)

@classmethod
def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "PoolerHead":
if pooler_config.normalize and pooler_config.softmax:
Expand Down Expand Up @@ -424,21 +439,17 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
return self.activation(pooled_data)


class SimplePooler(BasePooler):
class SimplePooler(Pooler):
"""A layer that pools specific information from hidden states.

This layer does the following:
1. Extracts specific tokens or aggregates data based on pooling method.
2. Normalizes output if specified.
3. Returns structured results as `PoolerOutput`.

Attributes:
pooling_type: The type of pooling to use.
normalize: Whether to normalize the pooled data.
"""

@classmethod
def from_config_with_defaults(
def from_config_with_defaults( # type: ignore[override]
cls,
pooler_config: PoolerConfig,
pooling_type: PoolingType,
Expand Down Expand Up @@ -471,6 +482,19 @@ def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
self.pooling = pooling
self.head = head

def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
    """Return the pooling params for ``task``, or ``None`` if unsupported."""
    if task == "encode":
        # Raw encoding is always available.
        return PoolingParams()

    # NOTE: the equalities stay split (not `task in (...)`) so that mypy
    # can narrow the Literal and `assert_never` stays exhaustive.
    if task == "embed" or task == "classify" or task == "score":
        # Only whole-sequence pooling methods can back these tasks.
        supported = isinstance(self.pooling, (LastPool, CLSPool, MeanPool))
        return PoolingParams() if supported else None

    assert_never(task)

def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
Expand All @@ -481,7 +505,7 @@ def forward(
return build_output(pooled_data)


class StepPooler(BasePooler):
class StepPooler(Pooler):

@classmethod
def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "StepPooler":
Expand Down Expand Up @@ -543,6 +567,16 @@ def extract_states(

return pooled_data

def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
    """Return the pooling params for ``task``, or ``None`` if unsupported."""
    # NOTE: the equalities stay split (not `task in (...)`) so that mypy
    # can narrow the Literal and `assert_never` stays exhaustive.
    if task == "embed" or task == "classify" or task == "score":
        # Step pooling does not produce embeddings or class labels.
        return None

    if task == "encode":
        # Step extraction needs access to the prompt token ids.
        return PoolingParams(logits_processing_needs_token_ids=True)

    assert_never(task)

def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
Expand All @@ -553,32 +587,6 @@ def forward(
return build_output(pooled_data)


class Pooler(nn.Module):

@staticmethod
def from_config_with_defaults(
pooler_config: PoolerConfig,
pooling_type: PoolingType,
normalize: bool,
softmax: bool,
step_tag_id: Optional[int] = None,
returned_token_ids: Optional[list[int]] = None,
) -> BasePooler:
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
pooler_config=pooler_config,
pooling_type=pooling_type,
normalize=normalize,
softmax=softmax,
step_tag_id=step_tag_id,
returned_token_ids=returned_token_ids,
)

if pooling_type == PoolingType.STEP:
return StepPooler.from_config(resolved_config)

return SimplePooler.from_config(resolved_config)


PoolingFn = Callable[
[Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
Union[torch.Tensor, list[torch.Tensor]]]
Expand Down Expand Up @@ -618,6 +626,18 @@ def _get_act_fn(self, use_cross_encoder: bool):
return (self.cross_encoder_act_fn
if use_cross_encoder else self.classification_act_fn)

def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
    """Return the pooling params for ``task``, or ``None`` if unsupported."""
    if task == "embed":
        # A classifier head cannot emit embeddings.
        return None
    if task == "encode" or task == "classify":
        return PoolingParams()
    if task == "score":
        # Scoring routes through the cross-encoder activation function.
        return PoolingParams(use_cross_encoder=True)

    assert_never(task)

def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
Expand Down
27 changes: 5 additions & 22 deletions vllm/model_executor/models/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union, cast
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast

import torch
import torch.nn as nn
Expand Down Expand Up @@ -42,8 +42,7 @@ def _create_pooling_model_cls(
default_softmax: bool,
) -> _T:
# Lazy import
from vllm.model_executor.layers.pooler import Pooler, PoolerOutput
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.layers.pooler import Pooler

from .utils import AutoWeightsLoader, WeightsMapper

Expand Down Expand Up @@ -73,20 +72,13 @@ def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None

self._pooler = Pooler.from_config_with_defaults(
self.pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=default_pooling_type,
normalize=default_normalize,
softmax=default_softmax,
)

def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
return self._pooler(hidden_states, pooling_metadata)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
# TODO: Support uninitialized params tracking

Expand Down Expand Up @@ -171,10 +163,8 @@ def as_seq_cls_model(cls: _T) -> _T:
# Lazy import
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import (ClassifierPooler,
PoolerOutput, PoolingType,
SimplePooler)
PoolingType, SimplePooler)
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors

from .utils import maybe_prefix
Expand Down Expand Up @@ -213,7 +203,7 @@ def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
softmax=True,
)

self._pooler = ClassifierPooler(
self.pooler = ClassifierPooler(
vllm_config.model_config,
pooling=pooler.pooling,
classifier=self._classifier,
Expand All @@ -234,13 +224,6 @@ def forward(
return super().forward(input_ids, positions, intermediate_tensors,
inputs_embeds)

def pooler(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
return self._pooler(hidden_states, pooling_metadata)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
tokens = getattr(self.config, "classifier_from_token", None)
method = getattr(self.config, "method", None)
Expand Down
Loading