@@ -381,13 +381,7 @@ def propose(
     ) -> torch.Tensor:
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
-        last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
-
-        if self.method == "eagle3":
-            assert isinstance(self.model, Eagle3LlamaForCausalLM)
-            target_hidden_states = self.model.combine_hidden_states(
-                target_hidden_states)
-            assert target_hidden_states.shape[-1] == self.hidden_size
+        block_table = common_attn_metadata.block_table_tensor

         prefill_shift_tokens = True
         has_prefill = decode_mask is not None and (
@@ -415,15 +409,15 @@ def propose(
                 target_positions,
                 target_hidden_states,
                 target_slot_mapping,
-                cu_num_tokens,
+                query_start_loc,
                 num_tokens,
                 partial_prefill_mask,
             ) = self._prepare_adjusted_tensors(
                 target_token_ids,
                 target_positions,
                 target_hidden_states,
-                target_slot_mapping,
-                cu_num_tokens,
+                common_attn_metadata.slot_mapping,
+                common_attn_metadata.query_start_loc,
                 decode_mask,
                 full_prefill_mask,
                 partial_prefill_mask,
@@ -432,7 +426,20 @@ def propose(
                 batch_size,
                 num_tokens,
             )
-            batch_size = cu_num_tokens.shape[0] - 1
+            if (partial_prefill_mask.all()
+                    and self.draft_prefill_kv_sharing_from_base):
+                # All requests are partial prefill and
+                # KV cache sharing is enabled
+                # Skip the rest of the function
+                # and return dummy draft tokens
+                return torch.zeros(
+                    (batch_size, self.num_speculative_tokens),
+                    dtype=target_token_ids.dtype,
+                    device=target_token_ids.device,
+                )
+            common_attn_metadata.query_start_loc = query_start_loc
+            common_attn_metadata.slot_mapping = target_slot_mapping
+            batch_size = query_start_loc.shape[0] - 1

         else:
             # Original behavior: shift all tokens by one
             self.input_ids[:num_tokens - 1] = target_token_ids[1:]
@@ -445,20 +452,28 @@ def propose(
             max_num_blocks_per_req = block_table.shape[1]
             segment_indices = torch.arange(len(target_positions),
                                            device=target_positions.device)
-            segment_indices = (segment_indices.unsqueeze(0)
-                               >= cu_num_tokens[:-1].unsqueeze(1)).sum(
-                                   dim=0) - 1
+            segment_indices = (
+                segment_indices.unsqueeze(0)
+                >= common_attn_metadata.query_start_loc[:-1].unsqueeze(1)).sum(
+                    dim=0) - 1
             # Calculate the block table indices
             block_table_indices = (
                 target_positions // self.block_size +
                 segment_indices * max_num_blocks_per_req)
             block_numbers = block_table.flatten()[block_table_indices]
             block_offsets = target_positions % self.block_size
-            target_slot_mapping = (block_numbers * self.block_size +
-                                   block_offsets)
+            common_attn_metadata.slot_mapping = (
+                block_numbers * self.block_size + block_offsets
+            )

             # Use the original last token indices
-            last_token_indices = cu_num_tokens[1:] - 1
+            last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
+
+        if self.method == "eagle3":
+            assert isinstance(self.model, Eagle3LlamaForCausalLM)
+            target_hidden_states = self.model.combine_hidden_states(
+                target_hidden_states)
+            assert target_hidden_states.shape[-1] == self.hidden_size

         if not prefill_shift_tokens and has_prefill:
             # Replace the last token with the next token under non-shifting,