Commit 8dd8fd2

ShangmingCai authored and lulmer committed
[Feature][Spec Decode] Simplify the use of Eagle Spec Decode (vllm-project#12304)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
1 parent 3a44f4c commit 8dd8fd2

File tree
8 files changed (+273, -18 lines)

docs/source/features/spec_decode.md
Lines changed: 7 additions & 9 deletions

@@ -175,7 +175,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 llm = LLM(
     model="meta-llama/Meta-Llama-3-8B-Instruct",
     tensor_parallel_size=4,
-    speculative_model="path/to/modified/eagle/model",
+    speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
     speculative_draft_tensor_parallel_size=1,
 )
 

@@ -190,14 +190,12 @@ for output in outputs:
 
 A few important things to consider when using the EAGLE based draft models:
 
-1. The EAGLE draft models available in the [HF repository for EAGLE models](https://huggingface.co/yuhuili) cannot be
-   used directly with vLLM due to differences in the expected layer names and model definition.
-   To use these models with vLLM, use the [following script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d)
-   to convert them. Note that this script does not modify the model's weights.
-
-   In the above example, use the script to first convert
-   the [yuhuili/EAGLE-LLaMA3-Instruct-8B](https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B) model
-   and then use the converted checkpoint as the draft model in vLLM.
+1. The EAGLE draft models available in the [HF repository for EAGLE models](https://huggingface.co/yuhuili) should
+   be able to be loaded and used directly by vLLM after [PR 12304](https://github.com/vllm-project/vllm/pull/12304).
+   If you are using vLLM version before [PR 12304](https://github.com/vllm-project/vllm/pull/12304), please use the
+   [script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d) to convert the speculative model,
+   and specify `speculative_model="path/to/modified/eagle/model"`. If weight-loading problems still occur when using
+   the latest version of vLLM, please leave a comment or raise an issue.
 
 2. The EAGLE based draft models need to be run without tensor parallelism
    (i.e. speculative_draft_tensor_parallel_size is set to 1), although

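For reference, a usage sketch based on the documented example above: after this change, the EAGLE checkpoint from the Hugging Face hub is passed to `speculative_model` directly, with no conversion script. The model names and parallelism settings come from the docs snippet; the prompt list is illustrative, and the hardware assumption (4 GPUs for tensor_parallel_size=4) should be adapted to your setup.

from vllm import LLM, SamplingParams

# Prompt and sampling settings mirror the documented example.
prompts = ["The future of AI is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    tensor_parallel_size=4,
    # EAGLE draft checkpoint loaded directly from the HF hub (no conversion).
    speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
    # EAGLE draft models must run without tensor parallelism (see note 2 above).
    speculative_draft_tensor_parallel_size=1,
)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.outputs[0].text)
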
tests/spec_decode/e2e/test_eagle_correctness.py
Lines changed: 144 additions & 0 deletions

@@ -305,6 +305,150 @@ def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
                                   batch_size, output_len, seed)
 
 
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Print spec metrics.
+        "disable_log_stats": False,
+
+        # Precision
+        "dtype": "float16",
+
+        # Main model
+        "model_name": "meta-llama/Llama-2-7b-chat-hf",
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": "yuhuili/EAGLE-llama2-chat-7B",
+        "num_speculative_tokens": MAX_SPEC_TOKENS,
+    },
+])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("batch_size", [1, 5])
+@pytest.mark.parametrize("seed", [1])
+def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
+                                             per_test_common_llm_kwargs,
+                                             baseline_llm_kwargs,
+                                             test_llm_kwargs, batch_size: int,
+                                             output_len: int, seed: int):
+
+    run_equality_correctness_test(vllm_runner,
+                                  common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs,
+                                  test_llm_kwargs,
+                                  batch_size,
+                                  output_len,
+                                  seed,
+                                  temperature=0.0)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Print spec metrics.
+        "disable_log_stats": False,
+
+        # Precision
+        "dtype": "float16",
+
+        # Main model
+        "model_name": "meta-llama/Meta-Llama-3-8B-Instruct",
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
+        "num_speculative_tokens": MAX_SPEC_TOKENS,
+    },
+])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("batch_size", [1, 5])
+@pytest.mark.parametrize("seed", [1])
+def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
+                                             per_test_common_llm_kwargs,
+                                             baseline_llm_kwargs,
+                                             test_llm_kwargs, batch_size: int,
+                                             output_len: int, seed: int):
+
+    run_equality_correctness_test(vllm_runner,
+                                  common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs,
+                                  test_llm_kwargs,
+                                  batch_size,
+                                  output_len,
+                                  seed,
+                                  temperature=0.0)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Print spec metrics.
+        "disable_log_stats": False,
+
+        # Precision
+        "dtype": "float16",
+
+        # Main model
+        "model_name": "Qwen/Qwen2-7B-Instruct",
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": "yuhuili/EAGLE-Qwen2-7B-Instruct",
+        "num_speculative_tokens": MAX_SPEC_TOKENS,
+    },
+])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("batch_size", [1, 5])
+@pytest.mark.parametrize("seed", [1])
+def test_qwen2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
+                                            per_test_common_llm_kwargs,
+                                            baseline_llm_kwargs,
+                                            test_llm_kwargs, batch_size: int,
+                                            output_len: int, seed: int):
+
+    run_equality_correctness_test(vllm_runner,
+                                  common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs,
+                                  test_llm_kwargs,
+                                  batch_size,
+                                  output_len,
+                                  seed,
+                                  temperature=0.0)
+
+
 if __name__ == "__main__":
     import pytest
     pytest.main([__file__])

tests/spec_decode/test_spec_decode_worker.py
Lines changed: 39 additions & 1 deletion

@@ -13,15 +13,18 @@
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import ExecuteModelRequest, SequenceOutput
 from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
+from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.metrics import (AsyncMetricsCollector,
                                       SpecDecodeWorkerMetrics)
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
                                                  split_num_cache_blocks_evenly)
+from vllm.worker.worker import Worker
 
 from .test_utils import mock_spec_decode_sampler
-from .utils import create_batch, create_sampler_output_list, mock_worker
+from .utils import (create_batch, create_sampler_output_list, create_worker,
+                    mock_worker)
 
 
 @pytest.mark.parametrize('k', [1, 2, 6])

@@ -905,3 +908,38 @@ def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str):
     worker.execute_model(execute_model_req=execute_model_req)
     # but first draft still counted
     assert draft_worker.get_spec_proposals.call_count == 1
+
+
+def test_correctly_load_weight_for_eagle():
+    """
+    Verify SpecDecodeWorker loads lm_head weight for eagle correctly.
+    """
+    seed = 100
+    block_size = 32
+    num_gpu_blocks = 8096 // block_size
+    target_worker = create_worker(
+        Worker,
+        "JackFram/llama-68m",
+        block_size,
+        num_gpu_blocks,
+        seed,
+    )
+    draft_worker = create_worker(
+        MultiStepWorker,
+        "abhigoyal/vllm-eagle-llama-68m-random",
+        block_size,
+        num_gpu_blocks,
+        seed,
+        model_runner_cls=TP1DraftModelRunner,
+    )
+
+    spec_decode_sampler = mock_spec_decode_sampler("rejection_sampler")
+    worker = SpecDecodeWorker(draft_worker,
+                              target_worker,
+                              spec_decode_sampler,
+                              disable_logprobs=False)
+    worker.proposer_worker.maybe_load_lm_head_weight(
+        target_worker.model_runner.model.lm_head.weight.data)
+    assert torch.allclose(
+        worker.proposer_worker.worker.model_runner.model.lm_head.weight.data,
+        worker.scorer_worker.model_runner.model.lm_head.weight.data)

vllm/config.py
Lines changed: 9 additions & 0 deletions

@@ -1833,6 +1833,15 @@ def maybe_create_spec_config(
 
             draft_hf_config = draft_model_config.hf_config
 
+            # Detect EAGLE prefix to replace hf_config for EAGLE draft_model
+            if "eagle-" in draft_model_config.model.lower():
+                from vllm.transformers_utils.configs.eagle import EAGLEConfig
+                if isinstance(draft_model_config.hf_config, EAGLEConfig):
+                    pass
+                else:
+                    eagle_config = EAGLEConfig(draft_model_config.hf_config)
+                    draft_model_config.hf_config = eagle_config
+
             if (num_speculative_tokens is not None
                     and hasattr(draft_hf_config, "num_lookahead_tokens")):
                 draft_hf_config.num_lookahead_tokens = num_speculative_tokens

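To make the auto-detection above concrete, here is a minimal standalone sketch of the same rule outside of vLLM: any draft model whose name contains "eagle-" (case-insensitive) gets its HF config wrapped in an EAGLEConfig unless it already is one. `DraftModelConfig` and `wrap_eagle_config` are hypothetical stand-ins for illustration only, not vLLM code.

from dataclasses import dataclass
from typing import Any


@dataclass
class DraftModelConfig:
    # Hypothetical stand-in for the draft model's config object.
    model: str        # e.g. "yuhuili/EAGLE-LLaMA3-Instruct-8B"
    hf_config: Any    # config object loaded from the HF hub


def wrap_eagle_config(cfg: DraftModelConfig, eagle_config_cls) -> None:
    """Wrap cfg.hf_config when the model name looks like an EAGLE checkpoint."""
    if ("eagle-" in cfg.model.lower()
            and not isinstance(cfg.hf_config, eagle_config_cls)):
        cfg.hf_config = eagle_config_cls(cfg.hf_config)
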
vllm/model_executor/models/eagle.py
Lines changed: 18 additions & 6 deletions

@@ -7,6 +7,7 @@
 
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import VllmConfig
+from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (

@@ -18,6 +19,8 @@
 
 from .utils import maybe_prefix
 
+logger = init_logger(__name__)
+
 
 class DummyInputLayerNorm(nn.Module):
 

@@ -190,8 +193,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                                             default_weight_loader)
                     weight_loader(self.fc.bias, loaded_weight)
                 else:
-                    raise ValueError("Found bias in the loaded weights "
-                                     "but the model config doesn't have bias")
+                    logger.warning_once("Found bias in the loaded weights but "
+                                        "the model config doesn't have bias.")
             elif name.startswith("model.lm_head.") or name.startswith(
                     "model.model."):
                 model_weights[name.split("model.", 1)[-1]] = loaded_weight

@@ -200,12 +203,21 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             else:
                 model_weights[f"model.{name}"] = loaded_weight
 
-        lm_head_weight = model_weights.pop("lm_head.weight")
+        if "lm_head.weight" in model_weights:
+            lm_head_weight = model_weights.pop("lm_head.weight")
+
+            if self.token_map is not None and\
+                lm_head_weight.shape[0] > self.token_map.shape[0]:
 
-        if self.token_map is not None and\
-            lm_head_weight.shape[0] > self.token_map.shape[0]:
+                lm_head_weight = lm_head_weight[self.token_map]
 
-            lm_head_weight = lm_head_weight[self.token_map]
+        else:
+            # NOTE(Shangming): initialize the placeholder for lm_head weight.
+            lm_head_weight = torch.zeros(
+                self.lm_head.org_vocab_size,
+                self.lm_head.embedding_dim,
+                dtype=self.config.torch_dtype,
+            )
 
         weight_loader = getattr(self.lm_head.weight, "weight_loader",
                                 default_weight_loader)

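The zero-placeholder fallback above matters because an EAGLE checkpoint may not ship an lm_head weight of its own; the draft model ends up reusing the target model's lm_head, which is copied in afterwards via the new maybe_load_lm_head_weight methods below. A minimal plain-PyTorch sketch of that pattern (not vLLM's classes; sizes are illustrative):

import torch
import torch.nn as nn

vocab_size, hidden_size = 1000, 128  # illustrative sizes only
draft_lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
target_lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

with torch.no_grad():
    # Step 1: checkpoint has no "lm_head.weight" -> install a zero placeholder.
    draft_lm_head.weight.zero_()

    # Step 2: once both models are loaded, copy the target model's lm_head
    # weight into the draft model (the role of maybe_load_lm_head_weight).
    draft_lm_head.weight.copy_(target_lm_head.weight)

assert torch.allclose(draft_lm_head.weight, target_lm_head.weight)
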
vllm/spec_decode/multi_step_worker.py
Lines changed: 12 additions & 0 deletions

@@ -7,6 +7,7 @@
 import torch
 
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.platforms import current_platform
 from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
                            SequenceGroupMetadata)

@@ -386,3 +387,14 @@ def _raise_if_unsupported(
                 execute_model_req.seq_group_metadata_list):
             raise NotImplementedError(
                 "MultiStepWorker does not support beam search.")
+
+    def maybe_load_lm_head_weight(
+        self,
+        lm_head_weight: torch.Tensor,
+    ) -> None:
+        weight_loader = getattr(
+            self.worker.model_runner.model_runner.model.lm_head.weight,
+            "weight_loader", default_weight_loader)
+        weight_loader(
+            self.worker.model_runner.model_runner.model.lm_head.weight,
+            lm_head_weight)

vllm/spec_decode/smaller_tp_proposer_worker.py
Lines changed: 19 additions & 0 deletions

@@ -10,6 +10,7 @@
                                              patch_tensor_parallel_group)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import ExecuteModelRequest
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.multi_step_worker import MultiStepWorker

@@ -173,3 +174,21 @@ def get_cache_block_size_bytes(self) -> int:
     @property
     def vocab_size(self) -> int:
         return self._worker.vocab_size
+
+    def maybe_load_lm_head_weight(
+        self,
+        lm_head_weight: torch.Tensor,
+    ) -> None:
+        if self._is_dummy:
+            return
+
+        with self._patch_tensor_parallel_group():
+            weight_loader = getattr(
+                self._worker.worker.model_runner.model_runner.model.\
+                lm_head.weight,
+                "weight_loader",
+                default_weight_loader)
+            weight_loader(
+                self._worker.worker.model_runner.model_runner.model.\
+                lm_head.weight,
+                lm_head_weight)
