17 changes: 1 addition & 16 deletions tests/e2e/singlecard/ops/test_fused_moe.py
@@ -118,12 +118,6 @@ def test_token_dispatcher_with_all_gather(
score = torch.softmax(score, dim=-1, dtype=dtype)
topk_weights, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.to(torch.int32)
row_idx = (torch.arange(
0,
m * topk,
device=device,
dtype=torch.int32,
).view(topk, -1).permute(1, 0).contiguous())

dispatcher_kwargs = {
"num_experts": e,
@@ -137,7 +131,6 @@ def test_token_dispatcher_with_all_gather(
hidden_states=a,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)

@@ -201,12 +194,6 @@ def test_token_dispatcher_with_all_gather_quant(
score = torch.softmax(score, dim=-1, dtype=dtype)
topk_weights, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.to(torch.int32)
row_idx = (torch.arange(
0,
m * topk,
device=device,
dtype=torch.int32,
).view(topk, -1).permute(1, 0).contiguous())

dispatcher_kwargs = {
"num_experts": e,
@@ -220,7 +207,6 @@ def test_token_dispatcher_with_all_gather_quant(
hidden_states=a,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
with_quant=True)
@@ -295,7 +281,7 @@ def test_select_experts(
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
x)

topk_weights, topk_ids, row_idx = select_experts(
topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=topk,
@@ -316,7 +302,6 @@ def test_select_experts(
assert topk_weights.shape == (m, topk)
assert topk_ids.shape == (m, topk)
assert topk_ids.dtype == torch.int32
assert row_idx.shape == (m, topk)

gc.collect()
torch.npu.empty_cache()
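For context, each deleted block above built the same auxiliary index tensor. A minimal standalone reconstruction of that helper (CPU tensors and small sizes for illustration; the tests allocated it on the NPU device):

```python
import torch

# Reconstruction of the row_idx helper this PR removes: a [m, topk] tensor
# of flattened row indices, laid out column-major over the topk dimension.
m, topk = 4, 2
row_idx = (torch.arange(
    0,
    m * topk,
    dtype=torch.int32,
).view(topk, -1).permute(1, 0).contiguous())

print(row_idx.shape)  # torch.Size([4, 2]), i.e. (m, topk)
print(row_idx)
# tensor([[0, 4],
#         [1, 5],
#         [2, 6],
#         [3, 7]], dtype=torch.int32)
```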
2 changes: 1 addition & 1 deletion tests/ut/ops/test_fused_ops.py
@@ -454,7 +454,7 @@ def test_select_experts(self, mock_dist_env, mock_moe_env,

x = torch.randn(8, 2)
router_logits = torch.randn(8, 2)
topk_weights, topk_ids, _ = select_experts(
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=2,
2 changes: 0 additions & 2 deletions tests/ut/ops/test_moe_comm_method.py
@@ -204,7 +204,6 @@ def test_fused_experts_method(self, mock_unified_apply_mlp,
topk_weights = torch.tensor([[0.5, 0.5], [0.3, 0.7], [0.8, 0.2],
[0.6, 0.4]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 0], [1, 1]])
row_idx = torch.arange(4)

# Make sure tensors are contiguous and have correct strides
hidden_states = hidden_states.contiguous()
@@ -216,7 +215,6 @@ def test_fused_experts_method(self, mock_unified_apply_mlp,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
activation="silu")

# Verify result shape
17 changes: 4 additions & 13 deletions tests/ut/ops/test_token_dispatcher.py
@@ -58,7 +58,6 @@ def setUp(self):

kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128}
self.dispatcher = TokenDispatcherWithMC2(**kwargs)
self.row_idx = torch.arange(10, dtype=torch.int32)

def tearDown(self):
self.mc2_group_patch.stop()
@@ -95,7 +94,7 @@ def test_token_permutation_dispatch(self):
return_value=(torch.randn(10, 128), ) * 5) as mock_dispatch:
output = self.dispatcher.token_dispatch(hidden_states,
topk_weights, topk_ids,
self.row_idx, expert_map)
expert_map)
mock_dispatch.assert_called_once()
self.assertEqual(output["group_list_type"],
0) # group_list_type == 0
@@ -116,7 +115,6 @@ def test_token_dispatch_with_shared_experts_and_quant(self):
self.dispatcher.token_dispatch(self.hidden_states,
self.topk_weights,
torch.randint(0, 8, (10, 1)),
self.row_idx,
torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 7]),
shared_experts=self.shared_experts)
@@ -180,7 +178,6 @@ def setUp(self):
torch.tensor([0, 1, 2, 3, 4, 5]), # expanded_row_idx
torch.tensor([0, 1, 0, 1, 0, 1]), # expanded_expert_idx
torch.tensor([0, 1, 0, 1, 0, 1]))
self.row_idx = torch.arange(10, dtype=torch.int32)
self.patcher_npu_moe_token_unpermute = patch(
'torch_npu.npu_moe_token_unpermute')
self.mock_npu_moe_token_unpermute = self.patcher_npu_moe_token_unpermute.start(
@@ -197,7 +194,7 @@ def test_token_dispatch_without_expert_map(self):
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])

results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
topk_ids, self.row_idx, None)
topk_ids, None)

# Verify npu_moe_init_routing is called
self.mock_npu_moe_init_routing_v2.assert_called_once()
@@ -212,7 +209,7 @@ def test_token_dispatch_with_expert_map(self):
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])

results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
topk_ids, self.row_idx, None)
topk_ids, None)

# Verify npu_moe_init_routing is called
self.mock_npu_moe_init_routing_v2.assert_called_once()
@@ -236,7 +233,7 @@ def test_token_dispatch_without_quant(self):

results = self.dispatcher_quant.token_dispatch(hidden_states,
topk_weights, topk_ids,
self.row_idx, None)
None)

self.assertEqual(results["group_list_type"], 1)

@@ -257,7 +254,6 @@ def test_token_dispatch_with_quant(self):
results = self.dispatcher_quant.token_dispatch(hidden_states,
topk_weights,
topk_ids,
self.row_idx,
None,
with_quant=True)

@@ -399,7 +395,6 @@ def setUp(self):
num_experts=4,
num_local_experts=2,
with_quant=False)
self.row_idx = torch.arange(10, dtype=torch.int32)

def test_token_dispatch(self):
hidden_states = torch.randn(8, 16)
@@ -414,7 +409,6 @@ def test_token_dispatch(self):
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map)

self.assertIsNotNone(result["hidden_states"])
@@ -461,7 +455,6 @@ def test_token_dispatch_with_quant(self):
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map,
with_quant=True)

@@ -490,7 +483,6 @@ def test_token_dispatch_with_quant_no_active_tokens(self):
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map,
with_quant=True)

@@ -513,7 +505,6 @@ def test_token_dispatch_with_log2phy(self):
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map,
log2phy=log2phy)

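Across all dispatcher variants the positional signature shrinks by one argument. A hypothetical stand-in sketching the shape of the updated interface (the real classes live in vllm_ascend and require NPU ops, so this is illustrative only):

```python
from typing import Any, Dict, Optional

import torch

# Hypothetical stand-in for the dispatcher interface after this PR:
# token_dispatch no longer takes a row_idx argument between topk_ids and
# expert_map; routing indices are derived inside the NPU kernels instead.
class TokenDispatcherSketch:

    def token_dispatch(self,
                       hidden_states: torch.Tensor,
                       topk_weights: torch.Tensor,
                       topk_ids: torch.Tensor,
                       expert_map: Optional[torch.Tensor] = None,
                       **kwargs: Any) -> Dict[str, Any]:
        # Real implementations call torch_npu routing ops here.
        return {"hidden_states": hidden_states, "group_list_type": 0}


dispatcher = TokenDispatcherSketch()
out = dispatcher.token_dispatch(torch.randn(8, 16),
                                torch.rand(8, 2),
                                torch.randint(0, 4, (8, 2), dtype=torch.int32),
                                expert_map=torch.tensor([-1, -1, 0, 1]))
assert out["group_list_type"] == 0
```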
73 changes: 43 additions & 30 deletions tests/ut/quantization/test_w8a8.py
@@ -715,8 +715,16 @@ def setUp(self):

self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
self.router_logits = torch.randn(self.num_tokens, self.num_experts)

@patch('torch_npu.npu_moe_gating_top_k_softmax')
"""Mock custom routing"""
self.mock_custom_routing = MagicMock()
self.mock_custom_routing.return_value = (torch.ones(
self.num_tokens, self.top_k),
torch.zeros(
self.num_tokens,
self.top_k,
dtype=torch.int32))

@patch('torch_npu.npu_moe_gating_top_k')
def test_softmax_scoring(self, mock_topk):
"""Test softmax scoring function"""
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
@@ -730,25 +738,27 @@ def test_softmax_scoring(self, mock_topk):
-1).permute(1,
0).contiguous())

weights, ids, _ = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="softmax")
weights, ids = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="softmax")

self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))

def test_sigmoid_scoring(self):
"""Test sigmoid scoring function"""

weights, ids, _ = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="sigmoid")
weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="sigmoid",
custom_routing_function=self.mock_custom_routing)

self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
@@ -761,7 +771,8 @@ def test_invalid_scoring_func(self):
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="invalid_func")
scoring_func="invalid_func",
custom_routing_function=self.mock_custom_routing)

@patch('torch.topk')
def test_grouped_topk(self, mock_topk):
@@ -771,13 +782,15 @@ def test_grouped_topk(self, mock_topk):
self.top_k,
dtype=torch.long))

weights, ids, _ = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=False,
topk_group=4,
num_expert_group=2)
weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=False,
topk_group=4,
num_expert_group=2,
custom_routing_function=self.mock_custom_routing)

mock_topk.assert_called()
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
@@ -791,15 +804,16 @@ def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
self.num_experts)

e_score_correction_bias = torch.randn(self.num_experts)
weights, ids, _ = select_experts(
weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=False,
topk_group=4,
num_expert_group=2,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
custom_routing_function=self.mock_custom_routing)

mock_grouped_topk.assert_called_once()
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
@@ -814,7 +828,7 @@ def test_custom_routing_function(self):
self.top_k,
dtype=torch.int32))

weights, ids, _ = select_experts(
weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
@@ -827,7 +841,7 @@ def test_custom_routing_function(self):
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.dtype, torch.int32)

@patch('torch_npu.npu_moe_gating_top_k_softmax')
@patch('torch_npu.npu_moe_gating_top_k')
def test_renormalize(self, mock_topk):
"""Test renormalization"""
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
@@ -841,7 +855,7 @@ def test_renormalize(self, mock_topk):
-1).permute(1,
0).contiguous())

weights, ids, _ = select_experts(
weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
@@ -853,7 +867,7 @@ def test_renormalize(self, mock_topk):
sums = weights.sum(dim=-1)
self.assertTrue(torch.allclose(sums, torch.ones_like(sums)))

@patch('torch_npu.npu_moe_gating_top_k_softmax')
@patch('torch_npu.npu_moe_gating_top_k')
def test_output_dtypes(self, mock_topk):
"""Test output dtypes"""
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
@@ -867,7 +881,7 @@ def test_output_dtypes(self, mock_topk):
-1).permute(1,
0).contiguous())

weights, ids, _ = select_experts(
weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
@@ -876,7 +890,6 @@ def test_output_dtypes(self, mock_topk):
)

self.assertEqual(weights.dtype, self.hidden_states.dtype)
self.assertEqual(ids.dtype, torch.int32)


class TestNativeGroupedTopkPartialMock(TestBase):
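Two test-side changes recur in this file: the patch target moves from `torch_npu.npu_moe_gating_top_k_softmax` to `torch_npu.npu_moe_gating_top_k`, and a shared routing mock now returns the two-tuple that `select_experts` produces after this PR. A minimal sketch of that mock pattern, runnable without torch_npu:

```python
from unittest.mock import MagicMock

import torch

num_tokens, top_k = 4, 2

# Shared routing mock mirroring the updated setUp: it returns only
# (weights, ids) now that select_experts no longer yields row_idx.
mock_custom_routing = MagicMock()
mock_custom_routing.return_value = (
    torch.ones(num_tokens, top_k),
    torch.zeros(num_tokens, top_k, dtype=torch.int32),
)

weights, ids = mock_custom_routing()  # two-tuple unpacking, no third value
assert weights.shape == (num_tokens, top_k)
assert ids.dtype == torch.int32
```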
3 changes: 1 addition & 2 deletions vllm_ascend/ops/common_fused_moe.py
@@ -81,7 +81,7 @@ def forward_oot(
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:

topk_weights, topk_ids, row_idx = select_experts(
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
@@ -101,7 +101,6 @@ def forward_oot(
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
global_num_experts=global_num_experts,
expert_map=expert_map)

3 changes: 1 addition & 2 deletions vllm_ascend/ops/fused_moe.py
@@ -105,7 +105,7 @@ def apply(
**kwargs,
) -> torch.Tensor:

topk_weights, topk_ids, row_idx = select_experts(
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
@@ -132,7 +132,6 @@ def apply(
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
global_num_experts=global_num_experts,
expert_map=expert_map,
shared_experts=shared_experts,
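The caller-side migration in both op files is the same one-line unpacking change. A sketch with a hypothetical stand-in for `select_experts` (same return arity as the updated helper; random tensors for illustration):

```python
from typing import Tuple

import torch

def select_experts_stub(num_tokens: int = 8,
                        top_k: int = 2) -> Tuple[torch.Tensor, torch.Tensor]:
    # Stand-in with the post-PR return arity: (topk_weights, topk_ids).
    return (torch.rand(num_tokens, top_k),
            torch.zeros(num_tokens, top_k, dtype=torch.int32))

# Before: topk_weights, topk_ids, row_idx = select_experts(...)
# After:
topk_weights, topk_ids = select_experts_stub()

# Downstream fused-experts calls simply drop the row_idx kwarg.
assert topk_ids.dtype == torch.int32  # ids stay int32, as the tests assert
```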