|
7 | 7 |
|
8 | 8 | import torch
|
9 | 9 | import torch.nn.functional as F
|
| 10 | +from typing_extensions import assert_never |
10 | 11 |
|
11 | 12 | from vllm.config import VllmConfig
|
12 | 13 | from vllm.logger import init_logger
|
13 |
| -from vllm.model_executor.layers.pooler import (HAS_TRITON, Pooler, PoolingType, |
| 14 | +from vllm.model_executor.layers.pooler import (HAS_TRITON, Pooler, PoolingTask, |
| 15 | + PoolingType, |
14 | 16 | extract_vision_tokens_kernel)
|
15 | 17 | # yapf: disable
|
16 | 18 | from vllm.model_executor.pooling_metadata import (
|
17 | 19 | PoolingMetadata as V0PoolingMetadata)
|
18 | 20 | from vllm.model_executor.pooling_metadata import PoolingTensors
|
19 | 21 | # yapf: enable
|
20 | 22 | from vllm.multimodal import MULTIMODAL_REGISTRY
|
| 23 | +from vllm.pooling_params import PoolingParams |
21 | 24 | from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
|
22 | 25 | from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
|
23 | 26 |
|
|
36 | 39 | PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
|
37 | 40 |
|
38 | 41 |
|
39 |
| -@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, |
40 |
| - info=Qwen2VLProcessingInfo, |
41 |
| - dummy_inputs=Qwen2VLDummyInputsBuilder) |
42 |
| -class JinaVLForEmbedding(Qwen2VLForConditionalGeneration, |
43 |
| - SupportsCrossEncoding, SupportsMultiModal): |
44 |
| - # Weight mapping for HuggingFace checkpoint compatibility |
45 |
| - weight_mapper = WeightsMapper( |
46 |
| - orig_to_new_prefix={ |
47 |
| - "model.": "language_model.model.", |
48 |
| - "visual.": "visual.", |
49 |
| - "lm_head.": "language_model.lm_head.", |
50 |
| - }) |
51 |
| - |
52 |
| - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
53 |
| - super().__init__(vllm_config=vllm_config, |
54 |
| - prefix=maybe_prefix(prefix, "qwen2_vl")) |
| 42 | +class JinaVLPooler(Pooler): |
| 43 | + """Vision-aware pooler for Jina V4 with special vision token handling.""" |
55 | 44 |
|
def __init__(self,
             vllm_config: VllmConfig,
             pooling_backend: str = "pytorch"):
    """Build the vision-aware pooler.

    Args:
        vllm_config: Engine configuration; supplies the model hidden size,
            the pooler config used for the fallback pooler, and the
            observability settings that gate timing metrics.
        pooling_backend: Either "triton" or "pytorch"; selects which
            vision-pooling implementation ``forward`` dispatches to.
            The caller is expected to have validated this value.
    """
    super().__init__()

    model_config = vllm_config.model_config
    self.hidden_size = model_config.hf_config.hidden_size
    self.pooling_backend = pooling_backend
    self.observability_config = vllm_config.observability_config

    # Timing accumulators, updated by forward() when observability is on.
    self._pooling_count = 0
    self._pooling_time_ms = 0.0

    # Mean-pooling fallback, used when no valid sequences are extracted.
    self._base_pooler = Pooler.from_config_with_defaults(
        model_config.pooler_config,
        pooling_type=PoolingType.MEAN,
        normalize=True,
        softmax=False)
| 65 | + def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: |
| 66 | + """Return pooling params for embedding task.""" |
| 67 | + if task == "embed": |
| 68 | + return PoolingParams() |
80 | 69 |
|
81 |
| - logger.info("Initialized JinaVLForEmbedding with thread-safe pooling") |
| 70 | + # The equalities are split up to keep mypy happy |
| 71 | + if task == "encode" or task == "classify" or task == "score": |
| 72 | + return None |
| 73 | + |
| 74 | + assert_never(task) |
| 75 | + |
def forward(
    self,
    hidden_states: Union[torch.Tensor, list[torch.Tensor]],
    pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
    """Apply vision-aware pooling to hidden states.

    Pools one embedding per sequence group, using the configured backend
    ("triton" or "pytorch") to give vision tokens special handling.
    Falls back to the base mean pooler when no valid sequences can be
    extracted from the metadata, and returns an empty output for empty
    hidden states.
    """
    # Timing is collected only when observability is configured; the
    # matching metrics block at the end is gated by the same condition.
    start_time = time.time() if self.observability_config else None

    # Validate inputs
    # NOTE(review): .numel() assumes a single tensor — the
    # list[torch.Tensor] side of the annotation would raise here (and at
    # .device below); confirm callers never pass a list.
    if hidden_states is None or hidden_states.numel() == 0:
        logger.warning("Empty hidden states received")
        return PoolerOutput(outputs=[])

    # Extract token IDs safely from metadata
    token_ids_list, seq_ids = self._extract_token_ids_safe(
        pooling_metadata)

    if not token_ids_list:
        logger.warning("No valid sequences found for pooling")
        # Fallback to base pooler
        return self._base_pooler(hidden_states, pooling_metadata)

    # Get prompt lengths based on metadata type
    # (V1 metadata exposes prompt_lens directly; V0 requires converting
    # through PoolingTensors on the right device.)
    if isinstance(pooling_metadata, V1PoolingMetadata):
        prompt_lens = pooling_metadata.prompt_lens
    else:
        prompt_lens = PoolingTensors.from_pooling_metadata(
            pooling_metadata, hidden_states.device).prompt_lens

    # Validate lengths match
    assert len(token_ids_list) == len(prompt_lens), (
        f"Mismatch: {len(token_ids_list)} sequences vs "
        f"{len(prompt_lens)} lengths")

    # Apply pooling based on configured backend
    if self.pooling_backend == "triton":
        pooled_data = self._apply_vision_pooling_optimized(
            hidden_states, token_ids_list, prompt_lens)
    else:  # self.pooling_backend == "pytorch"
        pooled_data = self._apply_vision_pooling_pytorch(
            hidden_states, token_ids_list, prompt_lens)

    # Build output
    pooled_outputs = [
        PoolingSequenceGroupOutput(data) for data in pooled_data
    ]

    # Record metrics
    if self.observability_config:
        elapsed_ms = (time.time() - start_time) * 1000
        self._pooling_time_ms += elapsed_ms
        self._pooling_count += 1

        # Log a running average every 100 calls to keep logging cheap.
        if self._pooling_count % 100 == 0:
            avg_time = self._pooling_time_ms / self._pooling_count
            logger.debug("Average pooling time: %.2fms", avg_time)

    return PoolerOutput(outputs=pooled_outputs)
82 | 134 |
|
83 | 135 | def _extract_token_ids_safe(
|
84 | 136 | self, pooling_metadata: PoolingMetadata
|
@@ -239,64 +291,41 @@ def _apply_vision_pooling_pytorch(
|
239 | 291 |
|
240 | 292 | return pooled_outputs
|
241 | 293 |
|
242 |
| - def pooler( |
243 |
| - self, |
244 |
| - hidden_states: torch.Tensor, |
245 |
| - pooling_metadata: PoolingMetadata, |
246 |
| - ) -> Optional[PoolerOutput]: |
247 |
| - """Thread-safe pooler with production error handling.""" |
248 |
| - start_time = time.time() if self.observability_config else None |
249 |
| - |
250 |
| - # Validate inputs |
251 |
| - if hidden_states is None or hidden_states.numel() == 0: |
252 |
| - logger.warning("Empty hidden states received") |
253 |
| - return PoolerOutput(outputs=[]) |
254 |
| - |
255 |
| - # Extract token IDs safely from metadata |
256 |
| - token_ids_list, seq_ids = self._extract_token_ids_safe( |
257 |
| - pooling_metadata) |
258 | 294 |
|
259 |
| - if not token_ids_list: |
260 |
| - logger.warning("No valid sequences found for pooling") |
261 |
| - # Fallback to base pooler |
262 |
| - return self._base_pooler(hidden_states, pooling_metadata) |
263 |
| - |
264 |
| - # Get prompt lengths based on metadata type |
265 |
| - if isinstance(pooling_metadata, V1PoolingMetadata): |
266 |
| - prompt_lens = pooling_metadata.prompt_lens |
267 |
| - else: |
268 |
| - prompt_lens = PoolingTensors.from_pooling_metadata( |
269 |
| - pooling_metadata, hidden_states.device).prompt_lens |
| 295 | +@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, |
| 296 | + info=Qwen2VLProcessingInfo, |
| 297 | + dummy_inputs=Qwen2VLDummyInputsBuilder) |
| 298 | +class JinaVLForEmbedding(Qwen2VLForConditionalGeneration, |
| 299 | + SupportsCrossEncoding, SupportsMultiModal): |
270 | 300 |
|
271 |
| - # Validate lengths match |
272 |
| - assert len(token_ids_list) == len(prompt_lens), ( |
273 |
| - f"Mismatch: {len(token_ids_list)} sequences vs " |
274 |
| - f"{len(prompt_lens)} lengths") |
| 301 | + is_pooling_model = True |
275 | 302 |
|
276 |
| - # Apply pooling based on configured backend |
277 |
| - if self.pooling_backend == "triton": |
278 |
| - pooled_data = self._apply_vision_pooling_optimized( |
279 |
| - hidden_states, token_ids_list, prompt_lens) |
280 |
| - else: # self.pooling_backend == "pytorch" |
281 |
| - pooled_data = self._apply_vision_pooling_pytorch( |
282 |
| - hidden_states, token_ids_list, prompt_lens) |
| 303 | + # Weight mapping for HuggingFace checkpoint compatibility |
| 304 | + weight_mapper = WeightsMapper( |
| 305 | + orig_to_new_prefix={ |
| 306 | + "model.": "language_model.model.", |
| 307 | + "visual.": "visual.", |
| 308 | + "lm_head.": "language_model.lm_head.", |
| 309 | + }) |
283 | 310 |
|
284 |
| - # Build output |
285 |
| - pooled_outputs = [ |
286 |
| - PoolingSequenceGroupOutput(data) for data in pooled_data |
287 |
| - ] |
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Initialize the Qwen2-VL backbone and attach the vision-aware pooler.

    Args:
        vllm_config: Engine configuration passed through to the base model
            and to the pooler.
        prefix: Weight-name prefix for this submodule.
    """
    super().__init__(vllm_config=vllm_config,
                     prefix=maybe_prefix(prefix, "qwen2_vl"))

    # Resolve the vision pooling backend, falling back to "pytorch" for
    # anything unrecognized.
    backend = getattr(vllm_config.model_config, "jina_pooling_backend",
                      "pytorch")
    if backend not in ("triton", "pytorch"):
        logger.warning(
            "Invalid jina_pooling_backend '%s'. "
            "Must be 'triton' or 'pytorch'. Defaulting to 'pytorch'.",
            backend)
        backend = "pytorch"
    self.pooling_backend = backend

    # Delegate sequence pooling to the vision-aware pooler.
    self.pooler = JinaVLPooler(vllm_config, self.pooling_backend)

    logger.info("Initialized JinaVLForEmbedding with thread-safe pooling")
300 | 329 |
|
301 | 330 | def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
302 | 331 | """Load weights with validation and error handling."""
|
|
0 commit comments