
Commit 7a5c10c

adapt glm4.5 moe

Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent 741a8cf

File tree: 3 files changed, +356 -0 lines


vllm_ascend/ascend_forward_context.py

Lines changed: 1 addition & 0 deletions

@@ -23,6 +23,7 @@ class FusedMoEState(Enum):
 
 # TODO(zzzzwwjj): add soc_version to choose branch
 def get_fused_moe_state(ep_size: int, with_prefill: bool):
+    return FusedMoEState.AllGather
     enable_chunk_mc2 = envs.VLLM_ASCEND_ENABLE_CHUNK_MC2
     if ep_size == 1:
         return FusedMoEState.AllGather
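Note that the added early return short-circuits get_fused_moe_state: every caller now receives FusedMoEState.AllGather regardless of ep_size or with_prefill, and the branches below it become unreachable. A minimal sketch of the resulting behavior; enum members other than AllGather are placeholders, not taken from this diff:

    from enum import Enum

    class FusedMoEState(Enum):
        AllGather = 0
        All2All = 1  # placeholder member
        MC2 = 2      # placeholder member

    def get_fused_moe_state(ep_size: int, with_prefill: bool):
        # The line added by this commit returns before any other branch runs.
        return FusedMoEState.AllGather

    assert get_fused_moe_state(ep_size=16, with_prefill=True) is FusedMoEState.AllGather
    assert get_fused_moe_state(ep_size=1, with_prefill=False) is FusedMoEState.AllGather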

vllm_ascend/models/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -13,6 +13,7 @@ def register_model():
         AscendQwen2_5_VLForConditionalGeneration  # noqa: F401
     from .qwen2_vl import AscendQwen2VLForConditionalGeneration  # noqa: F401
     from .qwen3 import CustomQwen3ForCausalLM  # noqa: F401
+    from .glm4_moe import CustomGlm4MoeForCausalLM
 
     ModelRegistry.register_model(
         "DeepSeekMTPModel",
@@ -64,3 +65,6 @@ def register_model():
 
     ModelRegistry.register_model(
         "Qwen3ForCausalLM", "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")
+
+    ModelRegistry.register_model(
+        "Glm4MoeForCausalLM", "vllm_ascend.models.glm4_moe:CustomGlm4MoeForCausalLM")

vllm_ascend/models/glm4_moe.py (new file)

Lines changed: 351 additions & 0 deletions

@@ -0,0 +1,351 @@
import typing
from collections.abc import Callable, Iterable
from typing import Any, Optional, Union

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP

from vllm.model_executor.models.utils import (
    PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers,
    maybe_prefix)
from vllm.distributed import (get_dp_group, get_pp_group,
                              get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              get_tp_group, split_tensor_along_last_dim,
                              tensor_model_parallel_reduce_scatter)
from vllm.model_executor.models.glm4_moe import (Glm4MoeAttention,
                                                 Glm4MoeDecoderLayer,
                                                 Glm4MoeForCausalLM,
                                                 Glm4MoeMLP, Glm4MoeModel)
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm.forward_context import get_forward_context

logger = init_logger(__name__)


class CustomGlm4MoE(nn.Module):

    top_k: int

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts
        if self.tp_size > config.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.n_routed_experts}.")

        if config.hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                             "Only silu is supported for now.")

        self.gate = ReplicatedLinear(config.hidden_size,
                                     config.n_routed_experts,
                                     bias=False,
                                     quant_config=None,
                                     prefix=f"{prefix}.gate")

        self.gate.e_score_correction_bias = nn.Parameter(
            torch.empty(config.n_routed_experts))

        self.experts = AscendFusedMoE(
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            prefix=f"{prefix}.experts",
            scoring_func=config.scoring_func,
            e_score_correction_bias=self.gate.e_score_correction_bias)

        if config.n_shared_experts is not None:
            intermediate_size = (config.moe_intermediate_size *
                                 config.n_shared_experts)
            self.shared_experts = Glm4MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=True,
                prefix=f"{prefix}.shared_experts",
            )
        else:
            self.shared_experts = None  # type: ignore
        CustomGlm4MoE.top_k = config.num_experts_per_tok

        self.dp_size = get_dp_group().world_size

        self.tp_group = get_tp_group().device_group
        self.tp_rank = get_tp_group().rank_in_group

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        forward_context = get_forward_context()
        # when profile runs, force experts to load balanced tokens
        # to avoid high memory consumption on a single rank.
        enable_force_load_balance = forward_context.in_profile_run

        is_prefill = forward_context.with_prefill

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)

        experts_hidden_states = self.experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            is_prefill=is_prefill,
            top_k=CustomGlm4MoE.top_k,
            enable_force_load_balance=enable_force_load_balance,
            shared_experts=self.shared_experts,
            gate=None,
        )

        hidden_states = (
            experts_hidden_states[0] * self.routed_scaling_factor +
            experts_hidden_states[1])

        return hidden_states


class CustomGlm4MoeDecoderLayer(Glm4MoeDecoderLayer):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ) -> None:
        nn.Module.__init__(self)
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          131072)
        # DecoderLayers are created with `make_layers` which passes the prefix
        # with the layer's index.
        layer_idx = int(prefix.split(sep='.')[-1])
        self.layer_idx = layer_idx

        self.self_attn = Glm4MoeAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            head_dim=config.head_dim,
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=config.attention_bias,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            use_qk_norm=config.use_qk_norm,
        )

        if (config.n_routed_experts is not None
                and layer_idx >= config.first_k_dense_replace):
            self.mlp = CustomGlm4MoE(
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.mlp = Glm4MoeMLP(hidden_size=config.hidden_size,
                                  intermediate_size=config.intermediate_size,
                                  hidden_act=config.hidden_act,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.mlp")

        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)
        self.routed_scaling_factor = config.routed_scaling_factor

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(positions=positions,
                                       hidden_states=hidden_states)
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


@support_torch_compile
class CustomGlm4MoeModel(Glm4MoeModel):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        # enable_eplb = vllm_config.parallel_config.enable_eplb
        self.config = config

        self.vocab_size = config.vocab_size

        if get_pp_group().is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                prefix=f"{prefix}.embed_tokens")
        else:
            self.embed_tokens = PPMissingLayer()

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: CustomGlm4MoeDecoderLayer(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
                # enable_eplb=enable_eplb,
            ),
            prefix=f"{prefix}.layers")

        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class CustomGlm4MoeForCausalLM(Glm4MoeForCausalLM):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        SupportsPP.__init__(self)
        SupportsLoRA.__init__(self)
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = CustomGlm4MoeModel(vllm_config=vllm_config,
                                        prefix=maybe_prefix(prefix, "model"))
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(config.vocab_size,
                                          config.hidden_size,
                                          quant_config=quant_config)
        else:
            self.lm_head = PPMissingLayer()
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)
        self.expert_weights = []

        # Set MoE hyperparameters.
        self.num_moe_layers = (config.num_hidden_layers -
                               config.first_k_dense_replace)
        self.num_expert_groups = config.n_group

        self.moe_layers: list[FusedMoE] = []
        example_moe = None
        for layer in self.model.layers:
            if isinstance(layer, PPMissingLayer):
                continue

            assert isinstance(layer, CustomGlm4MoeDecoderLayer)
            if isinstance(layer.mlp, CustomGlm4MoE):
                # Pick the last MoE layer, since the first layers may be dense.
                example_moe = layer.mlp
                self.moe_layers.append(layer.mlp.experts)

        if example_moe is None:
            raise RuntimeError("No Glm4MoE layer found in model.layers.")

        self.num_logical_experts = example_moe.n_logical_experts
        self.num_physical_experts = example_moe.n_physical_experts
        self.num_local_physical_experts = example_moe.n_local_physical_experts
        self.num_routed_experts = example_moe.n_routed_experts
        self.num_shared_experts = example_moe.n_shared_experts
        self.num_redundant_experts = example_moe.n_redundant_experts

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states
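In CustomGlm4MoE.forward above, the AscendFusedMoE call is expected to return a pair of tensors, the routed-expert output and the shared-expert output; the routed part is scaled by routed_scaling_factor before the two are summed. A toy sketch of that final combination, with shapes and the scaling value chosen for illustration rather than read from the model config:

    import torch

    num_tokens, hidden_size = 4, 8
    routed_out = torch.randn(num_tokens, hidden_size)  # experts_hidden_states[0]
    shared_out = torch.randn(num_tokens, hidden_size)  # experts_hidden_states[1]
    routed_scaling_factor = 1.0  # taken from the HF config in the real model

    hidden_states = routed_out * routed_scaling_factor + shared_out
    print(hidden_states.shape)  # torch.Size([4, 8])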
