Commit c0429b4

fix

committed · 1 parent ecf25e8 · commit c0429b4

File tree: 5 files changed, +40 -314 lines changed

vllm_ascend/models/qwen3_moe.py

Lines changed: 8 additions & 294 deletions
@@ -21,301 +21,19 @@
 
 import torch
 from torch import nn
-from transformers import PretrainedConfig
-from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, CompilationLevel, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
-from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
-                                             get_tp_group)
-from vllm.forward_context import get_forward_context
+from vllm.config import VllmConfig
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
-from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.models.interfaces import (MixtureOfExperts,
                                                    SupportsLoRA, SupportsPP)
-from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
-                                                  Qwen3MoeDecoderLayer,
+from vllm.model_executor.models.qwen3_moe import (Qwen3MoeDecoderLayer,
                                                   Qwen3MoeForCausalLM,
-                                                  Qwen3MoeMLP, Qwen3MoeModel,
+                                                  Qwen3MoeModel,
                                                   Qwen3MoeSparseMoeBlock)
-from vllm.model_executor.models.utils import (
-    PPMissingLayer, extract_layer_index,
-    make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
+from vllm.model_executor.models.utils import PPMissingLayer, maybe_prefix
 from vllm.sequence import IntermediateTensors
 
-from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
-from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
-                                               init_metadata_for_sp)
-
-
-class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ):
-        nn.Module.__init__(self)
-        self.tp_size = get_tensor_model_parallel_world_size()
-        if self.tp_size > config.num_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {config.num_experts}.")
-
-        self.gate = ReplicatedLinear(
-            config.hidden_size,
-            config.num_experts,
-            bias=False,
-            quant_config=None,
-            prefix=f"{prefix}.gate",
-        )
-
-        self.experts = AscendFusedMoE(
-            num_experts=config.num_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
-            renormalize=config.norm_topk_prob,
-            quant_config=quant_config,
-            prefix=f"{prefix}.experts",
-        )
-
-        self.top_k = config.num_experts_per_tok
-
-        self.dp_size = get_dp_group().world_size
-
-        self.tp_group = get_tp_group().device_group
-        self.tp_rank = get_tp_group().rank_in_group
-        self.ep_group = get_ep_group()
-
-        self.params_dtype = torch.get_default_dtype()
-
-    def forward(
-        self,
-        hidden_states,
-        attn_metadata=None,
-        _metadata_for_padding: Optional[MetadataForPadding] = None,
-    ):
-        if attn_metadata is None:
-            attn_metadata = get_forward_context().attn_metadata
-        # when profile runs, force experts to load balanced tokens
-        # to avoid high memory consumption on a single rank.
-        enable_force_load_balance = get_forward_context().in_profile_run
-        is_prefill = get_forward_context().with_prefill
-
-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-
-        hidden_states = self.experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-            is_prefill=is_prefill,
-            top_k=self.top_k,
-            enable_force_load_balance=enable_force_load_balance,
-            shared_experts=None,
-            _metadata_for_padding=_metadata_for_padding,
-        )
-
-        return hidden_states
-
-
-class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        vllm_config: Optional[VllmConfig] = None,
-        prefix: str = "",
-    ) -> None:
-
-        nn.Module.__init__(self)
-        self.hidden_size = config.hidden_size
-        rope_theta = getattr(config, "rope_theta", 10000)
-        rope_scaling = getattr(config, "rope_scaling", None)
-        max_position_embeddings = getattr(config, "max_position_embeddings",
-                                          8192)
-        self.self_attn = Qwen3MoeAttention(
-            hidden_size=self.hidden_size,
-            num_heads=config.num_attention_heads,
-            num_kv_heads=config.num_key_value_heads,
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
-            max_position_embeddings=max_position_embeddings,
-            rms_norm_eps=config.rms_norm_eps,
-            qkv_bias=getattr(config, 'attention_bias', False),
-            head_dim=getattr(config, 'head_dim', None),
-            cache_config=cache_config,
-            quant_config=quant_config,
-            prefix=f"{prefix}.self_attn",
-        )
-
-        # `mlp_only_layers` in the config.
-        layer_idx = extract_layer_index(prefix)
-        mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
-                           config.mlp_only_layers)
-        self.use_aclgraph = (vllm_config is not None
-                             and vllm_config.compilation_config.level
-                             == CompilationLevel.PIECEWISE
-                             and not vllm_config.model_config.enforce_eager)
-        if (layer_idx not in mlp_only_layers) and (
-                config.num_experts > 0 and
-            (layer_idx + 1) % config.decoder_sparse_step == 0):
-            if not self.use_aclgraph:
-                # FIXME: custom sparse moe block doesn't work with aclgraph.
-                self.mlp = CustomSparseMoeBlock(config=config,
-                                                quant_config=quant_config,
-                                                prefix=f"{prefix}.mlp")
-            else:
-                self.mlp = Qwen3MoeSparseMoeBlock(config=config,
-                                                  quant_config=quant_config,
-                                                  prefix=f"{prefix}.mlp")
-        else:
-            self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
-                                   intermediate_size=config.intermediate_size,
-                                   hidden_act=config.hidden_act,
-                                   quant_config=quant_config,
-                                   prefix=f"{prefix}.mlp")
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(config.hidden_size,
-                                                eps=config.rms_norm_eps)
-
-        self.enable_sequence_parallelism = (
-            vllm_config.compilation_config.pass_config.
-            enable_sequence_parallelism if vllm_config is not None else False)
-
-    def forward(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        residual: Optional[torch.Tensor],
-        _metadata_for_padding: Optional[MetadataForPadding] = None,
-    ) -> torch.Tensor:
-
-        # To prevent precision issues during the decoder phase when only prefilling enables SP
-        if not self.enable_sequence_parallelism:
-            self.self_attn.o_proj.reduce_results = True
-        else:
-            self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True
-
-        # Self Attention
-        if residual is None:
-            residual = hidden_states
-            if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
-                residual = _metadata_for_padding.padding_slice(residual)
-
-            hidden_states = self.input_layernorm(hidden_states)
-        else:
-            hidden_states, residual = self.input_layernorm(
-                hidden_states, residual)
-
-        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
-            hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
-                hidden_states)
-
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-        )
-
-        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
-            hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(
-                hidden_states)
-
-        # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
-
-        if not self.use_aclgraph:
-            hidden_states = self.mlp(
-                hidden_states, _metadata_for_padding=_metadata_for_padding)
-        else:
-            hidden_states = self.mlp(hidden_states)
-
-        return hidden_states, residual
-
-
-@support_torch_compile
-class CustomQwen3MoeModel(Qwen3MoeModel):
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        nn.Module.__init__(self)
-        config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
-
-        parallel_config = vllm_config.parallel_config
-        eplb_config = parallel_config.eplb_config
-        self.num_redundant_experts = eplb_config.num_redundant_experts
-        self.padding_idx = config.pad_token_id
-        self.vocab_size = config.vocab_size
-        self.config = config
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            prefix=f"{prefix}.embed_tokens")
-        self.start_layer, self.end_layer, self.layers = make_layers(
-            config.num_hidden_layers,
-            lambda prefix: CustomQwen3MoeDecoderLayer(
-                config=config,
-                cache_config=cache_config,
-                quant_config=quant_config,
-                vllm_config=vllm_config,
-                prefix=prefix),
-            prefix=f"{prefix}.layers",
-        )
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.make_empty_intermediate_tensors = (
-            make_empty_intermediate_tensors_factory(
-                ["hidden_states", "residual"], config.hidden_size))
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        _metadata_for_padding: Optional[MetadataForPadding] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        if get_pp_group().is_first_rank:
-            if inputs_embeds is not None:
-                hidden_states = inputs_embeds
-            else:
-                hidden_states = self.get_input_embeddings(input_ids)
-            residual = None
-        else:
-            assert intermediate_tensors is not None
-            hidden_states = intermediate_tensors["hidden_states"]
-            residual = intermediate_tensors["residual"]
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
-            hidden_states, residual = layer(
-                positions,
-                hidden_states,
-                residual,
-                _metadata_for_padding=_metadata_for_padding)
-        if not get_pp_group().is_last_rank:
-            return IntermediateTensors({
-                "hidden_states": hidden_states,
-                "residual": residual
-            })
-
-        hidden_states, _ = self.norm(hidden_states, residual)
-
-        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
-            hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
-                hidden_states)
-
-        return hidden_states
-
 
 class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
     packed_modules_mapping = {
@@ -341,8 +59,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         quant_config = vllm_config.quant_config
         self.config = config
         self.quant_config = quant_config
-        self.model = CustomQwen3MoeModel(vllm_config=vllm_config,
-                                         prefix=maybe_prefix(prefix, "model"))
+        self.model = Qwen3MoeModel(vllm_config=vllm_config,
+                                   prefix=maybe_prefix(prefix, "model"))
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       quant_config=quant_config,
@@ -352,8 +70,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
-
-        self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
         # Set MoE hyperparameters
         self.expert_weights: list[torch.Tensor] = []
35975

@@ -382,8 +98,6 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        _metadata_for_padding = init_metadata_for_sp(
-            input_ids, self.enable_sequence_parallelism)
         hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds, _metadata_for_padding)
+                                   inputs_embeds)
         return hidden_states
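After this change, CustomQwen3MoeForCausalLM wraps the upstream Qwen3MoeModel directly. A minimal offline smoke-test sketch of that code path (not part of this commit; the checkpoint name, tensor_parallel_size, and prompt are illustrative assumptions):

    # Sketch: exercise the simplified Qwen3-MoE path through vLLM's offline API.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="Qwen/Qwen3-30B-A3B",  # any Qwen3-MoE checkpoint (assumed)
        tensor_parallel_size=4,      # adjust to the available devices
    )
    outputs = llm.generate(
        ["Briefly explain mixture-of-experts routing."],
        SamplingParams(temperature=0.0, max_tokens=64),
    )
    print(outputs[0].outputs[0].text)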
