
Commit 2952369

fix lint

Signed-off-by: hust17yixuan <303660421@qq.com>

1 parent 0fc6dcf commit 2952369

File tree

3 files changed: +52 −49 lines changed

tests/ut/torchair/ops/test_torchair_fused_moe.py

Lines changed: 20 additions & 22 deletions
@@ -23,10 +23,8 @@
 from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase

 from vllm_ascend.ascend_forward_context import _get_fused_moe_state
-# from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
-#                                        AscendUnquantizedFusedMoEMethod)
-from vllm_ascend.torchair.ops.torchair_fused_moe import (TorchairAscendFusedMoE,
-                                                         TorchairAscendUnquantizedFusedMoEMethod)
+from vllm_ascend.torchair.ops.torchair_fused_moe import (
+    TorchairAscendFusedMoE, TorchairAscendUnquantizedFusedMoEMethod)
 from vllm_ascend.utils import AscendSocVersion, adapt_patch  # noqa E402

 adapt_patch(True)
@@ -57,33 +55,33 @@ def mock_dist_env(mocker: MockerFixture):

     with patch('torch.distributed.get_rank', return_value=0), \
         patch('torch.distributed.get_world_size', return_value=4), \
-        patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
-        patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
-        patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
+        patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
+        patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
+        patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
-        patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
+        patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('torch.distributed.all_gather', return_value=MagicMock(return_value=torch.randn(10,32))), \
         patch('torch.distributed.all_to_all_single', return_value=torch.randn(8, 32)), \
-        patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce',
+        patch('vllm_ascend.torchair.ops.torchair_fused_moe.tensor_model_parallel_all_reduce',
              return_value=torch.randn(5, 32)), \
-        patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter',
+        patch('vllm_ascend.torchair.ops.torchair_fused_moe.data_parallel_reduce_scatter',
              return_value=torch.randn(5, 32)), \
         patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
              return_value=mock_dp_and_tp_group(mocker)), \
-        patch('vllm_ascend.ops.fused_moe.get_ascend_config',
+        patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_config',
              return_value=MagicMock(
                  torchair_graph_config=MagicMock(enabled=False, enable_multistream_moe=False),
                  expert_map_path=None
              )), \
-        patch('vllm_ascend.ops.fused_moe.determine_expert_map',
+        patch('vllm_ascend.torchair.ops.torchair_fused_moe.determine_expert_map',
              return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
-        patch('vllm_ascend.ops.fused_moe.get_forward_context',
+        patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context',
              return_value=MagicMock(
                  max_tokens_across_dp=10,
                  dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10])
              )), \
-        patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
+        patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config',
              return_value=MagicMock(
                  parallel_config=MagicMock(tensor_parallel_size=2),
                  scheduler_config=MagicMock(max_num_seqs=4),
@@ -196,7 +194,7 @@ def apply(self, hidden_states: torch.Tensor,
         pass


-class TestAscendFusedMoe:
+class TestTorchairAscendFusedMoe:

     def test_init_no_quant(self, mock_dist_env, default_moe_config):
         layer = TorchairAscendFusedMoE(**default_moe_config)
@@ -233,7 +231,7 @@ def test_init_with_quant(self, mock_dist_env, default_moe_config):
        mock_quant_config.get_quant_method.return_value = mock_quant_method

        moe = TorchairAscendFusedMoE(**default_moe_config,
-                             quant_config=mock_quant_config)
+                                     quant_config=mock_quant_config)

        assert moe.quant_method is not None
        assert moe.quant_method == mock_quant_method
@@ -266,7 +264,7 @@ def test_forward(self, mock_dist_env, default_moe_config, others_param):
        forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
                                                         dtype=torch.bool),
                                    padded_num_tokens=num_tokens)
-        with patch("vllm_ascend.ops.fused_moe.get_forward_context",
+        with patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context",
                   return_value=forward_context):
            output = moe.forward(inputs,
                                 router_logits,
@@ -299,7 +297,7 @@ def test_forward_ms_fused_moe_comp(self, mock_dist_env,
        assert output.shape == (5, 32)


-class TestAscendUnquantizedFusedMoEMethod:
+class TestTorchairAscendUnquantizedFusedMoEMethod:

    def test_process_weights_after_loading(self, moe_method, mock_dist_env):
        layer = MagicMock()
@@ -328,7 +326,7 @@ def test_apply_without_expert_map(self, moe_method, mock_dist_env,
        is_deepseek_v3_r1 = global_num_experts == 256
        forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
            ep_size, is_prefill, is_deepseek_v3_r1))
-        with patch("vllm_ascend.ops.fused_moe.get_forward_context",
+        with patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context",
                   return_value=forward_context):
            moe_method.ep_size = ep_size
            x = torch.randn(8, 2, 2)
@@ -363,10 +361,10 @@ def test_apply_with_expert_map(self, moe_method, mock_dist_env,
        is_prefill = False
        forward_context = MagicMock(
            fused_moe_state=_get_fused_moe_state(ep_size, is_prefill, True))
-        with patch("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER",
+        with patch("vllm_ascend.torchair.ops.torchair_fused_moe.MOE_ALL2ALL_BUFFER",
                   alltoall_buffer), \
-            patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
-            patch("vllm_ascend.ops.fused_moe.get_ascend_soc_version", return_value=AscendSocVersion.A3):
+            patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context", return_value=forward_context), \
+            patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_soc_version", return_value=AscendSocVersion.A3):
            expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
            moe_method.ep_size = ep_size
            x = torch.randn(8, 2, 2)
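Note on the patch-target changes above: unittest.mock.patch replaces a name in the namespace where it is looked up, not where it is defined. Once these tests exercise torchair_fused_moe, which binds get_ep_group and friends into its own namespace at import time, patching the old vllm_ascend.ops.fused_moe path would leave the real functions in place. A minimal, runnable sketch of that lookup rule, using hypothetical stand-in modules distmod and opsmod (not from this repo):

import sys
import types
from unittest.mock import MagicMock, patch

distmod = types.ModuleType("distmod")  # stands in for the module defining get_ep_group
distmod.get_ep_group = lambda: MagicMock(world_size=1)
sys.modules["distmod"] = distmod

opsmod = types.ModuleType("opsmod")  # stands in for torchair_fused_moe
opsmod.get_ep_group = distmod.get_ep_group  # effect of `from distmod import get_ep_group`
opsmod.ep_world_size = lambda: opsmod.get_ep_group().world_size
sys.modules["opsmod"] = opsmod

# Patching the defining module leaves the copy bound inside opsmod untouched:
with patch("distmod.get_ep_group", return_value=MagicMock(world_size=4)):
    assert opsmod.ep_world_size() == 1

# Patching the consuming module (what these tests switch to) takes effect:
with patch("opsmod.get_ep_group", return_value=MagicMock(world_size=4)):
    assert opsmod.ep_world_size() == 4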

vllm_ascend/torchair/models/torchair_deepseek_v2.py

Lines changed: 1 addition & 1 deletion
@@ -70,9 +70,9 @@
 from vllm.sequence import IntermediateTensors

 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
+from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
 from vllm_ascend.utils import dispose_tensor, npu_prefetch

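The only change here is import ordering: within the vllm_ascend block, quantization.* sorts before torchair.*, so the TorchairAscendFusedMoE import moves down two lines. A small sketch of that rule (assuming an isort-style alphabetical sort, which this fix is consistent with; the repo's exact linter config is not shown in this commit):

imports = [
    "vllm_ascend.ascend_config",
    "vllm_ascend.torchair.ops.torchair_fused_moe",  # old slot, out of order
    "vllm_ascend.quantization.quant_config",
    "vllm_ascend.quantization.w8a8_dynamic",
    "vllm_ascend.utils",
]
# Lexicographic order puts quantization.* before torchair.*, matching the diff:
for module in sorted(imports):
    print(module)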

vllm_ascend/torchair/ops/torchair_fused_moe.py

Lines changed: 31 additions & 26 deletions
@@ -57,9 +57,10 @@
 MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER


-def torchair_process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
-                              max_row_per_ep_rank: int, num_tokens: int,
-                              top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
+def torchair_process_topk_ids(topk_ids: torch.Tensor, expert_num: int,
+                              ep_size: int, max_row_per_ep_rank: int,
+                              num_tokens: int,
+                              top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
     original_total_elements = num_tokens * top_k
     device = topk_ids.device
     original_dtype = topk_ids.dtype
@@ -538,10 +539,10 @@ def torchair_fused_experts_with_all2all_buffer(
     group_list_type = 0

     hidden_states = torchair_apply_mlp(hidden_states,
-                              w1,
-                              w2,
-                              expert_tokens,
-                              group_list_type=group_list_type)
+                                       w1,
+                                       w2,
+                                       expert_tokens,
+                                       group_list_type=group_list_type)

     resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype)
     hidden_states = hidden_states[resorted_idx]
@@ -686,7 +687,8 @@ def torchair_fused_experts_with_all2allv(
     tokens_per_expert) = (token_dispatcher.token_permutation(
         hidden_states, probs, routing_map))

-    expert_output = torchair_apply_mlp(dispatched_input, w1, w2, tokens_per_expert)
+    expert_output = torchair_apply_mlp(dispatched_input, w1, w2,
+                                       tokens_per_expert)
     output, mlp_bias = token_dispatcher.token_unpermutation(expert_output)
     return output

@@ -960,8 +962,9 @@ def _renormalize_topk_weights(

     # TODO: Change to npu_group_topk when the latest CANN and NNAL is available
     # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
-    topk_weights = torchair_native_grouped_topk(topk_weights, num_expert_group,
-                                       topk_group)
+    topk_weights = torchair_native_grouped_topk(topk_weights,
+                                                num_expert_group,
+                                                topk_group)
     # TODO bfloat16 is not supported in torch.topk with ge graph.
     if e_score_correction_bias is not None:
         topk_ids = torch.topk(topk_weights.to(torch.float32),
@@ -1111,12 +1114,12 @@ def apply(
             FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
         ]:
             return torchair_fused_experts(hidden_states=x,
-                                 w1=layer.w13_weight,
-                                 w2=layer.w2_weight,
-                                 topk_weights=topk_weights,
-                                 topk_ids=topk_ids,
-                                 top_k=top_k,
-                                 expert_map=expert_map)
+                                          w1=layer.w13_weight,
+                                          w2=layer.w2_weight,
+                                          topk_weights=topk_weights,
+                                          topk_ids=topk_ids,
+                                          top_k=top_k,
+                                          expert_map=expert_map)
         elif MOE_ALL2ALL_BUFFER:
             return torchair_fused_experts_with_all2all_buffer(
                 hidden_states=x,
@@ -1140,14 +1143,15 @@ def apply(
                 w2=layer.w2_weight,
             )
         else:
-            return torchair_fused_experts_with_all2all(hidden_states=x,
-                                              w1=layer.w13_weight,
-                                              w2=layer.w2_weight,
-                                              topk_weights=topk_weights,
-                                              topk_ids=topk_ids,
-                                              top_k=top_k,
-                                              expert_map=expert_map,
-                                              ep_group=get_ep_group())
+            return torchair_fused_experts_with_all2all(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                top_k=top_k,
+                expert_map=expert_map,
+                ep_group=get_ep_group())


 class TorchairAscendFusedMoE(FusedMoE):
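The re-wrap hunks in this file keep long calls within the line-length limit in one of two ways: re-aligning continuation arguments under the opening parenthesis, or, where even the aligned form overflows (as in the all2all call above), switching to a hanging indent. A generic sketch of the two PEP 8 wrapping styles, using a hypothetical fused_experts stand-in rather than this repo's function:

def fused_experts(hidden_states, w1, w2, topk_weights):  # hypothetical stand-in
    return hidden_states

# Style 1: continuation arguments aligned with the opening delimiter.
out = fused_experts([1.0],
                    w1=[[1.0]],
                    w2=[[1.0]],
                    topk_weights=[0.5])

# Style 2: hanging indent (break right after the parenthesis and indent one
# level), used when the aligned form would overflow the line limit.
out = fused_experts(
    [1.0],
    w1=[[1.0]],
    w2=[[1.0]],
    topk_weights=[0.5])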
@@ -1284,7 +1288,8 @@ def __init__(
             quant_config=quant_config)

         if quant_config is None:
-            self.quant_method = TorchairAscendUnquantizedFusedMoEMethod(self.moe)
+            self.quant_method = TorchairAscendUnquantizedFusedMoEMethod(
+                self.moe)
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)

@@ -1563,4 +1568,4 @@ def _forward_ms_fused_moe_comp(
         enable_force_load_balance=enable_force_load_balance,
     )

-    return hidden_states
\ No newline at end of file
+    return hidden_states
