
Commit a63e381

Author: yangcheng (AJ)
add moe operation
Signed-off-by: yangcheng (AJ) <y00806874@china.huawei.com>
1 parent cc1588b commit a63e381

File tree

3 files changed: +296 −0 lines changed


tests/e2e/multicard/test_qwen3_moe.py

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
"""
End-to-end test: Qwen3 MoE inference with data parallelism across multiple cards.

Run `pytest tests/e2e/multicard/test_qwen3_moe.py`.
"""

import os
import subprocess
import sys
from unittest.mock import patch

import pytest

MODELS = ["vllm-ascend/Qwen3-30B-A3B-Puring"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [32])
@patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3"})
def test_qwen3_moe_inference(model, max_tokens):
    script = "examples/offline_data_parallel.py"

    env = os.environ.copy()

    cmd = [
        sys.executable,
        script,
        "--model",
        model,
        "--dp-size",
        "2",
        "--tp-size",
        "2",
        "--node-size",
        "1",
        "--node-rank",
        "0",
        "--trust-remote-code",
        "--enforce-eager",
    ]

    print(f"Running subprocess: {' '.join(cmd)}")
    proc = subprocess.run(cmd,
                          env=env,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.STDOUT,
                          timeout=600)
    output = proc.stdout.decode()

    print(output)

    assert "DP rank 0 needs to process" in output
    assert "DP rank 1 needs to process" in output
    assert "Generated text:" in output
    assert proc.returncode == 0

tests/ut/moe/test_qwen3_moe_block.py

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
import unittest
from unittest.mock import MagicMock, patch

import torch
import torch.nn as nn
from vllm_ascend.models.qwen3_moe import AscendQwen3MoeSparseMoeBlock


class TestAscendQwen3MoeSparseMoeBlock(unittest.TestCase):

    def setUp(self):
        # Create a mock config
        self.mock_config = MagicMock()
        self.mock_config.hidden_size = 512
        self.mock_config.num_experts = 8
        self.mock_config.num_experts_per_tok = 2
        self.mock_config.moe_intermediate_size = 1024
        self.mock_config.norm_topk_prob = True

        # Mock all the distributed and environment dependencies
        self.patchers = [
            patch('vllm.distributed.get_tensor_model_parallel_world_size',
                  return_value=1),
            patch('vllm_ascend.ascend_config.get_ascend_config',
                  return_value=MagicMock(torchair_graph_config=MagicMock(
                      enabled=True, enable_multistream_moe=True))),
            patch('vllm.distributed.parallel_state.get_dp_group',
                  return_value=MagicMock(world_size=1)),
            patch('vllm.distributed.get_tp_group',
                  return_value=MagicMock(device_group=None, rank_in_group=0)),
            patch('vllm_ascend.distributed.parallel_state.get_ep_group',
                  return_value=None),
            patch('vllm.forward_context.get_forward_context',
                  return_value=MagicMock(attn_metadata=None)),
            patch('torch.get_default_dtype', return_value=torch.float32)
        ]

        for patcher in self.patchers:
            patcher.start()

        # Mock the ReplicatedLinear and AscendFusedMoE classes
        self.mock_replicated_linear = MagicMock(spec=nn.Linear)
        self.mock_fused_moe = MagicMock()

        with patch('vllm.model_executor.layers.linear.ReplicatedLinear', return_value=self.mock_replicated_linear), \
             patch('vllm_ascend.ops.fused_moe.AscendFusedMoE', return_value=self.mock_fused_moe):

            self.block = AscendQwen3MoeSparseMoeBlock(config=self.mock_config,
                                                      quant_config=None,
                                                      prefix="moe")

    def tearDown(self):
        for patcher in self.patchers:
            patcher.stop()

    def test_initialization(self):
        # Test initialization values
        self.assertEqual(self.block.top_k,
                         self.mock_config.num_experts_per_tok)
        self.assertEqual(self.block.params_dtype, torch.float32)
        self.assertTrue(self.block.torchair_graph_enabled)
        self.assertTrue(self.block.enable_multistream_moe)

        # Check if submodules were created
        self.mock_replicated_linear.assert_called_once()
        self.mock_fused_moe.assert_called_once()

    def test_forward_with_attn_metadata(self):
        # Setup mock return values
        mock_router_logits = torch.randn(10, self.mock_config.num_experts)
        self.mock_replicated_linear.return_value = (mock_router_logits, None)

        mock_hidden_states = torch.randn(10, self.mock_config.hidden_size)
        mock_output = torch.randn(10, self.mock_config.hidden_size)
        self.mock_fused_moe.return_value = mock_output

        # Mock attention metadata
        mock_attn_metadata = MagicMock()
        mock_attn_metadata.with_prefill_across_dp = False

        # Test forward pass
        output = self.block(mock_hidden_states, mock_attn_metadata)

        # Verify calls
        self.mock_replicated_linear.assert_called_once_with(mock_hidden_states)
        self.mock_fused_moe.assert_called_once_with(
            hidden_states=mock_hidden_states,
            router_logits=mock_router_logits,
            is_prefill=False,
            top_k=self.mock_config.num_experts_per_tok,
            enable_force_load_balance=False,
            shared_experts=None)
        self.assertTrue(torch.equal(output, mock_output))

    def test_forward_without_attn_metadata(self):
        # Setup mock return values
        mock_router_logits = torch.randn(10, self.mock_config.num_experts)
        self.mock_replicated_linear.return_value = (mock_router_logits, None)

        mock_hidden_states = torch.randn(10, self.mock_config.hidden_size)
        mock_output = torch.randn(10, self.mock_config.hidden_size)
        self.mock_fused_moe.return_value = mock_output

        # Test forward pass without attention metadata
        output = self.block(mock_hidden_states)

        # Verify calls - should use default values when no metadata
        self.mock_replicated_linear.assert_called_once_with(mock_hidden_states)
        self.mock_fused_moe.assert_called_once_with(
            hidden_states=mock_hidden_states,
            router_logits=mock_router_logits,
            is_prefill=True,
            top_k=self.mock_config.num_experts_per_tok,
            enable_force_load_balance=True,
            shared_experts=None)
        self.assertTrue(torch.equal(output, mock_output))

    def test_tp_size_greater_than_experts(self):
        # Test the validation for TP size vs number of experts
        with patch('vllm.distributed.get_tensor_model_parallel_world_size',
                   return_value=10):
            with self.assertRaises(ValueError) as context:
                self.block = AscendQwen3MoeSparseMoeBlock(
                    config=self.mock_config, quant_config=None, prefix="moe")
            self.assertIn("Tensor parallel size 10 is greater than",
                          str(context.exception))
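Both new test files can be invoked with pytest. The following is a minimal sketch of a combined run, not part of the commit; it assumes a working vllm-ascend environment, and the e2e case pins ASCEND_RT_VISIBLE_DEVICES to four devices, so it needs a multicard machine.

import sys

import pytest

# Unit tests for the Ascend MoE block, then the multicard e2e smoke test.
sys.exit(
    pytest.main([
        "tests/ut/moe/test_qwen3_moe_block.py",
        "tests/e2e/multicard/test_qwen3_moe.py",
    ]))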

vllm_ascend/models/qwen3_moe.py

Lines changed: 98 additions & 0 deletions
@@ -16,8 +16,24 @@
 # Adapted from vllm/model_executor/models/qwen3_moe.py
 # This file is a part of the vllm-ascend project.
 
+from typing import Optional
+
+import torch
+import vllm
+from torch import nn
+from transformers import PretrainedConfig
+from vllm.attention import AttentionMetadata
+from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
+from vllm.distributed.parallel_state import get_dp_group
+from vllm.forward_context import get_forward_context
+from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
 
+from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.distributed.parallel_state import get_ep_group
+from vllm_ascend.ops.fused_moe import AscendFusedMoE
+
 
 class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
     packed_modules_mapping = {
@@ -33,3 +49,85 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
         "experts":
         ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
     }
+
+
+class AscendQwen3MoeSparseMoeBlock(nn.Module):
+    top_k: int
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        if self.tp_size > config.num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {config.num_experts}.")
+
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_multistream_moe = \
+            ascend_config.torchair_graph_config.enable_multistream_moe
+
+        self.gate = ReplicatedLinear(config.hidden_size,
+                                     config.num_experts,
+                                     bias=False,
+                                     quant_config=None,
+                                     prefix=f"{prefix}.gate")
+
+        self.experts = AscendFusedMoE(
+            num_experts=config.num_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts")
+
+        self.top_k = config.num_experts_per_tok
+
+        self.dp_size = get_dp_group().world_size
+
+        self.tp_group = get_tp_group().device_group
+        self.tp_rank = get_tp_group().rank_in_group
+        self.ep_group = get_ep_group()
+
+        self.params_dtype = torch.get_default_dtype()
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+        if attn_metadata is None:
+            attn_metadata = get_forward_context().attn_metadata
+        # when profile runs, force experts to load balanced tokens
+        # to avoid high memory consumption on a single rank.
+        # TODO: need a better flag to indicate whether in profile run or not.
+        if attn_metadata is None:
+            # for profile run
+            is_prefill = True
+            enable_force_load_balance = True
+        else:
+            enable_force_load_balance = False
+            if hasattr(attn_metadata, 'with_prefill_across_dp'):
+                is_prefill = attn_metadata.with_prefill_across_dp
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+
+        hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            is_prefill=is_prefill,
+            top_k=self.top_k,
+            enable_force_load_balance=enable_force_load_balance,
+            shared_experts=None)
+
+        return hidden_states
+
+
+vllm.model_executor.models.qwen3_moe.Qwen3MoeSparseMoeBlock = AscendQwen3MoeSparseMoeBlock
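The final assignment rebinds the upstream block class in place: once vllm_ascend.models.qwen3_moe has been imported, the name Qwen3MoeSparseMoeBlock inside vLLM's Qwen3 MoE module resolves to the Ascend implementation, so models built afterwards pick it up. Below is a minimal sketch of checking that effect, not part of the commit; it assumes both vllm and vllm-ascend are importable in the environment.

import vllm.model_executor.models.qwen3_moe as upstream_qwen3_moe

import vllm_ascend.models.qwen3_moe  # noqa: F401 -- importing applies the patch

# After the import above, the upstream name points at the Ascend block.
assert (upstream_qwen3_moe.Qwen3MoeSparseMoeBlock.__name__
        == "AscendQwen3MoeSparseMoeBlock")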
