Commit a766a66 (1 parent: 1035e37)

rebase and adapt to new attn builder

1 file changed: vllm/v1/spec_decode/eagle.py (32 additions, 17 deletions)
```diff
@@ -381,13 +381,7 @@ def propose(
     ) -> torch.Tensor:
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
-        last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
-
-        if self.method == "eagle3":
-            assert isinstance(self.model, Eagle3LlamaForCausalLM)
-            target_hidden_states = self.model.combine_hidden_states(
-                target_hidden_states)
-            assert target_hidden_states.shape[-1] == self.hidden_size
+        block_table = common_attn_metadata.block_table_tensor
 
         prefill_shift_tokens = True
         has_prefill = decode_mask is not None and (
```
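For orientation, the diff reads and writes exactly three fields of `common_attn_metadata`. A minimal sketch of those fields, assuming the `CommonAttentionMetadata` layout implied by the new attention builder (the real class in vLLM v1 carries more fields than shown here):

```python
from dataclasses import dataclass

import torch


# Sketch only: field names are the ones used in this diff; the actual
# CommonAttentionMetadata class has additional fields.
@dataclass
class CommonAttentionMetadata:
    # (batch_size + 1,) cumulative start offset of each request's
    # tokens in the flattened token batch
    query_start_loc: torch.Tensor
    # (num_tokens,) destination KV-cache slot for each token
    slot_mapping: torch.Tensor
    # (batch_size, max_num_blocks_per_req) physical block ids per request
    block_table_tensor: torch.Tensor
```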
```diff
@@ -415,15 +409,15 @@ def propose(
                 target_positions,
                 target_hidden_states,
                 target_slot_mapping,
-                cu_num_tokens,
+                query_start_loc,
                 num_tokens,
                 partial_prefill_mask,
             ) = self._prepare_adjusted_tensors(
                 target_token_ids,
                 target_positions,
                 target_hidden_states,
-                target_slot_mapping,
-                cu_num_tokens,
+                common_attn_metadata.slot_mapping,
+                common_attn_metadata.query_start_loc,
                 decode_mask,
                 full_prefill_mask,
                 partial_prefill_mask,
```
```diff
@@ -432,7 +426,20 @@ def propose(
                 batch_size,
                 num_tokens,
             )
-            batch_size = cu_num_tokens.shape[0] - 1
+            if (partial_prefill_mask.all()
+                    and self.draft_prefill_kv_sharing_from_base):
+                # All requests are partial prefill and
+                # KV cache sharing is enabled
+                # Skip the rest of the function
+                # and return dummy draft tokens
+                return torch.zeros(
+                    (batch_size, self.num_speculative_tokens),
+                    dtype=target_token_ids.dtype,
+                    device=target_token_ids.device,
+                )
+            common_attn_metadata.query_start_loc = query_start_loc
+            common_attn_metadata.slot_mapping = target_slot_mapping
+            batch_size = query_start_loc.shape[0] - 1
         else:
             # Original behavior: shift all tokens by one
             self.input_ids[:num_tokens - 1] = target_token_ids[1:]
```
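The `batch_size = query_start_loc.shape[0] - 1` recomputation above, and the `last_token_indices` derivation in the next hunk, both follow from `query_start_loc` being a cumulative-offset tensor. A toy illustration (not vLLM code):

```python
import torch

# Three requests with 3, 4, and 2 tokens in the flattened batch.
query_start_loc = torch.tensor([0, 3, 7, 9])

batch_size = query_start_loc.shape[0] - 1                         # 3
tokens_per_request = query_start_loc[1:] - query_start_loc[:-1]   # [3, 4, 2]
last_token_indices = query_start_loc[1:] - 1                      # [2, 6, 8]
```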
```diff
@@ -445,20 +452,28 @@ def propose(
             max_num_blocks_per_req = block_table.shape[1]
             segment_indices = torch.arange(len(target_positions),
                                            device=target_positions.device)
-            segment_indices = (segment_indices.unsqueeze(0)
-                               >= cu_num_tokens[:-1].unsqueeze(1)).sum(
-                                   dim=0) - 1
+            segment_indices = (
+                segment_indices.unsqueeze(0)
+                >= common_attn_metadata.query_start_loc[:-1].unsqueeze(1)).sum(
+                    dim=0) - 1
             # Calculate the block table indices
             block_table_indices = (
                 target_positions // self.block_size +
                 segment_indices * max_num_blocks_per_req)
             block_numbers = block_table.flatten()[block_table_indices]
             block_offsets = target_positions % self.block_size
-            target_slot_mapping = (block_numbers * self.block_size +
-                                   block_offsets)
+            common_attn_metadata.slot_mapping = (
+                block_numbers * self.block_size + block_offsets
+            )
 
         # Use the original last token indices
-        last_token_indices = cu_num_tokens[1:] - 1
+        last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
+
+        if self.method == "eagle3":
+            assert isinstance(self.model, Eagle3LlamaForCausalLM)
+            target_hidden_states = self.model.combine_hidden_states(
+                target_hidden_states)
+            assert target_hidden_states.shape[-1] == self.hidden_size
 
         if not prefill_shift_tokens and has_prefill:
             # Replace the last token with the next token under non-shifting,
```
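The segment-index and slot-mapping math in the hunk above is compact; here is a self-contained toy rerun of it with made-up block-table contents (illustration only, not vLLM code):

```python
import torch

block_size = 4
query_start_loc = torch.tensor([0, 3, 7])      # 2 requests: 3 and 4 tokens
target_positions = torch.tensor([0, 1, 2, 0, 1, 2, 3])
block_table = torch.tensor([[5, 9],            # physical blocks of request 0
                            [2, 7]])           # physical blocks of request 1
max_num_blocks_per_req = block_table.shape[1]

# Map each flat token index to its request index: token i belongs to
# request r iff query_start_loc[r] <= i < query_start_loc[r + 1].
token_indices = torch.arange(len(target_positions))
segment_indices = (token_indices.unsqueeze(0)
                   >= query_start_loc[:-1].unsqueeze(1)).sum(dim=0) - 1
# -> tensor([0, 0, 0, 1, 1, 1, 1])

# Index into the flattened block table, then compute per-token slots.
block_table_indices = (target_positions // block_size
                       + segment_indices * max_num_blocks_per_req)
block_numbers = block_table.flatten()[block_table_indices]
block_offsets = target_positions % block_size
slot_mapping = block_numbers * block_size + block_offsets
# -> tensor([20, 21, 22, 8, 9, 10, 11])
```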
