@@ -42,9 +42,9 @@ def _get_inputs(
42
42
prompt ,
43
43
stream = True ,
44
44
sampling_parameters = None ,
45
- output_finish_reason = None ,
46
- output_cumulative_logprob = None ,
47
- output_num_token_ids = None ,
45
+ return_finish_reason = None ,
46
+ return_cumulative_logprob = None ,
47
+ return_num_token_ids = None ,
48
48
):
49
49
inputs = []
50
50
@@ -64,21 +64,21 @@ def _get_inputs(
64
64
)
65
65
)
66
66
67
- if output_finish_reason is not None :
68
- inputs .append (grpcclient .InferInput ("output_finish_reason " , [1 ], "BOOL" ))
69
- inputs [- 1 ].set_data_from_numpy (np .array ([output_finish_reason ], dtype = bool ))
67
+ if return_finish_reason is not None :
68
+ inputs .append (grpcclient .InferInput ("return_finish_reason " , [1 ], "BOOL" ))
69
+ inputs [- 1 ].set_data_from_numpy (np .array ([return_finish_reason ], dtype = bool ))
70
70
71
- if output_cumulative_logprob is not None :
71
+ if return_cumulative_logprob is not None :
72
72
inputs .append (
73
- grpcclient .InferInput ("output_cumulative_logprob " , [1 ], "BOOL" )
73
+ grpcclient .InferInput ("return_cumulative_logprob " , [1 ], "BOOL" )
74
74
)
75
75
inputs [- 1 ].set_data_from_numpy (
76
- np .array ([output_cumulative_logprob ], dtype = bool )
76
+ np .array ([return_cumulative_logprob ], dtype = bool )
77
77
)
78
78
79
- if output_num_token_ids is not None :
80
- inputs .append (grpcclient .InferInput ("output_num_token_ids " , [1 ], "BOOL" ))
81
- inputs [- 1 ].set_data_from_numpy (np .array ([output_num_token_ids ], dtype = bool ))
79
+ if return_num_token_ids is not None :
80
+ inputs .append (grpcclient .InferInput ("return_num_token_ids " , [1 ], "BOOL" ))
81
+ inputs [- 1 ].set_data_from_numpy (np .array ([return_num_token_ids ], dtype = bool ))
82
82
83
83
return inputs
84
84
@@ -104,12 +104,12 @@ def _assert_text_output_valid(self):
104
104
assert len (text_output ) > 0 , "output is empty"
105
105
assert text_output .count (" " ) > 4 , "output is not a sentence"
106
106
107
- def _assert_finish_reason (self , output_finish_reason ):
107
+ def _assert_finish_reason (self , return_finish_reason ):
108
108
for i in range (len (self ._responses )):
109
109
result , error = self ._responses [i ]["result" ], self ._responses [i ]["error" ]
110
110
assert error is None
111
111
finish_reason_np = result .as_numpy (name = "finish_reason" )
112
- if output_finish_reason is None or output_finish_reason == False :
112
+ if return_finish_reason is None or return_finish_reason == False :
113
113
assert finish_reason_np is None
114
114
continue
115
115
finish_reason = finish_reason_np [0 ].decode ("utf-8" )
@@ -118,25 +118,25 @@ def _assert_finish_reason(self, output_finish_reason):
118
118
else :
119
119
assert finish_reason == "length"
120
120
121
- def _assert_cumulative_logprob (self , output_cumulative_logprob ):
121
+ def _assert_cumulative_logprob (self , return_cumulative_logprob ):
122
122
prev_cumulative_logprob = 0.0
123
123
for response in self ._responses :
124
124
result , error = response ["result" ], response ["error" ]
125
125
assert error is None
126
126
cumulative_logprob_np = result .as_numpy (name = "cumulative_logprob" )
127
- if output_cumulative_logprob is None or output_cumulative_logprob == False :
127
+ if return_cumulative_logprob is None or return_cumulative_logprob == False :
128
128
assert cumulative_logprob_np is None
129
129
continue
130
130
cumulative_logprob = cumulative_logprob_np [0 ].astype (float )
131
131
assert cumulative_logprob != prev_cumulative_logprob
132
132
prev_cumulative_logprob = cumulative_logprob
133
133
134
- def _assert_num_token_ids (self , output_num_token_ids ):
134
+ def _assert_num_token_ids (self , return_num_token_ids ):
135
135
for response in self ._responses :
136
136
result , error = response ["result" ], response ["error" ]
137
137
assert error is None
138
138
num_token_ids_np = result .as_numpy (name = "num_token_ids" )
139
- if output_num_token_ids is None or output_num_token_ids == False :
139
+ if return_num_token_ids is None or return_num_token_ids == False :
140
140
assert num_token_ids_np is None
141
141
continue
142
142
num_token_ids = num_token_ids_np [0 ].astype (int )
@@ -160,26 +160,26 @@ def _assert_num_token_ids(self, output_num_token_ids):
160
160
assert num_token_ids >= 0
161
161
162
162
@pytest .mark .parametrize ("stream" , [True , False ])
163
- @pytest .mark .parametrize ("output_finish_reason " , [None , True , False ])
164
- @pytest .mark .parametrize ("output_cumulative_logprob " , [None , True , False ])
165
- @pytest .mark .parametrize ("output_num_token_ids " , [None , True , False ])
163
+ @pytest .mark .parametrize ("return_finish_reason " , [None , True , False ])
164
+ @pytest .mark .parametrize ("return_cumulative_logprob " , [None , True , False ])
165
+ @pytest .mark .parametrize ("return_num_token_ids " , [None , True , False ])
166
166
def test_additional_outputs (
167
167
self ,
168
168
stream ,
169
- output_finish_reason ,
170
- output_cumulative_logprob ,
171
- output_num_token_ids ,
169
+ return_finish_reason ,
170
+ return_cumulative_logprob ,
171
+ return_num_token_ids ,
172
172
):
173
173
inputs = self ._get_inputs (
174
174
self ._prompt ,
175
175
stream = stream ,
176
176
sampling_parameters = self ._sampling_parameters ,
177
- output_finish_reason = output_finish_reason ,
178
- output_cumulative_logprob = output_cumulative_logprob ,
179
- output_num_token_ids = output_num_token_ids ,
177
+ return_finish_reason = return_finish_reason ,
178
+ return_cumulative_logprob = return_cumulative_logprob ,
179
+ return_num_token_ids = return_num_token_ids ,
180
180
)
181
181
self ._llm_infer (inputs )
182
182
self ._assert_text_output_valid ()
183
- self ._assert_finish_reason (output_finish_reason )
184
- self ._assert_cumulative_logprob (output_cumulative_logprob )
185
- self ._assert_num_token_ids (output_num_token_ids )
183
+ self ._assert_finish_reason (return_finish_reason )
184
+ self ._assert_cumulative_logprob (return_cumulative_logprob )
185
+ self ._assert_num_token_ids (return_num_token_ids )
0 commit comments