Skip to content

【main】 Support SP for qwen2.5 and qwen3 moe #1761

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions tests/e2e/multicard/test_offline_inference_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,47 @@ def test_models_distributed_topk() -> None:
distributed_executor_backend="mp",
) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)


@patch.dict(os.environ, {"VLLM_ENABLE_SP": "1"})
def test_fluash_comm1_for_qwen3_moe() -> None:
example_prompts = [
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
"Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
"Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
]
dtype = "half"
sampling_params = SamplingParams(max_tokens=5,
temperature=0.0,
top_k=50,
top_p=0.9)

with VllmRunner(
"Qwen/Qwen3-30B-A3B",
dtype=dtype,
tensor_parallel_size=4,
distributed_executor_backend="mp",
) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)


@patch.dict(os.environ, {"VLLM_ENABLE_SP": "1"})
def test_fluash_comm1_for_qwen2_5() -> None:
example_prompts = [
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
"Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
"Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
]
dtype = "half"
sampling_params = SamplingParams(max_tokens=5,
temperature=0.0,
top_k=50,
top_p=0.9)

with VllmRunner(
"Qwen/Qwen2.5-0.5B-Instruct",
dtype=dtype,
tensor_parallel_size=4,
distributed_executor_backend="mp",
) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)
64 changes: 64 additions & 0 deletions tests/ut/ops/test_flash_comm1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#


import torch
import importlib
from tests.ut.base import TestBase
from unittest.mock import MagicMock, patch

from vllm.distributed.parallel_state import GroupCoordinator

from vllm_ascend.ops import sequence_parallel


class Test_Flash_Comm1(TestBase):

@patch('vllm.distributed.tensor_model_parallel_all_gather')
@patch('vllm.distributed.tensor_model_parallel_reduce_scatter')
@patch('vllm.distributed.parallel_state._TP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
def test_test_flash_comm1(self, mock_TP,
mock_tensor_model_parallel_reduce_scatter,
mock_tensor_model_parallel_all_gather):
with patch('vllm.distributed.get_tp_group',
return_value=MagicMock(world_size=4, rank_in_group=0)) as mock_get_tp_group:
num_tokens = 9
hidden_size = 128
tp_size = 4
hidden_states = torch.randn(num_tokens, hidden_size)

mock_tp_group = mock_get_tp_group.return_value
assert mock_tp_group.world_size == 4 # 手动断言属性存在
assert mock_tp_group.rank_in_group == 0

lengths_sum_unpadding = hidden_states.shape[0]
lengths_sum_padding = ((lengths_sum_unpadding + tp_size - 1) // tp_size) * tp_size
padding_flag = True
pad_size = lengths_sum_padding - lengths_sum_unpadding
importlib.reload(sequence_parallel)
_metadata_for_padding = sequence_parallel.MetadataForPadding(lengths_sum_unpadding=lengths_sum_unpadding,
lengths_sum_padding=lengths_sum_padding,
padding_flag=padding_flag,
pad_size=pad_size,
not_dummy_and_is_prefill=True)

mock_tensor_model_parallel_reduce_scatter.return_value = torch.randn(lengths_sum_padding // tp_size, hidden_size)
mock_tensor_model_parallel_all_gather.return_value = torch.randn(lengths_sum_padding, hidden_size)

hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(hidden_states)
output = _metadata_for_padding.allgather_unpadding_aligned(hidden_states)

self.assertEqual(output.shape, (num_tokens, hidden_size))
6 changes: 4 additions & 2 deletions vllm_ascend/attention/attention_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ def build(self,
max_query_len,
common_prefix_len,
max_num_tokens_across_dp: int = 0,
with_prefill_across_dp: bool = False):
with_prefill_across_dp: bool = False,
is_only_prefill:bool = False):

block_table = self.runner.input_batch.block_table[0].get_device_tensor(
)
Expand Down Expand Up @@ -207,7 +208,8 @@ def build(self,
attn_mask=attn_mask,
attn_state=attn_state,
max_num_tokens_across_dp=max_num_tokens_across_dp,
with_prefill_across_dp=with_prefill_across_dp)
with_prefill_across_dp=with_prefill_across_dp,
is_only_prefill=is_only_prefill)
return attn_metadata


Expand Down
2 changes: 2 additions & 0 deletions vllm_ascend/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@
"VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION":
lambda: bool(
int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION", '0'))),
"VLLM_ENABLE_SP":
lambda: bool(int(os.getenv("VLLM_ENABLE_SP", '0')))
}

# end-env-vars-definition
Expand Down
8 changes: 6 additions & 2 deletions vllm_ascend/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,12 @@ def register_model():

ModelRegistry.register_model(
"Qwen3MoeForCausalLM",
"vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
"vllm_ascend.models.qwen3_moe:AscendQwen3MoeForCausalLM")

ModelRegistry.register_model(
"PanguProMoEForCausalLM",
"vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM")
"vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM")

ModelRegistry.register_model(
"Qwen2ForCausalLM",
"vllm_ascend.models.qwen2:AscendQwen2ForCausalLM")
204 changes: 204 additions & 0 deletions vllm_ascend/models/qwen2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from vllm/model_executor/models/qwen3_moe.py
# This file is a part of the vllm-ascend project.

from typing import Optional, Union

import torch
from torch import nn
from transformers import Qwen2Config

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.utils import PPMissingLayer, maybe_prefix
from vllm.model_executor.models.qwen2 import (Qwen2ForCausalLM,
Qwen2Model,
Qwen2DecoderLayer)

from vllm_ascend.ops.sequence_parallel import init_metadata_for_sp, MetadataForPadding
import vllm_ascend.envs as envs_ascend


class AscendQwen2DecoderLayer(Qwen2DecoderLayer):

def __init__(
self,
config: Qwen2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(config, cache_config, quant_config, prefix)

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
_metadata_for_padding: Optional[MetadataForPadding] = None,
) -> tuple[torch.Tensor, torch.Tensor]:

# To prevent precision issues during the decoder phase when only prefilling enables SP
if not envs_ascend.VLLM_ENABLE_SP:
self.self_attn.o_proj.reduce_results = True
self.mlp.down_proj.reduce_results = True
else:

Check failure on line 64 in vllm_ascend/models/qwen2.py

View workflow job for this annotation

GitHub Actions / lint / pre-commit

Item "None" of "Optional[MetadataForPadding]" has no attribute "not_dummy_and_is_prefill" [union-attr]
self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill

Check failure on line 65 in vllm_ascend/models/qwen2.py

View workflow job for this annotation

GitHub Actions / lint / pre-commit

Item "None" of "Optional[MetadataForPadding]" has no attribute "not_dummy_and_is_prefill" [union-attr]
self.mlp.down_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill

if residual is None:
residual = hidden_states
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
residual = _metadata_for_padding.padding_slice(residual)

hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)

if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
hidden_states = _metadata_for_padding.allgather_unpadding_aligned(hidden_states)

hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(hidden_states)

# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)

if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
hidden_states = _metadata_for_padding.allgather_unpadding_aligned(hidden_states)

hidden_states = self.mlp(hidden_states)

if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(hidden_states)

return hidden_states, residual


@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
# otherwise (seq_len, ).
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
})
class AscendQwen2Model(Qwen2Model):

def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
decoder_layer_type: type[nn.Module] = AscendQwen2DecoderLayer):
super().__init__(vllm_config=vllm_config,
prefix=prefix,
decoder_layer_type=decoder_layer_type)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
_metadata_for_padding: Optional[MetadataForPadding] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(
positions,
hidden_states,
residual,
_metadata_for_padding
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)

if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
hidden_states = _metadata_for_padding.allgather_unpadding_aligned(hidden_states)

return hidden_states


class AscendQwen2ForCausalLM(Qwen2ForCausalLM):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config

self.config = config
self.lora_config = lora_config

self.quant_config = quant_config
self.model = AscendQwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))

if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
else:
self.lm_head = PPMissingLayer()

self.logits_processor = LogitsProcessor(config.vocab_size)

self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
_metadata_for_padding = init_metadata_for_sp(input_ids)

hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds, _metadata_for_padding)
return hidden_states
Loading
Loading