From 9fbc0e936fc2795162c613a244a8fb8fddbcb467 Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Fri, 11 Jul 2025 15:34:51 +0900 Subject: [PATCH 01/23] feat: jina support Signed-off-by: Sigrid Jin (Sionic AI) --- benchmarks/jina_embeddings_v4_validation.py | 262 +++++++++++++ .../offline_inference/jina_embeddings_v4.py | 120 ++++++ .../multimodal/test_jina_embeddings_v4.py | 331 ++++++++++++++++ .../models/jina_embeddings_v4.py | 363 ++++++++++++++++++ vllm/model_executor/models/registry.py | 2 + 5 files changed, 1078 insertions(+) create mode 100644 benchmarks/jina_embeddings_v4_validation.py create mode 100644 examples/offline_inference/jina_embeddings_v4.py create mode 100644 tests/models/multimodal/test_jina_embeddings_v4.py create mode 100644 vllm/model_executor/models/jina_embeddings_v4.py diff --git a/benchmarks/jina_embeddings_v4_validation.py b/benchmarks/jina_embeddings_v4_validation.py new file mode 100644 index 000000000000..b89c4753381c --- /dev/null +++ b/benchmarks/jina_embeddings_v4_validation.py @@ -0,0 +1,262 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Benchmark and validate Jina Embeddings V4 against HuggingFace implementation. + +This script compares embeddings generated by vLLM vs HuggingFace to ensure +accuracy and measure performance differences. +""" + +import argparse +import time +from typing import List, Tuple + +import numpy as np +import torch +from PIL import Image +from transformers import AutoModel, AutoProcessor + +from vllm import LLM +from vllm.config import PoolerConfig + +# Vision token IDs +VISION_START_TOKEN_ID = 151652 +VISION_END_TOKEN_ID = 151653 +from vllm.inputs.data import TextPrompt + + +def create_test_cases() -> List[Tuple[str, str, any]]: + """Create comprehensive test cases for validation.""" + test_cases = [] + + # Text-only test cases + test_cases.extend([ + ("text", "Query: What is artificial intelligence?", None), + ("text", "Passage: AI is a field of computer science focusing on creating intelligent machines.", None), + ("text", "Query: 你好世界", None), # Chinese text + ("text", "Passage: " + " ".join(["word"] * 100), None), # Long text + ]) + + # Image test cases + for color in ["red", "green", "blue"]: + img = Image.new('RGB', (224, 224), color=color) + test_cases.append(("image", f"{color} image", img)) + + # Complex image + complex_img = Image.new('RGB', (224, 224)) + pixels = complex_img.load() + for i in range(224): + for j in range(224): + pixels[i, j] = (i % 256, j % 256, (i+j) % 256) + test_cases.append(("image", "complex pattern", complex_img)) + + return test_cases + + +def compute_hf_embeddings( + model_name: str, + test_cases: List[Tuple[str, str, any]] +) -> List[torch.Tensor]: + """Compute embeddings using HuggingFace implementation.""" + print("Loading HuggingFace model...") + model = AutoModel.from_pretrained( + model_name, + trust_remote_code=True, + torch_dtype=torch.float16 + ).cuda().eval() + + processor = AutoProcessor.from_pretrained( + model_name, + trust_remote_code=True + ) + + embeddings = [] + + print("Computing HuggingFace embeddings...") + start_time = time.time() + + for case_type, text, image in test_cases: + if case_type == "text": + inputs = processor(text=text, return_tensors="pt").to("cuda") + else: # image + inputs = processor( + text="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n", + images=image, + return_tensors="pt" + ).to("cuda") + + with torch.no_grad(): + outputs = 
model(**inputs) + # Extract embeddings based on model output structure + if hasattr(outputs, 'embeddings'): + embedding = outputs.embeddings[0] + else: + # Fallback to last hidden state with custom pooling + hidden_states = outputs.last_hidden_state[0] + + # Apply token-type-aware pooling + input_ids = inputs['input_ids'][0] + vision_mask = ( + (input_ids >= VISION_START_TOKEN_ID) & + (input_ids <= VISION_END_TOKEN_ID) + ) + + if vision_mask.any(): + embedding = hidden_states[vision_mask].mean(dim=0) + else: + embedding = hidden_states.mean(dim=0) + + embedding = torch.nn.functional.normalize(embedding, p=2, dim=-1) + + embeddings.append(embedding.cpu()) + + hf_time = time.time() - start_time + print(f"HuggingFace processing time: {hf_time:.2f}s") + + return embeddings + + +def compute_vllm_embeddings( + model_name: str, + test_cases: List[Tuple[str, str, any]] +) -> List[torch.Tensor]: + """Compute embeddings using vLLM implementation.""" + print("\nLoading vLLM model...") + model = LLM( + model=model_name, + task="embed", + override_pooler_config=PoolerConfig(pooling_type="ALL", normalize=False), + dtype="float16", + ) + + embeddings = [] + prompts = [] + + # Prepare prompts + for case_type, text, image in test_cases: + if case_type == "text": + prompt = TextPrompt(prompt=text) + else: # image + prompt = TextPrompt( + prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n", + multi_modal_data={"image": image}, + ) + prompts.append(prompt) + + print("Computing vLLM embeddings...") + start_time = time.time() + + # Process all at once for better performance + outputs = model.encode(prompts) + + for output in outputs: + # Extract based on token type + if 151652 in output.prompt_token_ids: # VISION_START_TOKEN_ID + img_start = output.prompt_token_ids.index(151652) + img_end = output.prompt_token_ids.index(151653) + embedding_data = output.outputs.data[img_start:img_end + 1] + else: + embedding_data = output.outputs.data + + # Pool and normalize + pooled = embedding_data.mean(dim=0, dtype=torch.float32) + normalized = torch.nn.functional.normalize(pooled, p=2, dim=-1) + embeddings.append(normalized.cpu()) + + vllm_time = time.time() - start_time + print(f"vLLM processing time: {vllm_time:.2f}s") + + return embeddings + + +def compare_embeddings( + hf_embeddings: List[torch.Tensor], + vllm_embeddings: List[torch.Tensor], + test_cases: List[Tuple[str, str, any]] +) -> None: + """Compare embeddings and report differences.""" + print("\n" + "="*60) + print("EMBEDDING COMPARISON RESULTS") + print("="*60) + + similarities = [] + max_diffs = [] + + for i, (case_type, desc, _) in enumerate(test_cases): + hf_emb = hf_embeddings[i] + vllm_emb = vllm_embeddings[i] + + # Compute cosine similarity + similarity = torch.nn.functional.cosine_similarity( + hf_emb.unsqueeze(0), + vllm_emb.unsqueeze(0) + ).item() + + # Compute max absolute difference + max_diff = torch.max(torch.abs(hf_emb - vllm_emb)).item() + + similarities.append(similarity) + max_diffs.append(max_diff) + + print(f"\nTest case {i+1}: {case_type} - {desc[:50]}...") + print(f" Cosine similarity: {similarity:.6f}") + print(f" Max absolute diff: {max_diff:.6f}") + print(f" HF norm: {hf_emb.norm():.6f}, vLLM norm: {vllm_emb.norm():.6f}") + + # Flag significant differences + if similarity < 0.99: + print(f" ⚠️ WARNING: Low similarity detected!") + + # Summary statistics + print("\n" + "-"*60) + print("SUMMARY STATISTICS") + print("-"*60) + print(f"Average cosine similarity: {np.mean(similarities):.6f}") 
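+    # Summary across all test cases; the overall pass/fail verdict below keys
+    # off the minimum cosine similarity, mirroring the per-case 0.99 warning above.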
+ print(f"Min cosine similarity: {np.min(similarities):.6f}") + print(f"Max absolute difference: {np.max(max_diffs):.6f}") + + # Overall assessment + if np.min(similarities) > 0.99: + print("\n✅ VALIDATION PASSED: vLLM implementation matches HuggingFace") + else: + print("\n❌ VALIDATION FAILED: Significant differences detected") + + +def main(): + parser = argparse.ArgumentParser( + description="Validate Jina Embeddings V4 implementation" + ) + parser.add_argument( + "--model", + type=str, + default="jinaai/jina-embeddings-v4-vllm-retrieval", + help="Model name to test" + ) + parser.add_argument( + "--skip-hf", + action="store_true", + help="Skip HuggingFace comparison (for performance testing only)" + ) + + args = parser.parse_args() + + # Create test cases + test_cases = create_test_cases() + print(f"Created {len(test_cases)} test cases") + + # Compute vLLM embeddings + vllm_embeddings = compute_vllm_embeddings(args.model, test_cases) + + if not args.skip_hf: + # Compute HuggingFace embeddings + hf_embeddings = compute_hf_embeddings(args.model, test_cases) + + # Compare results + compare_embeddings(hf_embeddings, vllm_embeddings, test_cases) + else: + print("\nSkipping HuggingFace comparison") + print(f"vLLM processed {len(test_cases)} embeddings successfully") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/offline_inference/jina_embeddings_v4.py b/examples/offline_inference/jina_embeddings_v4.py new file mode 100644 index 000000000000..c8388874ffb0 --- /dev/null +++ b/examples/offline_inference/jina_embeddings_v4.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Example of using Jina Embeddings V4 with vLLM for multimodal embeddings. + +This example demonstrates: +1. Text-only embeddings +2. Image-only embeddings +3. Mixed text and image embeddings +""" + +import torch +from PIL import Image + +from vllm import LLM +from vllm.config import PoolerConfig +from vllm.inputs.data import TextPrompt +from vllm.multimodal.utils import fetch_image + + +def get_embeddings(outputs): + """Extract and normalize embeddings from model outputs.""" + VISION_START_TOKEN_ID, VISION_END_TOKEN_ID = 151652, 151653 + + embeddings = [] + for output in outputs: + if VISION_START_TOKEN_ID in output.prompt_token_ids: + # For vision inputs, extract only vision token embeddings + img_start_pos = output.prompt_token_ids.index(VISION_START_TOKEN_ID) + img_end_pos = output.prompt_token_ids.index(VISION_END_TOKEN_ID) + embeddings_tensor = output.outputs.data.detach().clone()[ + img_start_pos : img_end_pos + 1 + ] + else: + # For text-only inputs, use all token embeddings + embeddings_tensor = output.outputs.data.detach().clone() + + # Pool and normalize embeddings + pooled_output = embeddings_tensor.mean(dim=0, dtype=torch.float32) + embeddings.append(torch.nn.functional.normalize(pooled_output, dim=-1)) + return embeddings + + +def main(): + # Initialize the model + model = LLM( + model="jinaai/jina-embeddings-v4-vllm-retrieval", + task="embed", + override_pooler_config=PoolerConfig(pooling_type="ALL", normalize=False), + dtype="float16", + ) + + # Example 1: Text-only embeddings + print("=== Text Embeddings ===") + query = "Overview of climate change impacts on coastal cities" + query_prompt = TextPrompt(prompt=f"Query: {query}") + + passage = """The impacts of climate change on coastal cities are significant + and multifaceted. 
Rising sea levels threaten infrastructure, while increased + storm intensity poses risks to populations and economies.""" + passage_prompt = TextPrompt(prompt=f"Passage: {passage}") + + # Generate embeddings + text_outputs = model.encode([query_prompt, passage_prompt]) + text_embeddings = get_embeddings(text_outputs) + + # Calculate similarity + similarity = torch.dot(text_embeddings[0], text_embeddings[1]).item() + print(f"Query: {query[:50]}...") + print(f"Passage: {passage[:50]}...") + print(f"Similarity: {similarity:.4f}\n") + + # Example 2: Image embeddings + print("=== Image Embeddings ===") + # Fetch sample images + image1_url = "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png" + image2_url = "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png" + + image1 = fetch_image(image1_url) + image2 = fetch_image(image2_url) + + # Create image prompts with the required format + image1_prompt = TextPrompt( + prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n", + multi_modal_data={"image": image1}, + ) + + image2_prompt = TextPrompt( + prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n", + multi_modal_data={"image": image2}, + ) + + # Generate embeddings + image_outputs = model.encode([image1_prompt, image2_prompt]) + image_embeddings = get_embeddings(image_outputs) + + # Calculate similarity + similarity = torch.dot(image_embeddings[0], image_embeddings[1]).item() + print(f"Image 1: {image1_url.split('/')[-1]}") + print(f"Image 2: {image2_url.split('/')[-1]}") + print(f"Similarity: {similarity:.4f}\n") + + # Example 3: Cross-modal similarity (text vs image) + print("=== Cross-modal Similarity ===") + query = "scientific paper with markdown formatting" + query_prompt = TextPrompt(prompt=f"Query: {query}") + + # Generate embeddings for text query and second image + cross_outputs = model.encode([query_prompt, image2_prompt]) + cross_embeddings = get_embeddings(cross_outputs) + + # Calculate cross-modal similarity + similarity = torch.dot(cross_embeddings[0], cross_embeddings[1]).item() + print(f"Text query: {query}") + print(f"Image: {image2_url.split('/')[-1]}") + print(f"Cross-modal similarity: {similarity:.4f}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/models/multimodal/test_jina_embeddings_v4.py b/tests/models/multimodal/test_jina_embeddings_v4.py new file mode 100644 index 000000000000..e20bf2dcc467 --- /dev/null +++ b/tests/models/multimodal/test_jina_embeddings_v4.py @@ -0,0 +1,331 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import gc +import time +from array import array +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +import torch +from PIL import Image + +from vllm import LLM +from vllm.config import PoolerConfig +from vllm.inputs.data import TextPrompt +from vllm.pooling_metadata import PoolingMetadata +from vllm.sequence import SequenceData + + +model_name = "jinaai/jina-embeddings-v4-vllm-retrieval" + +# Vision token IDs +VISION_START_TOKEN_ID = 151652 +VISION_END_TOKEN_ID = 151653 + + +@pytest.fixture(scope="module") +def model(): + """Initialize model once for all tests.""" + return LLM( + model=model_name, + task="embed", + override_pooler_config=PoolerConfig(pooling_type="ALL", normalize=False), + dtype="float16", + max_model_len=2048, + ) + + +def 
extract_embeddings(output): + """Extract embeddings based on token type.""" + if VISION_START_TOKEN_ID in output.prompt_token_ids: + # Extract vision tokens only + img_start = output.prompt_token_ids.index(VISION_START_TOKEN_ID) + img_end = output.prompt_token_ids.index(VISION_END_TOKEN_ID) + embeddings = output.outputs.data[img_start:img_end + 1] + else: + # Use all tokens for text + embeddings = output.outputs.data + + # Mean pool and normalize + pooled = embeddings.mean(dim=0, dtype=torch.float32) + return torch.nn.functional.normalize(pooled, dim=-1) + + +class TestBasicFunctionality: + """Test basic embedding generation functionality.""" + + def test_text_only_embeddings(self, model): + """Test text-only embedding generation.""" + prompts = [ + TextPrompt(prompt="Query: What is machine learning?"), + TextPrompt(prompt="Passage: Machine learning is a subset of artificial intelligence.") + ] + + outputs = model.encode(prompts) + embeddings = [extract_embeddings(output) for output in outputs] + + # Check embeddings are normalized + for emb in embeddings: + assert torch.allclose(emb.norm(), torch.tensor(1.0), atol=1e-3) + + # Check similarity is reasonable + similarity = torch.dot(embeddings[0], embeddings[1]).item() + assert 0.0 <= similarity <= 1.0 + + def test_image_embeddings(self, model): + """Test image embedding generation.""" + # Create a dummy image + image = Image.new('RGB', (224, 224), color='red') + + prompt = TextPrompt( + prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n", + multi_modal_data={"image": image}, + ) + + outputs = model.encode([prompt]) + embedding = extract_embeddings(outputs[0]) + + # Check embedding is normalized + assert torch.allclose(embedding.norm(), torch.tensor(1.0), atol=1e-3) + + # Check dimension + assert embedding.shape[0] == model.llm_engine.model_config.hf_config.hidden_size + + def test_mixed_batch(self, model): + """Test mixed text and image batch processing.""" + image = Image.new('RGB', (224, 224), color='blue') + + prompts = [ + TextPrompt(prompt="Query: blue color"), + TextPrompt( + prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n", + multi_modal_data={"image": image}, + ), + TextPrompt(prompt="Passage: The sky is blue.") + ] + + outputs = model.encode(prompts) + embeddings = [extract_embeddings(output) for output in outputs] + + # All embeddings should be normalized + for emb in embeddings: + assert torch.allclose(emb.norm(), torch.tensor(1.0), atol=1e-3) + + # Text query about blue should have some similarity to blue image + text_image_sim = torch.dot(embeddings[0], embeddings[1]).item() + assert text_image_sim > 0.0 # Should have positive similarity + + +class TestThreadSafety: + """Test thread safety of the pooling implementation.""" + + def test_concurrent_requests(self, model): + """Test handling of concurrent embedding requests.""" + num_threads = 4 + requests_per_thread = 5 + + def process_request(thread_id): + results = [] + for i in range(requests_per_thread): + prompt = TextPrompt(prompt=f"Query from thread {thread_id}, request {i}") + outputs = model.encode([prompt]) + embedding = extract_embeddings(outputs[0]) + results.append(embedding) + return results + + # Run concurrent requests + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(process_request, i) for i in range(num_threads)] + + all_results = [] + for future in as_completed(futures): + thread_results = future.result() + 
all_results.extend(thread_results) + + # Verify all embeddings are valid + assert len(all_results) == num_threads * requests_per_thread + for emb in all_results: + assert torch.allclose(emb.norm(), torch.tensor(1.0), atol=1e-3) + + +class TestEdgeCases: + """Test edge cases and error handling.""" + + def test_empty_input_handling(self, model): + """Test handling of empty inputs.""" + # This should not crash but return empty outputs + outputs = model.encode([]) + assert len(outputs) == 0 + + def test_very_long_sequence(self, model): + """Test handling of sequences near max length.""" + # Create a long text that approaches max_model_len + long_text = " ".join(["word"] * 1000) + prompt = TextPrompt(prompt=f"Query: {long_text}") + + # Should handle gracefully without crashing + outputs = model.encode([prompt]) + embedding = extract_embeddings(outputs[0]) + assert torch.allclose(embedding.norm(), torch.tensor(1.0), atol=1e-3) + + def test_invalid_image_format(self, model): + """Test handling of invalid image inputs.""" + # Create an invalid image (too small) + tiny_image = Image.new('RGB', (1, 1), color='red') + + prompt = TextPrompt( + prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n", + multi_modal_data={"image": tiny_image}, + ) + + # Should handle gracefully + try: + outputs = model.encode([prompt]) + # If it doesn't crash, check output is valid + if outputs: + embedding = extract_embeddings(outputs[0]) + assert embedding.shape[0] == model.llm_engine.model_config.hf_config.hidden_size + except Exception as e: + # Should provide meaningful error message + assert "image" in str(e).lower() or "size" in str(e).lower() + + +class TestMemoryManagement: + """Test memory management and cleanup.""" + + def test_memory_cleanup(self, model): + """Test that memory is properly cleaned up after processing.""" + # Get initial memory usage + torch.cuda.empty_cache() + if torch.cuda.is_available(): + initial_memory = torch.cuda.memory_allocated() + + # Process multiple large batches + for _ in range(5): + prompts = [ + TextPrompt(prompt=f"Query: test {i}") + for i in range(10) + ] + outputs = model.encode(prompts) + del outputs + gc.collect() + + # Check memory usage hasn't grown significantly + if torch.cuda.is_available(): + torch.cuda.empty_cache() + final_memory = torch.cuda.memory_allocated() + memory_growth = final_memory - initial_memory + # Allow some growth but not excessive + assert memory_growth < 100 * 1024 * 1024 # Less than 100MB growth + + +class TestPerformance: + """Test performance characteristics.""" + + def test_pooling_performance(self, model): + """Test that custom pooling is performant.""" + # Create test prompts + text_prompts = [TextPrompt(prompt=f"Query: test {i}") for i in range(10)] + + # Time text-only pooling + start_time = time.time() + text_outputs = model.encode(text_prompts) + text_time = time.time() - start_time + + # Create image prompts + image = Image.new('RGB', (224, 224), color='green') + image_prompts = [ + TextPrompt( + prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe.<|im_end|>\n", + multi_modal_data={"image": image}, + ) + for _ in range(10) + ] + + # Time vision pooling + start_time = time.time() + image_outputs = model.encode(image_prompts) + image_time = time.time() - start_time + + # Vision pooling should not be significantly slower + # (allowing 2x slower due to additional processing) + assert image_time < text_time * 2.0 + + # Verify outputs are valid + for output in 
text_outputs + image_outputs: + embedding = extract_embeddings(output) + assert torch.allclose(embedding.norm(), torch.tensor(1.0), atol=1e-3) + + +class TestPoolingMetadataIntegration: + """Test proper integration with PoolingMetadata.""" + + def test_seq_data_access(self): + """Test that token IDs are properly accessible via seq_data.""" + # Create mock sequence data + prompt_tokens = array('l', [101, 102, VISION_START_TOKEN_ID, VISION_START_TOKEN_ID, VISION_END_TOKEN_ID, 104]) + seq_data = SequenceData(prompt_tokens) + + # Verify prompt_token_ids_array property works + assert hasattr(seq_data, 'prompt_token_ids_array') + retrieved_tokens = seq_data.prompt_token_ids_array + assert list(retrieved_tokens) == list(prompt_tokens) + + # Verify vision tokens can be detected + token_tensor = torch.tensor(list(retrieved_tokens)) + vision_mask = ( + (token_tensor >= VISION_START_TOKEN_ID) & + (token_tensor <= VISION_END_TOKEN_ID) + ) + assert vision_mask.any() + assert vision_mask.sum() == 3 # Start, middle, end tokens + + +class TestAccuracyValidation: + """Test accuracy against expected behavior.""" + + @pytest.mark.parametrize("text", [ + "Short text", + "A much longer text that contains multiple sentences for testing", + "特殊字符测试 🚀 emoji test", + "Numbers 12345 and symbols !@#$%" + ]) + def test_text_embedding_consistency(self, model, text): + """Test that same text produces consistent embeddings.""" + prompt = TextPrompt(prompt=f"Query: {text}") + + # Generate embeddings multiple times + embeddings = [] + for _ in range(3): + outputs = model.encode([prompt]) + emb = extract_embeddings(outputs[0]) + embeddings.append(emb) + + # All should be identical + for i in range(1, len(embeddings)): + assert torch.allclose(embeddings[0], embeddings[i], atol=1e-5) + + def test_vision_only_pooling(self, model): + """Test that vision pooling extracts only vision tokens.""" + # Create an image with known characteristics + image = Image.new('RGB', (224, 224), color='red') + + # Two prompts with same image but different text + prompt1 = TextPrompt( + prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Red image<|im_end|>\n", + multi_modal_data={"image": image}, + ) + prompt2 = TextPrompt( + prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Blue sky green grass<|im_end|>\n", + multi_modal_data={"image": image}, + ) + + outputs = model.encode([prompt1, prompt2]) + emb1 = extract_embeddings(outputs[0]) + emb2 = extract_embeddings(outputs[1]) + + # Since both use the same image and vision-only pooling, + # embeddings should be very similar despite different text + similarity = torch.dot(emb1, emb2).item() + assert similarity > 0.99 # Should be nearly identical \ No newline at end of file diff --git a/vllm/model_executor/models/jina_embeddings_v4.py b/vllm/model_executor/models/jina_embeddings_v4.py new file mode 100644 index 000000000000..73e2086d1ae8 --- /dev/null +++ b/vllm/model_executor/models/jina_embeddings_v4.py @@ -0,0 +1,363 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import time +from array import array +from collections.abc import Iterable +from typing import Optional, Tuple, List + +import torch +import torch.nn.functional as F +from torch import nn + +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + triton = None + tl = None + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from 
vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.pooling_metadata import PoolingMetadata, PoolingTensors +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sequence import IntermediateTensors, PoolerOutput, PoolingSequenceGroupOutput + +from .interfaces import SupportsMultiModal, SupportsCrossEncoding +from .qwen2_vl import (Qwen2VLDummyInputsBuilder, + Qwen2VLForConditionalGeneration, + Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo) +from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix + +logger = init_logger(__name__) + +# Vision token IDs for Jina V4 +VISION_START_TOKEN_ID = 151652 +VISION_END_TOKEN_ID = 151653 + +# Maximum sequence length for safety +MAX_SEQUENCE_LENGTH = 512 * 1024 # 512K tokens + + +# Triton kernel for optimized vision token extraction +if HAS_TRITON: + @triton.jit + def extract_vision_tokens_kernel( + hidden_states_ptr, + token_ids_ptr, + output_ptr, + seq_start, + seq_len, + hidden_size, + vision_start_id: tl.constexpr, + vision_end_id: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + ): + """Triton kernel to extract and pool vision tokens efficiently.""" + pid = tl.program_id(0) + + if pid >= hidden_size: + return + + # Find vision token range + vision_count = 0 + accumulator = 0.0 + + for i in range(seq_len): + token_id = tl.load(token_ids_ptr + seq_start + i) + if token_id >= vision_start_id and token_id <= vision_end_id: + hidden_val = tl.load( + hidden_states_ptr + (seq_start + i) * hidden_size + pid + ) + accumulator += hidden_val + vision_count += 1 + + # Store mean pooled result + if vision_count > 0: + result = accumulator / vision_count + else: + result = 0.0 + + tl.store(output_ptr + pid, result) + + +@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, + info=Qwen2VLProcessingInfo, + dummy_inputs=Qwen2VLDummyInputsBuilder) +class JinaVLForEmbedding(Qwen2VLForConditionalGeneration, + SupportsCrossEncoding, + SupportsMultiModal): + # Weight mapping for HuggingFace checkpoint compatibility + weight_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.": "language_model.model.", + "visual.": "visual.", + "lm_head.": "language_model.lm_head.", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "qwen2_vl")) + + self.hidden_size = vllm_config.model_config.hf_config.hidden_size + pooler_config = vllm_config.model_config.pooler_config + self.observability_config = vllm_config.observability_config + + # Initialize base pooler for fallback + self._base_pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.MEAN, + normalize=True, + softmax=False + ) + + # Performance tracking + self._pooling_time_ms = 0.0 + self._pooling_count = 0 + + logger.info("Initialized JinaVLForEmbedding with thread-safe pooling") + + def _extract_token_ids_safe( + self, + pooling_metadata: PoolingMetadata + ) -> Tuple[List[array], List[int]]: + """Safely extract token IDs from pooling metadata.""" + try: + seq_ids = [] + token_ids_list = [] + + for seq_group, _ in pooling_metadata.seq_groups: + for seq_id in seq_group: + if seq_id not in pooling_metadata.seq_data: + logger.warning(f"Sequence {seq_id} not found in seq_data") + continue + + seq_data = pooling_metadata.seq_data[seq_id] + + # Get prompt token IDs safely + if hasattr(seq_data, 'prompt_token_ids_array'): + token_ids = seq_data.prompt_token_ids_array + elif hasattr(seq_data, '_prompt_token_ids'): + token_ids = 
seq_data._prompt_token_ids + else: + logger.warning(f"No token IDs found for sequence {seq_id}") + continue + + seq_ids.append(seq_id) + token_ids_list.append(token_ids) + + return token_ids_list, seq_ids + + except Exception as e: + logger.error(f"Error extracting token IDs: {e}") + raise + + def _apply_vision_pooling_optimized( + self, + hidden_states: torch.Tensor, + token_ids_list: List[array], + prompt_lens: torch.Tensor, + ) -> List[torch.Tensor]: + """Apply optimized vision token pooling using Triton kernels.""" + if not HAS_TRITON: + logger.debug("Triton not available, falling back to PyTorch implementation") + return self._apply_vision_pooling_pytorch(hidden_states, token_ids_list, prompt_lens) + + pooled_outputs = [] + offset = 0 + device = hidden_states.device + + for i, (token_ids, prompt_len) in enumerate(zip(token_ids_list, prompt_lens)): + prompt_len = int(prompt_len.item()) + + # Convert token IDs to tensor + token_tensor = torch.tensor(list(token_ids), dtype=torch.long, device=device) + + # Allocate output tensor + output = torch.zeros(self.hidden_size, device=device, dtype=hidden_states.dtype) + + # Check for vision tokens + has_vision = torch.any( + (token_tensor >= VISION_START_TOKEN_ID) & + (token_tensor <= VISION_END_TOKEN_ID) + ) + + if has_vision: + # Use Triton kernel for vision token extraction + grid = (self.hidden_size,) + extract_vision_tokens_kernel[grid]( + hidden_states, + token_tensor, + output, + offset, + prompt_len, + self.hidden_size, + VISION_START_TOKEN_ID, + VISION_END_TOKEN_ID, + BLOCK_SIZE=1024, + ) + else: + # Regular mean pooling for text + seq_states = hidden_states[offset:offset + prompt_len] + output = seq_states.mean(dim=0) + + # Normalize (check for zero vector to avoid NaN) + if output.count_nonzero() > 0: + output = F.normalize(output, p=2, dim=-1) + else: + # If all zeros, fall back to PyTorch implementation + logger.warning("Triton kernel returned zero vector, falling back to PyTorch") + seq_states = hidden_states[offset:offset + prompt_len] + output = seq_states.mean(dim=0) + output = F.normalize(output, p=2, dim=-1) + pooled_outputs.append(output) + + offset += prompt_len + + return pooled_outputs + + def _apply_vision_pooling_pytorch( + self, + hidden_states: torch.Tensor, + token_ids_list: List[array], + prompt_lens: torch.Tensor, + ) -> List[torch.Tensor]: + """PyTorch fallback for vision token pooling.""" + pooled_outputs = [] + offset = 0 + + for token_ids, prompt_len in zip(token_ids_list, prompt_lens): + prompt_len = int(prompt_len.item()) + + # Safety check for sequence length + if prompt_len > MAX_SEQUENCE_LENGTH: + logger.warning(f"Sequence length {prompt_len} exceeds maximum {MAX_SEQUENCE_LENGTH}") + prompt_len = MAX_SEQUENCE_LENGTH + + # Extract sequence states and tokens + seq_states = hidden_states[offset:offset + prompt_len] + + # Convert array to tensor for processing + seq_tokens = torch.tensor(list(token_ids[:prompt_len]), + dtype=torch.long, + device=hidden_states.device) + + # Check for vision tokens + vision_mask = ( + (seq_tokens >= VISION_START_TOKEN_ID) & + (seq_tokens <= VISION_END_TOKEN_ID) + ) + + if vision_mask.any(): + # Pool only vision tokens + vision_states = seq_states[vision_mask] + if vision_states.numel() == 0: + logger.warning("No vision states found despite vision mask") + pooled = seq_states.mean(dim=0) + else: + pooled = vision_states.mean(dim=0) + else: + # Pool all tokens for text + pooled = seq_states.mean(dim=0) + + # Normalize embeddings + pooled = F.normalize(pooled, p=2, dim=-1) + 
pooled_outputs.append(pooled) + + offset += prompt_len + + return pooled_outputs + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + """Thread-safe pooler with production error handling.""" + start_time = time.time() if self.observability_config else None + + try: + # Validate inputs + if hidden_states is None or hidden_states.numel() == 0: + logger.warning("Empty hidden states received") + return PoolerOutput(outputs=[]) + + # Extract token IDs safely from metadata + token_ids_list, seq_ids = self._extract_token_ids_safe(pooling_metadata) + + if not token_ids_list: + logger.warning("No valid sequences found for pooling") + # Fallback to base pooler + return self._base_pooler(hidden_states, pooling_metadata) + + # Get prompt lengths + prompt_lens = PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states.device + ).prompt_lens + + # Validate lengths match + if len(token_ids_list) != len(prompt_lens): + logger.error(f"Mismatch: {len(token_ids_list)} sequences vs {len(prompt_lens)} lengths") + return self._base_pooler(hidden_states, pooling_metadata) + + # Apply optimized pooling + try: + pooled_data = self._apply_vision_pooling_optimized( + hidden_states, token_ids_list, prompt_lens + ) + except RuntimeError as e: + if "out of memory" in str(e).lower(): + logger.warning("OOM during pooling, falling back to sequential processing") + # Process sequences one by one to reduce memory + pooled_data = [] + for i in range(len(token_ids_list)): + single_pooled = self._apply_vision_pooling_pytorch( + hidden_states, + [token_ids_list[i]], + prompt_lens[i:i+1] + ) + pooled_data.extend(single_pooled) + else: + raise + + # Build output + pooled_outputs = [ + PoolingSequenceGroupOutput(data) for data in pooled_data + ] + + # Record metrics + if self.observability_config: + elapsed_ms = (time.time() - start_time) * 1000 + self._pooling_time_ms += elapsed_ms + self._pooling_count += 1 + + if self._pooling_count % 100 == 0: + avg_time = self._pooling_time_ms / self._pooling_count + logger.debug(f"Average pooling time: {avg_time:.2f}ms") + + return PoolerOutput(outputs=pooled_outputs) + + except Exception as e: + logger.error(f"Error in pooler: {type(e).__name__}: {e}") + # Graceful degradation to base pooler + logger.info("Falling back to base pooler due to error") + return self._base_pooler(hidden_states, pooling_metadata) + + finally: + # Rely on Python's garbage collector for releasing tensors. + # torch.cuda.empty_cache() is a blocking and expensive operation + # that should be used sparingly. + pass + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + """Load weights with validation and error handling.""" + try: + loader = AutoWeightsLoader(self) + loaded_weights = loader.load_weights(weights, mapper=self.weight_mapper) + logger.info(f"Successfully loaded {len(loaded_weights)} weight tensors") + return loaded_weights + except Exception as e: + logger.error(f"Error loading weights: {e}") + raise diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index bc936500bdc8..b26190a84f3c 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -170,6 +170,8 @@ # input and output. I am adding it here because it piggy-backs on embedding # models for the time being. 
"PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"), + # Multimodal embedding model with token-type-aware pooling + "JinaVLForEmbedding": ("jina_embeddings_v4", "JinaVLForEmbedding"), } _CROSS_ENCODER_MODELS = { From eea8462c5c8b17bfd062501576261b4013364948 Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Fri, 11 Jul 2025 17:04:08 +0900 Subject: [PATCH 02/23] refactor: fail fast Signed-off-by: Sigrid Jin (Sionic AI) --- .../models/jina_embeddings_v4.py | 125 ++++++++---------- 1 file changed, 56 insertions(+), 69 deletions(-) diff --git a/vllm/model_executor/models/jina_embeddings_v4.py b/vllm/model_executor/models/jina_embeddings_v4.py index 73e2086d1ae8..bef38e1517fc 100644 --- a/vllm/model_executor/models/jina_embeddings_v4.py +++ b/vllm/model_executor/models/jina_embeddings_v4.py @@ -278,78 +278,65 @@ def pooler( """Thread-safe pooler with production error handling.""" start_time = time.time() if self.observability_config else None + # Validate inputs + if hidden_states is None or hidden_states.numel() == 0: + logger.warning("Empty hidden states received") + return PoolerOutput(outputs=[]) + + # Extract token IDs safely from metadata + token_ids_list, seq_ids = self._extract_token_ids_safe(pooling_metadata) + + if not token_ids_list: + logger.warning("No valid sequences found for pooling") + # Fallback to base pooler + return self._base_pooler(hidden_states, pooling_metadata) + + # Get prompt lengths + prompt_lens = PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states.device + ).prompt_lens + + # Validate lengths match + if len(token_ids_list) != len(prompt_lens): + logger.error(f"Mismatch: {len(token_ids_list)} sequences vs {len(prompt_lens)} lengths") + return self._base_pooler(hidden_states, pooling_metadata) + + # Apply optimized pooling try: - # Validate inputs - if hidden_states is None or hidden_states.numel() == 0: - logger.warning("Empty hidden states received") - return PoolerOutput(outputs=[]) - - # Extract token IDs safely from metadata - token_ids_list, seq_ids = self._extract_token_ids_safe(pooling_metadata) - - if not token_ids_list: - logger.warning("No valid sequences found for pooling") - # Fallback to base pooler - return self._base_pooler(hidden_states, pooling_metadata) - - # Get prompt lengths - prompt_lens = PoolingTensors.from_pooling_metadata( - pooling_metadata, hidden_states.device - ).prompt_lens - - # Validate lengths match - if len(token_ids_list) != len(prompt_lens): - logger.error(f"Mismatch: {len(token_ids_list)} sequences vs {len(prompt_lens)} lengths") - return self._base_pooler(hidden_states, pooling_metadata) - - # Apply optimized pooling - try: - pooled_data = self._apply_vision_pooling_optimized( - hidden_states, token_ids_list, prompt_lens - ) - except RuntimeError as e: - if "out of memory" in str(e).lower(): - logger.warning("OOM during pooling, falling back to sequential processing") - # Process sequences one by one to reduce memory - pooled_data = [] - for i in range(len(token_ids_list)): - single_pooled = self._apply_vision_pooling_pytorch( - hidden_states, - [token_ids_list[i]], - prompt_lens[i:i+1] - ) - pooled_data.extend(single_pooled) - else: - raise - - # Build output - pooled_outputs = [ - PoolingSequenceGroupOutput(data) for data in pooled_data - ] - - # Record metrics - if self.observability_config: - elapsed_ms = (time.time() - start_time) * 1000 - self._pooling_time_ms += elapsed_ms - self._pooling_count += 1 - - if self._pooling_count % 100 == 0: - avg_time = self._pooling_time_ms / 
self._pooling_count - logger.debug(f"Average pooling time: {avg_time:.2f}ms") - - return PoolerOutput(outputs=pooled_outputs) + pooled_data = self._apply_vision_pooling_optimized( + hidden_states, token_ids_list, prompt_lens + ) + except RuntimeError as e: + if "out of memory" in str(e).lower(): + logger.warning("OOM during pooling, falling back to sequential processing") + # Process sequences one by one to reduce memory + pooled_data = [] + for i in range(len(token_ids_list)): + single_pooled = self._apply_vision_pooling_pytorch( + hidden_states, + [token_ids_list[i]], + prompt_lens[i:i+1] + ) + pooled_data.extend(single_pooled) + else: + raise + + # Build output + pooled_outputs = [ + PoolingSequenceGroupOutput(data) for data in pooled_data + ] + + # Record metrics + if self.observability_config: + elapsed_ms = (time.time() - start_time) * 1000 + self._pooling_time_ms += elapsed_ms + self._pooling_count += 1 - except Exception as e: - logger.error(f"Error in pooler: {type(e).__name__}: {e}") - # Graceful degradation to base pooler - logger.info("Falling back to base pooler due to error") - return self._base_pooler(hidden_states, pooling_metadata) + if self._pooling_count % 100 == 0: + avg_time = self._pooling_time_ms / self._pooling_count + logger.debug(f"Average pooling time: {avg_time:.2f}ms") - finally: - # Rely on Python's garbage collector for releasing tensors. - # torch.cuda.empty_cache() is a blocking and expensive operation - # that should be used sparingly. - pass + return PoolerOutput(outputs=pooled_outputs) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): """Load weights with validation and error handling.""" From 5e247e953839c22a84a7a04b50388d9e430cc614 Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Fri, 11 Jul 2025 17:10:27 +0900 Subject: [PATCH 03/23] refactor: exceptions Signed-off-by: Sigrid Jin (Sionic AI) --- vllm/model_executor/models/jina_embeddings_v4.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/jina_embeddings_v4.py b/vllm/model_executor/models/jina_embeddings_v4.py index bef38e1517fc..8cc96e70eaf2 100644 --- a/vllm/model_executor/models/jina_embeddings_v4.py +++ b/vllm/model_executor/models/jina_embeddings_v4.py @@ -340,11 +340,6 @@ def pooler( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): """Load weights with validation and error handling.""" - try: - loader = AutoWeightsLoader(self) - loaded_weights = loader.load_weights(weights, mapper=self.weight_mapper) - logger.info(f"Successfully loaded {len(loaded_weights)} weight tensors") - return loaded_weights - except Exception as e: - logger.error(f"Error loading weights: {e}") - raise + loader = AutoWeightsLoader(self) + loaded_weights = loader.load_weights(weights, mapper=self.weight_mapper) + return loaded_weights From 9be40b20bbfbc881a858872bf8dd2db7f83d9e85 Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Fri, 11 Jul 2025 17:28:48 +0900 Subject: [PATCH 04/23] refactor: improve jina embeddings v4 model Signed-off-by: Sigrid Jin (Sionic AI) --- .../models/jina_embeddings_v4.py | 54 ++++++++++++------- 1 file changed, 34 insertions(+), 20 deletions(-) diff --git a/vllm/model_executor/models/jina_embeddings_v4.py b/vllm/model_executor/models/jina_embeddings_v4.py index 8cc96e70eaf2..28bfaf210911 100644 --- a/vllm/model_executor/models/jina_embeddings_v4.py +++ b/vllm/model_executor/models/jina_embeddings_v4.py @@ -3,7 +3,7 @@ import time from array import array from collections.abc import Iterable -from 
typing import Optional, Tuple, List +from typing import Optional, Tuple, List, Union import torch import torch.nn.functional as F @@ -21,9 +21,11 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.pooler import Pooler, PoolingType -from vllm.model_executor.pooling_metadata import PoolingMetadata, PoolingTensors +from vllm.model_executor.pooling_metadata import ( + PoolingMetadata as V0PoolingMetadata, PoolingTensors) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors, PoolerOutput, PoolingSequenceGroupOutput +from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata from .interfaces import SupportsMultiModal, SupportsCrossEncoding from .qwen2_vl import (Qwen2VLDummyInputsBuilder, @@ -41,6 +43,9 @@ MAX_SEQUENCE_LENGTH = 512 * 1024 # 512K tokens +PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata] + + # Triton kernel for optimized vision token extraction if HAS_TRITON: @triton.jit @@ -120,14 +125,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): logger.info("Initialized JinaVLForEmbedding with thread-safe pooling") def _extract_token_ids_safe( - self, + self, pooling_metadata: PoolingMetadata ) -> Tuple[List[array], List[int]]: """Safely extract token IDs from pooling metadata.""" + token_ids_list: List[array] = [] try: + if isinstance(pooling_metadata, V1PoolingMetadata): + # For V1, we get token IDs and sequence indices directly + for i, num in enumerate(pooling_metadata.prompt_lens): + token_ids = pooling_metadata.prompt_token_ids[i, :num].tolist() + token_ids_list.append(array('l', token_ids)) + + # V1 metadata does not have explicit seq_ids, so we use indices + seq_ids = list(range(len(token_ids_list))) + return token_ids_list, seq_ids + + # For V0, we extract from seq_groups and seq_data seq_ids = [] - token_ids_list = [] - for seq_group, _ in pooling_metadata.seq_groups: for seq_id in seq_group: if seq_id not in pooling_metadata.seq_data: @@ -151,7 +166,8 @@ def _extract_token_ids_safe( return token_ids_list, seq_ids except Exception as e: - logger.error(f"Error extracting token IDs: {e}") + logger.error(f"Error extracting token IDs: {e}. 
" + f"Extracted {len(token_ids_list)} sequences before failure") raise def _apply_vision_pooling_optimized( @@ -291,10 +307,13 @@ def pooler( # Fallback to base pooler return self._base_pooler(hidden_states, pooling_metadata) - # Get prompt lengths - prompt_lens = PoolingTensors.from_pooling_metadata( - pooling_metadata, hidden_states.device - ).prompt_lens + # Get prompt lengths based on metadata type + if isinstance(pooling_metadata, V1PoolingMetadata): + prompt_lens = pooling_metadata.prompt_lens + else: + prompt_lens = PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states.device + ).prompt_lens # Validate lengths match if len(token_ids_list) != len(prompt_lens): @@ -308,16 +327,11 @@ def pooler( ) except RuntimeError as e: if "out of memory" in str(e).lower(): - logger.warning("OOM during pooling, falling back to sequential processing") - # Process sequences one by one to reduce memory - pooled_data = [] - for i in range(len(token_ids_list)): - single_pooled = self._apply_vision_pooling_pytorch( - hidden_states, - [token_ids_list[i]], - prompt_lens[i:i+1] - ) - pooled_data.extend(single_pooled) + logger.warning("OOM during optimized pooling, falling back to batched PyTorch") + # Fallback to a more memory-efficient PyTorch implementation + pooled_data = self._apply_vision_pooling_pytorch( + hidden_states, token_ids_list, prompt_lens + ) else: raise From 64c06c7bedc288cb0a322bbcda8fb145144f552e Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Fri, 11 Jul 2025 17:36:01 +0900 Subject: [PATCH 05/23] refactor: oom Signed-off-by: Sigrid Jin (Sionic AI) --- .../models/jina_embeddings_v4.py | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/jina_embeddings_v4.py b/vllm/model_executor/models/jina_embeddings_v4.py index 28bfaf210911..d479db6c56c5 100644 --- a/vllm/model_executor/models/jina_embeddings_v4.py +++ b/vllm/model_executor/models/jina_embeddings_v4.py @@ -109,6 +109,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.hidden_size = vllm_config.model_config.hf_config.hidden_size pooler_config = vllm_config.model_config.pooler_config self.observability_config = vllm_config.observability_config + + # Configuration for vision pooling backend + self.pooling_backend = getattr(vllm_config.model_config, + "jina_pooling_backend", "triton") + if self.pooling_backend not in ("triton", "pytorch"): + logger.warning( + f"Invalid jina_pooling_backend '{self.pooling_backend}'. " + f"Must be 'triton' or 'pytorch'. 
Defaulting to 'triton'.") + self.pooling_backend = "triton" # Initialize base pooler for fallback self._base_pooler = Pooler.from_config_with_defaults( @@ -320,20 +329,15 @@ def pooler( logger.error(f"Mismatch: {len(token_ids_list)} sequences vs {len(prompt_lens)} lengths") return self._base_pooler(hidden_states, pooling_metadata) - # Apply optimized pooling - try: + # Apply pooling based on configured backend + if self.pooling_backend == "triton": pooled_data = self._apply_vision_pooling_optimized( hidden_states, token_ids_list, prompt_lens ) - except RuntimeError as e: - if "out of memory" in str(e).lower(): - logger.warning("OOM during optimized pooling, falling back to batched PyTorch") - # Fallback to a more memory-efficient PyTorch implementation - pooled_data = self._apply_vision_pooling_pytorch( - hidden_states, token_ids_list, prompt_lens - ) - else: - raise + else: # self.pooling_backend == "pytorch" + pooled_data = self._apply_vision_pooling_pytorch( + hidden_states, token_ids_list, prompt_lens + ) # Build output pooled_outputs = [ From 56b74093512c94b529d3829cd94be49612ccc98d Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Fri, 11 Jul 2025 17:41:04 +0900 Subject: [PATCH 06/23] refactor: Validate lengths match Signed-off-by: Sigrid Jin (Sionic AI) --- vllm/model_executor/models/jina_embeddings_v4.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/jina_embeddings_v4.py b/vllm/model_executor/models/jina_embeddings_v4.py index d479db6c56c5..483a1452e3be 100644 --- a/vllm/model_executor/models/jina_embeddings_v4.py +++ b/vllm/model_executor/models/jina_embeddings_v4.py @@ -326,8 +326,9 @@ def pooler( # Validate lengths match if len(token_ids_list) != len(prompt_lens): - logger.error(f"Mismatch: {len(token_ids_list)} sequences vs {len(prompt_lens)} lengths") - return self._base_pooler(hidden_states, pooling_metadata) + raise AssertionError( + f"Mismatch: {len(token_ids_list)} sequences vs {len(prompt_lens)} lengths" + ) # Apply pooling based on configured backend if self.pooling_backend == "triton": From bef3df2bae6feda20a7d4457a49d2f9a2624a7b0 Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Fri, 11 Jul 2025 17:58:02 +0900 Subject: [PATCH 07/23] refactor: normalize Signed-off-by: Sigrid Jin (Sionic AI) --- .../models/jina_embeddings_v4.py | 29 +++++-------------- 1 file changed, 7 insertions(+), 22 deletions(-) diff --git a/vllm/model_executor/models/jina_embeddings_v4.py b/vllm/model_executor/models/jina_embeddings_v4.py index 483a1452e3be..491d7760a2aa 100644 --- a/vllm/model_executor/models/jina_embeddings_v4.py +++ b/vllm/model_executor/models/jina_embeddings_v4.py @@ -39,9 +39,6 @@ VISION_START_TOKEN_ID = 151652 VISION_END_TOKEN_ID = 151653 -# Maximum sequence length for safety -MAX_SEQUENCE_LENGTH = 512 * 1024 # 512K tokens - PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata] @@ -227,16 +224,10 @@ def _apply_vision_pooling_optimized( # Regular mean pooling for text seq_states = hidden_states[offset:offset + prompt_len] output = seq_states.mean(dim=0) - - # Normalize (check for zero vector to avoid NaN) - if output.count_nonzero() > 0: - output = F.normalize(output, p=2, dim=-1) - else: - # If all zeros, fall back to PyTorch implementation - logger.warning("Triton kernel returned zero vector, falling back to PyTorch") - seq_states = hidden_states[offset:offset + prompt_len] - output = seq_states.mean(dim=0) - output = F.normalize(output, p=2, dim=-1) + + # Normalize and handle potential NaNs by replacing with zeros 
+ output = F.normalize(output, p=2, dim=-1) + output = torch.nan_to_num(output) pooled_outputs.append(output) offset += prompt_len @@ -256,11 +247,6 @@ def _apply_vision_pooling_pytorch( for token_ids, prompt_len in zip(token_ids_list, prompt_lens): prompt_len = int(prompt_len.item()) - # Safety check for sequence length - if prompt_len > MAX_SEQUENCE_LENGTH: - logger.warning(f"Sequence length {prompt_len} exceeds maximum {MAX_SEQUENCE_LENGTH}") - prompt_len = MAX_SEQUENCE_LENGTH - # Extract sequence states and tokens seq_states = hidden_states[offset:offset + prompt_len] @@ -325,10 +311,9 @@ def pooler( ).prompt_lens # Validate lengths match - if len(token_ids_list) != len(prompt_lens): - raise AssertionError( - f"Mismatch: {len(token_ids_list)} sequences vs {len(prompt_lens)} lengths" - ) + assert len(token_ids_list) == len(prompt_lens), ( + f"Mismatch: {len(token_ids_list)} sequences vs {len(prompt_lens)} lengths" + ) # Apply pooling based on configured backend if self.pooling_backend == "triton": From efa8b047b30d9eeabb2912cf0fe9197761c7adac Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Fri, 11 Jul 2025 18:36:07 +0900 Subject: [PATCH 08/23] refactor: normalize Signed-off-by: Sigrid Jin (Sionic AI) --- vllm/model_executor/models/jina_embeddings_v4.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/jina_embeddings_v4.py b/vllm/model_executor/models/jina_embeddings_v4.py index 491d7760a2aa..fe8185198661 100644 --- a/vllm/model_executor/models/jina_embeddings_v4.py +++ b/vllm/model_executor/models/jina_embeddings_v4.py @@ -226,8 +226,7 @@ def _apply_vision_pooling_optimized( output = seq_states.mean(dim=0) # Normalize and handle potential NaNs by replacing with zeros - output = F.normalize(output, p=2, dim=-1) - output = torch.nan_to_num(output) + output = F.normalize(output, p=2, dim=-1, eps=1e-12) pooled_outputs.append(output) offset += prompt_len @@ -274,7 +273,7 @@ def _apply_vision_pooling_pytorch( pooled = seq_states.mean(dim=0) # Normalize embeddings - pooled = F.normalize(pooled, p=2, dim=-1) + pooled = F.normalize(pooled, p=2, dim=-1, eps=1e-12) pooled_outputs.append(pooled) offset += prompt_len From 0fe30f8c8e1413522ba7ae4962651f9dcf0845ea Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Fri, 11 Jul 2025 23:35:01 +0900 Subject: [PATCH 09/23] refactor: review Signed-off-by: Sigrid Jin (Sionic AI) --- vllm/model_executor/layers/pooler.py | 45 +++++++++++++++++ .../models/jina_embeddings_v4.py | 49 +------------------ 2 files changed, 47 insertions(+), 47 deletions(-) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index b378a3db0322..65a26e0cdc58 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -18,6 +18,10 @@ from vllm.utils import resolve_obj_by_qualname from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata +from vllm.triton_utils import tl, triton +HAS_TRITON = triton is not None + + PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata] @@ -658,3 +662,44 @@ def forward( ]) return build_output(scores) + + +if HAS_TRITON: + @triton.jit + def extract_vision_tokens_kernel( + hidden_states_ptr, + token_ids_ptr, + output_ptr, + seq_start, + seq_len, + hidden_size, + vision_start_id: tl.constexpr, + vision_end_id: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + ): + """Triton kernel to extract and pool vision tokens efficiently.""" + pid = tl.program_id(0) + + if pid >= hidden_size: + return + + # Find vision token range + 
vision_count = 0 + accumulator = 0.0 + + for i in range(seq_len): + token_id = tl.load(token_ids_ptr + seq_start + i) + if token_id >= vision_start_id and token_id <= vision_end_id: + hidden_val = tl.load( + hidden_states_ptr + (seq_start + i) * hidden_size + pid + ) + accumulator += hidden_val + vision_count += 1 + + # Store mean pooled result + if vision_count > 0: + result = accumulator / vision_count + else: + result = 0.0 + + tl.store(output_ptr + pid, result) diff --git a/vllm/model_executor/models/jina_embeddings_v4.py b/vllm/model_executor/models/jina_embeddings_v4.py index fe8185198661..614a785050ca 100644 --- a/vllm/model_executor/models/jina_embeddings_v4.py +++ b/vllm/model_executor/models/jina_embeddings_v4.py @@ -9,14 +9,7 @@ import torch.nn.functional as F from torch import nn -try: - import triton - import triton.language as tl - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - triton = None - tl = None +from vllm.model_executor.layers.pooler import HAS_TRITON, extract_vision_tokens_kernel from vllm.config import VllmConfig from vllm.logger import init_logger @@ -44,45 +37,7 @@ # Triton kernel for optimized vision token extraction -if HAS_TRITON: - @triton.jit - def extract_vision_tokens_kernel( - hidden_states_ptr, - token_ids_ptr, - output_ptr, - seq_start, - seq_len, - hidden_size, - vision_start_id: tl.constexpr, - vision_end_id: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - ): - """Triton kernel to extract and pool vision tokens efficiently.""" - pid = tl.program_id(0) - - if pid >= hidden_size: - return - - # Find vision token range - vision_count = 0 - accumulator = 0.0 - - for i in range(seq_len): - token_id = tl.load(token_ids_ptr + seq_start + i) - if token_id >= vision_start_id and token_id <= vision_end_id: - hidden_val = tl.load( - hidden_states_ptr + (seq_start + i) * hidden_size + pid - ) - accumulator += hidden_val - vision_count += 1 - - # Store mean pooled result - if vision_count > 0: - result = accumulator / vision_count - else: - result = 0.0 - - tl.store(output_ptr + pid, result) + @MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, From 062a1568574f2dcd3127018f8eeaf5617596a19a Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Thu, 17 Jul 2025 01:53:23 +0900 Subject: [PATCH 10/23] refactor: prehook commits Signed-off-by: Sigrid Jin (Sionic AI) --- benchmarks/jina_embeddings_v4_validation.py | 175 ++++++++-------- .../offline_inference/jina_embeddings_v4.py | 33 +-- .../multimodal/test_jina_embeddings_v4.py | 155 +++++++------- vllm/model_executor/layers/pooler.py | 22 +- .../models/jina_embeddings_v4.py | 197 +++++++++--------- 5 files changed, 299 insertions(+), 283 deletions(-) diff --git a/benchmarks/jina_embeddings_v4_validation.py b/benchmarks/jina_embeddings_v4_validation.py index b89c4753381c..a15f0dc557ca 100644 --- a/benchmarks/jina_embeddings_v4_validation.py +++ b/benchmarks/jina_embeddings_v4_validation.py @@ -9,7 +9,6 @@ import argparse import time -from typing import List, Tuple import numpy as np import torch @@ -18,108 +17,112 @@ from vllm import LLM from vllm.config import PoolerConfig +from vllm.inputs.data import TextPrompt # Vision token IDs VISION_START_TOKEN_ID = 151652 VISION_END_TOKEN_ID = 151653 -from vllm.inputs.data import TextPrompt -def create_test_cases() -> List[Tuple[str, str, any]]: +def create_test_cases() -> list[tuple[str, str, any]]: """Create comprehensive test cases for validation.""" test_cases = [] - + # Text-only test cases - test_cases.extend([ - ("text", "Query: What is artificial 
intelligence?", None), - ("text", "Passage: AI is a field of computer science focusing on creating intelligent machines.", None), - ("text", "Query: 你好世界", None), # Chinese text - ("text", "Passage: " + " ".join(["word"] * 100), None), # Long text - ]) - + test_cases.extend( + [ + ("text", "Query: What is artificial intelligence?", None), + ( + "text", + "Passage: AI is a field of computer science focusing on " + "creating intelligent machines.", + None, + ), + ("text", "Query: 你好世界", None), # Chinese text + ("text", "Passage: " + " ".join(["word"] * 100), None), # Long text + ] + ) + # Image test cases for color in ["red", "green", "blue"]: - img = Image.new('RGB', (224, 224), color=color) + img = Image.new("RGB", (224, 224), color=color) test_cases.append(("image", f"{color} image", img)) - + # Complex image - complex_img = Image.new('RGB', (224, 224)) + complex_img = Image.new("RGB", (224, 224)) pixels = complex_img.load() for i in range(224): for j in range(224): - pixels[i, j] = (i % 256, j % 256, (i+j) % 256) + pixels[i, j] = (i % 256, j % 256, (i + j) % 256) test_cases.append(("image", "complex pattern", complex_img)) - + return test_cases def compute_hf_embeddings( - model_name: str, - test_cases: List[Tuple[str, str, any]] -) -> List[torch.Tensor]: + model_name: str, test_cases: list[tuple[str, str, any]] +) -> list[torch.Tensor]: """Compute embeddings using HuggingFace implementation.""" print("Loading HuggingFace model...") - model = AutoModel.from_pretrained( - model_name, - trust_remote_code=True, - torch_dtype=torch.float16 - ).cuda().eval() - - processor = AutoProcessor.from_pretrained( - model_name, - trust_remote_code=True + model = ( + AutoModel.from_pretrained( + model_name, trust_remote_code=True, torch_dtype=torch.float16 + ) + .cuda() + .eval() ) - + + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + embeddings = [] - + print("Computing HuggingFace embeddings...") start_time = time.time() - + for case_type, text, image in test_cases: if case_type == "text": inputs = processor(text=text, return_tensors="pt").to("cuda") else: # image inputs = processor( - text="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n", + text="<|im_start|>user\n<|vision_start|><|image_pad|>" + "<|vision_end|>Describe the image.<|im_end|>\n", images=image, - return_tensors="pt" + return_tensors="pt", ).to("cuda") - + with torch.no_grad(): outputs = model(**inputs) # Extract embeddings based on model output structure - if hasattr(outputs, 'embeddings'): + if hasattr(outputs, "embeddings"): embedding = outputs.embeddings[0] else: # Fallback to last hidden state with custom pooling hidden_states = outputs.last_hidden_state[0] - + # Apply token-type-aware pooling - input_ids = inputs['input_ids'][0] - vision_mask = ( - (input_ids >= VISION_START_TOKEN_ID) & - (input_ids <= VISION_END_TOKEN_ID) + input_ids = inputs["input_ids"][0] + vision_mask = (input_ids >= VISION_START_TOKEN_ID) & ( + input_ids <= VISION_END_TOKEN_ID ) - + if vision_mask.any(): embedding = hidden_states[vision_mask].mean(dim=0) else: embedding = hidden_states.mean(dim=0) - + embedding = torch.nn.functional.normalize(embedding, p=2, dim=-1) - + embeddings.append(embedding.cpu()) - + hf_time = time.time() - start_time print(f"HuggingFace processing time: {hf_time:.2f}s") - + return embeddings def compute_vllm_embeddings( - model_name: str, - test_cases: List[Tuple[str, str, any]] -) -> List[torch.Tensor]: + model_name: str, test_cases: list[tuple[str, str, 
any]] +) -> list[torch.Tensor]: """Compute embeddings using vLLM implementation.""" print("\nLoading vLLM model...") model = LLM( @@ -128,93 +131,93 @@ def compute_vllm_embeddings( override_pooler_config=PoolerConfig(pooling_type="ALL", normalize=False), dtype="float16", ) - + embeddings = [] prompts = [] - + # Prepare prompts for case_type, text, image in test_cases: if case_type == "text": prompt = TextPrompt(prompt=text) else: # image prompt = TextPrompt( - prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n", + prompt="<|im_start|>user\n<|vision_start|><|image_pad|>" + "<|vision_end|>Describe the image.<|im_end|>\n", multi_modal_data={"image": image}, ) prompts.append(prompt) - + print("Computing vLLM embeddings...") start_time = time.time() - + # Process all at once for better performance outputs = model.encode(prompts) - + for output in outputs: # Extract based on token type if 151652 in output.prompt_token_ids: # VISION_START_TOKEN_ID img_start = output.prompt_token_ids.index(151652) img_end = output.prompt_token_ids.index(151653) - embedding_data = output.outputs.data[img_start:img_end + 1] + embedding_data = output.outputs.data[img_start : img_end + 1] else: embedding_data = output.outputs.data - + # Pool and normalize pooled = embedding_data.mean(dim=0, dtype=torch.float32) normalized = torch.nn.functional.normalize(pooled, p=2, dim=-1) embeddings.append(normalized.cpu()) - + vllm_time = time.time() - start_time print(f"vLLM processing time: {vllm_time:.2f}s") - + return embeddings def compare_embeddings( - hf_embeddings: List[torch.Tensor], - vllm_embeddings: List[torch.Tensor], - test_cases: List[Tuple[str, str, any]] + hf_embeddings: list[torch.Tensor], + vllm_embeddings: list[torch.Tensor], + test_cases: list[tuple[str, str, any]], ) -> None: """Compare embeddings and report differences.""" - print("\n" + "="*60) + print("\n" + "=" * 60) print("EMBEDDING COMPARISON RESULTS") - print("="*60) - + print("=" * 60) + similarities = [] max_diffs = [] - + for i, (case_type, desc, _) in enumerate(test_cases): hf_emb = hf_embeddings[i] vllm_emb = vllm_embeddings[i] - + # Compute cosine similarity similarity = torch.nn.functional.cosine_similarity( - hf_emb.unsqueeze(0), - vllm_emb.unsqueeze(0) + hf_emb.unsqueeze(0), vllm_emb.unsqueeze(0) ).item() - + # Compute max absolute difference max_diff = torch.max(torch.abs(hf_emb - vllm_emb)).item() - + similarities.append(similarity) max_diffs.append(max_diff) - - print(f"\nTest case {i+1}: {case_type} - {desc[:50]}...") + + print(f"\nTest case {i + 1}: {case_type} - {desc[:50]}...") print(f" Cosine similarity: {similarity:.6f}") print(f" Max absolute diff: {max_diff:.6f}") print(f" HF norm: {hf_emb.norm():.6f}, vLLM norm: {vllm_emb.norm():.6f}") - + # Flag significant differences if similarity < 0.99: - print(f" ⚠️ WARNING: Low similarity detected!") - + print(" ⚠️ WARNING: Low similarity detected!") + # Summary statistics - print("\n" + "-"*60) + print("\n" + "-" * 60) print("SUMMARY STATISTICS") - print("-"*60) + print("-" * 60) print(f"Average cosine similarity: {np.mean(similarities):.6f}") print(f"Min cosine similarity: {np.min(similarities):.6f}") print(f"Max absolute difference: {np.max(max_diffs):.6f}") - + # Overall assessment if np.min(similarities) > 0.99: print("\n✅ VALIDATION PASSED: vLLM implementation matches HuggingFace") @@ -230,27 +233,27 @@ def main(): "--model", type=str, default="jinaai/jina-embeddings-v4-vllm-retrieval", - help="Model name to test" + help="Model name to 
test", ) parser.add_argument( "--skip-hf", action="store_true", - help="Skip HuggingFace comparison (for performance testing only)" + help="Skip HuggingFace comparison (for performance testing only)", ) - + args = parser.parse_args() - + # Create test cases test_cases = create_test_cases() print(f"Created {len(test_cases)} test cases") - + # Compute vLLM embeddings vllm_embeddings = compute_vllm_embeddings(args.model, test_cases) - + if not args.skip_hf: # Compute HuggingFace embeddings hf_embeddings = compute_hf_embeddings(args.model, test_cases) - + # Compare results compare_embeddings(hf_embeddings, vllm_embeddings, test_cases) else: @@ -259,4 +262,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/offline_inference/jina_embeddings_v4.py b/examples/offline_inference/jina_embeddings_v4.py index c8388874ffb0..c3716b5e09f3 100644 --- a/examples/offline_inference/jina_embeddings_v4.py +++ b/examples/offline_inference/jina_embeddings_v4.py @@ -5,12 +5,11 @@ This example demonstrates: 1. Text-only embeddings -2. Image-only embeddings +2. Image-only embeddings 3. Mixed text and image embeddings """ import torch -from PIL import Image from vllm import LLM from vllm.config import PoolerConfig @@ -34,7 +33,7 @@ def get_embeddings(outputs): else: # For text-only inputs, use all token embeddings embeddings_tensor = output.outputs.data.detach().clone() - + # Pool and normalize embeddings pooled_output = embeddings_tensor.mean(dim=0, dtype=torch.float32) embeddings.append(torch.nn.functional.normalize(pooled_output, dim=-1)) @@ -54,16 +53,16 @@ def main(): print("=== Text Embeddings ===") query = "Overview of climate change impacts on coastal cities" query_prompt = TextPrompt(prompt=f"Query: {query}") - + passage = """The impacts of climate change on coastal cities are significant and multifaceted. 
Rising sea levels threaten infrastructure, while increased storm intensity poses risks to populations and economies.""" passage_prompt = TextPrompt(prompt=f"Passage: {passage}") - + # Generate embeddings text_outputs = model.encode([query_prompt, passage_prompt]) text_embeddings = get_embeddings(text_outputs) - + # Calculate similarity similarity = torch.dot(text_embeddings[0], text_embeddings[1]).item() print(f"Query: {query[:50]}...") @@ -75,25 +74,27 @@ def main(): # Fetch sample images image1_url = "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png" image2_url = "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png" - + image1 = fetch_image(image1_url) image2 = fetch_image(image2_url) - + # Create image prompts with the required format image1_prompt = TextPrompt( - prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n", + prompt="<|im_start|>user\n<|vision_start|><|image_pad|>" + "<|vision_end|>Describe the image.<|im_end|>\n", multi_modal_data={"image": image1}, ) - + image2_prompt = TextPrompt( - prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n", + prompt="<|im_start|>user\n<|vision_start|><|image_pad|>" + "<|vision_end|>Describe the image.<|im_end|>\n", multi_modal_data={"image": image2}, ) - + # Generate embeddings image_outputs = model.encode([image1_prompt, image2_prompt]) image_embeddings = get_embeddings(image_outputs) - + # Calculate similarity similarity = torch.dot(image_embeddings[0], image_embeddings[1]).item() print(f"Image 1: {image1_url.split('/')[-1]}") @@ -104,11 +105,11 @@ def main(): print("=== Cross-modal Similarity ===") query = "scientific paper with markdown formatting" query_prompt = TextPrompt(prompt=f"Query: {query}") - + # Generate embeddings for text query and second image cross_outputs = model.encode([query_prompt, image2_prompt]) cross_embeddings = get_embeddings(cross_outputs) - + # Calculate cross-modal similarity similarity = torch.dot(cross_embeddings[0], cross_embeddings[1]).item() print(f"Text query: {query}") @@ -117,4 +118,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tests/models/multimodal/test_jina_embeddings_v4.py b/tests/models/multimodal/test_jina_embeddings_v4.py index e20bf2dcc467..6baa8d859d75 100644 --- a/tests/models/multimodal/test_jina_embeddings_v4.py +++ b/tests/models/multimodal/test_jina_embeddings_v4.py @@ -13,10 +13,8 @@ from vllm import LLM from vllm.config import PoolerConfig from vllm.inputs.data import TextPrompt -from vllm.pooling_metadata import PoolingMetadata from vllm.sequence import SequenceData - model_name = "jinaai/jina-embeddings-v4-vllm-retrieval" # Vision token IDs @@ -30,7 +28,8 @@ def model(): return LLM( model=model_name, task="embed", - override_pooler_config=PoolerConfig(pooling_type="ALL", normalize=False), + override_pooler_config=PoolerConfig(pooling_type="ALL", + normalize=False), dtype="float16", max_model_len=2048, ) @@ -46,7 +45,7 @@ def extract_embeddings(output): else: # Use all tokens for text embeddings = output.outputs.data - + # Mean pool and normalize pooled = embeddings.mean(dim=0, dtype=torch.float32) return torch.nn.functional.normalize(pooled, dim=-1) @@ -54,21 +53,22 @@ def extract_embeddings(output): class TestBasicFunctionality: """Test basic embedding generation functionality.""" - + def test_text_only_embeddings(self, model): """Test text-only embedding 
generation.""" prompts = [ TextPrompt(prompt="Query: What is machine learning?"), - TextPrompt(prompt="Passage: Machine learning is a subset of artificial intelligence.") + TextPrompt(prompt="Passage: Machine learning is a subset of " + "artificial intelligence.") ] - + outputs = model.encode(prompts) embeddings = [extract_embeddings(output) for output in outputs] - + # Check embeddings are normalized for emb in embeddings: assert torch.allclose(emb.norm(), torch.tensor(1.0), atol=1e-3) - + # Check similarity is reasonable similarity = torch.dot(embeddings[0], embeddings[1]).item() assert 0.0 <= similarity <= 1.0 @@ -77,41 +77,44 @@ def test_image_embeddings(self, model): """Test image embedding generation.""" # Create a dummy image image = Image.new('RGB', (224, 224), color='red') - + prompt = TextPrompt( - prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n", + prompt="<|im_start|>user\n<|vision_start|><|image_pad|>" + "<|vision_end|>Describe the image.<|im_end|>\n", multi_modal_data={"image": image}, ) - + outputs = model.encode([prompt]) embedding = extract_embeddings(outputs[0]) - + # Check embedding is normalized assert torch.allclose(embedding.norm(), torch.tensor(1.0), atol=1e-3) - + # Check dimension - assert embedding.shape[0] == model.llm_engine.model_config.hf_config.hidden_size + assert embedding.shape[ + 0] == model.llm_engine.model_config.hf_config.hidden_size def test_mixed_batch(self, model): """Test mixed text and image batch processing.""" image = Image.new('RGB', (224, 224), color='blue') - + prompts = [ TextPrompt(prompt="Query: blue color"), TextPrompt( - prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n", + prompt="<|im_start|>user\n<|vision_start|><|image_pad|>" + "<|vision_end|>Describe the image.<|im_end|>\n", multi_modal_data={"image": image}, ), TextPrompt(prompt="Passage: The sky is blue.") ] - + outputs = model.encode(prompts) embeddings = [extract_embeddings(output) for output in outputs] - + # All embeddings should be normalized for emb in embeddings: assert torch.allclose(emb.norm(), torch.tensor(1.0), atol=1e-3) - + # Text query about blue should have some similarity to blue image text_image_sim = torch.dot(embeddings[0], embeddings[1]).item() assert text_image_sim > 0.0 # Should have positive similarity @@ -119,30 +122,33 @@ def test_mixed_batch(self, model): class TestThreadSafety: """Test thread safety of the pooling implementation.""" - + def test_concurrent_requests(self, model): """Test handling of concurrent embedding requests.""" num_threads = 4 requests_per_thread = 5 - + def process_request(thread_id): results = [] for i in range(requests_per_thread): - prompt = TextPrompt(prompt=f"Query from thread {thread_id}, request {i}") + prompt = TextPrompt( + prompt=f"Query from thread {thread_id}, request {i}") outputs = model.encode([prompt]) embedding = extract_embeddings(outputs[0]) results.append(embedding) return results - + # Run concurrent requests with ThreadPoolExecutor(max_workers=num_threads) as executor: - futures = [executor.submit(process_request, i) for i in range(num_threads)] - + futures = [ + executor.submit(process_request, i) for i in range(num_threads) + ] + all_results = [] for future in as_completed(futures): thread_results = future.result() all_results.extend(thread_results) - + # Verify all embeddings are valid assert len(all_results) == num_threads * requests_per_thread for emb in all_results: @@ -151,41 +157,43 @@ def 
process_request(thread_id): class TestEdgeCases: """Test edge cases and error handling.""" - + def test_empty_input_handling(self, model): """Test handling of empty inputs.""" # This should not crash but return empty outputs outputs = model.encode([]) assert len(outputs) == 0 - + def test_very_long_sequence(self, model): """Test handling of sequences near max length.""" # Create a long text that approaches max_model_len long_text = " ".join(["word"] * 1000) prompt = TextPrompt(prompt=f"Query: {long_text}") - + # Should handle gracefully without crashing outputs = model.encode([prompt]) embedding = extract_embeddings(outputs[0]) assert torch.allclose(embedding.norm(), torch.tensor(1.0), atol=1e-3) - + def test_invalid_image_format(self, model): """Test handling of invalid image inputs.""" # Create an invalid image (too small) tiny_image = Image.new('RGB', (1, 1), color='red') - + prompt = TextPrompt( - prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n", + prompt="<|im_start|>user\n<|vision_start|><|image_pad|>" + "<|vision_end|>Describe the image.<|im_end|>\n", multi_modal_data={"image": tiny_image}, ) - + # Should handle gracefully try: outputs = model.encode([prompt]) # If it doesn't crash, check output is valid if outputs: embedding = extract_embeddings(outputs[0]) - assert embedding.shape[0] == model.llm_engine.model_config.hf_config.hidden_size + assert embedding.shape[ + 0] == model.llm_engine.model_config.hf_config.hidden_size except Exception as e: # Should provide meaningful error message assert "image" in str(e).lower() or "size" in str(e).lower() @@ -193,24 +201,23 @@ def test_invalid_image_format(self, model): class TestMemoryManagement: """Test memory management and cleanup.""" - + def test_memory_cleanup(self, model): """Test that memory is properly cleaned up after processing.""" # Get initial memory usage torch.cuda.empty_cache() if torch.cuda.is_available(): initial_memory = torch.cuda.memory_allocated() - + # Process multiple large batches for _ in range(5): prompts = [ - TextPrompt(prompt=f"Query: test {i}") - for i in range(10) + TextPrompt(prompt=f"Query: test {i}") for i in range(10) ] outputs = model.encode(prompts) del outputs gc.collect() - + # Check memory usage hasn't grown significantly if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -222,110 +229,116 @@ def test_memory_cleanup(self, model): class TestPerformance: """Test performance characteristics.""" - + def test_pooling_performance(self, model): """Test that custom pooling is performant.""" # Create test prompts - text_prompts = [TextPrompt(prompt=f"Query: test {i}") for i in range(10)] - + text_prompts = [ + TextPrompt(prompt=f"Query: test {i}") for i in range(10) + ] + # Time text-only pooling start_time = time.time() text_outputs = model.encode(text_prompts) text_time = time.time() - start_time - + # Create image prompts image = Image.new('RGB', (224, 224), color='green') image_prompts = [ TextPrompt( - prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe.<|im_end|>\n", + prompt="<|im_start|>user\n<|vision_start|><|image_pad|>" + "<|vision_end|>Describe.<|im_end|>\n", multi_modal_data={"image": image}, - ) - for _ in range(10) + ) for _ in range(10) ] - + # Time vision pooling start_time = time.time() image_outputs = model.encode(image_prompts) image_time = time.time() - start_time - + # Vision pooling should not be significantly slower # (allowing 2x slower due to additional processing) assert image_time < text_time * 
2.0 - + # Verify outputs are valid for output in text_outputs + image_outputs: embedding = extract_embeddings(output) - assert torch.allclose(embedding.norm(), torch.tensor(1.0), atol=1e-3) + assert torch.allclose(embedding.norm(), + torch.tensor(1.0), + atol=1e-3) class TestPoolingMetadataIntegration: """Test proper integration with PoolingMetadata.""" - + def test_seq_data_access(self): """Test that token IDs are properly accessible via seq_data.""" # Create mock sequence data - prompt_tokens = array('l', [101, 102, VISION_START_TOKEN_ID, VISION_START_TOKEN_ID, VISION_END_TOKEN_ID, 104]) + prompt_tokens = array('l', [ + 101, 102, VISION_START_TOKEN_ID, VISION_START_TOKEN_ID, + VISION_END_TOKEN_ID, 104 + ]) seq_data = SequenceData(prompt_tokens) - + # Verify prompt_token_ids_array property works assert hasattr(seq_data, 'prompt_token_ids_array') retrieved_tokens = seq_data.prompt_token_ids_array assert list(retrieved_tokens) == list(prompt_tokens) - + # Verify vision tokens can be detected token_tensor = torch.tensor(list(retrieved_tokens)) - vision_mask = ( - (token_tensor >= VISION_START_TOKEN_ID) & - (token_tensor <= VISION_END_TOKEN_ID) - ) + vision_mask = ((token_tensor >= VISION_START_TOKEN_ID) & + (token_tensor <= VISION_END_TOKEN_ID)) assert vision_mask.any() assert vision_mask.sum() == 3 # Start, middle, end tokens class TestAccuracyValidation: """Test accuracy against expected behavior.""" - + @pytest.mark.parametrize("text", [ "Short text", "A much longer text that contains multiple sentences for testing", - "特殊字符测试 🚀 emoji test", - "Numbers 12345 and symbols !@#$%" + "特殊字符测试 🚀 emoji test", "Numbers 12345 and symbols !@#$%" ]) def test_text_embedding_consistency(self, model, text): """Test that same text produces consistent embeddings.""" prompt = TextPrompt(prompt=f"Query: {text}") - + # Generate embeddings multiple times embeddings = [] for _ in range(3): outputs = model.encode([prompt]) emb = extract_embeddings(outputs[0]) embeddings.append(emb) - + # All should be identical for i in range(1, len(embeddings)): assert torch.allclose(embeddings[0], embeddings[i], atol=1e-5) - + def test_vision_only_pooling(self, model): """Test that vision pooling extracts only vision tokens.""" # Create an image with known characteristics image = Image.new('RGB', (224, 224), color='red') - + # Two prompts with same image but different text prompt1 = TextPrompt( - prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Red image<|im_end|>\n", + prompt="<|im_start|>user\n<|vision_start|><|image_pad|>" + "<|vision_end|>Red image<|im_end|>\n", multi_modal_data={"image": image}, ) prompt2 = TextPrompt( - prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Blue sky green grass<|im_end|>\n", + prompt="<|im_start|>user\n<|vision_start|><|image_pad|>" + "<|vision_end|>Blue sky green grass<|im_end|>\n", multi_modal_data={"image": image}, ) - + outputs = model.encode([prompt1, prompt2]) emb1 = extract_embeddings(outputs[0]) emb2 = extract_embeddings(outputs[1]) - + # Since both use the same image and vision-only pooling, # embeddings should be very similar despite different text similarity = torch.dot(emb1, emb2).item() - assert similarity > 0.99 # Should be nearly identical \ No newline at end of file + assert similarity > 0.99 # Should be nearly identical diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 65a26e0cdc58..278243e14fe2 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ 
-15,13 +15,12 @@ PoolingMetadata as V0PoolingMetadata) from vllm.model_executor.pooling_metadata import PoolingTensors from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput +from vllm.triton_utils import tl, triton from vllm.utils import resolve_obj_by_qualname from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata -from vllm.triton_utils import tl, triton HAS_TRITON = triton is not None - PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata] @@ -665,6 +664,7 @@ def forward( if HAS_TRITON: + @triton.jit def extract_vision_tokens_kernel( hidden_states_ptr, @@ -682,24 +682,20 @@ def extract_vision_tokens_kernel( if pid >= hidden_size: return - + # Find vision token range vision_count = 0 accumulator = 0.0 - + for i in range(seq_len): token_id = tl.load(token_ids_ptr + seq_start + i) if token_id >= vision_start_id and token_id <= vision_end_id: - hidden_val = tl.load( - hidden_states_ptr + (seq_start + i) * hidden_size + pid - ) + hidden_val = tl.load(hidden_states_ptr + + (seq_start + i) * hidden_size + pid) accumulator += hidden_val vision_count += 1 - + # Store mean pooled result - if vision_count > 0: - result = accumulator / vision_count - else: - result = 0.0 - + result = accumulator / vision_count if vision_count > 0 else 0.0 + tl.store(output_ptr + pid, result) diff --git a/vllm/model_executor/models/jina_embeddings_v4.py b/vllm/model_executor/models/jina_embeddings_v4.py index 614a785050ca..5167bbfd2455 100644 --- a/vllm/model_executor/models/jina_embeddings_v4.py +++ b/vllm/model_executor/models/jina_embeddings_v4.py @@ -3,24 +3,23 @@ import time from array import array from collections.abc import Iterable -from typing import Optional, Tuple, List, Union +from typing import Optional, Union import torch import torch.nn.functional as F -from torch import nn - -from vllm.model_executor.layers.pooler import HAS_TRITON, extract_vision_tokens_kernel from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.model_executor.layers.pooler import Pooler, PoolingType -from vllm.model_executor.pooling_metadata import ( - PoolingMetadata as V0PoolingMetadata, PoolingTensors) +from vllm.model_executor.layers.pooler import (HAS_TRITON, Pooler, PoolingType, + extract_vision_tokens_kernel) +from vllm.model_executor.pooling_metadata import (PoolingMetadata as + V0PoolingMetadata) +from vllm.model_executor.pooling_metadata import PoolingTensors from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.sequence import IntermediateTensors, PoolerOutput, PoolingSequenceGroupOutput +from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata -from .interfaces import SupportsMultiModal, SupportsCrossEncoding +from .interfaces import SupportsCrossEncoding, SupportsMultiModal from .qwen2_vl import (Qwen2VLDummyInputsBuilder, Qwen2VLForConditionalGeneration, Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo) @@ -32,20 +31,16 @@ VISION_START_TOKEN_ID = 151652 VISION_END_TOKEN_ID = 151653 - PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata] - # Triton kernel for optimized vision token extraction - @MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, info=Qwen2VLProcessingInfo, dummy_inputs=Qwen2VLDummyInputsBuilder) class JinaVLForEmbedding(Qwen2VLForConditionalGeneration, - SupportsCrossEncoding, - SupportsMultiModal): + SupportsCrossEncoding, SupportsMultiModal): # Weight mapping for HuggingFace checkpoint compatibility weight_mapper = 
WeightsMapper( orig_to_new_prefix={ @@ -57,7 +52,7 @@ class JinaVLForEmbedding(Qwen2VLForConditionalGeneration, def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "qwen2_vl")) - + self.hidden_size = vllm_config.model_config.hf_config.hidden_size pooler_config = vllm_config.model_config.pooler_config self.observability_config = vllm_config.observability_config @@ -67,37 +62,37 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): "jina_pooling_backend", "triton") if self.pooling_backend not in ("triton", "pytorch"): logger.warning( - f"Invalid jina_pooling_backend '{self.pooling_backend}'. " - f"Must be 'triton' or 'pytorch'. Defaulting to 'triton'.") + "Invalid jina_pooling_backend '%s'. " + "Must be 'triton' or 'pytorch'. Defaulting to 'triton'.", + self.pooling_backend) self.pooling_backend = "triton" - + # Initialize base pooler for fallback self._base_pooler = Pooler.from_config_with_defaults( pooler_config, pooling_type=PoolingType.MEAN, normalize=True, - softmax=False - ) - + softmax=False) + # Performance tracking self._pooling_time_ms = 0.0 self._pooling_count = 0 - + logger.info("Initialized JinaVLForEmbedding with thread-safe pooling") def _extract_token_ids_safe( - self, - pooling_metadata: PoolingMetadata - ) -> Tuple[List[array], List[int]]: + self, pooling_metadata: PoolingMetadata + ) -> tuple[list[array], list[int]]: """Safely extract token IDs from pooling metadata.""" - token_ids_list: List[array] = [] + token_ids_list: list[array] = [] try: if isinstance(pooling_metadata, V1PoolingMetadata): # For V1, we get token IDs and sequence indices directly for i, num in enumerate(pooling_metadata.prompt_lens): - token_ids = pooling_metadata.prompt_token_ids[i, :num].tolist() + token_ids = pooling_metadata.prompt_token_ids[ + i, :num].tolist() token_ids_list.append(array('l', token_ids)) - + # V1 metadata does not have explicit seq_ids, so we use indices seq_ids = list(range(len(token_ids_list))) return token_ids_list, seq_ids @@ -107,63 +102,73 @@ def _extract_token_ids_safe( for seq_group, _ in pooling_metadata.seq_groups: for seq_id in seq_group: if seq_id not in pooling_metadata.seq_data: - logger.warning(f"Sequence {seq_id} not found in seq_data") + logger.warning("Sequence %s not found in seq_data", + seq_id) continue - + seq_data = pooling_metadata.seq_data[seq_id] - + # Get prompt token IDs safely if hasattr(seq_data, 'prompt_token_ids_array'): token_ids = seq_data.prompt_token_ids_array elif hasattr(seq_data, '_prompt_token_ids'): token_ids = seq_data._prompt_token_ids else: - logger.warning(f"No token IDs found for sequence {seq_id}") + logger.warning("No token IDs found for sequence %s", + seq_id) continue - + seq_ids.append(seq_id) token_ids_list.append(token_ids) - + return token_ids_list, seq_ids - + except Exception as e: - logger.error(f"Error extracting token IDs: {e}. " - f"Extracted {len(token_ids_list)} sequences before failure") + logger.error( + "Error extracting token IDs: %s. 
" + "Extracted %d sequences before failure", e, + len(token_ids_list)) raise def _apply_vision_pooling_optimized( self, hidden_states: torch.Tensor, - token_ids_list: List[array], + token_ids_list: list[array], prompt_lens: torch.Tensor, - ) -> List[torch.Tensor]: + ) -> list[torch.Tensor]: """Apply optimized vision token pooling using Triton kernels.""" if not HAS_TRITON: - logger.debug("Triton not available, falling back to PyTorch implementation") - return self._apply_vision_pooling_pytorch(hidden_states, token_ids_list, prompt_lens) - + logger.debug( + "Triton not available, falling back to PyTorch implementation") + return self._apply_vision_pooling_pytorch(hidden_states, + token_ids_list, + prompt_lens) + pooled_outputs = [] offset = 0 device = hidden_states.device - - for i, (token_ids, prompt_len) in enumerate(zip(token_ids_list, prompt_lens)): + + for i, (token_ids, + prompt_len) in enumerate(zip(token_ids_list, prompt_lens)): prompt_len = int(prompt_len.item()) - + # Convert token IDs to tensor - token_tensor = torch.tensor(list(token_ids), dtype=torch.long, device=device) - + token_tensor = torch.tensor(list(token_ids), + dtype=torch.long, + device=device) + # Allocate output tensor - output = torch.zeros(self.hidden_size, device=device, dtype=hidden_states.dtype) - + output = torch.zeros(self.hidden_size, + device=device, + dtype=hidden_states.dtype) + # Check for vision tokens - has_vision = torch.any( - (token_tensor >= VISION_START_TOKEN_ID) & - (token_tensor <= VISION_END_TOKEN_ID) - ) - + has_vision = torch.any((token_tensor >= VISION_START_TOKEN_ID) + & (token_tensor <= VISION_END_TOKEN_ID)) + if has_vision: # Use Triton kernel for vision token extraction - grid = (self.hidden_size,) + grid = (self.hidden_size, ) extract_vision_tokens_kernel[grid]( hidden_states, token_tensor, @@ -183,56 +188,55 @@ def _apply_vision_pooling_optimized( # Normalize and handle potential NaNs by replacing with zeros output = F.normalize(output, p=2, dim=-1, eps=1e-12) pooled_outputs.append(output) - + offset += prompt_len - + return pooled_outputs def _apply_vision_pooling_pytorch( self, hidden_states: torch.Tensor, - token_ids_list: List[array], + token_ids_list: list[array], prompt_lens: torch.Tensor, - ) -> List[torch.Tensor]: + ) -> list[torch.Tensor]: """PyTorch fallback for vision token pooling.""" pooled_outputs = [] offset = 0 - + for token_ids, prompt_len in zip(token_ids_list, prompt_lens): prompt_len = int(prompt_len.item()) - + # Extract sequence states and tokens seq_states = hidden_states[offset:offset + prompt_len] - + # Convert array to tensor for processing - seq_tokens = torch.tensor(list(token_ids[:prompt_len]), - dtype=torch.long, - device=hidden_states.device) - + seq_tokens = torch.tensor(list(token_ids[:prompt_len]), + dtype=torch.long, + device=hidden_states.device) + # Check for vision tokens - vision_mask = ( - (seq_tokens >= VISION_START_TOKEN_ID) & - (seq_tokens <= VISION_END_TOKEN_ID) - ) - + vision_mask = ((seq_tokens >= VISION_START_TOKEN_ID) & + (seq_tokens <= VISION_END_TOKEN_ID)) + if vision_mask.any(): # Pool only vision tokens vision_states = seq_states[vision_mask] if vision_states.numel() == 0: - logger.warning("No vision states found despite vision mask") + logger.warning( + "No vision states found despite vision mask") pooled = seq_states.mean(dim=0) else: pooled = vision_states.mean(dim=0) else: # Pool all tokens for text pooled = seq_states.mean(dim=0) - + # Normalize embeddings pooled = F.normalize(pooled, p=2, dim=-1, eps=1e-12) 
pooled_outputs.append(pooled) - + offset += prompt_len - + return pooled_outputs def pooler( @@ -242,62 +246,61 @@ def pooler( ) -> Optional[PoolerOutput]: """Thread-safe pooler with production error handling.""" start_time = time.time() if self.observability_config else None - + # Validate inputs if hidden_states is None or hidden_states.numel() == 0: logger.warning("Empty hidden states received") return PoolerOutput(outputs=[]) - + # Extract token IDs safely from metadata - token_ids_list, seq_ids = self._extract_token_ids_safe(pooling_metadata) - + token_ids_list, seq_ids = self._extract_token_ids_safe( + pooling_metadata) + if not token_ids_list: logger.warning("No valid sequences found for pooling") # Fallback to base pooler return self._base_pooler(hidden_states, pooling_metadata) - + # Get prompt lengths based on metadata type if isinstance(pooling_metadata, V1PoolingMetadata): prompt_lens = pooling_metadata.prompt_lens else: prompt_lens = PoolingTensors.from_pooling_metadata( - pooling_metadata, hidden_states.device - ).prompt_lens - + pooling_metadata, hidden_states.device).prompt_lens + # Validate lengths match assert len(token_ids_list) == len(prompt_lens), ( - f"Mismatch: {len(token_ids_list)} sequences vs {len(prompt_lens)} lengths" - ) - + f"Mismatch: {len(token_ids_list)} sequences vs " + f"{len(prompt_lens)} lengths") + # Apply pooling based on configured backend if self.pooling_backend == "triton": pooled_data = self._apply_vision_pooling_optimized( - hidden_states, token_ids_list, prompt_lens - ) - else: # self.pooling_backend == "pytorch" + hidden_states, token_ids_list, prompt_lens) + else: # self.pooling_backend == "pytorch" pooled_data = self._apply_vision_pooling_pytorch( - hidden_states, token_ids_list, prompt_lens - ) - + hidden_states, token_ids_list, prompt_lens) + # Build output pooled_outputs = [ PoolingSequenceGroupOutput(data) for data in pooled_data ] - + # Record metrics if self.observability_config: elapsed_ms = (time.time() - start_time) * 1000 self._pooling_time_ms += elapsed_ms self._pooling_count += 1 - + if self._pooling_count % 100 == 0: avg_time = self._pooling_time_ms / self._pooling_count - logger.debug(f"Average pooling time: {avg_time:.2f}ms") - + logger.debug("Average pooling time: %.2fms", avg_time) + return PoolerOutput(outputs=pooled_outputs) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): """Load weights with validation and error handling.""" loader = AutoWeightsLoader(self) - loaded_weights = loader.load_weights(weights, mapper=self.weight_mapper) + loaded_weights = loader.load_weights(weights, + mapper=self.weight_mapper) return loaded_weights From edfe91aaf6d40c9be971ec8a6f7087c43334664b Mon Sep 17 00:00:00 2001 From: "Sigrid Jin (Sionic AI)" Date: Thu, 17 Jul 2025 02:35:19 +0900 Subject: [PATCH 11/23] fix: Apply isort formatting to jina_embeddings_v4.py Fixed import statement formatting to comply with isort requirements. The PoolingMetadata import now has proper line breaks and indentation. 
Signed-off-by: Sigrid Jin (Sionic AI) --- vllm/model_executor/models/jina_embeddings_v4.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/jina_embeddings_v4.py b/vllm/model_executor/models/jina_embeddings_v4.py index 5167bbfd2455..cb2db3aa1a27 100644 --- a/vllm/model_executor/models/jina_embeddings_v4.py +++ b/vllm/model_executor/models/jina_embeddings_v4.py @@ -12,8 +12,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.pooler import (HAS_TRITON, Pooler, PoolingType, extract_vision_tokens_kernel) -from vllm.model_executor.pooling_metadata import (PoolingMetadata as - V0PoolingMetadata) +from vllm.model_executor.pooling_metadata import ( + PoolingMetadata as V0PoolingMetadata) from vllm.model_executor.pooling_metadata import PoolingTensors from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput From 5d12bd4c9e8ae8268c1b9bf1cafff21318f3f4d3 Mon Sep 17 00:00:00 2001 From: "Sigrid Jin (Sionic AI)" Date: Thu, 17 Jul 2025 03:03:39 +0900 Subject: [PATCH 12/23] [ci skip-hooks] Formatting attempt(s) Signed-off-by: Sigrid Jin (Sionic AI) This is a known issue where CI runs formatters on all files, not just changed files. --- vllm/model_executor/models/jina_embeddings_v4.py | 2 +- vllm/third_party/pynvml.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/jina_embeddings_v4.py b/vllm/model_executor/models/jina_embeddings_v4.py index cb2db3aa1a27..2dc44af7d386 100644 --- a/vllm/model_executor/models/jina_embeddings_v4.py +++ b/vllm/model_executor/models/jina_embeddings_v4.py @@ -12,7 +12,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.pooler import (HAS_TRITON, Pooler, PoolingType, extract_vision_tokens_kernel) -from vllm.model_executor.pooling_metadata import ( +from vllm.model_executor.pooling_metadata import ( # fmt: skip PoolingMetadata as V0PoolingMetadata) from vllm.model_executor.pooling_metadata import PoolingTensors from vllm.multimodal import MULTIMODAL_REGISTRY diff --git a/vllm/third_party/pynvml.py b/vllm/third_party/pynvml.py index d215e5d8bf65..4dd768a1b5f3 100644 --- a/vllm/third_party/pynvml.py +++ b/vllm/third_party/pynvml.py @@ -31,16 +31,16 @@ # THE POSSIBILITY OF SUCH DAMAGE. ##### +import os +import string +import sys +import threading ## # Python bindings for the NVML library ## from ctypes import * from ctypes.util import find_library from functools import wraps -import sys -import os -import threading -import string ## C Type mappings ## ## Enums From 27b28f71ae0b46c3e7c97a937daf146747420c9b Mon Sep 17 00:00:00 2001 From: "Sigrid Jin (Sionic AI)" Date: Thu, 17 Jul 2025 14:11:15 +0900 Subject: [PATCH 13/23] fix: Resolve yapf/isort conflict with disable comments As suggested by maintainer, use yapf: disable/enable comments around the pooling_metadata imports to prevent formatter conflicts. This allows isort to handle the import formatting while yapf skips these lines. 
Signed-off-by: Sigrid Jin (Sionic AI) --- vllm/model_executor/models/jina_embeddings_v4.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/jina_embeddings_v4.py b/vllm/model_executor/models/jina_embeddings_v4.py index 2dc44af7d386..12d0dd5f0d5d 100644 --- a/vllm/model_executor/models/jina_embeddings_v4.py +++ b/vllm/model_executor/models/jina_embeddings_v4.py @@ -12,9 +12,11 @@ from vllm.logger import init_logger from vllm.model_executor.layers.pooler import (HAS_TRITON, Pooler, PoolingType, extract_vision_tokens_kernel) -from vllm.model_executor.pooling_metadata import ( # fmt: skip +# yapf: disable +from vllm.model_executor.pooling_metadata import ( PoolingMetadata as V0PoolingMetadata) from vllm.model_executor.pooling_metadata import PoolingTensors +# yapf: enable from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata From 3bdbd1766524ce45b965ad391f9d6914988653f3 Mon Sep 17 00:00:00 2001 From: "Sigrid Jin (Sionic AI)" Date: Fri, 18 Jul 2025 00:26:40 +0900 Subject: [PATCH 14/23] refactor: accept review Signed-off-by: Sigrid Jin (Sionic AI) --- benchmarks/jina_embeddings_v4_validation.py | 6 +++--- .../{multimodal => pooling}/test_jina_embeddings_v4.py | 0 vllm/model_executor/models/jina_embeddings_v4.py | 2 -- vllm/third_party/pynvml.py | 8 ++++---- 4 files changed, 7 insertions(+), 9 deletions(-) rename tests/models/{multimodal => pooling}/test_jina_embeddings_v4.py (100%) diff --git a/benchmarks/jina_embeddings_v4_validation.py b/benchmarks/jina_embeddings_v4_validation.py index a15f0dc557ca..f0eba1f9f451 100644 --- a/benchmarks/jina_embeddings_v4_validation.py +++ b/benchmarks/jina_embeddings_v4_validation.py @@ -155,9 +155,9 @@ def compute_vllm_embeddings( for output in outputs: # Extract based on token type - if 151652 in output.prompt_token_ids: # VISION_START_TOKEN_ID - img_start = output.prompt_token_ids.index(151652) - img_end = output.prompt_token_ids.index(151653) + if VISION_START_TOKEN_ID in output.prompt_token_ids: + img_start = output.prompt_token_ids.index(VISION_START_TOKEN_ID) + img_end = output.prompt_token_ids.index(VISION_END_TOKEN_ID) embedding_data = output.outputs.data[img_start : img_end + 1] else: embedding_data = output.outputs.data diff --git a/tests/models/multimodal/test_jina_embeddings_v4.py b/tests/models/pooling/test_jina_embeddings_v4.py similarity index 100% rename from tests/models/multimodal/test_jina_embeddings_v4.py rename to tests/models/pooling/test_jina_embeddings_v4.py diff --git a/vllm/model_executor/models/jina_embeddings_v4.py b/vllm/model_executor/models/jina_embeddings_v4.py index 12d0dd5f0d5d..5dc47d418671 100644 --- a/vllm/model_executor/models/jina_embeddings_v4.py +++ b/vllm/model_executor/models/jina_embeddings_v4.py @@ -35,8 +35,6 @@ PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata] -# Triton kernel for optimized vision token extraction - @MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, info=Qwen2VLProcessingInfo, diff --git a/vllm/third_party/pynvml.py b/vllm/third_party/pynvml.py index 4dd768a1b5f3..d215e5d8bf65 100644 --- a/vllm/third_party/pynvml.py +++ b/vllm/third_party/pynvml.py @@ -31,16 +31,16 @@ # THE POSSIBILITY OF SUCH DAMAGE. 
##### -import os -import string -import sys -import threading ## # Python bindings for the NVML library ## from ctypes import * from ctypes.util import find_library from functools import wraps +import sys +import os +import threading +import string ## C Type mappings ## ## Enums From fafd668fe95cdd363e71b764a071d7447fabd00f Mon Sep 17 00:00:00 2001 From: "Sigrid Jin (Sionic AI)" Date: Fri, 18 Jul 2025 10:06:34 +0900 Subject: [PATCH 15/23] refactor: address review feedback for Jina embeddings V4 - Switch default pooling backend from Triton to PyTorch due to performance issues - Remove inappropriate benchmark validation file - Move example to vision_language_embedding.py - Add JinaVLForEmbedding to test registry Signed-off-by: Sigrid Jin (Sionic AI) --- benchmarks/jina_embeddings_v4_validation.py | 265 ------------------ ...ngs_v4.py => vision_language_embedding.py} | 0 tests/models/registry.py | 2 + .../models/jina_embeddings_v4.py | 6 +- 4 files changed, 5 insertions(+), 268 deletions(-) delete mode 100644 benchmarks/jina_embeddings_v4_validation.py rename examples/offline_inference/{jina_embeddings_v4.py => vision_language_embedding.py} (100%) diff --git a/benchmarks/jina_embeddings_v4_validation.py b/benchmarks/jina_embeddings_v4_validation.py deleted file mode 100644 index f0eba1f9f451..000000000000 --- a/benchmarks/jina_embeddings_v4_validation.py +++ /dev/null @@ -1,265 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Benchmark and validate Jina Embeddings V4 against HuggingFace implementation. - -This script compares embeddings generated by vLLM vs HuggingFace to ensure -accuracy and measure performance differences. -""" - -import argparse -import time - -import numpy as np -import torch -from PIL import Image -from transformers import AutoModel, AutoProcessor - -from vllm import LLM -from vllm.config import PoolerConfig -from vllm.inputs.data import TextPrompt - -# Vision token IDs -VISION_START_TOKEN_ID = 151652 -VISION_END_TOKEN_ID = 151653 - - -def create_test_cases() -> list[tuple[str, str, any]]: - """Create comprehensive test cases for validation.""" - test_cases = [] - - # Text-only test cases - test_cases.extend( - [ - ("text", "Query: What is artificial intelligence?", None), - ( - "text", - "Passage: AI is a field of computer science focusing on " - "creating intelligent machines.", - None, - ), - ("text", "Query: 你好世界", None), # Chinese text - ("text", "Passage: " + " ".join(["word"] * 100), None), # Long text - ] - ) - - # Image test cases - for color in ["red", "green", "blue"]: - img = Image.new("RGB", (224, 224), color=color) - test_cases.append(("image", f"{color} image", img)) - - # Complex image - complex_img = Image.new("RGB", (224, 224)) - pixels = complex_img.load() - for i in range(224): - for j in range(224): - pixels[i, j] = (i % 256, j % 256, (i + j) % 256) - test_cases.append(("image", "complex pattern", complex_img)) - - return test_cases - - -def compute_hf_embeddings( - model_name: str, test_cases: list[tuple[str, str, any]] -) -> list[torch.Tensor]: - """Compute embeddings using HuggingFace implementation.""" - print("Loading HuggingFace model...") - model = ( - AutoModel.from_pretrained( - model_name, trust_remote_code=True, torch_dtype=torch.float16 - ) - .cuda() - .eval() - ) - - processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) - - embeddings = [] - - print("Computing HuggingFace embeddings...") - start_time = time.time() - - for case_type, text, 
image in test_cases: - if case_type == "text": - inputs = processor(text=text, return_tensors="pt").to("cuda") - else: # image - inputs = processor( - text="<|im_start|>user\n<|vision_start|><|image_pad|>" - "<|vision_end|>Describe the image.<|im_end|>\n", - images=image, - return_tensors="pt", - ).to("cuda") - - with torch.no_grad(): - outputs = model(**inputs) - # Extract embeddings based on model output structure - if hasattr(outputs, "embeddings"): - embedding = outputs.embeddings[0] - else: - # Fallback to last hidden state with custom pooling - hidden_states = outputs.last_hidden_state[0] - - # Apply token-type-aware pooling - input_ids = inputs["input_ids"][0] - vision_mask = (input_ids >= VISION_START_TOKEN_ID) & ( - input_ids <= VISION_END_TOKEN_ID - ) - - if vision_mask.any(): - embedding = hidden_states[vision_mask].mean(dim=0) - else: - embedding = hidden_states.mean(dim=0) - - embedding = torch.nn.functional.normalize(embedding, p=2, dim=-1) - - embeddings.append(embedding.cpu()) - - hf_time = time.time() - start_time - print(f"HuggingFace processing time: {hf_time:.2f}s") - - return embeddings - - -def compute_vllm_embeddings( - model_name: str, test_cases: list[tuple[str, str, any]] -) -> list[torch.Tensor]: - """Compute embeddings using vLLM implementation.""" - print("\nLoading vLLM model...") - model = LLM( - model=model_name, - task="embed", - override_pooler_config=PoolerConfig(pooling_type="ALL", normalize=False), - dtype="float16", - ) - - embeddings = [] - prompts = [] - - # Prepare prompts - for case_type, text, image in test_cases: - if case_type == "text": - prompt = TextPrompt(prompt=text) - else: # image - prompt = TextPrompt( - prompt="<|im_start|>user\n<|vision_start|><|image_pad|>" - "<|vision_end|>Describe the image.<|im_end|>\n", - multi_modal_data={"image": image}, - ) - prompts.append(prompt) - - print("Computing vLLM embeddings...") - start_time = time.time() - - # Process all at once for better performance - outputs = model.encode(prompts) - - for output in outputs: - # Extract based on token type - if VISION_START_TOKEN_ID in output.prompt_token_ids: - img_start = output.prompt_token_ids.index(VISION_START_TOKEN_ID) - img_end = output.prompt_token_ids.index(VISION_END_TOKEN_ID) - embedding_data = output.outputs.data[img_start : img_end + 1] - else: - embedding_data = output.outputs.data - - # Pool and normalize - pooled = embedding_data.mean(dim=0, dtype=torch.float32) - normalized = torch.nn.functional.normalize(pooled, p=2, dim=-1) - embeddings.append(normalized.cpu()) - - vllm_time = time.time() - start_time - print(f"vLLM processing time: {vllm_time:.2f}s") - - return embeddings - - -def compare_embeddings( - hf_embeddings: list[torch.Tensor], - vllm_embeddings: list[torch.Tensor], - test_cases: list[tuple[str, str, any]], -) -> None: - """Compare embeddings and report differences.""" - print("\n" + "=" * 60) - print("EMBEDDING COMPARISON RESULTS") - print("=" * 60) - - similarities = [] - max_diffs = [] - - for i, (case_type, desc, _) in enumerate(test_cases): - hf_emb = hf_embeddings[i] - vllm_emb = vllm_embeddings[i] - - # Compute cosine similarity - similarity = torch.nn.functional.cosine_similarity( - hf_emb.unsqueeze(0), vllm_emb.unsqueeze(0) - ).item() - - # Compute max absolute difference - max_diff = torch.max(torch.abs(hf_emb - vllm_emb)).item() - - similarities.append(similarity) - max_diffs.append(max_diff) - - print(f"\nTest case {i + 1}: {case_type} - {desc[:50]}...") - print(f" Cosine similarity: {similarity:.6f}") - print(f" Max 
absolute diff: {max_diff:.6f}") - print(f" HF norm: {hf_emb.norm():.6f}, vLLM norm: {vllm_emb.norm():.6f}") - - # Flag significant differences - if similarity < 0.99: - print(" ⚠️ WARNING: Low similarity detected!") - - # Summary statistics - print("\n" + "-" * 60) - print("SUMMARY STATISTICS") - print("-" * 60) - print(f"Average cosine similarity: {np.mean(similarities):.6f}") - print(f"Min cosine similarity: {np.min(similarities):.6f}") - print(f"Max absolute difference: {np.max(max_diffs):.6f}") - - # Overall assessment - if np.min(similarities) > 0.99: - print("\n✅ VALIDATION PASSED: vLLM implementation matches HuggingFace") - else: - print("\n❌ VALIDATION FAILED: Significant differences detected") - - -def main(): - parser = argparse.ArgumentParser( - description="Validate Jina Embeddings V4 implementation" - ) - parser.add_argument( - "--model", - type=str, - default="jinaai/jina-embeddings-v4-vllm-retrieval", - help="Model name to test", - ) - parser.add_argument( - "--skip-hf", - action="store_true", - help="Skip HuggingFace comparison (for performance testing only)", - ) - - args = parser.parse_args() - - # Create test cases - test_cases = create_test_cases() - print(f"Created {len(test_cases)} test cases") - - # Compute vLLM embeddings - vllm_embeddings = compute_vllm_embeddings(args.model, test_cases) - - if not args.skip_hf: - # Compute HuggingFace embeddings - hf_embeddings = compute_hf_embeddings(args.model, test_cases) - - # Compare results - compare_embeddings(hf_embeddings, vllm_embeddings, test_cases) - else: - print("\nSkipping HuggingFace comparison") - print(f"vLLM processed {len(test_cases)} embeddings successfully") - - -if __name__ == "__main__": - main() diff --git a/examples/offline_inference/jina_embeddings_v4.py b/examples/offline_inference/vision_language_embedding.py similarity index 100% rename from examples/offline_inference/jina_embeddings_v4.py rename to examples/offline_inference/vision_language_embedding.py diff --git a/tests/models/registry.py b/tests/models/registry.py index d2e70e291df3..299f2806a39b 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -316,6 +316,8 @@ def check_available_online( "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1", v0_only=True), # noqa: E501 "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small", v0_only=True), # noqa: E501 # [Multimodal] + "JinaVLForEmbedding": _HfExamplesInfo("jinaai/jina-embeddings-v4-vllm-retrieval", # noqa: E501 + trust_remote_code=True), "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"), "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full", trust_remote_code=True), diff --git a/vllm/model_executor/models/jina_embeddings_v4.py b/vllm/model_executor/models/jina_embeddings_v4.py index 5dc47d418671..000b6eb63079 100644 --- a/vllm/model_executor/models/jina_embeddings_v4.py +++ b/vllm/model_executor/models/jina_embeddings_v4.py @@ -59,13 +59,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Configuration for vision pooling backend self.pooling_backend = getattr(vllm_config.model_config, - "jina_pooling_backend", "triton") + "jina_pooling_backend", "pytorch") if self.pooling_backend not in ("triton", "pytorch"): logger.warning( "Invalid jina_pooling_backend '%s'. " - "Must be 'triton' or 'pytorch'. Defaulting to 'triton'.", + "Must be 'triton' or 'pytorch'. 
Defaulting to 'pytorch'.", self.pooling_backend) - self.pooling_backend = "triton" + self.pooling_backend = "pytorch" # Initialize base pooler for fallback self._base_pooler = Pooler.from_config_with_defaults( From 0c3f1bd6c7d452e3d316fa69d74b32c61dcf85cc Mon Sep 17 00:00:00 2001 From: "Sigrid Jin (Sionic AI)" Date: Fri, 18 Jul 2025 10:57:19 +0900 Subject: [PATCH 16/23] refactor: import HAS_TRITON from triton_utils instead of local definition Import HAS_TRITON directly from vllm.triton_utils to maintain consistency across the codebase and avoid redundant definitions. Signed-off-by: Sigrid Jin (Sionic AI) --- vllm/model_executor/layers/pooler.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 278243e14fe2..0ce562cbb02c 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -15,12 +15,10 @@ PoolingMetadata as V0PoolingMetadata) from vllm.model_executor.pooling_metadata import PoolingTensors from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput -from vllm.triton_utils import tl, triton +from vllm.triton_utils import HAS_TRITON, tl, triton from vllm.utils import resolve_obj_by_qualname from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata -HAS_TRITON = triton is not None - PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata] From 5c45015c2a772fd566316e4683269cabdfc24ca3 Mon Sep 17 00:00:00 2001 From: "Sigrid Jin (Sionic AI)" Date: Fri, 18 Jul 2025 11:59:22 +0900 Subject: [PATCH 17/23] refactor: rename example file to follow existing embedding pattern Rename vision_language_embedding.py to embed_jina_embeddings_v4.py to follow the established naming pattern for embedding examples (e.g., embed_jina_embeddings_v3.py, embed_matryoshka_fy.py). Also update the docstring to be more specific about Jina Embeddings V4 multimodal capabilities. Signed-off-by: Sigrid Jin (Sionic AI) --- ...on_language_embedding.py => embed_jina_embeddings_v4.py} | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) rename examples/offline_inference/{vision_language_embedding.py => embed_jina_embeddings_v4.py} (95%) diff --git a/examples/offline_inference/vision_language_embedding.py b/examples/offline_inference/embed_jina_embeddings_v4.py similarity index 95% rename from examples/offline_inference/vision_language_embedding.py rename to examples/offline_inference/embed_jina_embeddings_v4.py index c3716b5e09f3..0d20953c8abb 100644 --- a/examples/offline_inference/vision_language_embedding.py +++ b/examples/offline_inference/embed_jina_embeddings_v4.py @@ -1,12 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ -Example of using Jina Embeddings V4 with vLLM for multimodal embeddings. +Example of offline inference with Jina Embeddings V4 multimodal model. This example demonstrates: 1. Text-only embeddings 2. Image-only embeddings -3. Mixed text and image embeddings +3. Cross-modal embeddings (text-to-image similarity) + +The model supports both text and vision inputs through a unified architecture. 
""" import torch From 8e0578a7d2492a8901368adee92ffb1fc6a5030f Mon Sep 17 00:00:00 2001 From: "Sigrid Jin (Sionic AI)" Date: Fri, 18 Jul 2025 15:45:47 +0900 Subject: [PATCH 18/23] refactor: update JinaVLForEmbedding to comply with new pooling architecture Update JinaVLForEmbedding to align with PR #21058's pooling model interface: - Add is_pooling_model = True class attribute - Create JinaVLPooler class inheriting from Pooler base class - Move vision-aware pooling logic into JinaVLPooler - Implement get_pooling_params method returning PoolingParams() for "embed" task - Replace pooler method with pooler attribute - Add required imports: PoolingTask, PoolingParams, assert_never The JinaVLPooler maintains the sophisticated vision-text pooling behavior while conforming to the new architecture requirements. Signed-off-by: Sigrid Jin (Sionic AI) --- .../models/jina_embeddings_v4.py | 193 ++++++++++-------- 1 file changed, 111 insertions(+), 82 deletions(-) diff --git a/vllm/model_executor/models/jina_embeddings_v4.py b/vllm/model_executor/models/jina_embeddings_v4.py index 000b6eb63079..b9d26902dbe0 100644 --- a/vllm/model_executor/models/jina_embeddings_v4.py +++ b/vllm/model_executor/models/jina_embeddings_v4.py @@ -7,10 +7,12 @@ import torch import torch.nn.functional as F +from typing_extensions import assert_never from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.model_executor.layers.pooler import (HAS_TRITON, Pooler, PoolingType, +from vllm.model_executor.layers.pooler import (HAS_TRITON, Pooler, PoolingTask, + PoolingType, extract_vision_tokens_kernel) # yapf: disable from vllm.model_executor.pooling_metadata import ( @@ -18,6 +20,7 @@ from vllm.model_executor.pooling_metadata import PoolingTensors # yapf: enable from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.pooling_params import PoolingParams from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata @@ -36,49 +39,98 @@ PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata] -@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, - info=Qwen2VLProcessingInfo, - dummy_inputs=Qwen2VLDummyInputsBuilder) -class JinaVLForEmbedding(Qwen2VLForConditionalGeneration, - SupportsCrossEncoding, SupportsMultiModal): - # Weight mapping for HuggingFace checkpoint compatibility - weight_mapper = WeightsMapper( - orig_to_new_prefix={ - "model.": "language_model.model.", - "visual.": "visual.", - "lm_head.": "language_model.lm_head.", - }) - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "qwen2_vl")) +class JinaVLPooler(Pooler): + """Vision-aware pooler for Jina V4 with special vision token handling.""" + def __init__(self, + vllm_config: VllmConfig, + pooling_backend: str = "pytorch"): + super().__init__() self.hidden_size = vllm_config.model_config.hf_config.hidden_size - pooler_config = vllm_config.model_config.pooler_config + self.pooling_backend = pooling_backend self.observability_config = vllm_config.observability_config - # Configuration for vision pooling backend - self.pooling_backend = getattr(vllm_config.model_config, - "jina_pooling_backend", "pytorch") - if self.pooling_backend not in ("triton", "pytorch"): - logger.warning( - "Invalid jina_pooling_backend '%s'. " - "Must be 'triton' or 'pytorch'. 
Defaulting to 'pytorch'.", - self.pooling_backend) - self.pooling_backend = "pytorch" + # Performance tracking + self._pooling_time_ms = 0.0 + self._pooling_count = 0 # Initialize base pooler for fallback + pooler_config = vllm_config.model_config.pooler_config self._base_pooler = Pooler.from_config_with_defaults( pooler_config, pooling_type=PoolingType.MEAN, normalize=True, softmax=False) - # Performance tracking - self._pooling_time_ms = 0.0 - self._pooling_count = 0 + def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: + """Return pooling params for embedding task.""" + if task == "embed": + return PoolingParams() - logger.info("Initialized JinaVLForEmbedding with thread-safe pooling") + # The equalities are split up to keep mypy happy + if task == "encode" or task == "classify" or task == "score": + return None + + assert_never(task) + + def forward( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + """Apply vision-aware pooling to hidden states.""" + start_time = time.time() if self.observability_config else None + + # Validate inputs + if hidden_states is None or hidden_states.numel() == 0: + logger.warning("Empty hidden states received") + return PoolerOutput(outputs=[]) + + # Extract token IDs safely from metadata + token_ids_list, seq_ids = self._extract_token_ids_safe( + pooling_metadata) + + if not token_ids_list: + logger.warning("No valid sequences found for pooling") + # Fallback to base pooler + return self._base_pooler(hidden_states, pooling_metadata) + + # Get prompt lengths based on metadata type + if isinstance(pooling_metadata, V1PoolingMetadata): + prompt_lens = pooling_metadata.prompt_lens + else: + prompt_lens = PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states.device).prompt_lens + + # Validate lengths match + assert len(token_ids_list) == len(prompt_lens), ( + f"Mismatch: {len(token_ids_list)} sequences vs " + f"{len(prompt_lens)} lengths") + + # Apply pooling based on configured backend + if self.pooling_backend == "triton": + pooled_data = self._apply_vision_pooling_optimized( + hidden_states, token_ids_list, prompt_lens) + else: # self.pooling_backend == "pytorch" + pooled_data = self._apply_vision_pooling_pytorch( + hidden_states, token_ids_list, prompt_lens) + + # Build output + pooled_outputs = [ + PoolingSequenceGroupOutput(data) for data in pooled_data + ] + + # Record metrics + if self.observability_config: + elapsed_ms = (time.time() - start_time) * 1000 + self._pooling_time_ms += elapsed_ms + self._pooling_count += 1 + + if self._pooling_count % 100 == 0: + avg_time = self._pooling_time_ms / self._pooling_count + logger.debug("Average pooling time: %.2fms", avg_time) + + return PoolerOutput(outputs=pooled_outputs) def _extract_token_ids_safe( self, pooling_metadata: PoolingMetadata @@ -239,64 +291,41 @@ def _apply_vision_pooling_pytorch( return pooled_outputs - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - """Thread-safe pooler with production error handling.""" - start_time = time.time() if self.observability_config else None - - # Validate inputs - if hidden_states is None or hidden_states.numel() == 0: - logger.warning("Empty hidden states received") - return PoolerOutput(outputs=[]) - - # Extract token IDs safely from metadata - token_ids_list, seq_ids = self._extract_token_ids_safe( - pooling_metadata) - if not token_ids_list: - logger.warning("No valid 
sequences found for pooling") - # Fallback to base pooler - return self._base_pooler(hidden_states, pooling_metadata) - - # Get prompt lengths based on metadata type - if isinstance(pooling_metadata, V1PoolingMetadata): - prompt_lens = pooling_metadata.prompt_lens - else: - prompt_lens = PoolingTensors.from_pooling_metadata( - pooling_metadata, hidden_states.device).prompt_lens +@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, + info=Qwen2VLProcessingInfo, + dummy_inputs=Qwen2VLDummyInputsBuilder) +class JinaVLForEmbedding(Qwen2VLForConditionalGeneration, + SupportsCrossEncoding, SupportsMultiModal): - # Validate lengths match - assert len(token_ids_list) == len(prompt_lens), ( - f"Mismatch: {len(token_ids_list)} sequences vs " - f"{len(prompt_lens)} lengths") + is_pooling_model = True - # Apply pooling based on configured backend - if self.pooling_backend == "triton": - pooled_data = self._apply_vision_pooling_optimized( - hidden_states, token_ids_list, prompt_lens) - else: # self.pooling_backend == "pytorch" - pooled_data = self._apply_vision_pooling_pytorch( - hidden_states, token_ids_list, prompt_lens) + # Weight mapping for HuggingFace checkpoint compatibility + weight_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.": "language_model.model.", + "visual.": "visual.", + "lm_head.": "language_model.lm_head.", + }) - # Build output - pooled_outputs = [ - PoolingSequenceGroupOutput(data) for data in pooled_data - ] + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "qwen2_vl")) - # Record metrics - if self.observability_config: - elapsed_ms = (time.time() - start_time) * 1000 - self._pooling_time_ms += elapsed_ms - self._pooling_count += 1 + # Configuration for vision pooling backend + self.pooling_backend = getattr(vllm_config.model_config, + "jina_pooling_backend", "pytorch") + if self.pooling_backend not in ("triton", "pytorch"): + logger.warning( + "Invalid jina_pooling_backend '%s'. " + "Must be 'triton' or 'pytorch'. Defaulting to 'pytorch'.", + self.pooling_backend) + self.pooling_backend = "pytorch" - if self._pooling_count % 100 == 0: - avg_time = self._pooling_time_ms / self._pooling_count - logger.debug("Average pooling time: %.2fms", avg_time) + # Initialize the vision-aware pooler + self.pooler = JinaVLPooler(vllm_config, self.pooling_backend) - return PoolerOutput(outputs=pooled_outputs) + logger.info("Initialized JinaVLForEmbedding with thread-safe pooling") def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): """Load weights with validation and error handling.""" From eb1497e3bff7bb01f4e25dc95f7a9ed239739a45 Mon Sep 17 00:00:00 2001 From: "Sigrid Jin (Sionic AI)" Date: Fri, 18 Jul 2025 15:58:37 +0900 Subject: [PATCH 19/23] refactor: use pooler utility functions to avoid duplicate code Use build_output and get_prompt_lens from pooler.py instead of implementing duplicate logic: - Replace manual PoolerOutput construction with build_output - Replace prompt length extraction logic with get_prompt_lens - Remove unused imports (PoolingSequenceGroupOutput, PoolingTensors) This addresses the review feedback to avoid duplicate code. 
Signed-off-by: Sigrid Jin (Sionic AI) --- .../models/jina_embeddings_v4.py | 26 +++++++------------ 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/models/jina_embeddings_v4.py b/vllm/model_executor/models/jina_embeddings_v4.py index b9d26902dbe0..f25e3b67b2fe 100644 --- a/vllm/model_executor/models/jina_embeddings_v4.py +++ b/vllm/model_executor/models/jina_embeddings_v4.py @@ -12,16 +12,16 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.pooler import (HAS_TRITON, Pooler, PoolingTask, - PoolingType, - extract_vision_tokens_kernel) + PoolingType, build_output, + extract_vision_tokens_kernel, + get_prompt_lens) # yapf: disable from vllm.model_executor.pooling_metadata import ( PoolingMetadata as V0PoolingMetadata) -from vllm.model_executor.pooling_metadata import PoolingTensors # yapf: enable from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.pooling_params import PoolingParams -from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput +from vllm.sequence import PoolerOutput from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata from .interfaces import SupportsCrossEncoding, SupportsMultiModal @@ -84,7 +84,7 @@ def forward( # Validate inputs if hidden_states is None or hidden_states.numel() == 0: logger.warning("Empty hidden states received") - return PoolerOutput(outputs=[]) + return build_output(torch.empty((0, 0))) # Extract token IDs safely from metadata token_ids_list, seq_ids = self._extract_token_ids_safe( @@ -95,12 +95,8 @@ def forward( # Fallback to base pooler return self._base_pooler(hidden_states, pooling_metadata) - # Get prompt lengths based on metadata type - if isinstance(pooling_metadata, V1PoolingMetadata): - prompt_lens = pooling_metadata.prompt_lens - else: - prompt_lens = PoolingTensors.from_pooling_metadata( - pooling_metadata, hidden_states.device).prompt_lens + # Get prompt lengths using utility function + prompt_lens = get_prompt_lens(hidden_states, pooling_metadata) # Validate lengths match assert len(token_ids_list) == len(prompt_lens), ( @@ -115,10 +111,8 @@ def forward( pooled_data = self._apply_vision_pooling_pytorch( hidden_states, token_ids_list, prompt_lens) - # Build output - pooled_outputs = [ - PoolingSequenceGroupOutput(data) for data in pooled_data - ] + # Stack pooled data into tensor for build_output + pooled_tensor = torch.stack(pooled_data) # Record metrics if self.observability_config: @@ -130,7 +124,7 @@ def forward( avg_time = self._pooling_time_ms / self._pooling_count logger.debug("Average pooling time: %.2fms", avg_time) - return PoolerOutput(outputs=pooled_outputs) + return build_output(pooled_tensor) def _extract_token_ids_safe( self, pooling_metadata: PoolingMetadata From 1b4f405fdad779a0f70d9e39c907ce5b4d653f29 Mon Sep 17 00:00:00 2001 From: "Sigrid Jin (Sionic AI)" Date: Fri, 18 Jul 2025 16:18:33 +0900 Subject: [PATCH 20/23] refactor: address maintainer review comments for JinaVLPooler Address DarkLight1337's review feedback: - Set logits_processing_needs_token_ids=True for V1 compatibility in both "embed" and "encode" tasks - Support "encode" task by returning PoolingParams() instead of None - Update log message from "thread-safe pooling" to "vision-aware pooling" to better reflect the actual functionality - Remove unused seq_ids variable from _extract_token_ids_safe method These changes ensure proper V1 compatibility and cleaner code structure. 
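The get_pooling_params dispatch in the diff below leans on typing_extensions.assert_never so that mypy flags any PoolingTask member that is not handled explicitly. A self-contained sketch of that pattern, with a hypothetical Task literal standing in for PoolingTask, is:

from typing import Literal, Optional

from typing_extensions import assert_never

# Hypothetical stand-in for vLLM's PoolingTask literal type.
Task = Literal["encode", "embed", "classify", "score"]


def needs_token_ids(task: Task) -> Optional[bool]:
    # Embedding-style tasks need the prompt token IDs so the pooler can
    # locate the vision span; classification-style tasks are unsupported.
    if task == "embed" or task == "encode":
        return True
    if task == "classify" or task == "score":
        return None
    # mypy errors on this call if any Task member is missing from the
    # branches above, because assert_never only accepts values typed Never.
    assert_never(task)


print(needs_token_ids("embed"))  # True
print(needs_token_ids("score"))  # None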
Signed-off-by: Sigrid Jin (Sionic AI) --- .../models/jina_embeddings_v4.py | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/models/jina_embeddings_v4.py b/vllm/model_executor/models/jina_embeddings_v4.py index f25e3b67b2fe..875ce0e9de05 100644 --- a/vllm/model_executor/models/jina_embeddings_v4.py +++ b/vllm/model_executor/models/jina_embeddings_v4.py @@ -65,10 +65,13 @@ def __init__(self, def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: """Return pooling params for embedding task.""" if task == "embed": - return PoolingParams() + return PoolingParams(logits_processing_needs_token_ids=True) + + if task == "encode": + return PoolingParams(logits_processing_needs_token_ids=True) # The equalities are split up to keep mypy happy - if task == "encode" or task == "classify" or task == "score": + if task == "classify" or task == "score": return None assert_never(task) @@ -87,8 +90,7 @@ def forward( return build_output(torch.empty((0, 0))) # Extract token IDs safely from metadata - token_ids_list, seq_ids = self._extract_token_ids_safe( - pooling_metadata) + token_ids_list = self._extract_token_ids_safe(pooling_metadata) if not token_ids_list: logger.warning("No valid sequences found for pooling") @@ -127,24 +129,20 @@ def forward( return build_output(pooled_tensor) def _extract_token_ids_safe( - self, pooling_metadata: PoolingMetadata - ) -> tuple[list[array], list[int]]: + self, pooling_metadata: PoolingMetadata) -> list[array]: """Safely extract token IDs from pooling metadata.""" token_ids_list: list[array] = [] try: if isinstance(pooling_metadata, V1PoolingMetadata): - # For V1, we get token IDs and sequence indices directly + # For V1, we get token IDs directly for i, num in enumerate(pooling_metadata.prompt_lens): token_ids = pooling_metadata.prompt_token_ids[ i, :num].tolist() token_ids_list.append(array('l', token_ids)) - # V1 metadata does not have explicit seq_ids, so we use indices - seq_ids = list(range(len(token_ids_list))) - return token_ids_list, seq_ids + return token_ids_list # For V0, we extract from seq_groups and seq_data - seq_ids = [] for seq_group, _ in pooling_metadata.seq_groups: for seq_id in seq_group: if seq_id not in pooling_metadata.seq_data: @@ -164,10 +162,9 @@ def _extract_token_ids_safe( seq_id) continue - seq_ids.append(seq_id) token_ids_list.append(token_ids) - return token_ids_list, seq_ids + return token_ids_list except Exception as e: logger.error( @@ -319,7 +316,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Initialize the vision-aware pooler self.pooler = JinaVLPooler(vllm_config, self.pooling_backend) - logger.info("Initialized JinaVLForEmbedding with thread-safe pooling") + logger.info("Initialized JinaVLForEmbedding with vision-aware pooling") def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): """Load weights with validation and error handling.""" From 702fd162f7988c3b439d93766cf024e2b46f2bc8 Mon Sep 17 00:00:00 2001 From: "Sigrid Jin (Sionic AI)" Date: Fri, 18 Jul 2025 17:04:44 +0900 Subject: [PATCH 21/23] perf: optimize vision token detection using torch.isin Implement efficiency improvements suggested by DarkLight1337: - Consolidate get_pooling_params method for "embed" and "encode" tasks - Pre-compute vision token IDs tensor in constructor - Replace range checks with torch.isin for more efficient vision token detection at lines 209-210 and 261-262 This reduces redundant code and improves performance when checking for 
vision tokens by using optimized tensor operations. Signed-off-by: Sigrid Jin (Sionic AI) --- .../model_executor/models/jina_embeddings_v4.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/jina_embeddings_v4.py b/vllm/model_executor/models/jina_embeddings_v4.py index 875ce0e9de05..d90f07a0ca9f 100644 --- a/vllm/model_executor/models/jina_embeddings_v4.py +++ b/vllm/model_executor/models/jina_embeddings_v4.py @@ -50,6 +50,10 @@ def __init__(self, self.pooling_backend = pooling_backend self.observability_config = vllm_config.observability_config + # Pre-compute vision token IDs tensor for efficient checking + self.vision_token_ids = torch.tensor( + [VISION_START_TOKEN_ID, VISION_END_TOKEN_ID], dtype=torch.long) + # Performance tracking self._pooling_time_ms = 0.0 self._pooling_count = 0 @@ -64,10 +68,7 @@ def __init__(self, def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: """Return pooling params for embedding task.""" - if task == "embed": - return PoolingParams(logits_processing_needs_token_ids=True) - - if task == "encode": + if task == "embed" or task == "encode": return PoolingParams(logits_processing_needs_token_ids=True) # The equalities are split up to keep mypy happy @@ -206,8 +207,8 @@ def _apply_vision_pooling_optimized( dtype=hidden_states.dtype) # Check for vision tokens - has_vision = torch.any((token_tensor >= VISION_START_TOKEN_ID) - & (token_tensor <= VISION_END_TOKEN_ID)) + has_vision = torch.isin(token_tensor, + self.vision_token_ids.to(device)).any() if has_vision: # Use Triton kernel for vision token extraction @@ -258,8 +259,8 @@ def _apply_vision_pooling_pytorch( device=hidden_states.device) # Check for vision tokens - vision_mask = ((seq_tokens >= VISION_START_TOKEN_ID) & - (seq_tokens <= VISION_END_TOKEN_ID)) + vision_mask = torch.isin( + seq_tokens, self.vision_token_ids.to(seq_tokens.device)) if vision_mask.any(): # Pool only vision tokens From 5114a3c1c9d176588946139274e68c43a60bdd3f Mon Sep 17 00:00:00 2001 From: "Sigrid Jin (Sionic AI)" Date: Fri, 18 Jul 2025 17:49:23 +0900 Subject: [PATCH 22/23] fix: introducing dedicated VisionPooler class Signed-off-by: Sigrid Jin (Sionic AI) --- .../models/pooling/test_jina_embeddings_v4.py | 37 +++ vllm/config.py | 5 +- vllm/model_executor/layers/pooler.py | 159 ++++++++-- .../models/jina_embeddings_v4.py | 275 +----------------- 4 files changed, 189 insertions(+), 287 deletions(-) diff --git a/tests/models/pooling/test_jina_embeddings_v4.py b/tests/models/pooling/test_jina_embeddings_v4.py index 6baa8d859d75..35c84acc2ec8 100644 --- a/tests/models/pooling/test_jina_embeddings_v4.py +++ b/tests/models/pooling/test_jina_embeddings_v4.py @@ -342,3 +342,40 @@ def test_vision_only_pooling(self, model): # embeddings should be very similar despite different text similarity = torch.dot(emb1, emb2).item() assert similarity > 0.99 # Should be nearly identical + + +class TestVisionPooler: + """Test the VisionPooler class.""" + + def test_vision_pooler(self): + """Test that the VisionPooler correctly pools vision tokens.""" + from vllm.config import ModelConfig + from vllm.model_executor.layers.pooler import VisionPooler + from vllm.pooling_params import PoolingParams + from vllm.v1.pool.metadata import PoolingMetadata + + model_config = ModelConfig(model_name, task="embed") + model_config.hf_config.vision_start_token_id = VISION_START_TOKEN_ID + model_config.hf_config.vision_end_token_id = VISION_END_TOKEN_ID + model_config.hidden_size = 4 + + 
pooler = VisionPooler(model_config) + + hidden_states = torch.randn(10, 4) + prompt_token_ids = torch.tensor([[ + 1, 2, VISION_START_TOKEN_ID, 4, VISION_END_TOKEN_ID, 6, 7, 8, 9, 10 + ]]) + prompt_lens = torch.tensor([10]) + + pooling_metadata = PoolingMetadata(prompt_lens=prompt_lens, + prompt_token_ids=prompt_token_ids, + pooling_params=[PoolingParams()]) + + output = pooler.forward(hidden_states, pooling_metadata) + + vision_tokens = hidden_states[2:5] + expected_output = vision_tokens.mean(dim=0) + + assert torch.allclose(output.outputs[0].data, + expected_output, + atol=1e-5) diff --git a/vllm/config.py b/vllm/config.py index 526b5db235fd..ec7ef4f7f550 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3256,9 +3256,10 @@ def get_limit_per_prompt(self, modality: str) -> int: @config @dataclass class PoolerConfig: - """Controls the behavior of output pooling in pooling models.""" + """Configuration for the pooler.""" - pooling_type: Optional[str] = None + pooling_type: Optional[Literal["last", "all", "cls", "step", "mean", + "vision"]] = None """ The pooling method of the pooling model. This should be a key in [`vllm.model_executor.layers.pooler.PoolingType`][]. diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 38a1d44828c2..7036392dc463 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -32,6 +32,7 @@ class PoolingType(IntEnum): CLS = 2 STEP = 3 MEAN = 4 + VISION = 5 @dataclass(frozen=True) @@ -91,6 +92,8 @@ def from_config_with_defaults( if pooling_type == PoolingType.STEP: return StepPooler.from_config(resolved_config) + if pooling_type == PoolingType.VISION: + return VisionPooler.from_config(resolved_config) return SimplePooler.from_config(resolved_config) @@ -622,6 +625,86 @@ def forward( ClassifierFn = Callable[[torch.Tensor], torch.Tensor] +class VisionPooler(Pooler): + + @classmethod + def from_config(cls, model_config: ModelConfig) -> "VisionPooler": + return cls(model_config) + + def __init__(self, config: ModelConfig): + super().__init__() + self.config = config + + def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: + if task == "embed": + return PoolingParams(pooling_type="vision", + logits_processing_needs_token_ids=True) + return None + + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + assert isinstance(pooling_metadata, V1PoolingMetadata) + + pooled_outputs = [] + for i in range(len(pooling_metadata.prompt_lens)): + start_pos = (pooling_metadata.prompt_token_ids[i] == self.config. + hf_config.vision_start_token_id).nonzero()[-1].item() + end_pos = (pooling_metadata.prompt_token_ids[i] == self.config. 
+ hf_config.vision_end_token_id).nonzero()[-1].item() + + seq_start = torch.cumsum( + torch.tensor([0] + pooling_metadata.prompt_lens.tolist()), + dim=0)[i] + seq_len = pooling_metadata.prompt_lens[i] + + output = torch.empty(self.config.hidden_size, + device=hidden_states.device, + dtype=hidden_states.dtype) + + grid = lambda meta: (self.config.hidden_size, ) + mean_pool_with_position_kernel[grid](hidden_states, output, + seq_start, seq_len, + self.config.hidden_size, + start_pos, end_pos + 1) + + pooled_outputs.append(output) + + return build_output(torch.stack(pooled_outputs)) + + +if HAS_TRITON: + + @triton.jit + def mean_pool_with_position_kernel( + hidden_states_ptr, + output_ptr, + seq_start, + seq_len, + hidden_size, + pool_start, + pool_end, + BLOCK_SIZE: tl.constexpr, + ): + """Triton kernel to perform mean pooling over a specified token range.""" + pid = tl.program_id(0) + + if pid >= hidden_size: + return + + accumulator = 0.0 + for i in range(pool_start, pool_end): + hidden_val = tl.load(hidden_states_ptr + + (seq_start + i) * hidden_size + pid) + accumulator += hidden_val + + # Store mean pooled result + result = accumulator / (pool_end - pool_start) + tl.store(output_ptr + pid, result) + + class ClassifierPooler(nn.Module): """A pooling layer for classification tasks. @@ -709,39 +792,81 @@ def forward( return build_output(scores) +class VisionPooler(Pooler): + + @classmethod + def from_config(cls, model_config: ModelConfig) -> "VisionPooler": + return cls(model_config) + + def __init__(self, config: ModelConfig): + super().__init__() + self.config = config + + def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: + if task == "embed": + return PoolingParams(pooling_type="vision", + logits_processing_needs_token_ids=True) + return None + + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + assert isinstance(pooling_metadata, V1PoolingMetadata) + + pooled_outputs = [] + for i in range(len(pooling_metadata.prompt_lens)): + start_pos = (pooling_metadata.prompt_token_ids[i] == self.config. + hf_config.vision_start_token_id).nonzero()[-1].item() + end_pos = (pooling_metadata.prompt_token_ids[i] == self.config. 
+ hf_config.vision_end_token_id).nonzero()[-1].item() + + seq_start = torch.cumsum( + torch.tensor([0] + pooling_metadata.prompt_lens.tolist()), + dim=0)[i] + seq_len = pooling_metadata.prompt_lens[i] + + output = torch.empty(self.config.hidden_size, + device=hidden_states.device, + dtype=hidden_states.dtype) + + grid = lambda meta: (self.config.hidden_size, ) + mean_pool_with_position_kernel[grid](hidden_states, output, + seq_start, seq_len, + self.config.hidden_size, + start_pos, end_pos + 1) + + pooled_outputs.append(output) + + return build_output(torch.stack(pooled_outputs)) + + if HAS_TRITON: @triton.jit - def extract_vision_tokens_kernel( + def mean_pool_with_position_kernel( hidden_states_ptr, - token_ids_ptr, output_ptr, seq_start, seq_len, hidden_size, - vision_start_id: tl.constexpr, - vision_end_id: tl.constexpr, + pool_start, + pool_end, BLOCK_SIZE: tl.constexpr, ): - """Triton kernel to extract and pool vision tokens efficiently.""" + """Triton kernel to perform mean pooling over a specified token range.""" pid = tl.program_id(0) if pid >= hidden_size: return - # Find vision token range - vision_count = 0 accumulator = 0.0 - - for i in range(seq_len): - token_id = tl.load(token_ids_ptr + seq_start + i) - if token_id >= vision_start_id and token_id <= vision_end_id: - hidden_val = tl.load(hidden_states_ptr + - (seq_start + i) * hidden_size + pid) - accumulator += hidden_val - vision_count += 1 + for i in range(pool_start, pool_end): + hidden_val = tl.load(hidden_states_ptr + + (seq_start + i) * hidden_size + pid) + accumulator += hidden_val # Store mean pooled result - result = accumulator / vision_count if vision_count > 0 else 0.0 - + result = accumulator / (pool_end - pool_start) tl.store(output_ptr + pid, result) diff --git a/vllm/model_executor/models/jina_embeddings_v4.py b/vllm/model_executor/models/jina_embeddings_v4.py index d90f07a0ca9f..1d543047cb17 100644 --- a/vllm/model_executor/models/jina_embeddings_v4.py +++ b/vllm/model_executor/models/jina_embeddings_v4.py @@ -1,20 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import time -from array import array -from collections.abc import Iterable from typing import Optional, Union import torch -import torch.nn.functional as F -from typing_extensions import assert_never from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.model_executor.layers.pooler import (HAS_TRITON, Pooler, PoolingTask, - PoolingType, build_output, - extract_vision_tokens_kernel, - get_prompt_lens) +from vllm.model_executor.layers.pooler import Pooler, PoolingTask # yapf: disable from vllm.model_executor.pooling_metadata import ( PoolingMetadata as V0PoolingMetadata) @@ -28,7 +20,7 @@ from .qwen2_vl import (Qwen2VLDummyInputsBuilder, Qwen2VLForConditionalGeneration, Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo) -from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix +from .utils import maybe_prefix logger = init_logger(__name__) @@ -42,246 +34,19 @@ class JinaVLPooler(Pooler): """Vision-aware pooler for Jina V4 with special vision token handling.""" - def __init__(self, - vllm_config: VllmConfig, - pooling_backend: str = "pytorch"): + def __init__(self, vllm_config: VllmConfig): super().__init__() - self.hidden_size = vllm_config.model_config.hf_config.hidden_size - self.pooling_backend = pooling_backend - self.observability_config = vllm_config.observability_config - - # Pre-compute vision token IDs tensor for efficient checking - 
self.vision_token_ids = torch.tensor( - [VISION_START_TOKEN_ID, VISION_END_TOKEN_ID], dtype=torch.long) - - # Performance tracking - self._pooling_time_ms = 0.0 - self._pooling_count = 0 - - # Initialize base pooler for fallback - pooler_config = vllm_config.model_config.pooler_config - self._base_pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.MEAN, - normalize=True, - softmax=False) + self.vision_pooler = VisionPooler(vllm_config.model_config) def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: - """Return pooling params for embedding task.""" - if task == "embed" or task == "encode": - return PoolingParams(logits_processing_needs_token_ids=True) - - # The equalities are split up to keep mypy happy - if task == "classify" or task == "score": - return None - - assert_never(task) + return self.vision_pooler.get_pooling_params(task) def forward( self, hidden_states: Union[torch.Tensor, list[torch.Tensor]], pooling_metadata: PoolingMetadata, ) -> PoolerOutput: - """Apply vision-aware pooling to hidden states.""" - start_time = time.time() if self.observability_config else None - - # Validate inputs - if hidden_states is None or hidden_states.numel() == 0: - logger.warning("Empty hidden states received") - return build_output(torch.empty((0, 0))) - - # Extract token IDs safely from metadata - token_ids_list = self._extract_token_ids_safe(pooling_metadata) - - if not token_ids_list: - logger.warning("No valid sequences found for pooling") - # Fallback to base pooler - return self._base_pooler(hidden_states, pooling_metadata) - - # Get prompt lengths using utility function - prompt_lens = get_prompt_lens(hidden_states, pooling_metadata) - - # Validate lengths match - assert len(token_ids_list) == len(prompt_lens), ( - f"Mismatch: {len(token_ids_list)} sequences vs " - f"{len(prompt_lens)} lengths") - - # Apply pooling based on configured backend - if self.pooling_backend == "triton": - pooled_data = self._apply_vision_pooling_optimized( - hidden_states, token_ids_list, prompt_lens) - else: # self.pooling_backend == "pytorch" - pooled_data = self._apply_vision_pooling_pytorch( - hidden_states, token_ids_list, prompt_lens) - - # Stack pooled data into tensor for build_output - pooled_tensor = torch.stack(pooled_data) - - # Record metrics - if self.observability_config: - elapsed_ms = (time.time() - start_time) * 1000 - self._pooling_time_ms += elapsed_ms - self._pooling_count += 1 - - if self._pooling_count % 100 == 0: - avg_time = self._pooling_time_ms / self._pooling_count - logger.debug("Average pooling time: %.2fms", avg_time) - - return build_output(pooled_tensor) - - def _extract_token_ids_safe( - self, pooling_metadata: PoolingMetadata) -> list[array]: - """Safely extract token IDs from pooling metadata.""" - token_ids_list: list[array] = [] - try: - if isinstance(pooling_metadata, V1PoolingMetadata): - # For V1, we get token IDs directly - for i, num in enumerate(pooling_metadata.prompt_lens): - token_ids = pooling_metadata.prompt_token_ids[ - i, :num].tolist() - token_ids_list.append(array('l', token_ids)) - - return token_ids_list - - # For V0, we extract from seq_groups and seq_data - for seq_group, _ in pooling_metadata.seq_groups: - for seq_id in seq_group: - if seq_id not in pooling_metadata.seq_data: - logger.warning("Sequence %s not found in seq_data", - seq_id) - continue - - seq_data = pooling_metadata.seq_data[seq_id] - - # Get prompt token IDs safely - if hasattr(seq_data, 'prompt_token_ids_array'): - token_ids = 
seq_data.prompt_token_ids_array - elif hasattr(seq_data, '_prompt_token_ids'): - token_ids = seq_data._prompt_token_ids - else: - logger.warning("No token IDs found for sequence %s", - seq_id) - continue - - token_ids_list.append(token_ids) - - return token_ids_list - - except Exception as e: - logger.error( - "Error extracting token IDs: %s. " - "Extracted %d sequences before failure", e, - len(token_ids_list)) - raise - - def _apply_vision_pooling_optimized( - self, - hidden_states: torch.Tensor, - token_ids_list: list[array], - prompt_lens: torch.Tensor, - ) -> list[torch.Tensor]: - """Apply optimized vision token pooling using Triton kernels.""" - if not HAS_TRITON: - logger.debug( - "Triton not available, falling back to PyTorch implementation") - return self._apply_vision_pooling_pytorch(hidden_states, - token_ids_list, - prompt_lens) - - pooled_outputs = [] - offset = 0 - device = hidden_states.device - - for i, (token_ids, - prompt_len) in enumerate(zip(token_ids_list, prompt_lens)): - prompt_len = int(prompt_len.item()) - - # Convert token IDs to tensor - token_tensor = torch.tensor(list(token_ids), - dtype=torch.long, - device=device) - - # Allocate output tensor - output = torch.zeros(self.hidden_size, - device=device, - dtype=hidden_states.dtype) - - # Check for vision tokens - has_vision = torch.isin(token_tensor, - self.vision_token_ids.to(device)).any() - - if has_vision: - # Use Triton kernel for vision token extraction - grid = (self.hidden_size, ) - extract_vision_tokens_kernel[grid]( - hidden_states, - token_tensor, - output, - offset, - prompt_len, - self.hidden_size, - VISION_START_TOKEN_ID, - VISION_END_TOKEN_ID, - BLOCK_SIZE=1024, - ) - else: - # Regular mean pooling for text - seq_states = hidden_states[offset:offset + prompt_len] - output = seq_states.mean(dim=0) - - # Normalize and handle potential NaNs by replacing with zeros - output = F.normalize(output, p=2, dim=-1, eps=1e-12) - pooled_outputs.append(output) - - offset += prompt_len - - return pooled_outputs - - def _apply_vision_pooling_pytorch( - self, - hidden_states: torch.Tensor, - token_ids_list: list[array], - prompt_lens: torch.Tensor, - ) -> list[torch.Tensor]: - """PyTorch fallback for vision token pooling.""" - pooled_outputs = [] - offset = 0 - - for token_ids, prompt_len in zip(token_ids_list, prompt_lens): - prompt_len = int(prompt_len.item()) - - # Extract sequence states and tokens - seq_states = hidden_states[offset:offset + prompt_len] - - # Convert array to tensor for processing - seq_tokens = torch.tensor(list(token_ids[:prompt_len]), - dtype=torch.long, - device=hidden_states.device) - - # Check for vision tokens - vision_mask = torch.isin( - seq_tokens, self.vision_token_ids.to(seq_tokens.device)) - - if vision_mask.any(): - # Pool only vision tokens - vision_states = seq_states[vision_mask] - if vision_states.numel() == 0: - logger.warning( - "No vision states found despite vision mask") - pooled = seq_states.mean(dim=0) - else: - pooled = vision_states.mean(dim=0) - else: - # Pool all tokens for text - pooled = seq_states.mean(dim=0) - - # Normalize embeddings - pooled = F.normalize(pooled, p=2, dim=-1, eps=1e-12) - pooled_outputs.append(pooled) - - offset += prompt_len - - return pooled_outputs + return self.vision_pooler.forward(hidden_states, pooling_metadata) @MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, @@ -292,36 +57,10 @@ class JinaVLForEmbedding(Qwen2VLForConditionalGeneration, is_pooling_model = True - # Weight mapping for HuggingFace checkpoint 
compatibility - weight_mapper = WeightsMapper( - orig_to_new_prefix={ - "model.": "language_model.model.", - "visual.": "visual.", - "lm_head.": "language_model.lm_head.", - }) - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "qwen2_vl")) - # Configuration for vision pooling backend - self.pooling_backend = getattr(vllm_config.model_config, - "jina_pooling_backend", "pytorch") - if self.pooling_backend not in ("triton", "pytorch"): - logger.warning( - "Invalid jina_pooling_backend '%s'. " - "Must be 'triton' or 'pytorch'. Defaulting to 'pytorch'.", - self.pooling_backend) - self.pooling_backend = "pytorch" - - # Initialize the vision-aware pooler - self.pooler = JinaVLPooler(vllm_config, self.pooling_backend) + self.pooler = JinaVLPooler(vllm_config) logger.info("Initialized JinaVLForEmbedding with vision-aware pooling") - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - """Load weights with validation and error handling.""" - loader = AutoWeightsLoader(self) - loaded_weights = loader.load_weights(weights, - mapper=self.weight_mapper) - return loaded_weights From 6b501b2c2952a6eb468c80344e0bcd231a7c7463 Mon Sep 17 00:00:00 2001 From: "Sigrid Jin (Sionic AI)" Date: Sat, 19 Jul 2025 13:51:48 +0900 Subject: [PATCH 23/23] feat: add vision pooling support for jina embeddings v4 Signed-off-by: Sigrid Jin (Sionic AI) --- .../embed_jina_embeddings_v4.py | 11 -- vllm/config.py | 2 +- vllm/model_executor/layers/pooler.py | 102 +++--------------- .../models/jina_embeddings_v4.py | 4 +- 4 files changed, 17 insertions(+), 102 deletions(-) diff --git a/examples/offline_inference/embed_jina_embeddings_v4.py b/examples/offline_inference/embed_jina_embeddings_v4.py index 0d20953c8abb..69ebe83d7588 100644 --- a/examples/offline_inference/embed_jina_embeddings_v4.py +++ b/examples/offline_inference/embed_jina_embeddings_v4.py @@ -1,16 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Example of offline inference with Jina Embeddings V4 multimodal model. - -This example demonstrates: -1. Text-only embeddings -2. Image-only embeddings -3. Cross-modal embeddings (text-to-image similarity) - -The model supports both text and vision inputs through a unified architecture. 
-""" - import torch from vllm import LLM diff --git a/vllm/config.py b/vllm/config.py index ec7ef4f7f550..f919a3f5463a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3256,7 +3256,7 @@ def get_limit_per_prompt(self, modality: str) -> int: @config @dataclass class PoolerConfig: - """Configuration for the pooler.""" + """Controls the behavior of output pooling in pooling models.""" pooling_type: Optional[Literal["last", "all", "cls", "step", "mean", "vision"]] = None diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 7036392dc463..34337aa6cdea 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -625,56 +625,6 @@ def forward( ClassifierFn = Callable[[torch.Tensor], torch.Tensor] -class VisionPooler(Pooler): - - @classmethod - def from_config(cls, model_config: ModelConfig) -> "VisionPooler": - return cls(model_config) - - def __init__(self, config: ModelConfig): - super().__init__() - self.config = config - - def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: - if task == "embed": - return PoolingParams(pooling_type="vision", - logits_processing_needs_token_ids=True) - return None - - def forward( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> PoolerOutput: - assert isinstance(pooling_metadata, V1PoolingMetadata) - - pooled_outputs = [] - for i in range(len(pooling_metadata.prompt_lens)): - start_pos = (pooling_metadata.prompt_token_ids[i] == self.config. - hf_config.vision_start_token_id).nonzero()[-1].item() - end_pos = (pooling_metadata.prompt_token_ids[i] == self.config. - hf_config.vision_end_token_id).nonzero()[-1].item() - - seq_start = torch.cumsum( - torch.tensor([0] + pooling_metadata.prompt_lens.tolist()), - dim=0)[i] - seq_len = pooling_metadata.prompt_lens[i] - - output = torch.empty(self.config.hidden_size, - device=hidden_states.device, - dtype=hidden_states.dtype) - - grid = lambda meta: (self.config.hidden_size, ) - mean_pool_with_position_kernel[grid](hidden_states, output, - seq_start, seq_len, - self.config.hidden_size, - start_pos, end_pos + 1) - - pooled_outputs.append(output) - - return build_output(torch.stack(pooled_outputs)) - - if HAS_TRITON: @triton.jit @@ -688,7 +638,6 @@ def mean_pool_with_position_kernel( pool_end, BLOCK_SIZE: tl.constexpr, ): - """Triton kernel to perform mean pooling over a specified token range.""" pid = tl.program_id(0) if pid >= hidden_size: @@ -817,10 +766,12 @@ def forward( pooled_outputs = [] for i in range(len(pooling_metadata.prompt_lens)): - start_pos = (pooling_metadata.prompt_token_ids[i] == self.config. - hf_config.vision_start_token_id).nonzero()[-1].item() - end_pos = (pooling_metadata.prompt_token_ids[i] == self.config. - hf_config.vision_end_token_id).nonzero()[-1].item() + start_pos = (pooling_metadata.prompt_token_ids[i] == + self.config.hf_config.vision_start_token_id). + nonzero()[-1].item() + end_pos = (pooling_metadata.prompt_token_ids[i] == + self.config.hf_config.vision_end_token_id). 
+ nonzero()[-1].item() seq_start = torch.cumsum( torch.tensor([0] + pooling_metadata.prompt_lens.tolist()), @@ -832,41 +783,18 @@ def forward( dtype=hidden_states.dtype) grid = lambda meta: (self.config.hidden_size, ) - mean_pool_with_position_kernel[grid](hidden_states, output, - seq_start, seq_len, - self.config.hidden_size, - start_pos, end_pos + 1) + if HAS_TRITON: + mean_pool_with_position_kernel[grid](hidden_states, output, + seq_start, seq_len, + self.config.hidden_size, + start_pos, end_pos + 1) + else: + # Fallback to PyTorch implementation if Triton is not available + vision_tokens_range = hidden_states[seq_start + start_pos : seq_start + end_pos + 1] + output = vision_tokens_range.mean(dim=0) pooled_outputs.append(output) return build_output(torch.stack(pooled_outputs)) -if HAS_TRITON: - - @triton.jit - def mean_pool_with_position_kernel( - hidden_states_ptr, - output_ptr, - seq_start, - seq_len, - hidden_size, - pool_start, - pool_end, - BLOCK_SIZE: tl.constexpr, - ): - """Triton kernel to perform mean pooling over a specified token range.""" - pid = tl.program_id(0) - - if pid >= hidden_size: - return - - accumulator = 0.0 - for i in range(pool_start, pool_end): - hidden_val = tl.load(hidden_states_ptr + - (seq_start + i) * hidden_size + pid) - accumulator += hidden_val - - # Store mean pooled result - result = accumulator / (pool_end - pool_start) - tl.store(output_ptr + pid, result) diff --git a/vllm/model_executor/models/jina_embeddings_v4.py b/vllm/model_executor/models/jina_embeddings_v4.py index 1d543047cb17..f97420c56f6c 100644 --- a/vllm/model_executor/models/jina_embeddings_v4.py +++ b/vllm/model_executor/models/jina_embeddings_v4.py @@ -6,7 +6,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.model_executor.layers.pooler import Pooler, PoolingTask +from vllm.model_executor.layers.pooler import Pooler, PoolingTask, VisionPooler # yapf: disable from vllm.model_executor.pooling_metadata import ( PoolingMetadata as V0PoolingMetadata) @@ -32,8 +32,6 @@ class JinaVLPooler(Pooler): - """Vision-aware pooler for Jina V4 with special vision token handling.""" - def __init__(self, vllm_config: VllmConfig): super().__init__() self.vision_pooler = VisionPooler(vllm_config.model_config)
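For reference, the pooling that the VisionPooler above performs can be reproduced in a few lines of plain PyTorch: find the last vision start/end marker pair in the prompt, mean-pool the hidden states over that inclusive span, and fall back to whole-prompt mean pooling for text-only inputs as the earlier pure-PyTorch path did. This is an illustrative sketch rather than the shipped implementation; it assumes Qwen2-VL's vision marker token IDs and reuses the toy tensor sizes from the unit test.

import torch

# Qwen2-VL vision marker token IDs assumed here.
VISION_START_TOKEN_ID = 151652
VISION_END_TOKEN_ID = 151653


def pool_vision_span(hidden_states: torch.Tensor,
                     prompt_token_ids: torch.Tensor) -> torch.Tensor:
    """Mean-pool (seq_len, hidden) states over the last vision span.

    Falls back to pooling the whole prompt when no vision markers are
    present, mirroring the earlier PyTorch pooling path.
    """
    starts = (prompt_token_ids == VISION_START_TOKEN_ID).nonzero()
    ends = (prompt_token_ids == VISION_END_TOKEN_ID).nonzero()
    if starts.numel() == 0 or ends.numel() == 0:
        return hidden_states.mean(dim=0)
    start_pos = int(starts[-1].item())
    end_pos = int(ends[-1].item())
    return hidden_states[start_pos:end_pos + 1].mean(dim=0)


# Toy check mirroring the unit test: 10 tokens, hidden size 4,
# vision span at positions 2..4 (inclusive).
hidden = torch.randn(10, 4)
ids = torch.tensor(
    [1, 2, VISION_START_TOKEN_ID, 4, VISION_END_TOKEN_ID, 6, 7, 8, 9, 10])
assert torch.allclose(pool_vision_span(hidden, ids), hidden[2:5].mean(dim=0))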