@@ -37,13 +37,21 @@ class TestAdditionalOutputs:
     _sampling_parameters = {"temperature": "0", "top_p": "1"}
     _prompt = "In this example,"
 
+    def _get_sampling_parameters(self, logprobs=None):
+        sampling_parameters = self._sampling_parameters.copy()
+        if logprobs is not None:
+            sampling_parameters["logprobs"] = logprobs
+        return sampling_parameters
+
     def _get_inputs(
         self,
         prompt,
         stream=True,
         sampling_parameters=None,
         return_finish_reason=None,
         return_cumulative_logprob=None,
+        return_logprobs=None,
+        return_num_input_tokens=None,
         return_num_output_tokens=None,
     ):
         inputs = []
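The helper above only layers the optional `logprobs` count onto the shared greedy defaults. A standalone restatement (not part of the diff) with the expected outputs spelled out:

```python
_sampling_parameters = {"temperature": "0", "top_p": "1"}

def _get_sampling_parameters(logprobs=None):
    # Copy so the shared defaults are never mutated between test cases.
    sampling_parameters = _sampling_parameters.copy()
    if logprobs is not None:
        sampling_parameters["logprobs"] = logprobs
    return sampling_parameters

assert _get_sampling_parameters() == {"temperature": "0", "top_p": "1"}
assert _get_sampling_parameters(logprobs=2) == {
    "temperature": "0",
    "top_p": "1",
    "logprobs": 2,
}
```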
@@ -76,6 +84,16 @@ def _get_inputs(
                 np.array([return_cumulative_logprob], dtype=bool)
             )
 
+        if return_logprobs is not None:
+            inputs.append(grpcclient.InferInput("return_logprobs", [1], "BOOL"))
+            inputs[-1].set_data_from_numpy(np.array([return_logprobs], dtype=bool))
+
+        if return_num_input_tokens is not None:
+            inputs.append(grpcclient.InferInput("return_num_input_tokens", [1], "BOOL"))
+            inputs[-1].set_data_from_numpy(
+                np.array([return_num_input_tokens], dtype=bool)
+            )
+
         if return_num_output_tokens is not None:
             inputs.append(
                 grpcclient.InferInput("return_num_output_tokens", [1], "BOOL")
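Each new flag follows the same Triton client pattern: a one-element `BOOL` tensor named after the optional output. A minimal sketch of building one flag on its own (the value is illustrative):

```python
import numpy as np
import tritonclient.grpc as grpcclient

# Request the new "logprobs" output; the same pattern applies to
# "return_num_input_tokens" and "return_num_output_tokens".
flag = grpcclient.InferInput("return_logprobs", [1], "BOOL")
flag.set_data_from_numpy(np.array([True], dtype=bool))
```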
@@ -89,12 +107,12 @@ def _get_inputs(
     def _callback(self, result, error):
         self._responses.append({"result": result, "error": error})
 
-    def _llm_infer(self, inputs):
+    def _llm_infer(self, inputs, sampling_parameters):
         self._responses = []
         with grpcclient.InferenceServerClient(self._grpc_url) as client:
             client.start_stream(self._callback)
             client.async_stream_infer(
-                self._model_name, inputs=inputs, parameters=self._sampling_parameters
+                self._model_name, inputs=inputs, parameters=sampling_parameters
             )
             client.stop_stream()
             assert len(self._responses) > 0
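Threading `sampling_parameters` through `_llm_infer` is what lets each parametrized case send its own `logprobs` setting. A sketch of the resulting call flow; the endpoint, model name, and prompt tensor name below are assumed placeholders, not taken from this diff:

```python
import numpy as np
import tritonclient.grpc as grpcclient

responses = []

def callback(result, error):
    responses.append({"result": result, "error": error})

# Assumed BYTES prompt input; the name and shape depend on the model config.
prompt = grpcclient.InferInput("text_input", [1], "BYTES")
prompt.set_data_from_numpy(
    np.array(["In this example,".encode("utf-8")], dtype=object)
)

with grpcclient.InferenceServerClient("localhost:8001") as client:
    client.start_stream(callback)
    client.async_stream_infer(
        "vllm_model",  # placeholder model name
        inputs=[prompt],
        # Per-case dict from _get_sampling_parameters(), not the class default.
        parameters={"temperature": "0", "top_p": "1", "logprobs": 2},
    )
    client.stop_stream()
```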
@@ -135,6 +153,63 @@ def _assert_cumulative_logprob(self, return_cumulative_logprob):
             assert cumulative_logprob != prev_cumulative_logprob
             prev_cumulative_logprob = cumulative_logprob
 
+    def _assert_logprobs(
+        self, stream, sampling_parameters, return_logprobs, return_num_output_tokens
+    ):
+        for response in self._responses:
+            result, error = response["result"], response["error"]
+            assert error is None
+            logprobs_np = result.as_numpy(name="logprobs")
+            if return_logprobs is None or return_logprobs == False:
+                assert logprobs_np is None
+                continue
+            logprobs = json.loads(logprobs_np[0].decode("utf-8"))
+            if "logprobs" not in sampling_parameters:
+                assert logprobs is None
+                continue
+            assert isinstance(logprobs, list)
+            assert len(logprobs) >= 1
+            if return_num_output_tokens == True:
+                num_output_tokens = result.as_numpy(name="num_output_tokens")[0].astype(
+                    int
+                )
+                assert len(logprobs) == num_output_tokens
+            text_output_logprobs = ""
+            for logprobs_d in logprobs:
+                assert isinstance(logprobs_d, dict)
+                assert len(logprobs_d) >= 1
+                assert len(logprobs_d) <= sampling_parameters["logprobs"] + 1
+                rank_one_found = False
+                for token_id, logprob_d in logprobs_d.items():
+                    assert isinstance(token_id, str)
+                    assert len(logprob_d) == 3
+                    assert isinstance(logprob_d["logprob"], float)
+                    assert isinstance(logprob_d["rank"], int)
+                    assert isinstance(logprob_d["decoded_token"], str)
+                    if logprob_d["rank"] == 1:
+                        assert not rank_one_found
+                        rank_one_found = True
+                        text_output_logprobs += logprob_d["decoded_token"]
+                assert rank_one_found
+            text_output = result.as_numpy(name="text_output")[0].decode("utf-8")
+            if not stream:
+                # given exclude_input_in_output is not set, prepend_input is True if not
+                # streaming and False if streaming
+                text_output_logprobs = self._prompt + text_output_logprobs
+            assert text_output_logprobs == text_output
+
+    def _assert_num_input_tokens(self, return_num_input_tokens):
+        for response in self._responses:
+            result, error = response["result"], response["error"]
+            assert error is None
+            num_input_tokens_np = result.as_numpy(name="num_input_tokens")
+            if return_num_input_tokens is None or return_num_input_tokens == False:
+                assert num_input_tokens_np is None
+                continue
+            num_input_tokens = num_input_tokens_np.astype(int)
+            assert num_input_tokens > 0
+            assert num_input_tokens <= len(self._prompt)
+
     def _assert_num_output_tokens(self, return_num_output_tokens):
         for response in self._responses:
             result, error = response["result"], response["error"]
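For reference, `_assert_logprobs` implicitly documents the payload shape: the `logprobs` output is a JSON list with one dict per generated token, keyed by token id, each entry carrying exactly `logprob`, `rank`, and `decoded_token`, with exactly one rank-1 entry per position whose decoded tokens concatenate back to `text_output`. An illustrative payload (token ids and values are made up):

```python
import json

# Hypothetical serialized "logprobs" tensor for two output tokens with
# sampling parameter logprobs=1 (one alternative plus the sampled token).
payload = """
[
  {"5":    {"logprob": -0.4, "rank": 1, "decoded_token": " the"},
   "1385": {"logprob": -1.7, "rank": 2, "decoded_token": " a"}},
  {"1385": {"logprob": -0.9, "rank": 1, "decoded_token": " term"},
   "44":   {"logprob": -2.3, "rank": 2, "decoded_token": ","}}
]
"""
logprobs = json.loads(payload)
text = "".join(
    d["decoded_token"]
    for per_token in logprobs
    for d in per_token.values()
    if d["rank"] == 1
)
assert text == " the term"
```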
@@ -144,46 +219,42 @@ def _assert_num_output_tokens(self, return_num_output_tokens):
             assert num_output_tokens_np is None
             continue
         num_output_tokens = num_output_tokens_np[0].astype(int)
-            # TODO: vLLM may return token ids identical to the previous one when
-            # streaming, for example:
-            #
-            # prev: None
-            # curr: text=' the', token_ids=array('l', [5])
-            #
-            # prev: text=' the', token_ids=array('l', [5, 1385])
-            # curr: text=' the term', token_ids=array('l', [5, 1385])
-            #
-            # prev: text=' the term', token_ids=array('l', [5, 1385, 44])
-            # curr: text=' the term', token_ids=array('l', [5, 1385, 44])
-            #
-            # prev: text=' the term', token_ids=array('l', [5, 1385, 44, 48])
-            # curr: text=' the term “', token_ids=array('l', [5, 1385, 44, 48])
-            #
-            # If this is no longer the case in a future release, change the assert
-            # to assert num_output_tokens > 0.
-            assert num_output_tokens >= 0
+            assert num_output_tokens > 0
 
     @pytest.mark.parametrize("stream", [True, False])
     @pytest.mark.parametrize("return_finish_reason", [None, True, False])
     @pytest.mark.parametrize("return_cumulative_logprob", [None, True, False])
+    @pytest.mark.parametrize("logprobs", [None, 0, 2])
+    @pytest.mark.parametrize("return_logprobs", [None, True, False])
+    @pytest.mark.parametrize("return_num_input_tokens", [None, True, False])
     @pytest.mark.parametrize("return_num_output_tokens", [None, True, False])
     def test_additional_outputs(
         self,
         stream,
         return_finish_reason,
         return_cumulative_logprob,
+        logprobs,
+        return_logprobs,
+        return_num_input_tokens,
         return_num_output_tokens,
     ):
+        sampling_parameters = self._get_sampling_parameters(logprobs=logprobs)
         inputs = self._get_inputs(
             self._prompt,
             stream=stream,
-            sampling_parameters=self._sampling_parameters,
+            sampling_parameters=sampling_parameters,
             return_finish_reason=return_finish_reason,
             return_cumulative_logprob=return_cumulative_logprob,
+            return_logprobs=return_logprobs,
+            return_num_input_tokens=return_num_input_tokens,
             return_num_output_tokens=return_num_output_tokens,
         )
-        self._llm_infer(inputs)
+        self._llm_infer(inputs, sampling_parameters)
         self._assert_text_output_valid()
         self._assert_finish_reason(return_finish_reason)
         self._assert_cumulative_logprob(return_cumulative_logprob)
+        self._assert_logprobs(
+            stream, sampling_parameters, return_logprobs, return_num_output_tokens
+        )
+        self._assert_num_input_tokens(return_num_input_tokens)
         self._assert_num_output_tokens(return_num_output_tokens)
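With the three new parametrize axes, the grid grows from 2 * 3**3 = 54 to 2 * 3**6 = 1458 test cases; a quick check of that count:

```python
from itertools import product

grid = list(
    product(
        [True, False],        # stream
        [None, True, False],  # return_finish_reason
        [None, True, False],  # return_cumulative_logprob
        [None, 0, 2],         # logprobs
        [None, True, False],  # return_logprobs
        [None, True, False],  # return_num_input_tokens
        [None, True, False],  # return_num_output_tokens
    )
)
assert len(grid) == 2 * 3**6 == 1458
```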