[Model] Add support for Jina Embeddings V4 #20802

Status: Open · wants to merge 24 commits into base: main
Changes from 15 commits

Commits (24)
9fbc0e9  feat: jina support (sigridjineth, Jul 11, 2025)
eea8462  refactor: fail fast (sigridjineth, Jul 11, 2025)
5e247e9  refactor: exceptions (sigridjineth, Jul 11, 2025)
9be40b2  refactor: improve jina embeddings v4 model (sigridjineth, Jul 11, 2025)
64c06c7  refactor: oom (sigridjineth, Jul 11, 2025)
56b7409  refactor: Validate lengths match (sigridjineth, Jul 11, 2025)
bef3df2  refactor: normalize (sigridjineth, Jul 11, 2025)
efa8b04  refactor: normalize (sigridjineth, Jul 11, 2025)
0fe30f8  refactor: review (sigridjineth, Jul 11, 2025)
062a156  refactor: prehook commits (sigridjineth, Jul 16, 2025)
edfe91a  fix: Apply isort formatting to jina_embeddings_v4.py (Jul 16, 2025)
5d12bd4  [ci skip-hooks] Formatting attempt(s) (Jul 16, 2025)
27b28f7  fix: Resolve yapf/isort conflict with disable comments (Jul 17, 2025)
3bdbd17  refactor: accept review (Jul 17, 2025)
fafd668  refactor: address review feedback for Jina embeddings V4 (Jul 18, 2025)
0c3f1bd  refactor: import HAS_TRITON from triton_utils instead of local defini… (Jul 18, 2025)
5c45015  refactor: rename example file to follow existing embedding pattern (Jul 18, 2025)
9d34781  Merge remote-tracking branch 'origin/main' into jina-support (Jul 18, 2025)
8e0578a  refactor: update JinaVLForEmbedding to comply with new pooling archit… (Jul 18, 2025)
eb1497e  refactor: use pooler utility functions to avoid duplicate code (Jul 18, 2025)
1b4f405  refactor: address maintainer review comments for JinaVLPooler (Jul 18, 2025)
702fd16  perf: optimize vision token detection using torch.isin (Jul 18, 2025; see the sketch after this list)
5114a3c  fix: introducing dedicated VisionPooler class (Jul 18, 2025)
6b501b2  feat: add vision pooling support for jina embeddings v4 (Jul 19, 2025)
examples/offline_inference/vision_language_embedding.py: 121 additions, 0 deletions
A collaborator commented on this file:
Oh, my bad. I meant vision_language_pooling.py, we just renamed that file from vision_language_embedding.py recently. 😅

@@ -0,0 +1,121 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example of using Jina Embeddings V4 with vLLM for multimodal embeddings.

This example demonstrates:
1. Text-only embeddings
2. Image-only embeddings
3. Mixed text and image embeddings
"""

import torch

from vllm import LLM
from vllm.config import PoolerConfig
from vllm.inputs.data import TextPrompt
from vllm.multimodal.utils import fetch_image


def get_embeddings(outputs):
    """Extract and normalize embeddings from model outputs."""
    VISION_START_TOKEN_ID, VISION_END_TOKEN_ID = 151652, 151653

    embeddings = []
    for output in outputs:
        if VISION_START_TOKEN_ID in output.prompt_token_ids:
            # For vision inputs, extract only vision token embeddings
            img_start_pos = output.prompt_token_ids.index(VISION_START_TOKEN_ID)
            img_end_pos = output.prompt_token_ids.index(VISION_END_TOKEN_ID)
            embeddings_tensor = output.outputs.data.detach().clone()[
                img_start_pos : img_end_pos + 1
            ]
        else:
            # For text-only inputs, use all token embeddings
            embeddings_tensor = output.outputs.data.detach().clone()

        # Pool and normalize embeddings
        pooled_output = embeddings_tensor.mean(dim=0, dtype=torch.float32)
        embeddings.append(torch.nn.functional.normalize(pooled_output, dim=-1))
    return embeddings


def main():
    # Initialize the model
    model = LLM(
        model="jinaai/jina-embeddings-v4-vllm-retrieval",
        task="embed",
        override_pooler_config=PoolerConfig(pooling_type="ALL", normalize=False),
        dtype="float16",
    )

    # Example 1: Text-only embeddings
    print("=== Text Embeddings ===")
    query = "Overview of climate change impacts on coastal cities"
    query_prompt = TextPrompt(prompt=f"Query: {query}")

    passage = """The impacts of climate change on coastal cities are significant
and multifaceted. Rising sea levels threaten infrastructure, while increased
storm intensity poses risks to populations and economies."""
    passage_prompt = TextPrompt(prompt=f"Passage: {passage}")

    # Generate embeddings
    text_outputs = model.encode([query_prompt, passage_prompt])
    text_embeddings = get_embeddings(text_outputs)

    # Calculate similarity
    similarity = torch.dot(text_embeddings[0], text_embeddings[1]).item()
    print(f"Query: {query[:50]}...")
    print(f"Passage: {passage[:50]}...")
    print(f"Similarity: {similarity:.4f}\n")

    # Example 2: Image embeddings
    print("=== Image Embeddings ===")
    # Fetch sample images
    image1_url = "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png"
    image2_url = "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"

    image1 = fetch_image(image1_url)
    image2 = fetch_image(image2_url)

    # Create image prompts with the required format
    image1_prompt = TextPrompt(
        prompt="<|im_start|>user\n<|vision_start|><|image_pad|>"
        "<|vision_end|>Describe the image.<|im_end|>\n",
        multi_modal_data={"image": image1},
    )

    image2_prompt = TextPrompt(
        prompt="<|im_start|>user\n<|vision_start|><|image_pad|>"
        "<|vision_end|>Describe the image.<|im_end|>\n",
        multi_modal_data={"image": image2},
    )

    # Generate embeddings
    image_outputs = model.encode([image1_prompt, image2_prompt])
    image_embeddings = get_embeddings(image_outputs)

    # Calculate similarity
    similarity = torch.dot(image_embeddings[0], image_embeddings[1]).item()
    print(f"Image 1: {image1_url.split('/')[-1]}")
    print(f"Image 2: {image2_url.split('/')[-1]}")
    print(f"Similarity: {similarity:.4f}\n")

    # Example 3: Cross-modal similarity (text vs image)
    print("=== Cross-modal Similarity ===")
    query = "scientific paper with markdown formatting"
    query_prompt = TextPrompt(prompt=f"Query: {query}")

    # Generate embeddings for text query and second image
    cross_outputs = model.encode([query_prompt, image2_prompt])
    cross_embeddings = get_embeddings(cross_outputs)

    # Calculate cross-modal similarity
    similarity = torch.dot(cross_embeddings[0], cross_embeddings[1]).item()
    print(f"Text query: {query}")
    print(f"Image: {image2_url.split('/')[-1]}")
    print(f"Cross-modal similarity: {similarity:.4f}")


if __name__ == "__main__":
    main()
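
As a possible follow-up to the example (not part of this PR's diff): because get_embeddings() returns L2-normalized vectors, pairwise cosine similarities for a whole batch can be computed with a single matrix product instead of separate torch.dot calls. The helper below is a hypothetical sketch; its name and usage are illustrative only.

import torch


def similarity_matrix(embeddings: list[torch.Tensor]) -> torch.Tensor:
    """Return an (n, n) matrix of cosine similarities for already-normalized embeddings."""
    stacked = torch.stack(embeddings)  # (n, hidden_dim)
    return stacked @ stacked.T  # (n, n)


# Example usage, assuming `model` and the prompts from main() are in scope:
# outputs = model.encode([query_prompt, passage_prompt, image1_prompt, image2_prompt])
# print(similarity_matrix(get_embeddings(outputs)))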