[Model] Update pooling model interface #21058
Conversation
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Code Review
This pull request refactors the pooling model interface by changing pooler from a method to a Pooler instance. The changes are consistent across the codebase and improve the interface design. I've identified a critical type inconsistency in the new Pooler abstract base class and a related issue with ClassifierPooler's inheritance that should be addressed to ensure type safety and interface correctness.
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Thanks for the cleanup!
Wait a moment: can we rename pooler? How about pooling, or anything except pooler. I planned to discuss this issue in a later PR, but it seems that discussing it here now is the best choice.
Why can't this be named Pooler? As long as weights are mapped properly, this shouldn't matter, right?
I think we can wait for @noooop to show why the pooler name would cause problems but otherwise this PR LGTM.
In principle, we should not use the same name for things that are often used together, as this can easily cause confusion. There is also a specific issue: #20930 can automatically handle the model's default_pooling_type, after which as_embedding_model can theoretically handle BertModel (and other BERT-like models) correctly. However, running as_embedding_model(BertModel) will not succeed, because there is a name conflict between BertModel.pooler and VllmModelForPooling.pooler (both before and after this PR). Therefore, we still need to wrap BertModel with BertEmbeddingModel. (vllm/vllm/model_executor/models/bert.py, line 395 at 01513a3)
In other words, if we rename VllmModelForPooling.pooler, we would no longer need wrappers like BertEmbeddingModel; as_embedding_model would support all models.
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Could you clarify a bit? Should |
And isn't this also a problem before this PR? Since both are named |
BertModel has a pooler module:

```python
class BertModel(nn.Module, SupportsQuant):
    packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 embedding_class: type = BertEmbedding,
                 add_pooling_layer: bool = False):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.embeddings = embedding_class(config)
        self.encoder = BertEncoder(vllm_config=vllm_config,
                                   prefix=f"{prefix}.encoder")
        self.pooler = BertPooler(config) if add_pooling_layer else None  # <- here
```

Previously, BertModel.pooler conflicted with the ModelForPooling.pooler method:

```python
class ModelForPooling(orig_cls, VllmModelForPooling):
    ...

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        return self._pooler(hidden_states, pooling_metadata)
```

Now it conflicts with the Pooler instance:

```python
class ModelForPooling(orig_cls, VllmModelForPooling):
    ...

    def __init__(self, ...):  # signature elided in the original comment
        ...
        if not getattr(self, "pooler", None):
            self._init_pooler(vllm_config, prefix=prefix)
        ...
```

In other words, if we rename VllmModelForPooling.pooler, we would no longer need wrappers like BertEmbeddingModel; as_embedding_model would support BertModel.
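The name conflict described above can be demonstrated with a minimal sketch. The class names are borrowed from the discussion, but the bodies are simplified stand-ins, not vLLM's actual implementations:

```python
class BertPooler:
    """Stand-in for BERT's own pooling submodule (an nn.Module in vLLM)."""

class VllmPooler:
    """Stand-in for the vLLM Pooler instance the wrapper wants to install."""

class BertModelSketch:
    def __init__(self, add_pooling_layer: bool = True) -> None:
        # BertModel claims the "pooler" attribute name for its own submodule.
        self.pooler = BertPooler() if add_pooling_layer else None

class ModelForPoolingSketch(BertModelSketch):
    def __init__(self) -> None:
        super().__init__()
        # Mirrors the guard quoted above: only install the vLLM Pooler
        # if the attribute is not already taken.
        if not getattr(self, "pooler", None):
            self.pooler = VllmPooler()

model = ModelForPoolingSketch()
# BertModel's own submodule wins the name, so the wrapper never installs
# the vLLM Pooler: this is why BertEmbeddingModel is still needed.
is_bert_pooler = isinstance(model.pooler, BertPooler)
```

Renaming the interface attribute (e.g. to something other than pooler) would remove the collision, which is the point being argued here.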
It feels weird.
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Is anyone able to repro the CI failure in language models test? It passes for me locally...
I was suspecting that somehow |
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Yeah, it seems to happen for me on Python 3.11 but not on Python 3.10 for whatever reason... fixing now
In my case it fails and |
Can you pull the latest commit and try again? See if it works now.
Instead of using |
It works now.
```python
class ModelForPooling(orig_cls, VllmModelForPooling):
    is_pooling_model = True
```
Isn't this achieved already by deriving VllmModelForPooling? In the docs we also say that all pooling models implement VllmModelForPooling, but not all do. Could this be a cause of confusion?
VllmModelForPooling is only an interface; we don't explicitly derive from it, just like generative models don't explicitly derive from VllmModelForTextGeneration.
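The "interface without explicit derivation" pattern mentioned here can be sketched with a runtime-checkable Protocol. The names below are illustrative stand-ins, not vLLM's actual definitions:

```python
from typing import Protocol, runtime_checkable

@runtime_checkable
class SupportsPooling(Protocol):
    """Checked structurally (by attribute presence), not by inheritance."""
    is_pooling_model: bool

class MyPoolingModel:
    # Satisfies the interface without deriving from it, which is how
    # the is_pooling_model flag is used in the snippet above.
    is_pooling_model = True

model = MyPoolingModel()
# runtime_checkable protocols make isinstance check for the attribute,
# not the class hierarchy, so this passes despite no subclassing.
ok = isinstance(model, SupportsPooling)
```

This is why setting a class attribute like is_pooling_model = True is enough, even though the model class never lists the interface as a base.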
```python
return SimplePooler.from_config(resolved_config)
```

```python
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
```
What is the intended use of get_pooling_params()? Will it get called from serving_embedding.py somehow?
It will be called by:
- LLMEngine (and its async version) to validate that the request is supported by the model.
- The model runner, in order to get information such as use_cross_encoder and logits_processing_needs_token_ids.
The task will be set by our code at the API level.
For example:
- Score API: We set task="score"
- LLMEngine: Call get_pooling_params with the task to see if it's supported
- Model runner: Call get_pooling_params to pass use_cross_encoder to the pooler

This abstraction lets each model define how to handle each task, instead of having static logic at the API level.
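A minimal sketch of the dispatch flow described in this thread. The shapes of PoolingTask and PoolingParams follow the discussion but are simplified assumptions, not vLLM's actual definitions:

```python
from dataclasses import dataclass
from typing import Literal, Optional

PoolingTask = Literal["embed", "classify", "score"]

@dataclass
class PoolingParams:
    use_cross_encoder: bool = False

class ExamplePooler:
    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
        # Each model's pooler decides which tasks it supports and what
        # information the model runner needs for them.
        if task == "score":
            return PoolingParams(use_cross_encoder=True)
        if task == "embed":
            return PoolingParams()
        return None  # task not supported by this model

# Engine-side validation, as described above: a None return means
# the request's task is rejected before it reaches the model runner.
pooler = ExamplePooler()
score_params = pooler.get_pooling_params("score")
embed_params = pooler.get_pooling_params("embed")
classify_params = pooler.get_pooling_params("classify")
```

The model runner would then read fields like use_cross_encoder off the returned params instead of hard-coding per-API logic.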
Yeah, this is good; we're starting to accumulate too much logic at the entrypoint level.

Just to understand the last detail: is EmbeddingCompletionRequest.to_pooling_params() going to be replaced with something like EmbeddingCompletionRequest.to_pooling_task()?
No, since we still have some parameters (e.g. dimensions) that need to be forwarded. I will add a task attribute to PoolingParams so that the task can be set in to_pooling_params.
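A hedged sketch of what the proposed change might look like. The task attribute and the request class below are assumptions based on this comment, not the merged implementation:

```python
from dataclasses import dataclass
from typing import Literal, Optional

PoolingTask = Literal["embed", "classify", "score"]

@dataclass
class PoolingParams:
    dimensions: Optional[int] = None    # still forwarded from the request
    task: Optional[PoolingTask] = None  # proposed: set by the API layer

@dataclass
class EmbeddingRequestSketch:  # stand-in for EmbeddingCompletionRequest
    dimensions: Optional[int] = None

    def to_pooling_params(self) -> PoolingParams:
        # The task is fixed by the endpoint; the rest comes from the request.
        return PoolingParams(dimensions=self.dimensions, task="embed")

params = EmbeddingRequestSketch(dimensions=256).to_pooling_params()
```

This keeps to_pooling_params as the single conversion point while still carrying the task downstream, which is why a separate to_pooling_task method isn't needed.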
@noooop and @DarkLight1337, the wrapper classes are useful as a place for custom logic that is needed to handle the idiosyncrasies of BERT-like models. For example, in #19988 I'm using the RobertaEmbeddingModel class to fix the position ids. I moved the logic out of the RobertaEmbedding class because I couldn't find a way to make it work with cuda graphs.
…ecture

Update JinaVLForEmbedding to align with PR vllm-project#21058's pooling model interface:
- Add is_pooling_model = True class attribute
- Create JinaVLPooler class inheriting from Pooler base class
- Move vision-aware pooling logic into JinaVLPooler
- Implement get_pooling_params method returning PoolingParams() for "embed" task
- Replace pooler method with pooler attribute
- Add required imports: PoolingTask, PoolingParams, assert_never

The JinaVLPooler maintains the sophisticated vision-text pooling behavior while conforming to the new architecture requirements.

Signed-off-by: Sigrid Jin (Sionic AI) <sigrid@sionic.ai>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Himanshu Jaju <hj@mistral.ai>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Essential Elements of an Effective PR Description Checklist

- Update supported_models.md and examples for a new model.

Purpose

Update the VllmModelForPooling interface so that pooler now must be a Pooler instance instead of a method. This enables the model runner to directly fetch information from the Pooler instance in subsequent PRs.

cc @maxdebayser @noooop

Test Plan

The existing tests should pass

Test Result

(Optional) Documentation Update
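The before/after shape of the interface change described in the Purpose can be sketched as follows. This is a simplified illustration with placeholder pooling logic, not vLLM's actual classes:

```python
class Pooler:
    """Simplified stand-in for vLLM's Pooler base class."""

    def __call__(self, hidden_states: list[float]) -> float:
        # Mean pooling as a placeholder for real pooling logic.
        return sum(hidden_states) / len(hidden_states)

# Before this PR: pooler was a method defined on the model, so the
# engine could only call it, not inspect it.
class OldStyleModel:
    def pooler(self, hidden_states: list[float]) -> float:
        return sum(hidden_states) / len(hidden_states)

# After this PR: pooler is a Pooler instance attribute, so the model
# runner can fetch information from it directly (e.g. pooling params).
class NewStyleModel:
    is_pooling_model = True

    def __init__(self) -> None:
        self.pooler = Pooler()

model = NewStyleModel()
result = model.pooler([1.0, 2.0, 3.0])  # mean of the inputs: 2.0
```

Both styles produce the same pooled output here; the difference is that the instance form gives the runner an object to query.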