From 756f5cbbb60936273a11c7b53bd644b68529fc43 Mon Sep 17 00:00:00 2001
From: "yangcheng (AJ)"
Date: Wed, 9 Jul 2025 11:53:49 +0800
Subject: [PATCH] add moe operation

Signed-off-by: yangcheng (AJ)
---
 tests/e2e/multicard/test_qwen3_moe.py |  72 +++++++++++++++
 tests/ut/moe/test_qwen3_moe_block.py  | 127 ++++++++++++++++++++++++++
 vllm_ascend/models/qwen3_moe.py       |  98 ++++++++++++++++++++
 3 files changed, 297 insertions(+)
 create mode 100644 tests/e2e/multicard/test_qwen3_moe.py
 create mode 100644 tests/ut/moe/test_qwen3_moe_block.py

diff --git a/tests/e2e/multicard/test_qwen3_moe.py b/tests/e2e/multicard/test_qwen3_moe.py
new file mode 100644
index 0000000000..e05266ee05
--- /dev/null
+++ b/tests/e2e/multicard/test_qwen3_moe.py
@@ -0,0 +1,72 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+"""
+End-to-end test of Qwen3 MoE inference with data and tensor parallelism.
+
+Run `pytest tests/e2e/multicard/test_qwen3_moe.py`.
+"""
+
+import os
+import subprocess
+import sys
+from unittest.mock import patch
+
+import pytest
+
+MODELS = ["vllm-ascend/Qwen3-30B-A3B-Puring"]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("max_tokens", [32])
+@patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3"})
+def test_qwen3_moe_inference(model, max_tokens):
+    script = "examples/offline_data_parallel.py"
+
+    env = os.environ.copy()
+
+    cmd = [
+        sys.executable,
+        script,
+        "--model",
+        model,
+        "--dp-size",
+        "2",
+        "--tp-size",
+        "2",
+        "--node-size",
+        "1",
+        "--node-rank",
+        "0",
+        "--trust-remote-code",
+        "--enforce-eager",
+    ]
+
+    print(f"Running subprocess: {' '.join(cmd)}")
+    proc = subprocess.run(cmd,
+                          env=env,
+                          stdout=subprocess.PIPE,
+                          stderr=subprocess.STDOUT,
+                          timeout=600)
+    output = proc.stdout.decode()
+
+    print(output)
+
+    assert "DP rank 0 needs to process" in output
+    assert "DP rank 1 needs to process" in output
+    assert "Generated text:" in output
+    assert proc.returncode == 0
diff --git a/tests/ut/moe/test_qwen3_moe_block.py b/tests/ut/moe/test_qwen3_moe_block.py
new file mode 100644
index 0000000000..87ce1c7a3c
--- /dev/null
+++ b/tests/ut/moe/test_qwen3_moe_block.py
@@ -0,0 +1,127 @@
+import unittest
+from unittest.mock import MagicMock, patch
+
+import torch
+import torch.nn as nn
+
+from vllm_ascend.models.qwen3_moe import AscendQwen3MoeSparseMoeBlock
+
+
+class TestAscendQwen3MoeSparseMoeBlock(unittest.TestCase):
+
+    def setUp(self):
+        # Create a mock config
+        self.mock_config = MagicMock()
+        self.mock_config.hidden_size = 512
+        self.mock_config.num_experts = 8
+        self.mock_config.num_experts_per_tok = 2
+        self.mock_config.moe_intermediate_size = 1024
+        self.mock_config.norm_topk_prob = True
+
+        # Mock all the distributed and environment dependencies
+        self.patchers = [
+            patch('vllm.distributed.get_tensor_model_parallel_world_size',
+                  return_value=1),
+            patch('vllm_ascend.ascend_config.get_ascend_config',
+                  return_value=MagicMock(torchair_graph_config=MagicMock(
+                      enabled=True, enable_multistream_moe=True))),
+            patch('vllm.distributed.parallel_state.get_dp_group',
+                  return_value=MagicMock(world_size=1)),
+            patch('vllm.distributed.get_tp_group',
+                  return_value=MagicMock(device_group=None, rank_in_group=0)),
+            patch('vllm_ascend.distributed.parallel_state.get_ep_group',
+                  return_value=None),
+            patch('vllm.forward_context.get_forward_context',
+                  return_value=MagicMock(attn_metadata=None)),
+            patch('torch.get_default_dtype', return_value=torch.float32)
+        ]
+
+        for patcher in self.patchers:
+            patcher.start()
+
+        # Mock the ReplicatedLinear and AscendFusedMoE classes
+        self.mock_replicated_linear = MagicMock(spec=nn.Linear)
+        self.mock_fused_moe = MagicMock()
+
+        with patch('vllm.model_executor.layers.linear.ReplicatedLinear', return_value=self.mock_replicated_linear), \
+             patch('vllm_ascend.ops.fused_moe.AscendFusedMoE', return_value=self.mock_fused_moe):
+
+            self.block = AscendQwen3MoeSparseMoeBlock(config=self.mock_config,
+                                                      quant_config=None,
+                                                      prefix="moe")
+
+    def tearDown(self):
+        for patcher in self.patchers:
+            patcher.stop()
+
+    def test_initialization(self):
+        # Test initialization values
+        self.assertEqual(self.block.top_k,
+                         self.mock_config.num_experts_per_tok)
+        self.assertEqual(self.block.params_dtype, torch.float32)
+        self.assertTrue(self.block.torchair_graph_enabled)
+        self.assertTrue(self.block.enable_multistream_moe)
+
+        # Check if submodules were created
+        self.mock_replicated_linear.assert_called_once()
+        self.mock_fused_moe.assert_called_once()
+
+    def test_forward_with_attn_metadata(self):
+        # Setup mock return values
+        mock_router_logits = torch.randn(10, self.mock_config.num_experts)
+        self.mock_replicated_linear.return_value = (mock_router_logits, None)
+
+        mock_hidden_states = torch.randn(10, self.mock_config.hidden_size)
+        mock_output = torch.randn(10, self.mock_config.hidden_size)
+        self.mock_fused_moe.return_value = mock_output
+
+        # Mock attention metadata
+        mock_attn_metadata = MagicMock()
+        mock_attn_metadata.with_prefill_across_dp = False
+
+        # Test forward pass
+        output = self.block(mock_hidden_states, mock_attn_metadata)
+
+        # Verify calls
+        self.mock_replicated_linear.assert_called_once_with(mock_hidden_states)
+        self.mock_fused_moe.assert_called_once_with(
+            hidden_states=mock_hidden_states,
+            router_logits=mock_router_logits,
+            is_prefill=False,
+            top_k=self.mock_config.num_experts_per_tok,
+            enable_force_load_balance=False,
+            shared_experts=None)
+        self.assertTrue(torch.equal(output, mock_output))
+
+    def test_forward_without_attn_metadata(self):
+        # Setup mock return values
+        mock_router_logits = torch.randn(10, self.mock_config.num_experts)
+        self.mock_replicated_linear.return_value = (mock_router_logits, None)
+
+        mock_hidden_states = torch.randn(10, self.mock_config.hidden_size)
+        mock_output = torch.randn(10, self.mock_config.hidden_size)
+        self.mock_fused_moe.return_value = mock_output
+
+        # Test forward pass without attention metadata
+        output = self.block(mock_hidden_states)
+
+        # Verify calls - should use default values when no metadata
+        self.mock_replicated_linear.assert_called_once_with(mock_hidden_states)
+        self.mock_fused_moe.assert_called_once_with(
+            hidden_states=mock_hidden_states,
+            router_logits=mock_router_logits,
+            is_prefill=True,
+            top_k=self.mock_config.num_experts_per_tok,
+            enable_force_load_balance=True,
+            shared_experts=None)
+        self.assertTrue(torch.equal(output, mock_output))
+
+    def test_tp_size_greater_than_experts(self):
+        # Test the validation for TP size vs number of experts
+        with patch('vllm.distributed.get_tensor_model_parallel_world_size',
+                   return_value=10):
+            with self.assertRaises(ValueError) as context:
+                self.block = AscendQwen3MoeSparseMoeBlock(
+                    config=self.mock_config, quant_config=None, prefix="moe")
+            self.assertIn("Tensor parallel size 10 is greater than",
+                          str(context.exception))
diff --git a/vllm_ascend/models/qwen3_moe.py b/vllm_ascend/models/qwen3_moe.py
index 8ff1b52a7a..0e02965b66 100644
--- a/vllm_ascend/models/qwen3_moe.py
+++ b/vllm_ascend/models/qwen3_moe.py
@@ -16,8 +16,24 @@
 # Adapted from vllm/model_executor/models/qwen3_moe.py
 # This file is a part of the vllm-ascend project.
 
+from typing import Optional
+
+import torch
+import vllm
+from torch import nn
+from transformers import PretrainedConfig
+from vllm.attention import AttentionMetadata
+from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
+from vllm.distributed.parallel_state import get_dp_group
+from vllm.forward_context import get_forward_context
+from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
 
+from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.distributed.parallel_state import get_ep_group
+from vllm_ascend.ops.fused_moe import AscendFusedMoE
+
 
 class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
     packed_modules_mapping = {
@@ -33,3 +49,85 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
         "experts":
         ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
     }
+
+
+class AscendQwen3MoeSparseMoeBlock(nn.Module):
+    top_k: int
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        if self.tp_size > config.num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {config.num_experts}.")
+
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_multistream_moe = \
+            ascend_config.torchair_graph_config.enable_multistream_moe
+
+        self.gate = ReplicatedLinear(config.hidden_size,
+                                     config.num_experts,
+                                     bias=False,
+                                     quant_config=None,
+                                     prefix=f"{prefix}.gate")
+
+        self.experts = AscendFusedMoE(
+            num_experts=config.num_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts")
+
+        self.top_k = config.num_experts_per_tok
+
+        self.dp_size = get_dp_group().world_size
+
+        self.tp_group = get_tp_group().device_group
+        self.tp_rank = get_tp_group().rank_in_group
+        self.ep_group = get_ep_group()
+
+        self.params_dtype = torch.get_default_dtype()
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+        if attn_metadata is None:
+            attn_metadata = get_forward_context().attn_metadata
+        # when profile runs, force experts to load balanced tokens
+        # to avoid high memory consumption on a single rank.
+        # TODO: need a better flag to indicate whether in profile run or not.
+        if attn_metadata is None:
+            # for profile run
+            is_prefill = True
+            enable_force_load_balance = True
+        else:
+            enable_force_load_balance = False
+            is_prefill = getattr(attn_metadata, 'with_prefill_across_dp',
+                                 False)  # decode path when the flag is absent
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+
+        hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            is_prefill=is_prefill,
+            top_k=self.top_k,
+            enable_force_load_balance=enable_force_load_balance,
+            shared_experts=None)
+
+        return hidden_states
+
+
+vllm.model_executor.models.qwen3_moe.Qwen3MoeSparseMoeBlock = AscendQwen3MoeSparseMoeBlock
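
Note for reviewers: the final added line monkey-patches the upstream module, so
the override can be smoke-checked without starting an engine. A minimal sketch,
assuming vllm and vllm-ascend are installed in the same Python environment (the
module aliases are illustrative):

    # Importing the vllm-ascend module runs the module-level assignment that
    # swaps the upstream sparse MoE block for the Ascend implementation.
    import vllm.model_executor.models.qwen3_moe as upstream_qwen3_moe
    import vllm_ascend.models.qwen3_moe as ascend_qwen3_moe

    assert (upstream_qwen3_moe.Qwen3MoeSparseMoeBlock
            is ascend_qwen3_moe.AscendQwen3MoeSparseMoeBlock)

Besides the subprocess-based e2e test, a single-process offline run through
vLLM's LLM API also exercises AscendQwen3MoeSparseMoeBlock; the checkpoint name
and tensor_parallel_size below are illustrative assumptions, not part of this
patch.

    from vllm import LLM, SamplingParams

    # tensor_parallel_size must not exceed the number of routed experts,
    # otherwise the block raises the ValueError added above.
    llm = LLM(model="Qwen/Qwen3-30B-A3B",
              tensor_parallel_size=2,
              enforce_eager=True)
    outputs = llm.generate(["Briefly explain mixture-of-experts routing."],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)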