diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py b/tests/e2e/multicard/test_offline_inference_distributed.py
index 58d0bf0ba0..9ec9d66bb1 100644
--- a/tests/e2e/multicard/test_offline_inference_distributed.py
+++ b/tests/e2e/multicard/test_offline_inference_distributed.py
@@ -167,3 +167,47 @@ def test_models_distributed_topk() -> None:
             distributed_executor_backend="mp",
     ) as vllm_model:
         vllm_model.generate(example_prompts, sampling_params)
+
+
+@patch.dict(os.environ, {"VLLM_ENABLE_SP": "1"})
+def test_flash_comm1_for_qwen3_moe() -> None:
+    example_prompts = [
+        "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
+        "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
+        "Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
+    ]
+    dtype = "half"
+    sampling_params = SamplingParams(max_tokens=5,
+                                     temperature=0.0,
+                                     top_k=50,
+                                     top_p=0.9)
+
+    with VllmRunner(
+            "Qwen/Qwen3-30B-A3B",
+            dtype=dtype,
+            tensor_parallel_size=4,
+            distributed_executor_backend="mp",
+    ) as vllm_model:
+        vllm_model.generate(example_prompts, sampling_params)
+
+
+@patch.dict(os.environ, {"VLLM_ENABLE_SP": "1"})
+def test_flash_comm1_for_qwen2_5() -> None:
+    example_prompts = [
+        "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
+        "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
+        "Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
+    ]
+    dtype = "half"
+    sampling_params = SamplingParams(max_tokens=5,
+                                     temperature=0.0,
+                                     top_k=50,
+                                     top_p=0.9)
+
+    with VllmRunner(
+            "Qwen/Qwen2.5-0.5B-Instruct",
+            dtype=dtype,
+            tensor_parallel_size=4,
+            distributed_executor_backend="mp",
+    ) as vllm_model:
+        vllm_model.generate(example_prompts, sampling_params)
diff --git a/tests/ut/ops/test_flash_comm1.py b/tests/ut/ops/test_flash_comm1.py
new file mode 100644
index 0000000000..35cf1e4b88
--- /dev/null
+++ b/tests/ut/ops/test_flash_comm1.py
@@ -0,0 +1,64 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+
+
+import torch
+import importlib
+from tests.ut.base import TestBase
+from unittest.mock import MagicMock, patch
+
+from vllm.distributed.parallel_state import GroupCoordinator
+
+from vllm_ascend.ops import sequence_parallel
+
+
+class Test_Flash_Comm1(TestBase):
+
+    @patch('vllm.distributed.tensor_model_parallel_all_gather')
+    @patch('vllm.distributed.tensor_model_parallel_reduce_scatter')
+    @patch('vllm.distributed.parallel_state._TP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
+    def test_flash_comm1(self, mock_TP,
+                         mock_tensor_model_parallel_reduce_scatter,
+                         mock_tensor_model_parallel_all_gather):
+        with patch('vllm.distributed.get_tp_group',
+                   return_value=MagicMock(world_size=4, rank_in_group=0)) as mock_get_tp_group:
+            num_tokens = 9
+            hidden_size = 128
+            tp_size = 4
+            hidden_states = torch.randn(num_tokens, hidden_size)
+
+            mock_tp_group = mock_get_tp_group.return_value
+            assert mock_tp_group.world_size == 4  # manually verify the mocked attributes exist
+            assert mock_tp_group.rank_in_group == 0
+
+            lengths_sum_unpadding = hidden_states.shape[0]
+            lengths_sum_padding = ((lengths_sum_unpadding + tp_size - 1) // tp_size) * tp_size
+            padding_flag = True
+            pad_size = lengths_sum_padding - lengths_sum_unpadding
+            importlib.reload(sequence_parallel)
+            _metadata_for_padding = sequence_parallel.MetadataForPadding(lengths_sum_unpadding=lengths_sum_unpadding,
+                                                                         lengths_sum_padding=lengths_sum_padding,
+                                                                         padding_flag=padding_flag,
+                                                                         pad_size=pad_size,
+                                                                         not_dummy_and_is_prefill=True)
+
+            mock_tensor_model_parallel_reduce_scatter.return_value = torch.randn(lengths_sum_padding // tp_size, hidden_size)
+            mock_tensor_model_parallel_all_gather.return_value = torch.randn(lengths_sum_padding, hidden_size)
+
+            hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(hidden_states)
+            output = _metadata_for_padding.allgather_unpadding_aligned(hidden_states)
+
+            self.assertEqual(output.shape, (num_tokens, hidden_size))
\ No newline at end of file
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 7d7f488f47..76fc9148ea 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -169,7 +169,8 @@ def build(self,
               max_query_len,
               common_prefix_len,
               max_num_tokens_across_dp: int = 0,
-              with_prefill_across_dp: bool = False):
+              with_prefill_across_dp: bool = False,
+              is_only_prefill: bool = False):
         block_table = self.runner.input_batch.block_table[0].get_device_tensor(
         )
 
@@ -207,7 +208,8 @@ def build(self,
             attn_mask=attn_mask,
             attn_state=attn_state,
             max_num_tokens_across_dp=max_num_tokens_across_dp,
-            with_prefill_across_dp=with_prefill_across_dp)
+            with_prefill_across_dp=with_prefill_across_dp,
+            is_only_prefill=is_only_prefill)
 
         return attn_metadata
 
diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py
index 5ea6aa9b3e..2c0808e82c 100644
--- a/vllm_ascend/envs.py
+++ b/vllm_ascend/envs.py
@@ -133,6 +133,8 @@
     "VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION":
     lambda: bool(
         int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION", '0'))),
+    "VLLM_ENABLE_SP":
+    lambda: bool(int(os.getenv("VLLM_ENABLE_SP", '0')))
 }
 
 # end-env-vars-definition
diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py
index cae779ce2f..5139eaed3a 100644
--- a/vllm_ascend/models/__init__.py
+++ b/vllm_ascend/models/__init__.py
@@ -51,8 +51,12 @@ def register_model():
 
     ModelRegistry.register_model(
         "Qwen3MoeForCausalLM",
-        "vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
+        "vllm_ascend.models.qwen3_moe:AscendQwen3MoeForCausalLM")
ModelRegistry.register_model( "PanguProMoEForCausalLM", - "vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM") \ No newline at end of file + "vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM") + + ModelRegistry.register_model( + "Qwen2ForCausalLM", + "vllm_ascend.models.qwen2:AscendQwen2ForCausalLM") \ No newline at end of file diff --git a/vllm_ascend/models/qwen2.py b/vllm_ascend/models/qwen2.py new file mode 100644 index 0000000000..182a523213 --- /dev/null +++ b/vllm_ascend/models/qwen2.py @@ -0,0 +1,204 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from vllm/model_executor/models/qwen3_moe.py +# This file is a part of the vllm-ascend project. + +from typing import Optional, Union + +import torch +from torch import nn +from transformers import Qwen2Config + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.sequence import IntermediateTensors +from vllm.model_executor.models.utils import PPMissingLayer, maybe_prefix +from vllm.model_executor.models.qwen2 import (Qwen2ForCausalLM, + Qwen2Model, + Qwen2DecoderLayer) + +from vllm_ascend.ops.sequence_parallel import init_metadata_for_sp, MetadataForPadding +import vllm_ascend.envs as envs_ascend + + +class AscendQwen2DecoderLayer(Qwen2DecoderLayer): + + def __init__( + self, + config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config, cache_config, quant_config, prefix) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + _metadata_for_padding: Optional[MetadataForPadding] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + + # To prevent precision issues during the decoder phase when only prefilling enables SP + if not envs_ascend.VLLM_ENABLE_SP: + self.self_attn.o_proj.reduce_results = True + self.mlp.down_proj.reduce_results = True + else: + self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill + self.mlp.down_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill + + if residual is None: + residual = hidden_states + if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: + residual = _metadata_for_padding.padding_slice(residual) + + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: + hidden_states = 
_metadata_for_padding.allgather_unpadding_aligned(hidden_states) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: + hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(hidden_states) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: + hidden_states = _metadata_for_padding.allgather_unpadding_aligned(hidden_states) + + hidden_states = self.mlp(hidden_states) + + if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: + hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(hidden_states) + + return hidden_states, residual + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, + # otherwise (seq_len, ). + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }) +class AscendQwen2Model(Qwen2Model): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layer_type: type[nn.Module] = AscendQwen2DecoderLayer): + super().__init__(vllm_config=vllm_config, + prefix=prefix, + decoder_layer_type=decoder_layer_type) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + _metadata_for_padding: Optional[MetadataForPadding] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + _metadata_for_padding + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + + if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: + hidden_states = _metadata_for_padding.allgather_unpadding_aligned(hidden_states) + + return hidden_states + + +class AscendQwen2ForCausalLM(Qwen2ForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = AscendQwen2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + 
intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + _metadata_for_padding = init_metadata_for_sp(input_ids) + + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds, _metadata_for_padding) + return hidden_states \ No newline at end of file diff --git a/vllm_ascend/models/qwen3_moe.py b/vllm_ascend/models/qwen3_moe.py index 8ff1b52a7a..18c475ea8c 100644 --- a/vllm_ascend/models/qwen3_moe.py +++ b/vllm_ascend/models/qwen3_moe.py @@ -16,10 +16,290 @@ # Adapted from vllm/model_executor/models/qwen3_moe.py # This file is a part of the vllm-ascend project. -from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM +from typing import Optional, Union -class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM): +import torch +from torch import nn + +from transformers import PretrainedConfig +from vllm.attention import AttentionMetadata +from vllm.distributed import (get_dp_group, + get_pp_group, + get_tensor_model_parallel_world_size, + get_tp_group) +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.linear import ReplicatedLinear + +from vllm.model_executor.layers.quantization import QuantizationConfig + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.models.utils import ( extract_layer_index, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) +from vllm.model_executor.models.qwen3_moe import (Qwen3MoeModel, + Qwen3MoeAttention, + Qwen3MoeMLP, + Qwen3MoeForCausalLM) +from vllm.sequence import IntermediateTensors + +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.distributed.parallel_state import get_ep_group +from vllm_ascend.ops.fused_moe import AscendFusedMoE +from vllm_ascend.ops.sequence_parallel import init_metadata_for_sp, MetadataForPadding +import vllm_ascend.envs as envs_ascend + + +class AscendQwen3MoeSparseMoeBlock(nn.Module): + + top_k: int + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}.") + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.enable_multistream_moe = \ + ascend_config.torchair_graph_config.enable_multistream_moe + + self.gate = ReplicatedLinear(config.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + + self.experts = AscendFusedMoE( + num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts") + + + self.top_k = config.num_experts_per_tok + + self.dp_size = get_dp_group().world_size + + self.tp_group = get_tp_group().device_group + self.tp_rank = get_tp_group().rank_in_group 
+ self.ep_group = get_ep_group() + + self.params_dtype = torch.get_default_dtype() + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: Optional[AttentionMetadata] = None, + _metadata_for_padding: Optional[MetadataForPadding] = None,) -> torch.Tensor: + if attn_metadata is None: + attn_metadata = get_forward_context().attn_metadata + # when profile runs, force experts to load balanced tokens + # to avoid high memory consumption on a single rank. + # TODO: need a better flag to indicate whether in profile run or not. + if attn_metadata is None: + # for profile run + is_prefill = True + enable_force_load_balance = True + else: + # is_prefill = attn_metadata.num_prefills > 0 is_prefill or + enable_force_load_balance = False + if hasattr(attn_metadata, 'with_prefill_across_dp'): + is_prefill = attn_metadata.with_prefill_across_dp + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + + hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits, + is_prefill=is_prefill, + top_k=self.top_k, + enable_force_load_balance=enable_force_load_balance, + shared_experts=None, + enable_sp=_metadata_for_padding is not None and _metadata_for_padding.not_dummy_and_is_prefill, + ) + + return hidden_states + + +class AscendQwen3MoeDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = Qwen3MoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + rms_norm_eps=config.rms_norm_eps, + qkv_bias=getattr(config, 'attention_bias', False), + head_dim=getattr(config, 'head_dim', None), + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + # `mlp_only_layers` in the config. 
+ layer_idx = extract_layer_index(prefix) + mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else + config.mlp_only_layers) + if (layer_idx not in mlp_only_layers) and ( + config.num_experts > 0 and + (layer_idx + 1) % config.decoder_sparse_step == 0): + self.mlp = AscendQwen3MoeSparseMoeBlock(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + _metadata_for_padding: Optional[MetadataForPadding] = None, + ) -> torch.Tensor: + + # To prevent precision issues during the decoder phase when only prefilling enables SP + if not envs_ascend.VLLM_ENABLE_SP: + self.self_attn.o_proj.reduce_results = True + else: + self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill + + # Self Attention + if residual is None: + residual = hidden_states + if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: + residual = _metadata_for_padding.padding_slice(residual) + + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: + hidden_states = _metadata_for_padding.allgather_unpadding_aligned(hidden_states) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: + hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(hidden_states) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + hidden_states = self.mlp(hidden_states, _metadata_for_padding=_metadata_for_padding) + + return hidden_states, residual + + +@support_torch_compile +class AscendQwen3MoeModel(Qwen3MoeModel): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.config = config + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + prefix=f"{prefix}.embed_tokens") + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: AscendQwen3MoeDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + _metadata_for_padding: Optional[MetadataForPadding] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if 
get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, residual, _metadata_for_padding=_metadata_for_padding) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + + if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: + hidden_states = _metadata_for_padding.allgather_unpadding_aligned(hidden_states) + + return hidden_states + + +class AscendQwen3MoeForCausalLM(Qwen3MoeForCausalLM): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -33,3 +313,33 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM): "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = AscendQwen3MoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + _metadata_for_padding = init_metadata_for_sp(input_ids) + + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds, _metadata_for_padding) + return hidden_states diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 1221d8984d..3484220611 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -1284,7 +1284,8 @@ def forward(self, top_k: Optional[int] = None, shared_experts: Optional[Any] = None, gate=None, - replace_allreduce: bool = False): + replace_allreduce: bool = False, + enable_sp: bool = False): assert self.quant_method is not None @@ -1297,14 +1298,16 @@ def forward(self, is_deepseek_v3_r1 = self.global_num_experts == 256 fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size, - is_prefill, is_deepseek_v3_r1) + is_prefill, is_deepseek_v3_r1, enable_sp=enable_sp) if shared_experts: if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2: # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce shared_hidden_states = shared_experts(hidden_states) tp_size = get_tensor_model_parallel_world_size() - if (tp_size > 1 and fused_moe_state not in [ + if enable_sp: + pass + elif (tp_size > 1 and fused_moe_state not in [ FusedMoEState.AllGather, FusedMoEState.AllGatherEP, FusedMoEState.NaiveMulticast ] and not 
replace_allreduce):
@@ -1323,7 +1326,7 @@ def forward(self,
                 hidden_states = chunk_hidden_states[tp_rank]
                 router_logits = chunk_router_logits[tp_rank]
 
-        if self.dp_size > 1:
+        if self.dp_size > 1 and not enable_sp:
             if fused_moe_state == FusedMoEState.AllGather:
                 # NOTE: When in torchair graph, it has been padded in model_runner_v1
                 if not self.torchair_graph_enabled:
@@ -1384,7 +1387,9 @@ def forward(self,
         if isinstance(e_hidden_states, tuple):
             e_hidden_states, shared_hidden_states = e_hidden_states
 
-        if (tp_size > 1 and fused_moe_state not in [
+        if enable_sp:
+            final_hidden_states = e_hidden_states
+        elif (tp_size > 1 and fused_moe_state not in [
                 FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
                 FusedMoEState.NaiveMulticast
         ] and not replace_allreduce):
diff --git a/vllm_ascend/ops/sequence_parallel.py b/vllm_ascend/ops/sequence_parallel.py
new file mode 100644
index 0000000000..19086461b4
--- /dev/null
+++ b/vllm_ascend/ops/sequence_parallel.py
@@ -0,0 +1,88 @@
+import torch
+from torch.nn import functional as F
+
+from vllm.distributed import (get_tensor_model_parallel_world_size,
+                              get_tp_group,
+                              tensor_model_parallel_all_gather,
+                              tensor_model_parallel_reduce_scatter)
+from vllm.forward_context import get_forward_context
+
+import vllm_ascend.envs as envs
+
+
+class MetadataForPadding:
+    def __init__(self, padding_flag=False, lengths_sum_padding=0, lengths_sum_unpadding=0, pad_size=0, not_dummy_and_is_prefill=False):
+        self.padding_flag = padding_flag
+        self.not_dummy_and_is_prefill = not_dummy_and_is_prefill
+
+        self.lengths_sum_padding = lengths_sum_padding
+        self.lengths_sum_unpadding = lengths_sum_unpadding
+        self.pad_size = pad_size
+
+        self.tp_size = get_tp_group().world_size
+        self.tp_rank_in_group = get_tp_group().rank_in_group
+
+        assert self.lengths_sum_padding % self.tp_size == 0
+        self.slice_size = self.lengths_sum_padding // self.tp_size
+
+    def padding_aligned_reduce_scatter(self, data: torch.Tensor) -> torch.Tensor:
+        if self.padding_flag:
+            pad_size = self.pad_size
+            padded_data = F.pad(data, (0, 0, 0, pad_size))
+        else:
+            padded_data = data
+
+        padded_data_reduce_scatter = tensor_model_parallel_reduce_scatter(padded_data, 0)
+
+        return padded_data_reduce_scatter
+
+    def allgather_unpadding_aligned(self, padded_data: torch.Tensor) -> torch.Tensor:
+        padded_data_allgather = tensor_model_parallel_all_gather(padded_data, 0)
+        if self.padding_flag:
+            lengths_sum_unpadding = self.lengths_sum_unpadding
+            unpadding_data = padded_data_allgather[:lengths_sum_unpadding]
+        else:
+            unpadding_data = padded_data_allgather
+
+        return unpadding_data
+
+    def padding_slice(self, data: torch.Tensor) -> torch.Tensor:
+
+        padded_data = F.pad(data, (0, 0, 0, self.pad_size))
+        start = self.tp_rank_in_group * self.slice_size
+        end = start + self.slice_size
+        slice_data = padded_data[start:end]
+
+        return slice_data
+
+
+def init_metadata_for_sp(input_ids):
+    if not envs.VLLM_ENABLE_SP:
+        return MetadataForPadding(padding_flag=False, not_dummy_and_is_prefill=False)
+    is_prefill = 0
+    attn_metadata = get_forward_context().attn_metadata
+    tp_size = get_tensor_model_parallel_world_size()
+    global _metadata_for_padding
+    if attn_metadata is not None:
+        if hasattr(attn_metadata, 'is_only_prefill') and attn_metadata.is_only_prefill:
+            is_prefill = 1
+        if hasattr(attn_metadata, 'num_prefills') and attn_metadata.num_prefills > 0:
+            is_prefill = 1
+
+        if is_prefill:
+            lengths_sum_unpadding = input_ids.shape[0]
+            lengths_sum_padding = ((lengths_sum_unpadding + tp_size - 1) // tp_size) * tp_size
+            if lengths_sum_unpadding == lengths_sum_padding:
+                padding_flag = False
+            else:
+                padding_flag = True
+            pad_size = lengths_sum_padding - lengths_sum_unpadding
+            _metadata_for_padding = MetadataForPadding(lengths_sum_unpadding=lengths_sum_unpadding,
+                                                       lengths_sum_padding=lengths_sum_padding,
+                                                       padding_flag=padding_flag,
+                                                       pad_size=pad_size,
+                                                       not_dummy_and_is_prefill=True)
+
+            return _metadata_for_padding
+
+    return MetadataForPadding(padding_flag=False, not_dummy_and_is_prefill=False)
\ No newline at end of file
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 634e13cb9e..223d7cf50e 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -437,6 +437,7 @@ class FusedMoEState(Enum):
     MC2 = 2
     AllGatherEP = 3
     NaiveMulticast = 4
+    NO_OP = 5
 
 
 # TODO(ttanzhiqiang): rm_router_logits
@@ -471,10 +472,12 @@ def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool):
 
 # TODO(zzzzwwjj): add soc_version to choose branch
 def get_fused_moe_state(ep_size: int, with_prefill: bool,
-                        is_deepseek_v3_r1: bool):
+                        is_deepseek_v3_r1: bool, enable_sp=False):
     # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
     # only supports deepseek v3/r1
-    if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
+    if enable_sp:
+        return FusedMoEState.NO_OP
+    elif (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
             and is_deepseek_v3_r1):
         return FusedMoEState.AllGatherEP
     elif ep_size == 1:
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index eabcdbcc19..0656987d71 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1009,6 +1009,9 @@ def _process_reqs(
             AscendAttentionState.DecodeOnly,
             AscendAttentionState.SpecDecoding
         ]
+        is_only_prefill = np.all(num_valid_tokens != 1)
+        extra_builder_kwargs['is_only_prefill'] = is_only_prefill
+
         if self.dp_size > 1:
             max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
                 total_num_scheduled_tokens, with_prefill)
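
The shape bookkeeping in MetadataForPadding can be sanity-checked without NPUs or an initialized tensor-parallel group. The following is a minimal single-process sketch: fake_reduce_scatter and fake_all_gather are hypothetical stand-ins for vLLM's tensor_model_parallel_reduce_scatter / tensor_model_parallel_all_gather collectives (they reproduce only the shapes, not the cross-rank reduction), and the 9-token / TP=4 numbers mirror the new unit test.

import torch
import torch.nn.functional as F

tp_size = 4
num_tokens, hidden_size = 9, 128      # same shapes as the unit test; 9 is not a multiple of 4
hidden_states = torch.randn(num_tokens, hidden_size)

# Same padding math as init_metadata_for_sp():
lengths_sum_padding = ((num_tokens + tp_size - 1) // tp_size) * tp_size   # 12
pad_size = lengths_sum_padding - num_tokens                               # 3
slice_size = lengths_sum_padding // tp_size                               # 3 rows per rank


def fake_reduce_scatter(x: torch.Tensor, rank: int) -> torch.Tensor:
    # Stand-in for tensor_model_parallel_reduce_scatter along dim 0:
    # only the per-rank output shape is reproduced here.
    return x.chunk(tp_size, dim=0)[rank]


def fake_all_gather(shards: list) -> torch.Tensor:
    # Stand-in for tensor_model_parallel_all_gather along dim 0.
    return torch.cat(shards, dim=0)


padded = F.pad(hidden_states, (0, 0, 0, pad_size))                 # (12, 128)
shards = [fake_reduce_scatter(padded, r) for r in range(tp_size)]
assert all(s.shape == (slice_size, hidden_size) for s in shards)

gathered = fake_all_gather(shards)                                 # (12, 128)
unpadded = gathered[:num_tokens]                                   # (9, 128)
assert unpadded.shape == (num_tokens, hidden_size)

This is the invariant the unit test asserts: padding the token dimension to a multiple of the TP world size lets each rank hold lengths_sum_padding // tp_size rows between padding_aligned_reduce_scatter and allgather_unpadding_aligned, and the final slice drops the padding so the layer output keeps the original token count.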
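For reviewers who want to try the new path outside the test harness, the sketch below is a hypothetical way to run the same configuration through the plain vllm.LLM API instead of the VllmRunner fixture; the model name and sampling values are copied from the e2e tests above, and everything else is standard vLLM usage rather than something this patch adds.

import os

# Set the flag the same way the e2e tests patch os.environ.
os.environ["VLLM_ENABLE_SP"] = "1"

from vllm import LLM, SamplingParams

prompts = [
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
]
sampling_params = SamplingParams(max_tokens=5, temperature=0.0, top_k=50, top_p=0.9)

# The reduce-scatter/all-gather path is only taken for prefill batches;
# decode batches keep the regular tensor-parallel all-reduce.
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct",
          dtype="half",
          tensor_parallel_size=4,
          distributed_executor_backend="mp")
outputs = llm.generate(prompts, sampling_params)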