Commit 8e0578a

Author: Sigrid Jin (Sionic AI)
refactor: update JinaVLForEmbedding to comply with new pooling architecture

Update JinaVLForEmbedding to align with PR vllm-project#21058's pooling model interface:

- Add is_pooling_model = True class attribute
- Create JinaVLPooler class inheriting from the Pooler base class
- Move vision-aware pooling logic into JinaVLPooler
- Implement get_pooling_params method returning PoolingParams() for the "embed" task
- Replace the pooler method with a pooler attribute
- Add required imports: PoolingTask, PoolingParams, assert_never

The JinaVLPooler maintains the sophisticated vision-text pooling behavior while conforming to the new architecture requirements.

Signed-off-by: Sigrid Jin (Sionic AI) <sigrid@sionic.ai>
1 parent 9d34781 commit 8e0578a
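
For orientation, here is a minimal sketch of the pooler contract this commit targets, written only with types that appear in the diff below. The class name, its plain mean-pooling body, and the assumption of flattened hidden states are illustrative; it is not the actual JinaVLPooler.

from typing import Optional, Union

import torch
from typing_extensions import assert_never

from vllm.model_executor.layers.pooler import Pooler, PoolingTask
from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
from vllm.v1.pool.metadata import PoolingMetadata


class ExampleEmbedPooler(Pooler):
    """Illustrative pooler: mean-pools each sequence (not the real JinaVLPooler)."""

    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
        # Advertise support for the "embed" task only; opt out of the rest.
        if task == "embed":
            return PoolingParams()
        if task == "encode" or task == "classify" or task == "score":
            return None
        assert_never(task)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        # Assumes hidden_states is a single flattened tensor; JinaVLPooler
        # additionally gives special treatment to vision tokens (see diff).
        pooled, offset = [], 0
        for prompt_len in pooling_metadata.prompt_lens.tolist():
            seq_states = hidden_states[offset:offset + prompt_len]
            pooled.append(PoolingSequenceGroupOutput(seq_states.mean(dim=0)))
            offset += prompt_len
        return PoolerOutput(outputs=pooled)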

File tree

1 file changed: +111, -82 lines changed


vllm/model_executor/models/jina_embeddings_v4.py

Lines changed: 111 additions & 82 deletions
@@ -7,17 +7,20 @@

 import torch
 import torch.nn.functional as F
+from typing_extensions import assert_never

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.model_executor.layers.pooler import (HAS_TRITON, Pooler, PoolingType,
+from vllm.model_executor.layers.pooler import (HAS_TRITON, Pooler, PoolingTask,
+                                               PoolingType,
                                                extract_vision_tokens_kernel)
 # yapf: disable
 from vllm.model_executor.pooling_metadata import (
     PoolingMetadata as V0PoolingMetadata)
 from vllm.model_executor.pooling_metadata import PoolingTensors
 # yapf: enable
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.pooling_params import PoolingParams
 from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
 from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata

@@ -36,49 +39,98 @@
 PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]


-@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor,
-                                        info=Qwen2VLProcessingInfo,
-                                        dummy_inputs=Qwen2VLDummyInputsBuilder)
-class JinaVLForEmbedding(Qwen2VLForConditionalGeneration,
-                         SupportsCrossEncoding, SupportsMultiModal):
-    # Weight mapping for HuggingFace checkpoint compatibility
-    weight_mapper = WeightsMapper(
-        orig_to_new_prefix={
-            "model.": "language_model.model.",
-            "visual.": "visual.",
-            "lm_head.": "language_model.lm_head.",
-        })
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__(vllm_config=vllm_config,
-                         prefix=maybe_prefix(prefix, "qwen2_vl"))
+class JinaVLPooler(Pooler):
+    """Vision-aware pooler for Jina V4 with special vision token handling."""

+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 pooling_backend: str = "pytorch"):
+        super().__init__()
         self.hidden_size = vllm_config.model_config.hf_config.hidden_size
-        pooler_config = vllm_config.model_config.pooler_config
+        self.pooling_backend = pooling_backend
         self.observability_config = vllm_config.observability_config

-        # Configuration for vision pooling backend
-        self.pooling_backend = getattr(vllm_config.model_config,
-                                       "jina_pooling_backend", "pytorch")
-        if self.pooling_backend not in ("triton", "pytorch"):
-            logger.warning(
-                "Invalid jina_pooling_backend '%s'. "
-                "Must be 'triton' or 'pytorch'. Defaulting to 'pytorch'.",
-                self.pooling_backend)
-            self.pooling_backend = "pytorch"
+        # Performance tracking
+        self._pooling_time_ms = 0.0
+        self._pooling_count = 0

         # Initialize base pooler for fallback
+        pooler_config = vllm_config.model_config.pooler_config
         self._base_pooler = Pooler.from_config_with_defaults(
             pooler_config,
             pooling_type=PoolingType.MEAN,
             normalize=True,
             softmax=False)

-        # Performance tracking
-        self._pooling_time_ms = 0.0
-        self._pooling_count = 0
+    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
+        """Return pooling params for embedding task."""
+        if task == "embed":
+            return PoolingParams()

-        logger.info("Initialized JinaVLForEmbedding with thread-safe pooling")
+        # The equalities are split up to keep mypy happy
+        if task == "encode" or task == "classify" or task == "score":
+            return None
+
+        assert_never(task)
+
+    def forward(
+        self,
+        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
+        pooling_metadata: PoolingMetadata,
+    ) -> PoolerOutput:
+        """Apply vision-aware pooling to hidden states."""
+        start_time = time.time() if self.observability_config else None
+
+        # Validate inputs
+        if hidden_states is None or hidden_states.numel() == 0:
+            logger.warning("Empty hidden states received")
+            return PoolerOutput(outputs=[])
+
+        # Extract token IDs safely from metadata
+        token_ids_list, seq_ids = self._extract_token_ids_safe(
+            pooling_metadata)
+
+        if not token_ids_list:
+            logger.warning("No valid sequences found for pooling")
+            # Fallback to base pooler
+            return self._base_pooler(hidden_states, pooling_metadata)
+
+        # Get prompt lengths based on metadata type
+        if isinstance(pooling_metadata, V1PoolingMetadata):
+            prompt_lens = pooling_metadata.prompt_lens
+        else:
+            prompt_lens = PoolingTensors.from_pooling_metadata(
+                pooling_metadata, hidden_states.device).prompt_lens
+
+        # Validate lengths match
+        assert len(token_ids_list) == len(prompt_lens), (
+            f"Mismatch: {len(token_ids_list)} sequences vs "
+            f"{len(prompt_lens)} lengths")
+
+        # Apply pooling based on configured backend
+        if self.pooling_backend == "triton":
+            pooled_data = self._apply_vision_pooling_optimized(
+                hidden_states, token_ids_list, prompt_lens)
+        else:  # self.pooling_backend == "pytorch"
+            pooled_data = self._apply_vision_pooling_pytorch(
+                hidden_states, token_ids_list, prompt_lens)
+
+        # Build output
+        pooled_outputs = [
+            PoolingSequenceGroupOutput(data) for data in pooled_data
+        ]
+
+        # Record metrics
+        if self.observability_config:
+            elapsed_ms = (time.time() - start_time) * 1000
+            self._pooling_time_ms += elapsed_ms
+            self._pooling_count += 1
+
+            if self._pooling_count % 100 == 0:
+                avg_time = self._pooling_time_ms / self._pooling_count
+                logger.debug("Average pooling time: %.2fms", avg_time)
+
+        return PoolerOutput(outputs=pooled_outputs)

     def _extract_token_ids_safe(
             self, pooling_metadata: PoolingMetadata
@@ -239,64 +291,41 @@ def _apply_vision_pooling_pytorch(

         return pooled_outputs

-    def pooler(
-        self,
-        hidden_states: torch.Tensor,
-        pooling_metadata: PoolingMetadata,
-    ) -> Optional[PoolerOutput]:
-        """Thread-safe pooler with production error handling."""
-        start_time = time.time() if self.observability_config else None
-
-        # Validate inputs
-        if hidden_states is None or hidden_states.numel() == 0:
-            logger.warning("Empty hidden states received")
-            return PoolerOutput(outputs=[])
-
-        # Extract token IDs safely from metadata
-        token_ids_list, seq_ids = self._extract_token_ids_safe(
-            pooling_metadata)

-        if not token_ids_list:
-            logger.warning("No valid sequences found for pooling")
-            # Fallback to base pooler
-            return self._base_pooler(hidden_states, pooling_metadata)
-
-        # Get prompt lengths based on metadata type
-        if isinstance(pooling_metadata, V1PoolingMetadata):
-            prompt_lens = pooling_metadata.prompt_lens
-        else:
-            prompt_lens = PoolingTensors.from_pooling_metadata(
-                pooling_metadata, hidden_states.device).prompt_lens
+@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor,
+                                        info=Qwen2VLProcessingInfo,
+                                        dummy_inputs=Qwen2VLDummyInputsBuilder)
+class JinaVLForEmbedding(Qwen2VLForConditionalGeneration,
+                         SupportsCrossEncoding, SupportsMultiModal):

-        # Validate lengths match
-        assert len(token_ids_list) == len(prompt_lens), (
-            f"Mismatch: {len(token_ids_list)} sequences vs "
-            f"{len(prompt_lens)} lengths")
+    is_pooling_model = True

-        # Apply pooling based on configured backend
-        if self.pooling_backend == "triton":
-            pooled_data = self._apply_vision_pooling_optimized(
-                hidden_states, token_ids_list, prompt_lens)
-        else:  # self.pooling_backend == "pytorch"
-            pooled_data = self._apply_vision_pooling_pytorch(
-                hidden_states, token_ids_list, prompt_lens)
+    # Weight mapping for HuggingFace checkpoint compatibility
+    weight_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "model.": "language_model.model.",
+            "visual.": "visual.",
+            "lm_head.": "language_model.lm_head.",
+        })

-        # Build output
-        pooled_outputs = [
-            PoolingSequenceGroupOutput(data) for data in pooled_data
-        ]
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config,
+                         prefix=maybe_prefix(prefix, "qwen2_vl"))

-        # Record metrics
-        if self.observability_config:
-            elapsed_ms = (time.time() - start_time) * 1000
-            self._pooling_time_ms += elapsed_ms
-            self._pooling_count += 1
+        # Configuration for vision pooling backend
+        self.pooling_backend = getattr(vllm_config.model_config,
+                                       "jina_pooling_backend", "pytorch")
+        if self.pooling_backend not in ("triton", "pytorch"):
+            logger.warning(
+                "Invalid jina_pooling_backend '%s'. "
+                "Must be 'triton' or 'pytorch'. Defaulting to 'pytorch'.",
+                self.pooling_backend)
+            self.pooling_backend = "pytorch"

-            if self._pooling_count % 100 == 0:
-                avg_time = self._pooling_time_ms / self._pooling_count
-                logger.debug("Average pooling time: %.2fms", avg_time)
+        # Initialize the vision-aware pooler
+        self.pooler = JinaVLPooler(vllm_config, self.pooling_backend)

-        return PoolerOutput(outputs=pooled_outputs)
+        logger.info("Initialized JinaVLForEmbedding with thread-safe pooling")

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         """Load weights with validation and error handling."""
