@@ -132,6 +132,7 @@ def decode(
     vocab_size=None,
     cg=False,
     enable_timing=False,
+    output_scores=False,
     streamer: Optional[TextStreamer] = None
 ):
     """Decoding, either greedy or with top-k or top-p sampling.
@@ -218,13 +219,15 @@ def should_stop(current_token, inference_params):
     scores, sequences = [], [input_ids]
     sequences_cat = input_ids
     while not should_stop(sequences[-1], inference_params):
-        scores.append(get_logits(sequences[-1], inference_params))
+        logits = get_logits(sequences[-1], inference_params)
+        if output_scores:
+            scores.append(logits.clone())
         inference_params.seqlen_offset += sequences[-1].shape[1]
         if repetition_penalty == 1.0:
-            sampled_tokens = sample_tokens(scores[-1], inference_params)
+            sampled_tokens = sample_tokens(logits, inference_params)
         else:
             logits = modify_logit_for_repetition_penalty(
-                scores[-1].clone(), sequences_cat, repetition_penalty
+                logits, sequences_cat, repetition_penalty
             )
             sampled_tokens = sample_tokens(logits, inference_params)
         sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
@@ -258,7 +261,7 @@ def generate(
         **kwargs,
     ):
         output = decode(
-            input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature, **kwargs
+            input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature, output_scores=output_scores, **kwargs
         )
         if not output_scores:
            output.scores = None
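
For reference, a minimal usage sketch of the new flag. It assumes a model object that mixes in the patched generate() plus a tokenizer producing token ids; model, tokenizer, and input_ids below are placeholders, and return_dict_in_generate is assumed to be the existing flag that makes generate() return the full output object rather than only the sequences.

import torch

# Placeholder setup: any model exposing the patched generate() works here.
input_ids = tokenizer("Hello", return_tensors="pt").input_ids  # hypothetical tokenizer

out = model.generate(
    input_ids,
    max_length=32,
    top_k=1,
    output_scores=True,            # new flag from this change: keep per-step logits
    return_dict_in_generate=True,  # assumed existing flag so out.scores is accessible
)

# out.sequences holds the generated token ids; out.scores holds one logits tensor
# per decoding step (it is set to None when output_scores is False).
step_probs = [torch.softmax(s.float(), dim=-1) for s in out.scores]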