Commit 2a9d02e

[Bugfix] eagle and eagle3 spec decode failures and enable e2e test (#2979)
### What this PR does / why we need it?
- Fix bug #2978.
- Enable the e2e test.
- Adapt to scenarios where the number of speculative tokens is greater than 2.
- Fix the bug that causes Eagle3 inference failures under high concurrency and improve the acceptance rate of the draft model (#2794).

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
CI passed with newly added and existing tests.

Co-authored-by: hukongyi <hukongyi@cmbchina.com>
Co-authored-by: guanyuzhu <zhuguanyu@huawei.com>
Co-authored-by: liumail680 <liumail680@163.com>

- vLLM version: v0.10.2
- vLLM main: vllm-project/vllm@f225ea7

Signed-off-by: Icey <1790571317@qq.com>
1 parent ac1c2cd commit 2a9d02e
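
The recurring pattern in this commit is a version gate around vLLM APIs that differ between v0.10.2 and current main: the attention metadata builder lookup and the compute_logits signature. Below is a minimal sketch of that pattern, assuming a vllm-ascend environment; build_attn_metadata is a hypothetical helper for illustration, not a function from this diff.

# Illustrative sketch only. The attribute/getter split mirrors the diff below:
# vLLM v0.10.2 exposes `metadata_builder` directly, while newer vLLM main
# exposes it through `get_metadata_builder()`. Both feed the same build() call.
from vllm_ascend.utils import vllm_version_is


def build_attn_metadata(runner, common_attn_metadata):
    if vllm_version_is("0.10.2"):
        builder = runner.attn_groups[0][0].metadata_builder
    else:
        builder = runner.attn_groups[0][0].get_metadata_builder()
    return builder.build(0, common_attn_metadata, runner.get_model())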

File tree

2 files changed: +40 -25 lines

tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py

Lines changed: 3 additions & 4 deletions
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 from __future__ import annotations
 
+import os
 import random
 from typing import Any
 
@@ -9,6 +10,8 @@
 
 from tests.e2e.conftest import VllmRunner
 
+os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
 
 @pytest.fixture
 def test_prompts():
@@ -99,7 +102,6 @@ def test_ngram_correctness(
     assert matches > int(0.7 * len(ref_outputs))
 
 
-@pytest.mark.skipif(True, reason="oom in CI, fix me")
 @pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
 def test_eagle_correctness(
     test_prompts: list[list[dict[str, Any]]],
@@ -111,8 +113,6 @@ def test_eagle_correctness(
     Compare the outputs of a original LLM and a speculative LLM
     should be the same when using eagle speculative decoding.
     '''
-    if not use_eagle3:
-        pytest.skip("Not current support for the test.")
 
     ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=True)
     ref_outputs = ref_llm.chat(test_prompts, sampling_config)
@@ -121,7 +121,6 @@
     spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name()
     with VllmRunner(
             model_name,
-            trust_remote_code=True,
            enable_chunked_prefill=True,
            max_num_seqs=1,
            max_num_batched_tokens=2048,
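
With the skip markers removed, this test now runs for both the eagle and eagle3 parametrizations. A hedged example of invoking it locally from a vllm-ascend checkout (the -k filter is one reasonable choice for a quick local run, not the CI command):

import pytest

# Runs test_eagle_correctness for both ids ("eagle" and "eagle3").
pytest.main([
    "tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py",
    "-k", "test_eagle_correctness",
])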

vllm_ascend/spec_decode/eagle_proposer.py

Lines changed: 37 additions & 21 deletions
@@ -22,6 +22,7 @@
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
+from vllm_ascend.utils import vllm_version_is
 
 PADDING_SLOT_ID = -1
 
@@ -139,8 +140,6 @@ def generate_token_ids(self,
                            hidden_states: torch.Tensor = None,
                            attn_metadata=None,
                            aux_hidden_states: torch.Tensor = None):
-        if self.name == SpecDcodeType.EAGLE:
-            raise NotImplementedError("Eagle Is Not Supported Yet.")
 
         attn_metadata = self._get_eagle_atten_dict(scheduler_output)
         next_token_ids: list[int] = []
@@ -355,8 +354,12 @@ def _get_eagle_atten_dict(
                 decode_token_per_req=self.runner.decode_token_per_req,
                 num_computed_tokens_cpu=None,
                 seq_lens=None)
-            attn_metadata_i = self.runner.attn_metadata_builder.build(
-                common_attn_metadata, self.runner.get_model())
+            if vllm_version_is("0.10.2"):
+                builder = self.runner.attn_groups[0][0].metadata_builder
+            else:
+                builder = self.runner.attn_groups[0][0].get_metadata_builder()
+            attn_metadata_i = builder.build(0, common_attn_metadata,
+                                            self.runner.get_model())
             for layer_name in kv_cache_group_spec.layer_names:
                 attn_metadata[layer_name] = attn_metadata_i
 
@@ -418,16 +421,19 @@ def _propose(
         self.input_ids[:num_tokens - 1] = target_token_ids[1:]
         # Replace the last token with the next token.
         # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
-        self.input_ids[last_token_indices] = next_token_ids[0]
+        self.input_ids[last_token_indices] = next_token_ids
+        seq_lens = (target_positions[last_token_indices] + 1).int()
 
         query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
         max_query_len = query_lens.max().item()
+        attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask(
+            seq_lens, target_positions, self.vllm_config.model_config.dtype,
+            self.device)
 
         common_attn_metadata = AscendCommonAttentionMetadata(
-            query_start_loc=self.runner.query_start_loc[:batch_size + 1],
-            query_start_loc_cpu=self.runner.query_start_loc_cpu[:batch_size +
-                                                                1],
-            seq_lens_cpu=self.runner.seq_lens_cpu,
+            query_start_loc=cu_num_tokens.to(device),
+            query_start_loc_cpu=cu_num_tokens,
+            seq_lens_cpu=seq_lens.cpu(),
             max_query_len=max_query_len,
             num_reqs=batch_size,
             num_actual_tokens=num_tokens,
@@ -436,15 +442,19 @@
             get_device_tensor(),
             slot_mapping=target_slot_mapping,
             positions=target_positions,
-            attn_mask=self.runner.attn_mask,
+            attn_mask=attn_mask,
             spec_attn_mask=self.runner.spec_attn_mask,
             attn_state=self.runner.attn_state,
             decode_token_per_req=self.runner.decode_token_per_req,
             num_computed_tokens_cpu=None,
             seq_lens=None)
         # FIXME(woosuk): The below two ops cause synchronization. Optimize.
-        attn_metadata = self.runner.attn_metadata_builder.build(
-            common_attn_metadata, self.runner.model)
+        if vllm_version_is("0.10.2"):
+            builder = self.runner.attn_groups[0][0].metadata_builder
+        else:
+            builder = self.runner.attn_groups[0][0].get_metadata_builder()
+        attn_metadata = builder.build(0, common_attn_metadata,
+                                      self.runner.get_model())
         if self.use_cuda_graph and \
             num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
@@ -471,7 +481,10 @@
             hidden_states=self.hidden_states[:num_input_tokens],
         )
         sample_hidden_states = last_hidden_states[last_token_indices]
-        logits = self.model.compute_logits(sample_hidden_states, None)
+        if vllm_version_is("0.10.2"):
+            logits = self.model.compute_logits(sample_hidden_states, None)
+        else:
+            logits = self.model.compute_logits(sample_hidden_states)
         draft_token_ids = logits.argmax(dim=-1)
 
         # Early exit if there is only one draft token to be generated.
@@ -501,9 +514,8 @@
         attn_metadata.num_actual_tokens = batch_size
         attn_metadata.max_query_len = 1
         attn_metadata.query_start_loc = self.arange[:batch_size + 1]
-
-        if self.vllm_config.speculative_config.num_speculative_tokens > 2:
-            raise ValueError("Speculative tokens > 2 are not supported yet.")
+        query_lens.fill_(1)
+        attn_metadata.query_lens = query_lens
 
         attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
         for now_speculative in range(
@@ -558,9 +570,8 @@
             self.input_ids[:batch_size] = input_ids
             self.positions[:batch_size] = clamped_positions
             self.hidden_states[:batch_size] = hidden_states
-            positions = positions_cpu.to(device)
             attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask(
-                attn_metadata.seq_lens, positions,
+                attn_metadata.seq_lens, positions_cpu,
                 self.vllm_config.model_config.dtype, self.device)
 
             attn_metadata.attn_mask = attn_mask
@@ -577,8 +588,12 @@
                 hidden_states=self.hidden_states[:input_batch_size],
             )
             hidden_states = hidden_states[:batch_size]
-            logits = self.model.compute_logits(last_hidden_states[:batch_size],
-                                               None)
+            if vllm_version_is("0.10.2"):
+                logits = self.model.compute_logits(
+                    last_hidden_states[:batch_size], None)
+            else:
+                logits = self.model.compute_logits(
+                    last_hidden_states[:batch_size])
 
             # TODO(wenlong): get more than one token for tree attention
             draft_token_ids = logits.argmax(dim=-1)
@@ -652,7 +667,8 @@ def _prepare_eagle_input_sequential(self, out_tensor: torch.Tensor,
                 dtype=torch.int32,
                 device=out_tensor.device) + offset_tensor
             values_to_store = torch.tensor(
-                index_start, dtype=torch.int32,
+                index_start + global_start_offset,
+                dtype=torch.int32,
                 device=out_tensor.device) + offset_tensor
             mask = (target_indices >= start_pos) & \
                    (target_indices < end_pos) & \
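
The last hunk adds global_start_offset to the values scattered by _prepare_eagle_input_sequential, which appears intended to keep per-request indices consistent when several requests are laid out one after another. A small, self-contained illustration of the masked-scatter pattern with made-up sizes and offsets (not the function's exact logic or variable values):

import torch

# Hypothetical sizes and offsets, chosen only to make the example concrete.
out_tensor = torch.full((16,), -1, dtype=torch.int32)
start_pos, end_pos = 4, 9                 # this request's window in the batch
index_start, global_start_offset = 0, 2   # local index start + global offset
offset_tensor = torch.arange(end_pos - start_pos, dtype=torch.int32)

# Candidate positions for this request, and the values to store there; the
# fix shifts the stored values by global_start_offset.
target_indices = torch.arange(start_pos, end_pos,
                              dtype=torch.int32) + offset_tensor
values_to_store = torch.tensor(index_start + global_start_offset,
                               dtype=torch.int32) + offset_tensor
# Write only the entries that remain inside [start_pos, end_pos).
mask = (target_indices >= start_pos) & (target_indices < end_pos)
out_tensor[target_indices[mask].long()] = values_to_store[mask]
print(out_tensor)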
