
add qwen3-moe operation #1709

Status: Closed (wanted to merge 1 commit).
72 changes: 72 additions & 0 deletions tests/e2e/multicard/test_qwen3_moe.py
@@ -0,0 +1,72 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
"""
Compare the outputs of vLLM with and without aclgraph.

Run `pytest tests/multicard/test_data_parallel.py`.
"""

import os
import subprocess
import sys
from unittest.mock import patch

import pytest

MODELS = ["vllm-ascend/Qwen3-30B-A3B-Puring"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [32])
@patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3"})
def test_qwen3_moe_inference(model, max_tokens):
script = "examples/offline_data_parallel.py"

env = os.environ.copy()

cmd = [
sys.executable,
script,
"--model",
model,
"--dp-size",
"2",
"--tp-size",
"2",
"--node-size",
"1",
"--node-rank",
"0",
"--trust-remote-code",
"--enforce-eager",
]

print(f"Running subprocess: {' '.join(cmd)}")
proc = subprocess.run(cmd,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
timeout=600)
output = proc.stdout.decode()

print(output)

assert "DP rank 0 needs to process" in output
assert "DP rank 1 needs to process" in output
assert "Generated text:" in output
assert proc.returncode == 0
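
The assertions above key off progress lines that the data-parallel example script prints while sharding prompts across DP ranks. As a point of reference, the sharding those log lines describe can be sketched as follows; the helper name and the contiguous even-split policy are illustrative assumptions, not taken from examples/offline_data_parallel.py itself.

    def shard_prompts(prompts: list[str], dp_rank: int, dp_size: int) -> list[str]:
        """Illustrative contiguous split of prompts across DP ranks."""
        per_rank = (len(prompts) + dp_size - 1) // dp_size  # ceil division
        shard = prompts[dp_rank * per_rank:(dp_rank + 1) * per_rank]
        # The e2e test greps for lines of this shape from both ranks.
        print(f"DP rank {dp_rank} needs to process {len(shard)} prompts")
        return shard

With --dp-size 2, the test expects one such line from rank 0 and one from rank 1, followed by at least one "Generated text:" line.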
127 changes: 127 additions & 0 deletions tests/ut/moe/test_qwen3_moe_block.py
@@ -0,0 +1,127 @@
import unittest
from unittest.mock import MagicMock, patch

import torch
import torch.nn as nn

from vllm_ascend.models.qwen3_moe import AscendQwen3MoeSparseMoeBlock


class TestAscendQwen3MoeSparseMoeBlock(unittest.TestCase):

def setUp(self):
# Create a mock config
self.mock_config = MagicMock()
self.mock_config.hidden_size = 512
self.mock_config.num_experts = 8
self.mock_config.num_experts_per_tok = 2
self.mock_config.moe_intermediate_size = 1024
self.mock_config.norm_topk_prob = True

        # Patch the distributed helpers and config accessors where the module
        # under test binds them: it imports these names at module load time,
        # so they must be patched in vllm_ascend.models.qwen3_moe, not in the
        # vllm namespace they originate from.
        self.patchers = [
            patch(
                'vllm_ascend.models.qwen3_moe.get_tensor_model_parallel_world_size',
                return_value=1),
            patch('vllm_ascend.models.qwen3_moe.get_ascend_config',
                  return_value=MagicMock(torchair_graph_config=MagicMock(
                      enabled=True, enable_multistream_moe=True))),
            patch('vllm_ascend.models.qwen3_moe.get_dp_group',
                  return_value=MagicMock(world_size=1)),
            patch('vllm_ascend.models.qwen3_moe.get_tp_group',
                  return_value=MagicMock(device_group=None, rank_in_group=0)),
            patch('vllm_ascend.models.qwen3_moe.get_ep_group',
                  return_value=None),
            patch('vllm_ascend.models.qwen3_moe.get_forward_context',
                  return_value=MagicMock(attn_metadata=None)),
            patch('torch.get_default_dtype', return_value=torch.float32)
        ]

for patcher in self.patchers:
patcher.start()

        # Mock the ReplicatedLinear and AscendFusedMoE classes, again in the
        # namespace of the module under test, and keep handles to the class
        # mocks so their construction can be asserted on later.
        self.mock_replicated_linear = MagicMock(spec=nn.Linear)
        self.mock_fused_moe = MagicMock()

        with patch('vllm_ascend.models.qwen3_moe.ReplicatedLinear',
                   return_value=self.mock_replicated_linear) as self.mock_linear_cls, \
                patch('vllm_ascend.models.qwen3_moe.AscendFusedMoE',
                      return_value=self.mock_fused_moe) as self.mock_moe_cls:
            self.block = AscendQwen3MoeSparseMoeBlock(config=self.mock_config,
                                                      quant_config=None,
                                                      prefix="moe")

def tearDown(self):
for patcher in self.patchers:
patcher.stop()

def test_initialization(self):
# Test initialization values
self.assertEqual(self.block.top_k,
self.mock_config.num_experts_per_tok)
self.assertEqual(self.block.params_dtype, torch.float32)
self.assertTrue(self.block.torchair_graph_enabled)
self.assertTrue(self.block.enable_multistream_moe)

        # Check that the gate and expert submodule classes were each
        # instantiated exactly once during construction
        self.mock_linear_cls.assert_called_once()
        self.mock_moe_cls.assert_called_once()

def test_forward_with_attn_metadata(self):
# Setup mock return values
mock_router_logits = torch.randn(10, self.mock_config.num_experts)
self.mock_replicated_linear.return_value = (mock_router_logits, None)

mock_hidden_states = torch.randn(10, self.mock_config.hidden_size)
mock_output = torch.randn(10, self.mock_config.hidden_size)
self.mock_fused_moe.return_value = mock_output

# Mock attention metadata
mock_attn_metadata = MagicMock()
mock_attn_metadata.with_prefill_across_dp = False

# Test forward pass
output = self.block(mock_hidden_states, mock_attn_metadata)

# Verify calls
self.mock_replicated_linear.assert_called_once_with(mock_hidden_states)
self.mock_fused_moe.assert_called_once_with(
hidden_states=mock_hidden_states,
router_logits=mock_router_logits,
is_prefill=False,
top_k=self.mock_config.num_experts_per_tok,
enable_force_load_balance=False,
shared_experts=None)
self.assertTrue(torch.equal(output, mock_output))

def test_forward_without_attn_metadata(self):
# Setup mock return values
mock_router_logits = torch.randn(10, self.mock_config.num_experts)
self.mock_replicated_linear.return_value = (mock_router_logits, None)

mock_hidden_states = torch.randn(10, self.mock_config.hidden_size)
mock_output = torch.randn(10, self.mock_config.hidden_size)
self.mock_fused_moe.return_value = mock_output

# Test forward pass without attention metadata
output = self.block(mock_hidden_states)

# Verify calls - should use default values when no metadata
self.mock_replicated_linear.assert_called_once_with(mock_hidden_states)
self.mock_fused_moe.assert_called_once_with(
hidden_states=mock_hidden_states,
router_logits=mock_router_logits,
is_prefill=True,
top_k=self.mock_config.num_experts_per_tok,
enable_force_load_balance=True,
shared_experts=None)
self.assertTrue(torch.equal(output, mock_output))

def test_tp_size_greater_than_experts(self):
# Test the validation for TP size vs number of experts
        with patch(
                'vllm_ascend.models.qwen3_moe.get_tensor_model_parallel_world_size',
                return_value=10):
with self.assertRaises(ValueError) as context:
self.block = AscendQwen3MoeSparseMoeBlock(
config=self.mock_config, quant_config=None, prefix="moe")
self.assertIn("Tensor parallel size 10 is greater than",
str(context.exception))
98 changes: 98 additions & 0 deletions vllm_ascend/models/qwen3_moe.py
@@ -16,8 +16,24 @@
# Adapted from vllm/model_executor/models/qwen3_moe.py
# This file is a part of the vllm-ascend project.

from typing import Optional

import torch
import vllm
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe import AscendFusedMoE


class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
packed_modules_mapping = {
@@ -33,3 +49,85 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
}


class AscendQwen3MoeSparseMoeBlock(nn.Module):
[Collaborator review comment on this class: "What is the issue that this PR aims to address?"]

top_k: int

def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}.")

ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_multistream_moe = \
ascend_config.torchair_graph_config.enable_multistream_moe

self.gate = ReplicatedLinear(config.hidden_size,
config.num_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.gate")

self.experts = AscendFusedMoE(
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts")

self.top_k = config.num_experts_per_tok

self.dp_size = get_dp_group().world_size

self.tp_group = get_tp_group().device_group
self.tp_rank = get_tp_group().rank_in_group
self.ep_group = get_ep_group()

self.params_dtype = torch.get_default_dtype()

def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
if attn_metadata is None:
attn_metadata = get_forward_context().attn_metadata
# when profile runs, force experts to load balanced tokens
# to avoid high memory consumption on a single rank.
# TODO: need a better flag to indicate whether in profile run or not.
if attn_metadata is None:
# for profile run
is_prefill = True
enable_force_load_balance = True
        else:
            enable_force_load_balance = False
            # Default to decode and override when the metadata carries the
            # cross-DP prefill flag, so is_prefill is always bound.
            is_prefill = getattr(attn_metadata, 'with_prefill_across_dp',
                                 False)

# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)

hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
is_prefill=is_prefill,
top_k=self.top_k,
enable_force_load_balance=enable_force_load_balance,
shared_experts=None)

return hidden_states


vllm.model_executor.models.qwen3_moe.Qwen3MoeSparseMoeBlock = AscendQwen3MoeSparseMoeBlock
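
The module-level assignment above takes effect at import time: loading vllm_ascend.models.qwen3_moe swaps vLLM's stock Qwen3MoeSparseMoeBlock for the Ascend implementation, so any Qwen3-MoE model constructed afterwards picks up the new block. A minimal activation sketch, assuming only the import side effect shown in this diff (the checkpoint name below is an example, not mandated by the PR):

    import vllm_ascend.models.qwen3_moe  # noqa: F401  (applies the patch above)
    from vllm import LLM

    # MoE layers in any Qwen3-MoE model loaded after the import are built
    # from AscendQwen3MoeSparseMoeBlock instead of the upstream block.
    llm = LLM(model="Qwen/Qwen3-30B-A3B", trust_remote_code=True)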