51 changes: 1 addition & 50 deletions tests/ut/ops/test_common_fused_moe.py
@@ -17,56 +17,7 @@
import torch

from tests.ut.base import TestBase
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE, fused_experts_moge


class TestFusedExpertsMoGE(TestBase):

def test_fused_experts_moge(self):
with patch('torch_npu.npu_grouped_matmul') as mock_grouped_matmul, \
patch('torch_npu.npu_swiglu') as mock_swiglu, \
patch('vllm_ascend.utils.is_310p') as mock_is_310p:

mock_is_310p.return_value = False

mock_grouped_matmul.side_effect = lambda x, weight, **kwargs: [
torch.randn(x[0].shape[0], weight[0].shape[1])
]

mock_swiglu.side_effect = lambda x: x

hidden_states = torch.randn(4, 128)
w1 = torch.randn(4, 256, 128)
w2 = torch.randn(4, 128, 128)
topk_weights = torch.rand(4, 1)
topk_ids = torch.tensor([[0], [1], [2], [3]], dtype=torch.long)
top_k = 1
global_num_experts = 4

moe_parallel_config = type(
'MockConfig', (), {
'ep_size': 1,
'tp_size': 1,
'dp_size': 1,
'tp_rank': 0,
'dp_rank': 0,
'ep_rank': 0,
'use_ep': True
})()

output = fused_experts_moge(
hidden_states=hidden_states,
w1=w1,
w2=w2,
moe_parallel_config=moe_parallel_config,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
global_num_experts=global_num_experts,
apply_router_weight_on_input=True,
)

self.assertEqual(output.shape, (4, 128))
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE


class TestLoadWeight(TestBase):
31 changes: 23 additions & 8 deletions tests/ut/ops/test_fused_ops.py
@@ -23,6 +23,7 @@
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase

from tests.ut.base import TestBase
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
AscendUnquantizedFusedMoEMethod)
from vllm_ascend.ops.moe.experts_selector import select_experts
@@ -55,6 +56,26 @@ def mock_npu_format_cast(weight_data, format):
return weight_data


@pytest.fixture(autouse=True)
def setup_vllm_config_mock(mocker: MockerFixture):
mock_hf_config = MagicMock()
mock_hf_config.model_type = "llama"

mock_model_config = MagicMock()
mock_model_config.hf_config = mock_hf_config

mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.parallel_config = MagicMock(tensor_parallel_size=2)
mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4)
mock_vllm_config.model_config.max_model_len = 2048

mocker.patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
return_value=mock_vllm_config)
mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config',
return_value=mock_vllm_config)


@pytest.fixture
def mock_dist_env(mocker: MockerFixture):
mock_moe_comm_method = MagicMock()
@@ -74,7 +95,7 @@ def mock_finalize(hidden_states, **kwargs):

mock_forward_context_obj = MagicMock(
moe_comm_method=mock_moe_comm_method,
moe_comm_method_name="mc2commimpl",
moe_comm_type=MoECommType.MC2,
max_tokens_across_dp=10,
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]),
mc2_mask=torch.zeros(16, dtype=torch.bool),
@@ -104,12 +125,6 @@ def mock_finalize(hidden_states, **kwargs):
return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
return_value=MagicMock(
parallel_config=MagicMock(tensor_parallel_size=2),
scheduler_config=MagicMock(max_num_seqs=4),
model_config=MagicMock(max_model_len=2048)
)), \
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context',
return_value=mock_forward_context_obj), \
@@ -501,7 +516,7 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
mock_get_forward_context):

mock_forward_context = MagicMock()
mock_forward_context.moe_comm_method_name = "mc2commimpl"
mock_forward_context.moe_comm_type = MoECommType.MC2
mock_get_forward_context.return_value = mock_forward_context

mock_is_310p.return_value = False
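The moe_comm_type assertions above show the pattern this PR moves the tests to. A small before/after sketch of the check itself; the context objects below are plain stand-ins, not the real forward context, and only the enum values come from the diff:

from enum import Enum
from types import SimpleNamespace


class MoECommType(Enum):
    ALLGATHER = 0
    MC2 = 1
    ALLTOALL = 2
    NAIVE_MULTICAST = 3


# Before: string matching against a concatenated name such as "mc2commimpl".
old_ctx = SimpleNamespace(moe_comm_method_name="mc2commimpl")
use_mc2_old = old_ctx.moe_comm_method_name == "mc2commimpl"

# After: a typed comparison that a typo cannot silently break.
new_ctx = SimpleNamespace(moe_comm_type=MoECommType.MC2)
use_mc2_new = new_ctx.moe_comm_type == MoECommType.MC2

assert use_mc2_old and use_mc2_new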
28 changes: 24 additions & 4 deletions tests/ut/ops/test_moe_comm_method.py
@@ -24,14 +24,19 @@ def setUp(self):
self.moe_config.dp_group = MagicMock()
self.moe_config.num_global_redundant_experts = 0

@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
)
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather")
def test_all_gather_comm_impl(self, mock_token_dispatcher,
mock_prepare_finalize,
mock_get_forward_context):
mock_get_forward_context,
mock_get_current_vllm_config):
# Mock vLLM config
mock_get_current_vllm_config.return_value = MagicMock()

# Mock forward context
mock_context = MagicMock()
mock_context.moe_comm_method = "all_gather"
@@ -64,13 +69,18 @@ def test_all_gather_comm_impl(self, mock_token_dispatcher,
comm_impl.finalize(h_out, reduce_results=True)
mock_pf_instance.finalize.assert_called_once_with(h_out, True)

@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithMC2"
)
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithMC2")
def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
mock_get_forward_context):
mock_get_forward_context,
mock_get_current_vllm_config):
# Mock vLLM config
mock_get_current_vllm_config.return_value = MagicMock()

# Mock forward context
mock_context = MagicMock()
mock_context.moe_comm_method = "mc2"
@@ -104,14 +114,19 @@ def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
comm_impl.finalize(h_out, reduce_results=True)
mock_pf_instance.finalize.assert_called_once_with(h_out, True)

@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAll2All"
)
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAll2AllV")
def test_alltoall_comm_impl(self, mock_token_dispatcher,
mock_prepare_finalize,
mock_get_forward_context):
mock_get_forward_context,
mock_get_current_vllm_config):
# Mock vLLM config
mock_get_current_vllm_config.return_value = MagicMock()

# Mock forward context
mock_context = MagicMock()
mock_context.moe_comm_method = "alltoall"
@@ -140,6 +155,7 @@ def test_alltoall_comm_impl(self, mock_token_dispatcher,
mock_pf_instance.prepare.assert_called_once_with(
hidden_states, router_logits, False, False, False, None)

@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
@@ -148,7 +164,11 @@ def test_alltoall_comm_impl(self, mock_token_dispatcher,
@patch("vllm_ascend.ops.moe.moe_comm_method.unified_apply_mlp")
def test_fused_experts_method(self, mock_unified_apply_mlp,
mock_token_dispatcher, mock_prepare_finalize,
mock_get_forward_context):
mock_get_forward_context,
mock_get_current_vllm_config):
# Mock vLLM config
mock_get_current_vllm_config.return_value = MagicMock()

# Mock forward context
mock_context = MagicMock()
mock_context.moe_comm_method = "all_gather"
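Each new @patch for get_current_vllm_config above is stacked on top of the existing decorators, which is why mock_get_current_vllm_config arrives as the last test argument: stacked @patch decorators are applied bottom-up, so the decorator nearest the function supplies the first mock and the topmost one supplies the last. A tiny standalone sketch of that ordering, patching arbitrary stdlib functions purely for demonstration:

from unittest.mock import patch


@patch("os.getcwd")        # topmost decorator -> last argument
@patch("os.path.exists")   # closest to the function -> first argument
def demo(mock_exists, mock_getcwd):
    # Each mock replaces the target named in its own decorator.
    mock_exists.return_value = True
    mock_getcwd.return_value = "/tmp"
    return mock_exists is not mock_getcwd


assert demo()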
11 changes: 10 additions & 1 deletion tests/ut/quantization/test_w4a8_dynamic.py
@@ -48,18 +48,27 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
output_size = 56
group_size = 2

@patch('vllm_ascend.quantization.w4a8_dynamic.get_ascend_config')
@patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
@patch('vllm_ascend.quantization.w4a8_dynamic.get_ep_group')
@patch('vllm_ascend.quantization.w4a8_dynamic.get_mc2_group')
@patch('torch.distributed.get_rank', return_value=0)
def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ep_group,
get_current_vllm_config):
get_current_vllm_config, mock_get_ascend_config):
# Mock ascend config
mock_ascend_config = Mock()
mock_ascend_config.dynamic_eplb = False
mock_get_ascend_config.return_value = mock_ascend_config

mock_vllm_config = Mock()
mock_vllm_config.quant_config = Mock(quant_description={
"group_size": self.group_size,
"version": "0.0.0"
})
mock_vllm_config.parallel_config = Mock(enable_expert_parallel=True)
mock_vllm_config.scheduler_config = Mock(max_num_batched_tokens=2048,
max_model_len=2048,
enable_chunked_prefill=False)
get_current_vllm_config.return_value = mock_vllm_config
self.quant_method = AscendW4A8DynamicFusedMoEMethod()

19 changes: 10 additions & 9 deletions tests/ut/worker/test_model_runner_v1.py
@@ -15,6 +15,7 @@

import pytest

from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.utils import AscendSocVersion
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner

@@ -24,21 +25,21 @@
"soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, quant_type, expected_method",
[
# Case 1: Expert parallel is disabled, should always be 'allgather'
(AscendSocVersion.A2, False, 8, 100, 256, None, "allgather"),
(AscendSocVersion.A3, False, 16, 500, 256, None, "allgather"),
(AscendSocVersion.A2, False, 8, 100, 256, None, MoECommType.ALLGATHER),
(AscendSocVersion.A3, False, 16, 500, 256, None, MoECommType.ALLGATHER),

# Case 2: A2 SOC with w4a8_dynamic -> use alltoall when not mc2
(AscendSocVersion.A2, True, 8, 100, 256, "w4a8_dynamic", "alltoall"),
(AscendSocVersion.A2, True, 16, 257, 256, "w4a8_dynamic", "alltoall"),
(AscendSocVersion.A2, True, 16, 100, 256, "w4a8_dynamic", "mc2"), # meets mc2 condition
(AscendSocVersion.A2, True, 8, 100, 256, "w4a8_dynamic", MoECommType.ALLTOALL),
(AscendSocVersion.A2, True, 16, 257, 256, "w4a8_dynamic", MoECommType.ALLTOALL),
(AscendSocVersion.A2, True, 16, 100, 256, "w4a8_dynamic", MoECommType.MC2), # meets mc2 condition

# Case 3: A2 SOC without w4a8_dynamic -> fallback to allgather
(AscendSocVersion.A2, True, 8, 100, 256, None, "allgather"),
(AscendSocVersion.A2, True, 16, 257, 256, None, "allgather"),
(AscendSocVersion.A2, True, 8, 100, 256, None, MoECommType.ALLGATHER),
(AscendSocVersion.A2, True, 16, 257, 256, None, MoECommType.ALLGATHER),

# Case 4: A3 SOC
(AscendSocVersion.A3, True, 8, 100, 256, None, "mc2"),
(AscendSocVersion.A3, True, 8, 257, 256, None, "alltoall"),
(AscendSocVersion.A3, True, 8, 100, 256, None, MoECommType.MC2),
(AscendSocVersion.A3, True, 8, 257, 256, None, MoECommType.ALLTOALL),
])
# yapf: enable
def test_select_moe_comm_method(soc_version, enable_expert_parallel,
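For readability, here is a sketch of the selection behaviour the parametrized cases above encode. It is derived only from those cases, not from the actual NPUModelRunner implementation, and it takes the SOC version as a plain string instead of AscendSocVersion:

from enum import Enum
from typing import Optional


class MoECommType(Enum):
    ALLGATHER = 0
    MC2 = 1
    ALLTOALL = 2
    NAIVE_MULTICAST = 3


def expected_moe_comm_type(soc_version: str, enable_expert_parallel: bool,
                           world_size: int, num_tokens: int,
                           mc2_tokens_capacity: int,
                           quant_type: Optional[str]) -> MoECommType:
    """Mirror of the expectations in the parametrized test cases."""
    if not enable_expert_parallel:
        return MoECommType.ALLGATHER              # Case 1
    if soc_version == "A2":
        if quant_type != "w4a8_dynamic":
            return MoECommType.ALLGATHER          # Case 3
        if world_size >= 16 and num_tokens <= mc2_tokens_capacity:
            return MoECommType.MC2                # Case 2, mc2 condition met
        return MoECommType.ALLTOALL               # Case 2, otherwise
    # Case 4: A3 only gates mc2 on the token capacity.
    return (MoECommType.MC2 if num_tokens <= mc2_tokens_capacity
            else MoECommType.ALLTOALL)


assert expected_moe_comm_type("A3", True, 8, 257, 256, None) is MoECommType.ALLTOALL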
15 changes: 13 additions & 2 deletions vllm_ascend/ascend_forward_context.py
@@ -22,6 +22,13 @@ class FusedMoEState(Enum):
All2AllSeq = 5


class MoECommType(Enum):
ALLGATHER = 0
MC2 = 1
ALLTOALL = 2
NAIVE_MULTICAST = 3


# TODO(zzzzwwjj): add soc_version to choose branch
def _get_fused_moe_state(ep_size: int, with_prefill: bool,
is_deepseek_v3_r1: bool):
@@ -52,7 +59,7 @@ def set_ascend_forward_context(
with_prefill: bool = True,
in_profile_run: bool = False,
reserved_mc2_mask: Optional[torch.Tensor] = None,
moe_comm_method: str = "",
moe_comm_type: Optional[MoECommType] = None,
num_actual_tokens: Optional[int] = None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: Optional[BatchDescriptor] = None,
@@ -72,7 +79,11 @@
batch_descriptor=batch_descriptor,
):
forward_context = get_forward_context()
forward_context.moe_comm_method_name = moe_comm_method + "commimpl"

from vllm_ascend.ops.moe.moe_comm_method import get_moe_comm_method
forward_context.moe_comm_type = moe_comm_type
forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)

forward_context.with_prefill = with_prefill
tp_world_size = get_tensor_model_parallel_world_size()
ep_size = (get_ep_group().world_size if
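A minimal, self-contained sketch of the lookup pattern this change introduces in set_ascend_forward_context: the context stores the MoECommType enum and the resolved comm method instead of a concatenated string like "mc2commimpl". Only the enum values come from the diff; the registry contents and class names below are assumptions for illustration, since the real get_moe_comm_method returns the concrete comm-impl objects:

from enum import Enum
from typing import Optional


class MoECommType(Enum):
    ALLGATHER = 0
    MC2 = 1
    ALLTOALL = 2
    NAIVE_MULTICAST = 3


# Hypothetical registry standing in for the real enum-to-impl mapping.
_COMM_IMPLS = {
    MoECommType.ALLGATHER: "AllGatherCommImpl",
    MoECommType.MC2: "MC2CommImpl",
    MoECommType.ALLTOALL: "AlltoAllCommImpl",
    MoECommType.NAIVE_MULTICAST: "NaiveMulticastCommImpl",
}


def get_moe_comm_method(moe_comm_type: Optional[MoECommType]) -> Optional[str]:
    # Returns None when no comm type was set, mirroring the Optional default
    # of set_ascend_forward_context(moe_comm_type=None).
    return _COMM_IMPLS.get(moe_comm_type)


assert get_moe_comm_method(MoECommType.MC2) == "MC2CommImpl"
assert get_moe_comm_method(None) is None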