@@ -34,6 +34,12 @@ def reset(self, max_seqlen, max_batch_size):
            self.lengths_per_sample.zero_()


+def modify_logits_for_min_p_filtering(logits, min_p):
+    """Set the logits that fall below the min_p cutoff to -inf. Done in-place."""
+    if min_p <= 0.0 or min_p >= 1.0:
+        return
+    indices_to_remove = logits < min_p
+    logits.masked_fill_(indices_to_remove, float("-Inf"))
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
def modify_logits_for_top_k_filtering(logits, top_k):
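The helper added above is the masking half of min-p filtering: anything scoring below a caller-supplied cutoff is set to -inf in place, so a later softmax gives it zero probability. The `sample` hunk below derives that cutoff as `min_p` times the top score, which assumes the scores arrive sorted in descending order. A minimal standalone sketch of the masking step, with illustrative values that are not from this commit:

```python
import torch

# Toy scores, sorted descending as the caller in this diff assumes.
scores = torch.tensor([2.0, 1.0, 0.5, -3.0])
min_p = 0.3
cutoff = scores[0].item() * min_p  # 0.6: keep scores >= 30% of the top score

# The same in-place masking the new helper performs.
scores.masked_fill_(scores < cutoff, float("-Inf"))
print(torch.softmax(scores, dim=-1))  # masked entries get probability 0.0
```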
@@ -74,7 +80,7 @@ def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_p
    return logits


-def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
+def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
    """Sample from top-k logits.
    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
@@ -95,6 +101,14 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
            ]
        else:
+            if min_p > 0.0:
+                logits_top = logits.clone()
+                max_prob = logits_top[..., 0].item()
+                min_prob = max_prob * min_p
+                modify_logits_for_min_p_filtering(logits_top, min_prob)
+                if temperature != 1.0:
+                    logits_top /= temperature
+                return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
            # Clone so that when we modify for top_p we don't change the original logits
            logits_top = logits / temperature if temperature != 1.0 else logits.clone()
            modify_logits_for_top_p_filtering(logits_top, top_p)
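Two details of the new branch are easy to miss: it sits on the non-top-k path (a positive `top_k` takes the branch above it), and `max_prob` is read from position 0 with `.item()`, so the logits are assumed sorted descending with batch size 1. A hedged usage sketch; the import path is hypothetical, and the values are kept below 1.0 so the derived cutoff passes the helper's `0 < min_p < 1` guard:

```python
import torch
from generation import sample  # hypothetical import path for the file in this diff

torch.manual_seed(0)
scores = torch.tensor([[0.6, 0.25, 0.1, 0.05]])  # sorted descending, batch of 1

# Cutoff = 0.6 * 0.3 = 0.18, so indices 2 and 3 are masked to -inf
# and only the first two candidates can be drawn.
token = sample(scores, top_k=0, top_p=0.0, min_p=0.3, temperature=1.0)
print(token)  # tensor([0]) or tensor([1])
```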
@@ -110,6 +124,7 @@ def decode(
    max_length,
    top_k=1,
    top_p=0.0,
+    min_p=0.0,
    temperature=1.0,
    repetition_penalty=1.0,
    eos_token_id=None,
@@ -180,7 +195,7 @@ def get_logits(input_ids, inference_params):

    def sample_tokens(logits, inference_params):
        if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
-            token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
+            token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)
        else:
            token = teacher_outputs[:, inference_params.seqlen_offset]
        # return rearrange(token, "b -> b 1")
@@ -242,7 +257,7 @@ def generate(
        **kwargs,
    ):
        output = decode(
-            input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs
+            input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature, **kwargs
        )
        if not output_scores:
            output.scores = None
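The remaining hunks are plumbing: `generate` forwards `min_p` to `decode`, which forwards it to `sample` inside `sample_tokens`, with a `0.0` default at every level so the feature is opt-in. A quick same-seed check that the default really is a no-op (minimal sketch under the same hypothetical import as above):

```python
import torch
from generation import sample  # hypothetical import path for the file in this diff

torch.manual_seed(0)
scores = torch.tensor([[0.6, 0.25, 0.1, 0.05]])
a = sample(scores, top_k=0, top_p=0.0, min_p=0.0, temperature=1.0)

torch.manual_seed(0)
b = sample(scores, top_k=0, top_p=0.0, temperature=1.0)  # min_p omitted
print(torch.equal(a, b))  # True: min_p=0.0 leaves the old sampling path untouched
```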