@@ -147,16 +147,24 @@ def rejection_sample(
147
147
if not sampling_metadata .all_random :
148
148
# Rejection sampling for greedy sampling requests.
149
149
target_argmax = target_probs .argmax (dim = - 1 )
150
- rejection_greedy_sample_pytorch (
151
- output_token_ids ,
152
- cu_num_draft_tokens ,
153
- draft_token_ids ,
154
- target_argmax ,
155
- bonus_token_ids ,
156
- is_greedy ,
157
- max_spec_len ,
158
- # num_warps=1,
159
- )
150
+ if min (num_draft_tokens ) == 1 and max (
151
+ num_draft_tokens ) == 1 and sampling_metadata .all_greedy :
152
+ rejection_greedy_sample_spec_len_1_pytorch (
153
+ output_token_ids ,
154
+ draft_token_ids ,
155
+ target_argmax ,
156
+ bonus_token_ids ,
157
+ )
158
+ else :
159
+ rejection_greedy_sample_pytorch (
160
+ output_token_ids ,
161
+ cu_num_draft_tokens ,
162
+ draft_token_ids ,
163
+ target_argmax ,
164
+ bonus_token_ids ,
165
+ max_spec_len ,
166
+ is_greedy ,
167
+ )
160
168
if sampling_metadata .all_greedy :
161
169
return output_token_ids
162
170
@@ -284,47 +292,52 @@ def sample_recovered_tokens(
284
292
return recovered_token_ids
285
293
286
294
295
def rejection_greedy_sample_spec_len_1_pytorch(
    output_token_ids,  # [batch_size, 2]
    draft_token_ids,  # [num_tokens]
    target_argmax,  # [num_tokens]
    bonus_token_ids,  # [batch_size, 1]
):
    """Greedy rejection-sampling fast path for exactly one draft token per request.

    Column 0 always receives the target model's argmax token: it is the
    correct output whether the single draft token was accepted or rejected.
    Column 1 receives the bonus token only for requests whose draft token
    matched the target argmax. ``output_token_ids`` is modified in place.
    """
    n_reqs = output_token_ids.size(0)
    n_drafts = draft_token_ids.size(0)
    assert n_reqs == n_drafts
    # Position 0: the target argmax is emitted unconditionally.
    output_token_ids[:, 0] = target_argmax
    # Position 1: only accepted requests advance to the bonus token.
    accepted = draft_token_ids == target_argmax
    flat_bonus = bonus_token_ids.squeeze(1)
    output_token_ids[accepted, 1] = flat_bonus[accepted]
287
310
def rejection_greedy_sample_pytorch(
    output_token_ids,  # [batch_size, max_spec_len + 1], modified in place
    cu_num_draft_tokens,  # [batch_size], cumulative draft-token counts
    draft_token_ids,  # [num_tokens]
    target_argmax,  # [num_tokens]
    bonus_token_ids,  # [batch_size] or [batch_size, 1]
    max_spec_len,  # int, max draft tokens of any request
    is_greedy=None,  # [batch_size] bool, or None meaning "all greedy"
):
    """Vectorized greedy rejection sampling.

    For every greedy request: draft tokens are accepted while they match the
    target model's argmax; at the first mismatch the target token is emitted
    in place of the rejected draft token and the rest of the draft is
    discarded. If every draft token is accepted, the bonus token is emitted
    immediately after the draft tokens. Rows whose ``is_greedy`` flag is
    False are left untouched.

    Returns ``output_token_ids`` (also mutated in place).
    """
    batch_size = output_token_ids.shape[0]
    device = output_token_ids.device
    if is_greedy is None:
        is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device)

    # Per-request draft lengths from the cumulative counts:
    # num_draft[i] = cu[i] - cu[i - 1] (with cu[-1] == 0).
    starts = torch.cat(
        [cu_num_draft_tokens.new_zeros(1), cu_num_draft_tokens[:-1]])
    num_draft = cu_num_draft_tokens - starts  # [batch_size]

    # Scatter the flat draft/target streams into padded [B, max_spec_len]
    # matrices so every request can be processed in lockstep. boolean-mask
    # assignment fills True cells in row-major order, which matches the
    # flat per-request layout of draft_token_ids/target_argmax.
    pos = torch.arange(max_spec_len, device=device).view(1, -1)
    valid = pos < num_draft.view(-1, 1)  # [B, max_spec_len]
    padded_draft = torch.zeros(batch_size,
                               max_spec_len,
                               dtype=draft_token_ids.dtype,
                               device=device)
    padded_target = torch.zeros_like(padded_draft)
    padded_draft[valid] = draft_token_ids
    padded_target[valid] = target_argmax

    # accepted[i, p] is True iff positions 0..p of request i all matched.
    # (cumprod needs an integer dtype; bool is not supported.)
    match = (padded_draft == padded_target) & valid
    accepted = torch.cumprod(match.to(torch.long), dim=1).bool()
    num_accepted = accepted.sum(dim=1)  # [batch_size]

    # Emit the target token for every accepted position plus the first
    # rejected one (the target token replaces the rejected draft token);
    # min() caps fully-accepted requests at their draft length.
    emit_len = torch.minimum(num_accepted + 1, num_draft)
    emit_mask = (pos < emit_len.view(-1, 1)) & is_greedy.view(-1, 1)
    output_token_ids[:, :max_spec_len][emit_mask] = padded_target[emit_mask]

    # Fully-accepted greedy requests also get the bonus token, placed
    # immediately after their (variable-length) draft tokens.
    all_accepted = (num_accepted == num_draft) & is_greedy
    rows = all_accepted.nonzero(as_tuple=True)[0]
    output_token_ids[rows, num_draft[rows]] = bonus_token_ids.view(-1)[rows]
    return output_token_ids
328
341
329
342
330
343
def rejection_random_sample_pytorch (
0 commit comments