
Commit 062a156

sigridjineth (Sigrid Jin, Sionic AI) authored and committed
refactor: prehook commits
Signed-off-by: Sigrid Jin (Sionic AI) <sigrid@sionic.ai>
1 parent 0fe30f8 · commit 062a156

File tree

5 files changed: +299 -283 lines changed


benchmarks/jina_embeddings_v4_validation.py

Lines changed: 89 additions & 86 deletions
@@ -9,7 +9,6 @@
 
 import argparse
 import time
-from typing import List, Tuple
 
 import numpy as np
 import torch
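
The dropped `from typing import List, Tuple` import is the point of this hunk: the annotations below move to built-in generics, which Python 3.9+ subscripts directly (PEP 585). A minimal sketch of the equivalence, with a hypothetical function name:

    # Hypothetical illustration: on Python 3.9+, list/tuple subscript directly
    # (PEP 585), replacing typing.List[typing.Tuple[str, int]].
    def first_pair(pairs: list[tuple[str, int]]) -> tuple[str, int]:
        return pairs[0]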
@@ -18,108 +17,112 @@
 
 from vllm import LLM
 from vllm.config import PoolerConfig
+from vllm.inputs.data import TextPrompt
 
 # Vision token IDs
 VISION_START_TOKEN_ID = 151652
 VISION_END_TOKEN_ID = 151653
-from vllm.inputs.data import TextPrompt
 
 
-def create_test_cases() -> List[Tuple[str, str, any]]:
+def create_test_cases() -> list[tuple[str, str, any]]:
     """Create comprehensive test cases for validation."""
     test_cases = []
-
+
     # Text-only test cases
-    test_cases.extend([
-        ("text", "Query: What is artificial intelligence?", None),
-        ("text", "Passage: AI is a field of computer science focusing on creating intelligent machines.", None),
-        ("text", "Query: 你好世界", None),  # Chinese text
-        ("text", "Passage: " + " ".join(["word"] * 100), None),  # Long text
-    ])
-
+    test_cases.extend(
+        [
+            ("text", "Query: What is artificial intelligence?", None),
+            (
+                "text",
+                "Passage: AI is a field of computer science focusing on "
+                "creating intelligent machines.",
+                None,
+            ),
+            ("text", "Query: 你好世界", None),  # Chinese text
+            ("text", "Passage: " + " ".join(["word"] * 100), None),  # Long text
+        ]
+    )
+
     # Image test cases
     for color in ["red", "green", "blue"]:
-        img = Image.new('RGB', (224, 224), color=color)
+        img = Image.new("RGB", (224, 224), color=color)
         test_cases.append(("image", f"{color} image", img))
-
+
     # Complex image
-    complex_img = Image.new('RGB', (224, 224))
+    complex_img = Image.new("RGB", (224, 224))
     pixels = complex_img.load()
     for i in range(224):
         for j in range(224):
-            pixels[i, j] = (i % 256, j % 256, (i+j) % 256)
+            pixels[i, j] = (i % 256, j % 256, (i + j) % 256)
     test_cases.append(("image", "complex pattern", complex_img))
-
+
     return test_cases
 
 
 def compute_hf_embeddings(
-    model_name: str,
-    test_cases: List[Tuple[str, str, any]]
-) -> List[torch.Tensor]:
+    model_name: str, test_cases: list[tuple[str, str, any]]
+) -> list[torch.Tensor]:
     """Compute embeddings using HuggingFace implementation."""
     print("Loading HuggingFace model...")
-    model = AutoModel.from_pretrained(
-        model_name,
-        trust_remote_code=True,
-        torch_dtype=torch.float16
-    ).cuda().eval()
-
-    processor = AutoProcessor.from_pretrained(
-        model_name,
-        trust_remote_code=True
+    model = (
+        AutoModel.from_pretrained(
+            model_name, trust_remote_code=True, torch_dtype=torch.float16
+        )
+        .cuda()
+        .eval()
     )
-
+
+    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
+
     embeddings = []
-
+
     print("Computing HuggingFace embeddings...")
     start_time = time.time()
-
+
     for case_type, text, image in test_cases:
         if case_type == "text":
             inputs = processor(text=text, return_tensors="pt").to("cuda")
         else:  # image
             inputs = processor(
-                text="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n",
+                text="<|im_start|>user\n<|vision_start|><|image_pad|>"
+                "<|vision_end|>Describe the image.<|im_end|>\n",
                 images=image,
-                return_tensors="pt"
+                return_tensors="pt",
             ).to("cuda")
-
+
         with torch.no_grad():
            outputs = model(**inputs)
             # Extract embeddings based on model output structure
-            if hasattr(outputs, 'embeddings'):
+            if hasattr(outputs, "embeddings"):
                 embedding = outputs.embeddings[0]
             else:
                 # Fallback to last hidden state with custom pooling
                 hidden_states = outputs.last_hidden_state[0]
-
+
                 # Apply token-type-aware pooling
-                input_ids = inputs['input_ids'][0]
-                vision_mask = (
-                    (input_ids >= VISION_START_TOKEN_ID) &
-                    (input_ids <= VISION_END_TOKEN_ID)
+                input_ids = inputs["input_ids"][0]
+                vision_mask = (input_ids >= VISION_START_TOKEN_ID) & (
+                    input_ids <= VISION_END_TOKEN_ID
                 )
-
+
                 if vision_mask.any():
                     embedding = hidden_states[vision_mask].mean(dim=0)
                 else:
                     embedding = hidden_states.mean(dim=0)
-
+
                 embedding = torch.nn.functional.normalize(embedding, p=2, dim=-1)
-
+
         embeddings.append(embedding.cpu())
-
+
     hf_time = time.time() - start_time
     print(f"HuggingFace processing time: {hf_time:.2f}s")
-
+
     return embeddings
 
 
 def compute_vllm_embeddings(
-    model_name: str,
-    test_cases: List[Tuple[str, str, any]]
-) -> List[torch.Tensor]:
+    model_name: str, test_cases: list[tuple[str, str, any]]
+) -> list[torch.Tensor]:
     """Compute embeddings using vLLM implementation."""
     print("\nLoading vLLM model...")
     model = LLM(
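
In the HuggingFace fallback path above, pooling is token-type-aware: when the prompt contains vision tokens, only that span is mean-pooled; otherwise all tokens are, and the result is L2-normalized. A self-contained sketch of that rule, assuming `hidden_states` of shape `(seq_len, dim)` and 1-D `input_ids` (the helper name is illustrative, not from the file):

    import torch

    VISION_START_TOKEN_ID = 151652
    VISION_END_TOKEN_ID = 151653

    def pool_token_embeddings(
        hidden_states: torch.Tensor, input_ids: torch.Tensor
    ) -> torch.Tensor:
        # Mean-pool only the vision-token span when present, else all tokens.
        vision_mask = (input_ids >= VISION_START_TOKEN_ID) & (
            input_ids <= VISION_END_TOKEN_ID
        )
        if vision_mask.any():
            pooled = hidden_states[vision_mask].mean(dim=0)
        else:
            pooled = hidden_states.mean(dim=0)
        # L2-normalize so downstream cosine similarity reduces to a dot product.
        return torch.nn.functional.normalize(pooled, p=2, dim=-1)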
@@ -128,93 +131,93 @@ def compute_vllm_embeddings(
         override_pooler_config=PoolerConfig(pooling_type="ALL", normalize=False),
         dtype="float16",
     )
-
+
     embeddings = []
     prompts = []
-
+
     # Prepare prompts
     for case_type, text, image in test_cases:
         if case_type == "text":
             prompt = TextPrompt(prompt=text)
         else:  # image
             prompt = TextPrompt(
-                prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n",
+                prompt="<|im_start|>user\n<|vision_start|><|image_pad|>"
+                "<|vision_end|>Describe the image.<|im_end|>\n",
                 multi_modal_data={"image": image},
             )
         prompts.append(prompt)
-
+
     print("Computing vLLM embeddings...")
     start_time = time.time()
-
+
     # Process all at once for better performance
     outputs = model.encode(prompts)
-
+
     for output in outputs:
         # Extract based on token type
         if 151652 in output.prompt_token_ids:  # VISION_START_TOKEN_ID
             img_start = output.prompt_token_ids.index(151652)
             img_end = output.prompt_token_ids.index(151653)
-            embedding_data = output.outputs.data[img_start:img_end + 1]
+            embedding_data = output.outputs.data[img_start : img_end + 1]
         else:
             embedding_data = output.outputs.data
-
+
         # Pool and normalize
         pooled = embedding_data.mean(dim=0, dtype=torch.float32)
         normalized = torch.nn.functional.normalize(pooled, p=2, dim=-1)
         embeddings.append(normalized.cpu())
-
+
     vllm_time = time.time() - start_time
     print(f"vLLM processing time: {vllm_time:.2f}s")
-
+
     return embeddings
 
 
 def compare_embeddings(
-    hf_embeddings: List[torch.Tensor],
-    vllm_embeddings: List[torch.Tensor],
-    test_cases: List[Tuple[str, str, any]]
+    hf_embeddings: list[torch.Tensor],
+    vllm_embeddings: list[torch.Tensor],
+    test_cases: list[tuple[str, str, any]],
 ) -> None:
     """Compare embeddings and report differences."""
-    print("\n" + "="*60)
+    print("\n" + "=" * 60)
     print("EMBEDDING COMPARISON RESULTS")
-    print("="*60)
-
+    print("=" * 60)
+
     similarities = []
     max_diffs = []
-
+
     for i, (case_type, desc, _) in enumerate(test_cases):
         hf_emb = hf_embeddings[i]
         vllm_emb = vllm_embeddings[i]
-
+
         # Compute cosine similarity
         similarity = torch.nn.functional.cosine_similarity(
-            hf_emb.unsqueeze(0),
-            vllm_emb.unsqueeze(0)
+            hf_emb.unsqueeze(0), vllm_emb.unsqueeze(0)
         ).item()
-
+
         # Compute max absolute difference
         max_diff = torch.max(torch.abs(hf_emb - vllm_emb)).item()
-
+
         similarities.append(similarity)
         max_diffs.append(max_diff)
-
-        print(f"\nTest case {i+1}: {case_type} - {desc[:50]}...")
+
+        print(f"\nTest case {i + 1}: {case_type} - {desc[:50]}...")
         print(f"  Cosine similarity: {similarity:.6f}")
         print(f"  Max absolute diff: {max_diff:.6f}")
         print(f"  HF norm: {hf_emb.norm():.6f}, vLLM norm: {vllm_emb.norm():.6f}")
-
+
         # Flag significant differences
         if similarity < 0.99:
-            print(f"  ⚠️  WARNING: Low similarity detected!")
-
+            print("  ⚠️  WARNING: Low similarity detected!")
+
     # Summary statistics
-    print("\n" + "-"*60)
+    print("\n" + "-" * 60)
     print("SUMMARY STATISTICS")
-    print("-"*60)
+    print("-" * 60)
     print(f"Average cosine similarity: {np.mean(similarities):.6f}")
     print(f"Min cosine similarity: {np.min(similarities):.6f}")
     print(f"Max absolute difference: {np.max(max_diffs):.6f}")
-
+
     # Overall assessment
     if np.min(similarities) > 0.99:
         print("\n✅ VALIDATION PASSED: vLLM implementation matches HuggingFace")
@@ -230,27 +233,27 @@ def main():
         "--model",
         type=str,
         default="jinaai/jina-embeddings-v4-vllm-retrieval",
-        help="Model name to test"
+        help="Model name to test",
     )
     parser.add_argument(
         "--skip-hf",
         action="store_true",
-        help="Skip HuggingFace comparison (for performance testing only)"
+        help="Skip HuggingFace comparison (for performance testing only)",
     )
-
+
     args = parser.parse_args()
-
+
     # Create test cases
     test_cases = create_test_cases()
     print(f"Created {len(test_cases)} test cases")
-
+
     # Compute vLLM embeddings
     vllm_embeddings = compute_vllm_embeddings(args.model, test_cases)
-
+
     if not args.skip_hf:
         # Compute HuggingFace embeddings
         hf_embeddings = compute_hf_embeddings(args.model, test_cases)
-
+
         # Compare results
         compare_embeddings(hf_embeddings, vllm_embeddings, test_cases)
     else:
@@ -259,4 +262,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
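
Going by the argparse definitions in `main()`, the script is meant to be run directly; two plausible invocations (the exact command lines are inferred from the flags above, and a CUDA-capable GPU is assumed since both backends load the model on device in float16):

    python benchmarks/jina_embeddings_v4_validation.py --model jinaai/jina-embeddings-v4-vllm-retrieval
    python benchmarks/jina_embeddings_v4_validation.py --skip-hf   # vLLM timing only, no HF comparison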
