
Commit 6981998

[Model] Support deepseek with eagle
Signed-off-by: Xin Yang <xyangx@amazon.com>
1 parent 10be209 commit 6981998
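
The new draft model follows the EAGLE recipe: the token embeddings and the hidden states handed over by the target model are each RMS-normalized, concatenated, and projected back to hidden_size by a bias-free linear layer before running the extra DeepseekV2DecoderLayer stack defined in the diff below. A minimal, self-contained sketch of that fusion step (toy sizes, norms omitted; this is not code from the commit):

    import torch

    hidden_size = 8                               # toy value; DeepSeek-V3 uses 7168
    embeds = torch.randn(4, hidden_size)          # embeddings of the draft input tokens
    target_hidden = torch.randn(4, hidden_size)   # hidden states passed in from the target model
    fc = torch.nn.Linear(2 * hidden_size, hidden_size, bias=False)

    # concatenate along the feature dimension, then project back to hidden_size
    fused = fc(torch.cat([embeds, target_hidden], dim=-1))   # shape: [4, hidden_size]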

File tree

2 files changed: 233 additions, 0 deletions

vllm/model_executor/models/deepseek_eagle.py

Lines changed: 232 additions & 0 deletions
@@ -0,0 +1,232 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable
from typing import Optional

import torch
import torch.nn as nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.deepseek_v2 import (DeepseekV2DecoderLayer,
                                                    DeepseekV3ForCausalLM)
from vllm.model_executor.sampling_metadata import SamplingMetadata

from .utils import AutoWeightsLoader, maybe_prefix


@support_torch_compile
class DeepseekV2Model(nn.Module):

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        start_layer_id: int = 0,
    ) -> None:
        super().__init__()
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.vocab_size = self.config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        self.layers = nn.ModuleList([
            DeepseekV2DecoderLayer(
                self.config,
                prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
            ) for i in range(self.config.num_hidden_layers)
        ])

        self.fc = nn.Linear(
            self.config.model.hidden_size * 2,
            self.config.model.hidden_size,
            bias=False,
        )

        self.enorm = RMSNorm(self.config.hidden_size,
                             eps=self.config.rms_norm_eps)
        self.hnorm = RMSNorm(self.config.hidden_size,
                             eps=self.config.rms_norm_eps)
        self.norm = RMSNorm(self.config.hidden_size,
                            eps=self.config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_embeds = self.embed_tokens(input_ids)

        inputs = torch.cat(
            [self.enorm(input_embeds),
             self.hnorm(hidden_states)], dim=-1)
        hidden_states = self.fc(inputs)

        # masking inputs at position=0
        hidden_states[positions == 0] = 0
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states, hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # if PP disabled then draft will share embed with target
                    if get_pp_group().world_size == 1 and \
                        "embed_tokens." in name:
                        continue

                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        quant_config = vllm_config.quant_config
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config)
        self.model = DeepseekV2Model(vllm_config=vllm_config,
                                     prefix="model",
                                     start_layer_id=target_layer_num)

        self.lm_head = ParallelLMHead(self.config.vocab_size,
                                      self.config.hidden_size,
                                      quant_config=quant_config)

        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.config.vocab_size,
                                                scale=logit_scale)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.model(input_ids, positions, hidden_states)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
        )

        model_weights = {}
        for name, loaded_weight in weights:
            if "lm_head" not in name:
                name = "model." + name
            model_weights[name] = loaded_weight
        loader.load_weights(model_weights.items())
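
A note on the remapping in EagleDeepseekV3ForCausalLM.load_weights above: every checkpoint key except lm_head gets a "model." prefix, so a flat draft checkpoint lines up with the wrapper's module tree. An illustrative mapping (the key names below are hypothetical, not taken from a real checkpoint):

    # "embed_tokens.weight"            -> "model.embed_tokens.weight"
    # "layers.0.mlp.gate_proj.weight"  -> "model.layers.0.mlp.gate_proj.weight"
    # "lm_head.weight"                 -> "lm_head.weight"   (left as-is)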

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions
@@ -246,6 +246,7 @@
     "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
     "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
     "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
+    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
     "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
     "MedusaModel": ("medusa", "Medusa"),
     "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
