Commit c7bca02

Merge pull request #135 from epicfilemcnulty/Add_min_p_sampling
Add min_p sampling method
2 parents: 2a3704f + 10bc4d6

3 files changed (+22, -3 lines changed)
README.md

Lines changed: 1 addition & 0 deletions
````diff
@@ -138,6 +138,7 @@ To test generation latency (e.g. batch size = 1) with different sampling strateg
 ```
 python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
 python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
+python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --minp 0.05 --topk 0 --temperature 0.7 --repetition-penalty 1.2
 ```
 
 To test generation throughput with random prompts (e.g. large batch size):
````
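Read as a relative cutoff (the usual min_p convention, and the reading assumed here), `--minp 0.05` keeps only tokens whose probability is at least 5% of the most likely token's probability; if the top token has probability 0.6, everything below 0.03 is discarded. `--topk 0` disables top-k filtering so that `sample()` falls through to the new min_p branch in `mamba_ssm/utils/generation.py` (shown below).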

benchmarks/benchmark_generation_mamba_simple.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -22,6 +22,7 @@
 parser.add_argument("--temperature", type=float, default=1.0)
 parser.add_argument("--topk", type=int, default=1)
 parser.add_argument("--topp", type=float, default=1.0)
+parser.add_argument("--minp", type=float, default=0.0)
 parser.add_argument("--repetition-penalty", type=float, default=1.0)
 parser.add_argument("--batch", type=int, default=1)
 args = parser.parse_args()
@@ -62,6 +63,7 @@
         temperature=args.temperature,
         top_k=args.topk,
         top_p=args.topp,
+        min_p=args.minp,
         repetition_penalty=args.repetition_penalty,
     )
 else:
```
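For orientation, here is a minimal usage sketch of how the new flag reaches the sampler, assuming the tokenizer/model pairing the benchmark script itself uses; the prompt, generation length, and sampling values are placeholders rather than part of the commit:

```python
# Minimal usage sketch (assumed setup, mirroring benchmark_generation_mamba_simple.py):
# the min_p value is passed straight through to GenerationMixin.generate().
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

device = "cuda"
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b", device=device, dtype=torch.float16)

input_ids = tokenizer("My cat wrote all this CUDA code", return_tensors="pt").input_ids.to(device)
sequences = model.generate(
    input_ids=input_ids,
    max_length=100,
    temperature=0.7,
    top_k=0,        # top_k must be 0 so sample() falls through to the min_p branch
    top_p=0.0,
    min_p=0.05,     # new keyword added by this commit
    repetition_penalty=1.2,
)
print(tokenizer.batch_decode(sequences))
```

With `top_k=0` and `top_p=0.0`, only the min_p threshold shapes the distribution before `torch.multinomial` draws the next token.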

mamba_ssm/utils/generation.py

Lines changed: 19 additions & 3 deletions
```diff
@@ -34,6 +34,12 @@ def reset(self, max_seqlen, max_batch_size):
             self.lengths_per_sample.zero_()
 
 
+def modify_logits_for_min_p_filtering(logits, min_p):
+    """Set the logits for none min_p values to -inf. Done in-place."""
+    if min_p <= 0.0 or min_p >= 1.0:
+        return
+    indices_to_remove = logits < min_p
+    logits.masked_fill_(indices_to_remove, float("-Inf"))
 # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
 # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
 def modify_logits_for_top_k_filtering(logits, top_k):
@@ -74,7 +80,7 @@ def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_p
     return logits
 
 
-def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
+def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
     """Sample from top-k logits.
     Arguments:
         logits: Tensor of shape (batch_size, vocab_size)
@@ -95,6 +101,14 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
                 torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
             ]
         else:
+            if min_p > 0.0:
+                logits_top = logits.clone()
+                max_prob = logits_top[..., 0].item()
+                min_prob = max_prob * min_p
+                modify_logits_for_min_p_filtering(logits_top, min_p)
+                if temperature != 1.0:
+                    logits_top /= temperature
+                return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
             # Clone so that when we modify for top_p we don't change the original logits
             logits_top = logits / temperature if temperature != 1.0 else logits.clone()
             modify_logits_for_top_p_filtering(logits_top, top_p)
@@ -110,6 +124,7 @@ def decode(
     max_length,
     top_k=1,
     top_p=0.0,
+    min_p=0.0,
     temperature=1.0,
     repetition_penalty=1.0,
     eos_token_id=None,
@@ -180,7 +195,7 @@ def get_logits(input_ids, inference_params):
 
     def sample_tokens(logits, inference_params):
         if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
-            token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
+            token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)
         else:
             token = teacher_outputs[:, inference_params.seqlen_offset]
         # return rearrange(token, "b -> b 1")
@@ -236,13 +251,14 @@ def generate(
         max_length,
         top_k=1,
         top_p=0.0,
+        min_p=0.0,
         temperature=1.0,
         return_dict_in_generate=False,
         output_scores=False,
         **kwargs,
    ):
        output = decode(
-            input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs
+            input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p = min_p, temperature=temperature, **kwargs
         )
         if not output_scores:
             output.scores = None
```
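For reference, min_p sampling is usually described as keeping only tokens whose probability is at least `min_p` times the probability of the most likely token. The sketch below illustrates that rule in probability space; it is not a copy of the committed code (the helper name is made up here, and the branch above applies its threshold to the logits tensor directly):

```python
import torch

def min_p_filter(logits: torch.Tensor, min_p: float) -> torch.Tensor:
    """Illustrative min_p rule: mask out tokens whose probability falls below
    min_p * (probability of the most likely token). Returns new logits."""
    probs = torch.softmax(logits, dim=-1)
    threshold = min_p * probs.max(dim=-1, keepdim=True).values
    return logits.masked_fill(probs < threshold, float("-inf"))

logits = torch.tensor([[2.0, 1.0, 0.0, -3.0]])   # probs ~ [0.66, 0.24, 0.09, 0.004]
filtered = min_p_filter(logits, min_p=0.1)       # cutoff ~ 0.066, so only the last token is dropped
token = torch.multinomial(torch.softmax(filtered, dim=-1), num_samples=1)
```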
