
import argparse
import time
-from typing import List, Tuple

import numpy as np
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

from vllm import LLM
from vllm.config import PoolerConfig
+from vllm.inputs.data import TextPrompt

# Vision token IDs
VISION_START_TOKEN_ID = 151652
VISION_END_TOKEN_ID = 151653
-from vllm.inputs.data import TextPrompt


-def create_test_cases() -> List[Tuple[str, str, any]]:
+def create_test_cases() -> list[tuple[str, str, any]]:
    """Create comprehensive test cases for validation."""
    test_cases = []
-
+
    # Text-only test cases
-    test_cases.extend([
-        ("text", "Query: What is artificial intelligence?", None),
-        ("text", "Passage: AI is a field of computer science focusing on creating intelligent machines.", None),
-        ("text", "Query: 你好世界", None),  # Chinese text
-        ("text", "Passage: " + " ".join(["word"] * 100), None),  # Long text
-    ])
-
+    test_cases.extend(
+        [
+            ("text", "Query: What is artificial intelligence?", None),
+            (
+                "text",
+                "Passage: AI is a field of computer science focusing on "
+                "creating intelligent machines.",
+                None,
+            ),
+            ("text", "Query: 你好世界", None),  # Chinese text
+            ("text", "Passage: " + " ".join(["word"] * 100), None),  # Long text
+        ]
+    )
+
    # Image test cases
    for color in ["red", "green", "blue"]:
-        img = Image.new('RGB', (224, 224), color=color)
+        img = Image.new("RGB", (224, 224), color=color)
        test_cases.append(("image", f"{color} image", img))
-
+
    # Complex image
-    complex_img = Image.new('RGB', (224, 224))
+    complex_img = Image.new("RGB", (224, 224))
    pixels = complex_img.load()
    for i in range(224):
        for j in range(224):
-            pixels[i, j] = (i % 256, j % 256, (i+j) % 256)
+            pixels[i, j] = (i % 256, j % 256, (i + j) % 256)
    test_cases.append(("image", "complex pattern", complex_img))
-
+
    return test_cases


def compute_hf_embeddings(
-    model_name: str,
-    test_cases: List[Tuple[str, str, any]]
-) -> List[torch.Tensor]:
+    model_name: str, test_cases: list[tuple[str, str, any]]
+) -> list[torch.Tensor]:
    """Compute embeddings using HuggingFace implementation."""
    print("Loading HuggingFace model...")
-    model = AutoModel.from_pretrained(
-        model_name,
-        trust_remote_code=True,
-        torch_dtype=torch.float16
-    ).cuda().eval()
-
-    processor = AutoProcessor.from_pretrained(
-        model_name,
-        trust_remote_code=True
+    model = (
+        AutoModel.from_pretrained(
+            model_name, trust_remote_code=True, torch_dtype=torch.float16
+        )
+        .cuda()
+        .eval()
    )
-
+
+    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
+
    embeddings = []
-
+
    print("Computing HuggingFace embeddings...")
    start_time = time.time()
-
+
    for case_type, text, image in test_cases:
        if case_type == "text":
            inputs = processor(text=text, return_tensors="pt").to("cuda")
        else:  # image
            inputs = processor(
-                text="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n",
+                text="<|im_start|>user\n<|vision_start|><|image_pad|>"
+                "<|vision_end|>Describe the image.<|im_end|>\n",
                images=image,
-                return_tensors="pt"
+                return_tensors="pt",
            ).to("cuda")
-
+
        with torch.no_grad():
            outputs = model(**inputs)
            # Extract embeddings based on model output structure
-            if hasattr(outputs, 'embeddings'):
+            if hasattr(outputs, "embeddings"):
                embedding = outputs.embeddings[0]
            else:
                # Fallback to last hidden state with custom pooling
                hidden_states = outputs.last_hidden_state[0]
-
+
                # Apply token-type-aware pooling
-                input_ids = inputs['input_ids'][0]
-                vision_mask = (
-                    (input_ids >= VISION_START_TOKEN_ID) &
-                    (input_ids <= VISION_END_TOKEN_ID)
+                input_ids = inputs["input_ids"][0]
+                vision_mask = (input_ids >= VISION_START_TOKEN_ID) & (
+                    input_ids <= VISION_END_TOKEN_ID
                )
-
+
                if vision_mask.any():
                    embedding = hidden_states[vision_mask].mean(dim=0)
                else:
                    embedding = hidden_states.mean(dim=0)
-
+
                embedding = torch.nn.functional.normalize(embedding, p=2, dim=-1)
-
+
            embeddings.append(embedding.cpu())
-
+
    hf_time = time.time() - start_time
    print(f"HuggingFace processing time: {hf_time:.2f}s")
-
+
    return embeddings


def compute_vllm_embeddings(
-    model_name: str,
-    test_cases: List[Tuple[str, str, any]]
-) -> List[torch.Tensor]:
+    model_name: str, test_cases: list[tuple[str, str, any]]
+) -> list[torch.Tensor]:
    """Compute embeddings using vLLM implementation."""
    print("\nLoading vLLM model...")
    model = LLM(
@@ -128,93 +131,93 @@ def compute_vllm_embeddings(
        override_pooler_config=PoolerConfig(pooling_type="ALL", normalize=False),
        dtype="float16",
    )
-
+
    embeddings = []
    prompts = []
-
+
    # Prepare prompts
    for case_type, text, image in test_cases:
        if case_type == "text":
            prompt = TextPrompt(prompt=text)
        else:  # image
            prompt = TextPrompt(
-                prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n",
+                prompt="<|im_start|>user\n<|vision_start|><|image_pad|>"
+                "<|vision_end|>Describe the image.<|im_end|>\n",
                multi_modal_data={"image": image},
            )
        prompts.append(prompt)
-
+
    print("Computing vLLM embeddings...")
    start_time = time.time()
-
+
    # Process all at once for better performance
    outputs = model.encode(prompts)
-
+
    for output in outputs:
        # Extract based on token type
        if 151652 in output.prompt_token_ids:  # VISION_START_TOKEN_ID
            img_start = output.prompt_token_ids.index(151652)
            img_end = output.prompt_token_ids.index(151653)
-            embedding_data = output.outputs.data[img_start:img_end + 1]
+            embedding_data = output.outputs.data[img_start : img_end + 1]
        else:
            embedding_data = output.outputs.data
-
+
        # Pool and normalize
        pooled = embedding_data.mean(dim=0, dtype=torch.float32)
        normalized = torch.nn.functional.normalize(pooled, p=2, dim=-1)
        embeddings.append(normalized.cpu())
-
+
    vllm_time = time.time() - start_time
    print(f"vLLM processing time: {vllm_time:.2f}s")
-
+
    return embeddings


def compare_embeddings(
-    hf_embeddings: List[torch.Tensor],
-    vllm_embeddings: List[torch.Tensor],
-    test_cases: List[Tuple[str, str, any]]
+    hf_embeddings: list[torch.Tensor],
+    vllm_embeddings: list[torch.Tensor],
+    test_cases: list[tuple[str, str, any]],
) -> None:
    """Compare embeddings and report differences."""
-    print("\n" + "="*60)
+    print("\n" + "=" * 60)
    print("EMBEDDING COMPARISON RESULTS")
-    print("="*60)
-
+    print("=" * 60)
+
    similarities = []
    max_diffs = []
-
+
    for i, (case_type, desc, _) in enumerate(test_cases):
        hf_emb = hf_embeddings[i]
        vllm_emb = vllm_embeddings[i]
-
+
        # Compute cosine similarity
        similarity = torch.nn.functional.cosine_similarity(
-            hf_emb.unsqueeze(0),
-            vllm_emb.unsqueeze(0)
+            hf_emb.unsqueeze(0), vllm_emb.unsqueeze(0)
        ).item()
-
+
        # Compute max absolute difference
        max_diff = torch.max(torch.abs(hf_emb - vllm_emb)).item()
-
+
        similarities.append(similarity)
        max_diffs.append(max_diff)
-
-        print(f"\nTest case {i+1}: {case_type} - {desc[:50]}...")
+
+        print(f"\nTest case {i + 1}: {case_type} - {desc[:50]}...")
        print(f"  Cosine similarity: {similarity:.6f}")
        print(f"  Max absolute diff: {max_diff:.6f}")
        print(f"  HF norm: {hf_emb.norm():.6f}, vLLM norm: {vllm_emb.norm():.6f}")
-
+
        # Flag significant differences
        if similarity < 0.99:
-            print(f"  ⚠️ WARNING: Low similarity detected!")
-
+            print("  ⚠️ WARNING: Low similarity detected!")
+
    # Summary statistics
-    print("\n" + "-"*60)
+    print("\n" + "-" * 60)
    print("SUMMARY STATISTICS")
-    print("-"*60)
+    print("-" * 60)
    print(f"Average cosine similarity: {np.mean(similarities):.6f}")
    print(f"Min cosine similarity: {np.min(similarities):.6f}")
    print(f"Max absolute difference: {np.max(max_diffs):.6f}")
-
+
    # Overall assessment
    if np.min(similarities) > 0.99:
        print("\n✅ VALIDATION PASSED: vLLM implementation matches HuggingFace")
@@ -230,27 +233,27 @@ def main():
        "--model",
        type=str,
        default="jinaai/jina-embeddings-v4-vllm-retrieval",
-        help="Model name to test"
+        help="Model name to test",
    )
    parser.add_argument(
        "--skip-hf",
        action="store_true",
-        help="Skip HuggingFace comparison (for performance testing only)"
+        help="Skip HuggingFace comparison (for performance testing only)",
    )
-
+
    args = parser.parse_args()
-
+
    # Create test cases
    test_cases = create_test_cases()
    print(f"Created {len(test_cases)} test cases")
-
+
    # Compute vLLM embeddings
    vllm_embeddings = compute_vllm_embeddings(args.model, test_cases)
-
+
    if not args.skip_hf:
        # Compute HuggingFace embeddings
        hf_embeddings = compute_hf_embeddings(args.model, test_cases)
-
+
        # Compare results
        compare_embeddings(hf_embeddings, vllm_embeddings, test_cases)
    else:
@@ -259,4 +262,4 @@ def main():


if __name__ == "__main__":
-    main()
+    main()
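
For reference, the vLLM half of the validation script condenses to the standalone sketch below. It is a reader's aid reconstructed from the diff, not part of the commit: the checkpoint name, vision token IDs, prompt template, PoolerConfig, and the encode/pooling calls are taken directly from the code above, while the single red test image and the printed output are illustrative assumptions.

# Minimal sketch of the script's vLLM path (reconstruction, see note above).
import torch
from PIL import Image

from vllm import LLM
from vllm.config import PoolerConfig
from vllm.inputs.data import TextPrompt

VISION_START_TOKEN_ID = 151652  # per the script's "Vision token IDs"
VISION_END_TOKEN_ID = 151653

# pooling_type="ALL" returns one vector per prompt token rather than a
# single pooled vector, so pooling is done manually below. Constructor
# arguments elided between hunks in the diff (e.g. a task/runner setting)
# may also be required depending on the vLLM version.
model = LLM(
    model="jinaai/jina-embeddings-v4-vllm-retrieval",
    override_pooler_config=PoolerConfig(pooling_type="ALL", normalize=False),
    dtype="float16",
)

prompts = [
    TextPrompt(prompt="Query: What is artificial intelligence?"),
    TextPrompt(
        prompt="<|im_start|>user\n<|vision_start|><|image_pad|>"
        "<|vision_end|>Describe the image.<|im_end|>\n",
        multi_modal_data={"image": Image.new("RGB", (224, 224), color="red")},
    ),
]

for output in model.encode(prompts):
    token_ids = output.prompt_token_ids
    data = output.outputs.data  # per-token embeddings
    if VISION_START_TOKEN_ID in token_ids:
        # For image prompts, pool only the vision token span.
        start = token_ids.index(VISION_START_TOKEN_ID)
        end = token_ids.index(VISION_END_TOKEN_ID)
        data = data[start : end + 1]
    embedding = torch.nn.functional.normalize(
        data.mean(dim=0, dtype=torch.float32), p=2, dim=-1
    )
    print(embedding.shape, embedding.norm().item())

These are the vectors that compare_embeddings() checks against the HuggingFace outputs, with cosine similarity above 0.99 as the pass threshold.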