Commit 1bd04e4

[v1] Support deepseek with eagle
Signed-off-by: Xin Yang <xyangx@amazon.com>
1 parent: 10be209 (commit 1bd04e4)
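
This change adds an EAGLE draft model built from DeepSeek decoder layers, so DeepSeek targets can be served with EAGLE speculative decoding on the v1 engine. As a rough illustration of how the feature is exercised (not part of this commit), the sketch below enables EAGLE through the dict-form speculative_config; the checkpoint paths, parallelism, and token count are placeholders, and the accepted keys can vary by vLLM version.

# Illustrative sketch only (not from this commit): serving a DeepSeek target
# with an EAGLE draft model. Paths and sizes below are placeholders.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V3",                # target model (placeholder)
    tensor_parallel_size=8,                         # assumption: multi-GPU setup
    speculative_config={
        "method": "eagle",                          # use the EAGLE proposer
        "model": "/path/to/eagle-deepseek-draft",   # draft checkpoint (placeholder)
        "num_speculative_tokens": 2,
    },
)

outputs = llm.generate(["What is speculative decoding?"],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)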

2 files changed: 222 additions, 0 deletions
vllm/model_executor/models/deepseek_eagle.py

Lines changed: 221 additions & 0 deletions

@@ -0,0 +1,221 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable

import torch
import torch.nn as nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.deepseek_v2 import (
    DeepseekV2DecoderLayer, DeepseekV3ForCausalLM,
    get_spec_layer_idx_from_weight_name)

from .utils import AutoWeightsLoader, maybe_prefix


@support_torch_compile
class DeepseekV2Model(nn.Module):

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        start_layer_id: int = 0,
    ) -> None:
        super().__init__()
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.vocab_size = self.config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        self.layers = nn.ModuleList([
            DeepseekV2DecoderLayer(
                self.config,
                prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
            ) for i in range(self.config.num_hidden_layers)
        ])

        self.fc = nn.Linear(
            self.config.model.hidden_size * 2,
            self.config.model.hidden_size,
            bias=False,
        )

        self.enorm = RMSNorm(self.config.hidden_size,
                             eps=self.config.rms_norm_eps)
        self.hnorm = RMSNorm(self.config.hidden_size,
                             eps=self.config.rms_norm_eps)
        self.norm = RMSNorm(self.config.hidden_size,
                            eps=self.config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_embeds = self.embed_tokens(input_ids)

        inputs = torch.cat(
            [self.enorm(input_embeds),
             self.hnorm(hidden_states)], dim=-1)
        hidden_states = self.fc(inputs)

        # masking inputs at position=0
        hidden_states[positions == 0] = 0
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states, hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is not None:
                continue  # skip spec decode layers for main model

            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # if PP disabled then draft will share embed with target
                    if get_pp_group().world_size == 1 and \
                            "embed_tokens." in name:
                        continue

                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class EagleDeepseekForCausalLM(DeepseekV3ForCausalLM):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config)
        self.model = DeepseekV2Model(vllm_config=vllm_config,
                                     prefix="model",
                                     start_layer_id=target_layer_num)

        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.config.vocab_size,
                                                scale=logit_scale)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.model(input_ids, positions, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
        )

        model_weights = {}
        for name, loaded_weight in weights:
            if "lm_head" not in name:
                name = "model." + name
            model_weights[name] = loaded_weight
        loader.load_weights(model_weights.items())
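
Two pieces of bookkeeping in the file above are easy to miss: the draft decoder layers get prefixes that start after the target model's last layer (start_layer_id=target_layer_num, the same convention as the Llama EAGLE draft), so their module names stay distinct from the target's, and EagleDeepseekForCausalLM.load_weights prepends "model." to every checkpoint tensor name except lm_head before handing the weights to AutoWeightsLoader. The standalone sketch below simply replays that bookkeeping for a hypothetical 61-layer target with a one-layer draft; the tensor names are invented for illustration.

# Standalone sketch of the name bookkeeping above; all names and sizes here
# are hypothetical and only mirror the logic of the file, not vLLM internals.
import torch

target_layer_num = 61        # assumption: decoder layers in the target model
num_draft_layers = 1         # assumption: a single EAGLE draft layer

# Prefixes produced via maybe_prefix(prefix, f"layers.{i + start_layer_id}")
# with prefix="model" and start_layer_id=target_layer_num:
draft_prefixes = [
    f"model.layers.{i + target_layer_num}" for i in range(num_draft_layers)
]
assert draft_prefixes == ["model.layers.61"]

# Checkpoint-name remap done in EagleDeepseekForCausalLM.load_weights:
checkpoint = {
    "embed_tokens.weight": torch.empty(0),   # invented tensor names
    "fc.weight": torch.empty(0),
    "lm_head.weight": torch.empty(0),
}
remapped = {
    (name if "lm_head" in name else "model." + name): weight
    for name, weight in checkpoint.items()
}
assert set(remapped) == {
    "model.embed_tokens.weight", "model.fc.weight", "lm_head.weight"
}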

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions
@@ -246,6 +246,7 @@
     "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
     "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
     "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
+    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekForCausalLM"),
     "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
     "MedusaModel": ("medusa", "Medusa"),
     "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
