Commit 2430b17

Add min_p sampling method
1 parent 86a3a90 commit 2430b17
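
This commit threads a new min_p parameter through sample(), decode(), and generate() in mamba_ssm/utils/generation.py. Min-p sampling, as commonly defined, truncates the next-token distribution with a confidence-scaled threshold: only tokens whose probability is at least min_p times the probability of the most likely token survive before sampling.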

File tree

1 file changed: mamba_ssm/utils/generation.py (+18, −3)
--- a/mamba_ssm/utils/generation.py
+++ b/mamba_ssm/utils/generation.py
@@ -34,6 +34,12 @@ def reset(self, max_seqlen, max_batch_size):
         self.lengths_per_sample.zero_()
 
 
+def modify_logits_for_min_p_filtering(logits, min_p):
+    """Set the logits for non-min_p values to -inf. Done in-place."""
+    if min_p <= 0.0 or min_p >= 1.0:
+        return
+    indices_to_remove = logits < min_p
+    logits.masked_fill_(indices_to_remove, float("-Inf"))
 # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
 # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
 def modify_logits_for_top_k_filtering(logits, top_k):
@@ -74,7 +80,7 @@ def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_p
     return logits
 
 
-def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
+def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
     """Sample from top-k logits.
     Arguments:
         logits: Tensor of shape (batch_size, vocab_size)
@@ -95,6 +101,14 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
                 torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
             ]
         else:
+            if min_p > 0.0:
+                logits_top = logits.clone()
+                max_prob = logits_top[..., 0].item()
+                min_prob = max_prob * min_p
+                modify_logits_for_min_p_filtering(logits_top, min_p)
+                if temperature != 1.0:
+                    logits_top /= temperature
+                return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
             # Clone so that when we modify for top_p we don't change the original logits
             logits_top = logits / temperature if temperature != 1.0 else logits.clone()
             modify_logits_for_top_p_filtering(logits_top, top_p)
@@ -110,6 +124,7 @@ def decode(
     max_length,
     top_k=1,
     top_p=0.0,
+    min_p=0.0,
     temperature=1.0,
     repetition_penalty=1.0,
     eos_token_id=None,
@@ -180,7 +195,7 @@ def get_logits(input_ids, inference_params):
 
     def sample_tokens(logits, inference_params):
         if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
-            token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
+            token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)
         else:
             token = teacher_outputs[:, inference_params.seqlen_offset]
         # return rearrange(token, "b -> b 1")
@@ -242,7 +257,7 @@ def generate(
         **kwargs,
     ):
         output = decode(
-            input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs
+            input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature, **kwargs
         )
         if not output_scores:
             output.scores = None
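
For context, here is a minimal sketch of min-p filtering as it is usually defined, operating on probabilities rather than raw logits. This is an illustration, not the committed helper: the diff above compares logits directly against min_p, and the min_prob = max_prob * min_p value it computes is never passed to the filter. The function name min_p_filter and the tensor shapes are assumptions for illustration.

import torch

def min_p_filter(logits, min_p):
    # Hypothetical sketch of the standard min-p rule; not the helper
    # committed above, which masks logits < min_p directly.
    # logits: (batch_size, vocab_size); returns new, filtered logits.
    if min_p <= 0.0 or min_p >= 1.0:
        return logits
    probs = torch.softmax(logits, dim=-1)
    # Keep tokens whose probability is at least min_p times the row maximum.
    threshold = min_p * probs.max(dim=-1, keepdim=True).values
    return logits.masked_fill(probs < threshold, float("-inf"))

Note that in the committed sample(), the new min_p branch sits on the else path of the top-k check, so it only takes effect when top_k is 0, e.g. sample(logits, top_k=0, min_p=0.05) or model.generate(..., top_k=0, min_p=0.05) (the 0.05 value is only an example).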
