Skip to content

Commit 8f42a5e

Browse files
Fix typo in min_p sampling (#385)
1 parent 219f03c commit 8f42a5e

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

mamba_ssm/utils/generation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
105105
logits_top = logits.clone()
106106
max_prob = logits_top[..., 0].item()
107107
min_prob = max_prob * min_p
108-
modify_logits_for_min_p_filtering(logits_top, min_p)
108+
modify_logits_for_min_p_filtering(logits_top, min_prob)
109109
if temperature != 1.0:
110110
logits_top /= temperature
111111
return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)

0 commit comments

Comments
 (0)