
Commit eaebd43

rebase and adapt to new attn builder
1 parent 021d4ab commit eaebd43

File tree

1 file changed: +45 -17 lines changed

vllm/v1/spec_decode/eagle.py

Lines changed: 45 additions & 17 deletions
@@ -381,13 +381,7 @@ def propose(
     ) -> torch.Tensor:
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
-        last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
-
-        if self.method == "eagle3":
-            assert isinstance(self.model, Eagle3LlamaForCausalLM)
-            target_hidden_states = self.model.combine_hidden_states(
-                target_hidden_states)
-            assert target_hidden_states.shape[-1] == self.hidden_size
+        block_table = common_attn_metadata.block_table_tensor
 
         prefill_shift_tokens = True
         has_prefill = decode_mask is not None and (
@@ -415,15 +409,15 @@ def propose(
                 target_positions,
                 target_hidden_states,
                 target_slot_mapping,
-                cu_num_tokens,
+                query_start_loc,
                 num_tokens,
                 partial_prefill_mask,
             ) = self._prepare_adjusted_tensors(
                 target_token_ids,
                 target_positions,
                 target_hidden_states,
-                target_slot_mapping,
-                cu_num_tokens,
+                common_attn_metadata.slot_mapping,
+                common_attn_metadata.query_start_loc,
                 decode_mask,
                 full_prefill_mask,
                 partial_prefill_mask,
@@ -432,7 +426,28 @@ def propose(
                 batch_size,
                 num_tokens,
             )
-            batch_size = cu_num_tokens.shape[0] - 1
+            if (partial_prefill_mask.all()
+                    and self.draft_prefill_kv_sharing_from_base):
+                # All requests are partial prefill and
+                # KV cache sharing is enabled
+                # Skip the rest of the function
+                # and return dummy draft tokens
+                return torch.zeros(
+                    (batch_size, self.num_speculative_tokens),
+                    dtype=target_token_ids.dtype,
+                    device=target_token_ids.device,
+                )
+
+            query_start_loc_cpu = query_start_loc.to("cpu", non_blocking=True)
+            max_num_tokens = (query_start_loc_cpu[1:] -
+                              query_start_loc_cpu[:-1]).max().item()
+
+            common_attn_metadata.query_start_loc = query_start_loc
+            common_attn_metadata.slot_mapping = target_slot_mapping
+            common_attn_metadata.query_start_loc_cpu = query_start_loc_cpu
+            common_attn_metadata.num_actual_tokens = num_tokens
+            common_attn_metadata.max_query_len = max_num_tokens
+            batch_size = query_start_loc_cpu.shape[0] - 1
         else:
             # Original behavior: shift all tokens by one
             self.input_ids[:num_tokens - 1] = target_token_ids[1:]
@@ -445,20 +460,33 @@ def propose(
         max_num_blocks_per_req = block_table.shape[1]
         segment_indices = torch.arange(len(target_positions),
                                        device=target_positions.device)
-        segment_indices = (segment_indices.unsqueeze(0)
-                           >= cu_num_tokens[:-1].unsqueeze(1)).sum(
-                               dim=0) - 1
+        segment_indices = (
+            segment_indices.unsqueeze(0)
+            >= common_attn_metadata.query_start_loc[:-1].unsqueeze(1)).sum(
+                dim=0) - 1
         # Calculate the block table indices
         block_table_indices = (
             target_positions // self.block_size +
             segment_indices * max_num_blocks_per_req)
         block_numbers = block_table.flatten()[block_table_indices]
         block_offsets = target_positions % self.block_size
-        target_slot_mapping = (block_numbers * self.block_size +
-                               block_offsets)
+        common_attn_metadata.slot_mapping = (
+            block_numbers * self.block_size + block_offsets
+        )
 
         # Use the original last token indices
-        last_token_indices = cu_num_tokens[1:] - 1
+        last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
+        if not prefill_shift_tokens:
+            seq_lens = (target_positions[last_token_indices] + 1).int()
+            seq_lens_cpu = seq_lens.to("cpu", non_blocking=True)
+            common_attn_metadata.seq_lens = seq_lens
+            common_attn_metadata.seq_lens_cpu = seq_lens_cpu
+
+        if self.method == "eagle3":
+            assert isinstance(self.model, Eagle3LlamaForCausalLM)
+            target_hidden_states = self.model.combine_hidden_states(
+                target_hidden_states)
+            assert target_hidden_states.shape[-1] == self.hidden_size
 
         if not prefill_shift_tokens and has_prefill:
             # Replace the last token with the next token under non-shifting,
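
For orientation, the following standalone sketch replays the arithmetic from the updated propose() on toy tensors: how query_start_loc yields max_query_len, batch_size, last_token_indices, and seq_lens, and how token positions plus the block table produce the slot mapping now written back into common_attn_metadata. The tensor values are invented for illustration; this is not the vLLM runtime path, and only the formulas and field names visible in the diff above are assumed.

# Illustrative sketch only (toy values): mirrors the metadata arithmetic above.
import torch

# Two requests with 3 and 2 query tokens respectively.
query_start_loc = torch.tensor([0, 3, 5])         # cumulative token offsets per request
target_positions = torch.tensor([5, 6, 7, 2, 3])  # position of each token in its sequence
block_size = 4
max_num_blocks_per_req = 8
block_table = torch.arange(2 * max_num_blocks_per_req).reshape(2, -1)  # fake block IDs

# Values the commit now derives from query_start_loc and writes into the metadata.
query_start_loc_cpu = query_start_loc.to("cpu")
max_query_len = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]).max().item()  # 3
batch_size = query_start_loc_cpu.shape[0] - 1                                       # 2
last_token_indices = query_start_loc[1:] - 1                                        # [2, 4]
seq_lens = (target_positions[last_token_indices] + 1).int()                         # [8, 4]

# Slot mapping: assign each token to its request by counting how many
# query_start_loc boundaries it has passed, then index the flattened block table.
segment_indices = torch.arange(len(target_positions))
segment_indices = (segment_indices.unsqueeze(0)
                   >= query_start_loc[:-1].unsqueeze(1)).sum(dim=0) - 1  # [0, 0, 0, 1, 1]
block_table_indices = (target_positions // block_size
                       + segment_indices * max_num_blocks_per_req)
block_numbers = block_table.flatten()[block_table_indices]
block_offsets = target_positions % block_size
slot_mapping = block_numbers * block_size + block_offsets  # [5, 6, 7, 34, 35]

The segment_indices comparison is the same trick used in the last hunk: for token i, counting how many entries of query_start_loc[:-1] are less than or equal to i identifies the request it belongs to, which selects the correct row of the per-request block table.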
