diff --git a/examples/offline_inference/embed_jina_embeddings_v4.py b/examples/offline_inference/embed_jina_embeddings_v4.py
new file mode 100644
index 00000000000..69ebe83d758
--- /dev/null
+++ b/examples/offline_inference/embed_jina_embeddings_v4.py
@@ -0,0 +1,112 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+
+from vllm import LLM
+from vllm.config import PoolerConfig
+from vllm.inputs.data import TextPrompt
+from vllm.multimodal.utils import fetch_image
+
+
+def get_embeddings(outputs):
+    """Extract and normalize embeddings from model outputs."""
+    VISION_START_TOKEN_ID, VISION_END_TOKEN_ID = 151652, 151653
+
+    embeddings = []
+    for output in outputs:
+        if VISION_START_TOKEN_ID in output.prompt_token_ids:
+            # For vision inputs, extract only vision token embeddings
+            img_start_pos = output.prompt_token_ids.index(VISION_START_TOKEN_ID)
+            img_end_pos = output.prompt_token_ids.index(VISION_END_TOKEN_ID)
+            embeddings_tensor = output.outputs.data.detach().clone()[
+                img_start_pos : img_end_pos + 1
+            ]
+        else:
+            # For text-only inputs, use all token embeddings
+            embeddings_tensor = output.outputs.data.detach().clone()
+
+        # Pool and normalize embeddings
+        pooled_output = embeddings_tensor.mean(dim=0, dtype=torch.float32)
+        embeddings.append(torch.nn.functional.normalize(pooled_output, dim=-1))
+    return embeddings
+
+
+def main():
+    # Initialize the model
+    model = LLM(
+        model="jinaai/jina-embeddings-v4-vllm-retrieval",
+        task="embed",
+        override_pooler_config=PoolerConfig(pooling_type="ALL", normalize=False),
+        dtype="float16",
+    )
+
+    # Example 1: Text-only embeddings
+    print("=== Text Embeddings ===")
+    query = "Overview of climate change impacts on coastal cities"
+    query_prompt = TextPrompt(prompt=f"Query: {query}")
+
+    passage = """The impacts of climate change on coastal cities are significant
+    and multifaceted. Rising sea levels threaten infrastructure, while increased
+    storm intensity poses risks to populations and economies."""
+    passage_prompt = TextPrompt(prompt=f"Passage: {passage}")
+
+    # Generate embeddings
+    text_outputs = model.encode([query_prompt, passage_prompt])
+    text_embeddings = get_embeddings(text_outputs)
+
+    # Calculate similarity
+    similarity = torch.dot(text_embeddings[0], text_embeddings[1]).item()
+    print(f"Query: {query[:50]}...")
+    print(f"Passage: {passage[:50]}...")
+    print(f"Similarity: {similarity:.4f}\n")
+
+    # Example 2: Image embeddings
+    print("=== Image Embeddings ===")
+    # Fetch sample images
+    image1_url = "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png"
+    image2_url = "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
+
+    image1 = fetch_image(image1_url)
+    image2 = fetch_image(image2_url)
+
+    # Create image prompts with the required format
+    image1_prompt = TextPrompt(
+        prompt="<|im_start|>user\n<|vision_start|><|image_pad|>"
+        "<|vision_end|>Describe the image.<|im_end|>\n",
+        multi_modal_data={"image": image1},
+    )
+
+    image2_prompt = TextPrompt(
+        prompt="<|im_start|>user\n<|vision_start|><|image_pad|>"
+        "<|vision_end|>Describe the image.<|im_end|>\n",
+        multi_modal_data={"image": image2},
+    )
+
+    # Generate embeddings
+    image_outputs = model.encode([image1_prompt, image2_prompt])
+    image_embeddings = get_embeddings(image_outputs)
+
+    # Calculate similarity
+    similarity = torch.dot(image_embeddings[0], image_embeddings[1]).item()
+    print(f"Image 1: {image1_url.split('/')[-1]}")
+    print(f"Image 2: {image2_url.split('/')[-1]}")
+    print(f"Similarity: {similarity:.4f}\n")
+
+    # Example 3: Cross-modal similarity (text vs image)
+    print("=== Cross-modal Similarity ===")
+    query = "scientific paper with markdown formatting"
+    query_prompt = TextPrompt(prompt=f"Query: {query}")
+
+    # Generate embeddings for text query and second image
+    cross_outputs = model.encode([query_prompt, image2_prompt])
+    cross_embeddings = get_embeddings(cross_outputs)
+
+    # Calculate cross-modal similarity
+    similarity = torch.dot(cross_embeddings[0], cross_embeddings[1]).item()
+    print(f"Text query: {query}")
+    print(f"Image: {image2_url.split('/')[-1]}")
+    print(f"Cross-modal similarity: {similarity:.4f}")
+
+
+if __name__ == "__main__":
+    main()
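To score one query against several passages in a single call, the normalized vectors returned by `get_embeddings` can be compared with one matrix product. The sketch below reuses `model` and `get_embeddings` from the example above; the passage strings are made-up placeholders.

import torch

from vllm.inputs.data import TextPrompt

query_prompt = TextPrompt(prompt="Query: impacts of rising sea levels")
passage_prompts = [
    TextPrompt(prompt=f"Passage: {p}")
    for p in ["Coastal flooding is increasing.", "A recipe for pancakes."]
]

outputs = model.encode([query_prompt, *passage_prompts])
embs = torch.stack(get_embeddings(outputs))

# Every vector is L2-normalized, so dot products are cosine similarities.
scores = embs[0] @ embs[1:].T
print(scores.tolist())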
diff --git a/tests/models/pooling/test_jina_embeddings_v4.py b/tests/models/pooling/test_jina_embeddings_v4.py
new file mode 100644
index 00000000000..35c84acc2ec
--- /dev/null
+++ b/tests/models/pooling/test_jina_embeddings_v4.py
@@ -0,0 +1,381 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import gc
+import time
+from array import array
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+import pytest
+import torch
+from PIL import Image
+
+from vllm import LLM
+from vllm.config import PoolerConfig
+from vllm.inputs.data import TextPrompt
+from vllm.sequence import SequenceData
+
+model_name = "jinaai/jina-embeddings-v4-vllm-retrieval"
+
+# Vision token IDs
+VISION_START_TOKEN_ID = 151652
+VISION_END_TOKEN_ID = 151653
+
+
+@pytest.fixture(scope="module")
+def model():
+    """Initialize model once for all tests."""
+    return LLM(
+        model=model_name,
+        task="embed",
+        override_pooler_config=PoolerConfig(pooling_type="ALL",
+                                            normalize=False),
+        dtype="float16",
+        max_model_len=2048,
+    )
+
+
+def extract_embeddings(output):
+    """Extract embeddings based on token type."""
+    if VISION_START_TOKEN_ID in output.prompt_token_ids:
+        # Extract vision tokens only
+        img_start = output.prompt_token_ids.index(VISION_START_TOKEN_ID)
+        img_end = output.prompt_token_ids.index(VISION_END_TOKEN_ID)
+        embeddings = output.outputs.data[img_start:img_end + 1]
+    else:
+        # Use all tokens for text
+        embeddings = output.outputs.data
+
+    # Mean pool and normalize
+    pooled = embeddings.mean(dim=0, dtype=torch.float32)
+    return torch.nn.functional.normalize(pooled, dim=-1)
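The slice taken by `extract_embeddings` is inclusive of both vision markers. A toy illustration of the indexing (the non-marker token IDs are made up):

prompt_token_ids = [9906, 151652, 111, 222, 333, 151653, 9907]
img_start = prompt_token_ids.index(151652)  # -> 1
img_end = prompt_token_ids.index(151653)    # -> 5
assert prompt_token_ids[img_start:img_end + 1] == [151652, 111, 222, 333, 151653]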
+
+
+class TestBasicFunctionality:
+    """Test basic embedding generation functionality."""
+
+    def test_text_only_embeddings(self, model):
+        """Test text-only embedding generation."""
+        prompts = [
+            TextPrompt(prompt="Query: What is machine learning?"),
+            TextPrompt(prompt="Passage: Machine learning is a subset of "
+                       "artificial intelligence.")
+        ]
+
+        outputs = model.encode(prompts)
+        embeddings = [extract_embeddings(output) for output in outputs]
+
+        # Check embeddings are normalized
+        for emb in embeddings:
+            assert torch.allclose(emb.norm(), torch.tensor(1.0), atol=1e-3)
+
+        # Check similarity is reasonable
+        similarity = torch.dot(embeddings[0], embeddings[1]).item()
+        assert 0.0 <= similarity <= 1.0
+
+    def test_image_embeddings(self, model):
+        """Test image embedding generation."""
+        # Create a dummy image
+        image = Image.new('RGB', (224, 224), color='red')
+
+        prompt = TextPrompt(
+            prompt="<|im_start|>user\n<|vision_start|><|image_pad|>"
+            "<|vision_end|>Describe the image.<|im_end|>\n",
+            multi_modal_data={"image": image},
+        )
+
+        outputs = model.encode([prompt])
+        embedding = extract_embeddings(outputs[0])
+
+        # Check embedding is normalized
+        assert torch.allclose(embedding.norm(), torch.tensor(1.0), atol=1e-3)
+
+        # Check dimension
+        assert embedding.shape[
+            0] == model.llm_engine.model_config.hf_config.hidden_size
+
+    def test_mixed_batch(self, model):
+        """Test mixed text and image batch processing."""
+        image = Image.new('RGB', (224, 224), color='blue')
+
+        prompts = [
+            TextPrompt(prompt="Query: blue color"),
+            TextPrompt(
+                prompt="<|im_start|>user\n<|vision_start|><|image_pad|>"
+                "<|vision_end|>Describe the image.<|im_end|>\n",
+                multi_modal_data={"image": image},
+            ),
+            TextPrompt(prompt="Passage: The sky is blue.")
+        ]
+
+        outputs = model.encode(prompts)
+        embeddings = [extract_embeddings(output) for output in outputs]
+
+        # All embeddings should be normalized
+        for emb in embeddings:
+            assert torch.allclose(emb.norm(), torch.tensor(1.0), atol=1e-3)
+
+        # Text query about blue should have some similarity to blue image
+        text_image_sim = torch.dot(embeddings[0], embeddings[1]).item()
+        assert text_image_sim > 0.0  # Should have positive similarity
+
+
+class TestThreadSafety:
+    """Test thread safety of the pooling implementation."""
+
+    def test_concurrent_requests(self, model):
+        """Test handling of concurrent embedding requests."""
+        num_threads = 4
+        requests_per_thread = 5
+
+        def process_request(thread_id):
+            results = []
+            for i in range(requests_per_thread):
+                prompt = TextPrompt(
+                    prompt=f"Query from thread {thread_id}, request {i}")
+                outputs = model.encode([prompt])
+                embedding = extract_embeddings(outputs[0])
+                results.append(embedding)
+            return results
+
+        # Run concurrent requests
+        with ThreadPoolExecutor(max_workers=num_threads) as executor:
+            futures = [
+                executor.submit(process_request, i) for i in range(num_threads)
+            ]
+
+            all_results = []
+            for future in as_completed(futures):
+                thread_results = future.result()
+                all_results.extend(thread_results)
+
+        # Verify all embeddings are valid
+        assert len(all_results) == num_threads * requests_per_thread
+        for emb in all_results:
+            assert torch.allclose(emb.norm(), torch.tensor(1.0), atol=1e-3)
+
+
+class TestEdgeCases:
+    """Test edge cases and error handling."""
+
+    def test_empty_input_handling(self, model):
+        """Test handling of empty inputs."""
+        # This should not crash but return empty outputs
+        outputs = model.encode([])
+        assert len(outputs) == 0
+
+    def test_very_long_sequence(self, model):
+        """Test handling of sequences near max length."""
+        # Create a long text that approaches max_model_len
+        long_text = " ".join(["word"] * 1000)
+        prompt = TextPrompt(prompt=f"Query: {long_text}")
+
+        # Should handle gracefully without crashing
+        outputs = model.encode([prompt])
+        embedding = extract_embeddings(outputs[0])
+        assert torch.allclose(embedding.norm(), torch.tensor(1.0), atol=1e-3)
+
+    def test_invalid_image_format(self, model):
+        """Test handling of invalid image inputs."""
+        # Create an invalid image (too small)
+        tiny_image = Image.new('RGB', (1, 1), color='red')
+
+        prompt = TextPrompt(
+            prompt="<|im_start|>user\n<|vision_start|><|image_pad|>"
+            "<|vision_end|>Describe the image.<|im_end|>\n",
+            multi_modal_data={"image": tiny_image},
+        )
+
+        # Should handle gracefully
+        try:
+            outputs = model.encode([prompt])
+            # If it doesn't crash, check output is valid
+            if outputs:
+                embedding = extract_embeddings(outputs[0])
+                assert embedding.shape[
+                    0] == model.llm_engine.model_config.hf_config.hidden_size
+        except Exception as e:
+            # Should provide meaningful error message
+            assert "image" in str(e).lower() or "size" in str(e).lower()
+
+
+class TestMemoryManagement:
+    """Test memory management and cleanup."""
+
+    def test_memory_cleanup(self, model):
+        """Test that memory is properly cleaned up after processing."""
+        # Get initial memory usage
+        torch.cuda.empty_cache()
+        if torch.cuda.is_available():
+            initial_memory = torch.cuda.memory_allocated()
+
+        # Process multiple large batches
+        for _ in range(5):
+            prompts = [
+                TextPrompt(prompt=f"Query: test {i}") for i in range(10)
+            ]
+            outputs = model.encode(prompts)
+            del outputs
+            gc.collect()
+
+        # Check memory usage hasn't grown significantly
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            final_memory = torch.cuda.memory_allocated()
+            memory_growth = final_memory - initial_memory
+            # Allow some growth but not excessive
+            assert memory_growth < 100 * 1024 * 1024  # Less than 100MB growth
+
+
+class TestPerformance:
+    """Test performance characteristics."""
+
+    def test_pooling_performance(self, model):
+        """Test that custom pooling is performant."""
+        # Create test prompts
+        text_prompts = [
+            TextPrompt(prompt=f"Query: test {i}") for i in range(10)
+        ]
+
+        # Time text-only pooling
+        start_time = time.time()
+        text_outputs = model.encode(text_prompts)
+        text_time = time.time() - start_time
+
+        # Create image prompts
+        image = Image.new('RGB', (224, 224), color='green')
+        image_prompts = [
+            TextPrompt(
+                prompt="<|im_start|>user\n<|vision_start|><|image_pad|>"
+                "<|vision_end|>Describe.<|im_end|>\n",
+                multi_modal_data={"image": image},
+            ) for _ in range(10)
+        ]
+
+        # Time vision pooling
+        start_time = time.time()
+        image_outputs = model.encode(image_prompts)
+        image_time = time.time() - start_time
+
+        # Vision pooling should not be significantly slower
+        # (allowing 2x slower due to additional processing)
+        assert image_time < text_time * 2.0
+
+        # Verify outputs are valid
+        for output in text_outputs + image_outputs:
+            embedding = extract_embeddings(output)
+            assert torch.allclose(embedding.norm(),
+                                  torch.tensor(1.0),
+                                  atol=1e-3)
+
+
+class TestPoolingMetadataIntegration:
+    """Test proper integration with PoolingMetadata."""
+
+    def test_seq_data_access(self):
+        """Test that token IDs are properly accessible via seq_data."""
+        # Create mock sequence data
+        prompt_tokens = array('l', [
+            101, 102, VISION_START_TOKEN_ID, VISION_START_TOKEN_ID,
+            VISION_END_TOKEN_ID, 104
+        ])
+        seq_data = SequenceData(prompt_tokens)
+
+        # Verify prompt_token_ids_array property works
+        assert hasattr(seq_data, 'prompt_token_ids_array')
+        retrieved_tokens = seq_data.prompt_token_ids_array
+        assert list(retrieved_tokens) == list(prompt_tokens)
+
+        # Verify vision tokens can be detected
+        token_tensor = torch.tensor(list(retrieved_tokens))
+        vision_mask = ((token_tensor >= VISION_START_TOKEN_ID) &
+                       (token_tensor <= VISION_END_TOKEN_ID))
+        assert vision_mask.any()
+        # Two start tokens and one end token fall in the vision-token ID range
+        assert vision_mask.sum() == 3
+
+
+class TestAccuracyValidation:
+    """Test accuracy against expected behavior."""
+
+    @pytest.mark.parametrize("text", [
+        "Short text",
+        "A much longer text that contains multiple sentences for testing",
+        "特殊字符测试 🚀 emoji test", "Numbers 12345 and symbols !@#$%"
+    ])
+    def test_text_embedding_consistency(self, model, text):
+        """Test that same text produces consistent embeddings."""
+        prompt = TextPrompt(prompt=f"Query: {text}")
+
+        # Generate embeddings multiple times
+        embeddings = []
+        for _ in range(3):
+            outputs = model.encode([prompt])
+            emb = extract_embeddings(outputs[0])
+            embeddings.append(emb)
+
+        # All should be identical
+        for i in range(1, len(embeddings)):
+            assert torch.allclose(embeddings[0], embeddings[i], atol=1e-5)
+
+    def test_vision_only_pooling(self, model):
+        """Test that vision pooling extracts only vision tokens."""
+        # Create an image with known characteristics
+        image = Image.new('RGB', (224, 224), color='red')
+
+        # Two prompts with same image but different text
+        prompt1 = TextPrompt(
+            prompt="<|im_start|>user\n<|vision_start|><|image_pad|>"
+            "<|vision_end|>Red image<|im_end|>\n",
+            multi_modal_data={"image": image},
+        )
+        prompt2 = TextPrompt(
+            prompt="<|im_start|>user\n<|vision_start|><|image_pad|>"
+            "<|vision_end|>Blue sky green grass<|im_end|>\n",
+            multi_modal_data={"image": image},
+        )
+
+        outputs = model.encode([prompt1, prompt2])
+        emb1 = extract_embeddings(outputs[0])
+        emb2 = extract_embeddings(outputs[1])
+
+        # Since both use the same image and vision-only pooling,
+        # embeddings should be very similar despite different text
+        similarity = torch.dot(emb1, emb2).item()
+        assert similarity > 0.99  # Should be nearly identical
+
+
+class TestVisionPooler:
+    """Test the VisionPooler class."""
+
+    def test_vision_pooler(self):
+        """Test that the VisionPooler correctly pools vision tokens."""
+        from vllm.config import ModelConfig
+        from vllm.model_executor.layers.pooler import VisionPooler
+        from vllm.pooling_params import PoolingParams
+        from vllm.v1.pool.metadata import PoolingMetadata
+
+        model_config = ModelConfig(model_name, task="embed")
+        model_config.hf_config.vision_start_token_id = VISION_START_TOKEN_ID
+        model_config.hf_config.vision_end_token_id = VISION_END_TOKEN_ID
+        model_config.hidden_size = 4
+
+        pooler = VisionPooler(model_config)
+
+        hidden_states = torch.randn(10, 4)
+        prompt_token_ids = torch.tensor([[
+            1, 2, VISION_START_TOKEN_ID, 4, VISION_END_TOKEN_ID, 6, 7, 8, 9, 10
+        ]])
+        prompt_lens = torch.tensor([10])
+
+        pooling_metadata = PoolingMetadata(prompt_lens=prompt_lens,
+                                           prompt_token_ids=prompt_token_ids,
+                                           pooling_params=[PoolingParams()])
+
+        output = pooler.forward(hidden_states, pooling_metadata)
+
+        vision_tokens = hidden_states[2:5]
+        expected_output = vision_tokens.mean(dim=0)
+
+        assert torch.allclose(output.outputs[0].data,
+                              expected_output,
+                              atol=1e-5)
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 2adfa859a1c..8d2f0c1fe40 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -316,6 +316,8 @@ def check_available_online(
     "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1", v0_only=True),  # noqa: E501
     "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small", v0_only=True),  # noqa: E501
     # [Multimodal]
+    "JinaVLForEmbedding": _HfExamplesInfo("jinaai/jina-embeddings-v4-vllm-retrieval",  # noqa: E501
+                                          trust_remote_code=True),
     "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
     "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
                                         trust_remote_code=True),
diff --git a/vllm/config.py b/vllm/config.py
index 526b5db235f..f919a3f5463 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -3258,7 +3258,8 @@ def get_limit_per_prompt(self, modality: str) -> int:
 class PoolerConfig:
     """Controls the behavior of output pooling in pooling models."""
 
-    pooling_type: Optional[str] = None
+    pooling_type: Optional[Literal["last", "all", "cls", "step", "mean",
+                                   "vision"]] = None
     """
     The pooling method of the pooling model. This should be a key in
     [`vllm.model_executor.layers.pooler.PoolingType`][].
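A sketch of selecting the new pooling type from user code. This assumes the string is resolved against `PoolingType` the same way the example script's `"ALL"` is; adjust the casing if your build expects the lowercase literals from the annotation above.

from vllm import LLM
from vllm.config import PoolerConfig

llm = LLM(
    model="jinaai/jina-embeddings-v4-vllm-retrieval",
    task="embed",
    override_pooler_config=PoolerConfig(pooling_type="VISION", normalize=False),
    dtype="float16",
)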
diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py
index 74916492f57..34337aa6cde 100644
--- a/vllm/model_executor/layers/pooler.py
+++ b/vllm/model_executor/layers/pooler.py
@@ -17,6 +17,7 @@
 from vllm.model_executor.pooling_metadata import PoolingTensors
 from vllm.pooling_params import PoolingParams
 from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
+from vllm.triton_utils import HAS_TRITON, tl, triton
 from vllm.utils import resolve_obj_by_qualname
 from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
 
@@ -31,6 +32,7 @@ class PoolingType(IntEnum):
     CLS = 2
     STEP = 3
     MEAN = 4
+    VISION = 5
 
 
 @dataclass(frozen=True)
@@ -90,6 +92,8 @@ def from_config_with_defaults(
 
         if pooling_type == PoolingType.STEP:
             return StepPooler.from_config(resolved_config)
+        if pooling_type == PoolingType.VISION:
+            return VisionPooler.from_config(resolved_config)
 
         return SimplePooler.from_config(resolved_config)
 
@@ -621,6 +625,35 @@ def forward(
 ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
 
 
+if HAS_TRITON:
+
+    @triton.jit
+    def mean_pool_with_position_kernel(
+        hidden_states_ptr,
+        output_ptr,
+        seq_start,
+        seq_len,
+        hidden_size,
+        pool_start,
+        pool_end,
+        BLOCK_SIZE: tl.constexpr,
+    ):
+        pid = tl.program_id(0)
+
+        if pid >= hidden_size:
+            return
+
+        accumulator = 0.0
+        for i in range(pool_start, pool_end):
+            hidden_val = tl.load(hidden_states_ptr +
+                                 (seq_start + i) * hidden_size + pid)
+            accumulator += hidden_val
+
+        # Store mean pooled result
+        result = accumulator / (pool_end - pool_start)
+        tl.store(output_ptr + pid, result)
+
+
 class ClassifierPooler(nn.Module):
     """A pooling layer for classification tasks.
@@ -706,3 +739,62 @@ def forward(
         ])
 
         return build_output(scores)
+
+
+class VisionPooler(Pooler):
+
+    @classmethod
+    def from_config(cls, model_config: ModelConfig) -> "VisionPooler":
+        return cls(model_config)
+
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+        self.config = config
+
+    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
+        if task == "embed":
+            return PoolingParams(pooling_type="vision",
+                                 logits_processing_needs_token_ids=True)
+        return None
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> PoolerOutput:
+        assert isinstance(pooling_metadata, V1PoolingMetadata)
+
+        pooled_outputs = []
+        for i in range(len(pooling_metadata.prompt_lens)):
+            start_pos = (pooling_metadata.prompt_token_ids[i] ==
+                         self.config.hf_config.vision_start_token_id
+                         ).nonzero()[-1].item()
+            end_pos = (pooling_metadata.prompt_token_ids[i] ==
+                       self.config.hf_config.vision_end_token_id
+                       ).nonzero()[-1].item()
+
+            seq_start = torch.cumsum(
+                torch.tensor([0] + pooling_metadata.prompt_lens.tolist()),
+                dim=0)[i]
+            seq_len = pooling_metadata.prompt_lens[i]
+
+            output = torch.empty(self.config.hidden_size,
+                                 device=hidden_states.device,
+                                 dtype=hidden_states.dtype)
+
+            grid = lambda meta: (self.config.hidden_size, )
+            if HAS_TRITON:
+                # BLOCK_SIZE is a constexpr parameter of the kernel and must
+                # be supplied at launch time.
+                mean_pool_with_position_kernel[grid](hidden_states, output,
+                                                     seq_start, seq_len,
+                                                     self.config.hidden_size,
+                                                     start_pos, end_pos + 1,
+                                                     BLOCK_SIZE=1)
+            else:
+                # Fall back to a PyTorch implementation if Triton is absent
+                vision_tokens_range = hidden_states[
+                    seq_start + start_pos:seq_start + end_pos + 1]
+                output = vision_tokens_range.mean(dim=0)
+
+            pooled_outputs.append(output)
+
+        return build_output(torch.stack(pooled_outputs))
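As a plain-PyTorch reference for what `mean_pool_with_position_kernel` computes per sequence (the per-dimension mean of the hidden states from `pool_start` inclusive to `pool_end` exclusive, offset by the sequence start), here is a minimal self-contained sketch; the shapes and positions are made up for illustration.

import torch

def mean_pool_with_position(hidden_states: torch.Tensor, seq_start: int,
                            pool_start: int, pool_end: int) -> torch.Tensor:
    # hidden_states: (total_tokens, hidden_size), flattened across sequences
    return hidden_states[seq_start + pool_start:seq_start + pool_end].mean(dim=0)

hidden_states = torch.randn(10, 8)
# Pool positions 2..4 of a sequence that starts at offset 0.
out = mean_pool_with_position(hidden_states, seq_start=0, pool_start=2, pool_end=5)
assert torch.allclose(out, hidden_states[2:5].mean(dim=0))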
diff --git a/vllm/model_executor/models/jina_embeddings_v4.py b/vllm/model_executor/models/jina_embeddings_v4.py
new file mode 100644
index 00000000000..f97420c56f6
--- /dev/null
+++ b/vllm/model_executor/models/jina_embeddings_v4.py
@@ -0,0 +1,64 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Optional, Union
+
+import torch
+
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.model_executor.layers.pooler import Pooler, PoolingTask, VisionPooler
+# yapf: disable
+from vllm.model_executor.pooling_metadata import (
+    PoolingMetadata as V0PoolingMetadata)
+# yapf: enable
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.pooling_params import PoolingParams
+from vllm.sequence import PoolerOutput
+from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
+
+from .interfaces import SupportsCrossEncoding, SupportsMultiModal
+from .qwen2_vl import (Qwen2VLDummyInputsBuilder,
+                       Qwen2VLForConditionalGeneration,
+                       Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo)
+from .utils import maybe_prefix
+
+logger = init_logger(__name__)
+
+# Vision token IDs for Jina V4
+VISION_START_TOKEN_ID = 151652
+VISION_END_TOKEN_ID = 151653
+
+PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
+
+
+class JinaVLPooler(Pooler):
+
+    def __init__(self, vllm_config: VllmConfig):
+        super().__init__()
+        self.vision_pooler = VisionPooler(vllm_config.model_config)
+
+    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
+        return self.vision_pooler.get_pooling_params(task)
+
+    def forward(
+        self,
+        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
+        pooling_metadata: PoolingMetadata,
+    ) -> PoolerOutput:
+        return self.vision_pooler.forward(hidden_states, pooling_metadata)
+
+
+@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor,
+                                        info=Qwen2VLProcessingInfo,
+                                        dummy_inputs=Qwen2VLDummyInputsBuilder)
+class JinaVLForEmbedding(Qwen2VLForConditionalGeneration,
+                         SupportsCrossEncoding, SupportsMultiModal):
+
+    is_pooling_model = True
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config,
+                         prefix=maybe_prefix(prefix, "qwen2_vl"))
+
+        self.pooler = JinaVLPooler(vllm_config)
+
+        logger.info("Initialized JinaVLForEmbedding with vision-aware pooling")
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 52fdb910891..1d8d81fc2bc 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -170,6 +170,8 @@
     # input and output. I am adding it here because it piggy-backs on embedding
     # models for the time being.
     "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
+    # Multimodal embedding model with token-type-aware pooling
+    "JinaVLForEmbedding": ("jina_embeddings_v4", "JinaVLForEmbedding"),
 }
 
 _CROSS_ENCODER_MODELS = {
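A quick sanity check that the new architecture name is visible through the model registry; this assumes a vLLM build that includes this change and that `ModelRegistry.get_supported_archs()` is available as in current vLLM.

from vllm import ModelRegistry

# The registry entry added above maps the HF architecture name to the
# jina_embeddings_v4 module, so the name should now be advertised.
assert "JinaVLForEmbedding" in ModelRegistry.get_supported_archs()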