
Commit bda9af3

do not save past scores if output_scores=False (#610)
1 parent 2017c98 commit bda9af3

1 file changed: +7 −4 lines


mamba_ssm/utils/generation.py

Lines changed: 7 additions & 4 deletions
@@ -132,6 +132,7 @@ def decode(
     vocab_size=None,
     cg=False,
     enable_timing=False,
+    output_scores=False,
     streamer: Optional[TextStreamer] = None
 ):
     """Decoding, either greedy or with top-k or top-p sampling.
@@ -218,13 +219,15 @@ def should_stop(current_token, inference_params):
     scores, sequences = [], [input_ids]
     sequences_cat = input_ids
     while not should_stop(sequences[-1], inference_params):
-        scores.append(get_logits(sequences[-1], inference_params))
+        logits = get_logits(sequences[-1], inference_params)
+        if output_scores:
+            scores.append(logits.clone())
         inference_params.seqlen_offset += sequences[-1].shape[1]
         if repetition_penalty == 1.0:
-            sampled_tokens = sample_tokens(scores[-1], inference_params)
+            sampled_tokens = sample_tokens(logits, inference_params)
         else:
             logits = modify_logit_for_repetition_penalty(
-                scores[-1].clone(), sequences_cat, repetition_penalty
+                logits, sequences_cat, repetition_penalty
             )
             sampled_tokens = sample_tokens(logits, inference_params)
         sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
@@ -258,7 +261,7 @@ def generate(
         **kwargs,
     ):
         output = decode(
-            input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p = min_p, temperature=temperature, **kwargs
+            input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p = min_p, temperature=temperature, output_scores=output_scores, **kwargs
         )
         if not output_scores:
             output.scores = None
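
With this change, decode() only accumulates per-step logits when output_scores=True; the default path skips the scores list instead of cloning every step's logits. A minimal usage sketch follows; the checkpoint name, model loading, and return_dict_in_generate usage are assumptions for illustration and are not part of this commit.

# Minimal sketch, assuming MambaLMHeadModel, a CUDA device, and the
# "state-spaces/mamba-130m" checkpoint are available.
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-130m", device="cuda", dtype=torch.float16
)
input_ids = torch.randint(0, 50277, (1, 16), device="cuda", dtype=torch.long)

# Default: output_scores=False, so the generation loop no longer stores logits.
out = model.generate(
    input_ids, max_length=64, top_k=1, return_dict_in_generate=True
)
assert out.scores is None  # generate() nulls scores when output_scores=False

# Opt in: each step's logits are cloned into the returned scores.
out = model.generate(
    input_ids, max_length=64, top_k=1,
    return_dict_in_generate=True, output_scores=True,
)
print(len(out.scores))  # roughly one logits tensor per generated step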

0 commit comments
