Commit 84ea6ff

[bugfix][torchair] fix recompiles and multistream_moe problems with torchair graph mode
Signed-off-by: linfeng-yuan <1102311262@qq.com>
1 parent: 0df059f

File tree

2 files changed (+26, -23 lines)

vllm_ascend/torchair/ops/torchair_fused_moe.py

Lines changed: 25 additions & 23 deletions
@@ -1121,14 +1121,14 @@ def forward(self,
         mc2_mask = forward_context.mc2_mask
         # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
         quantized_x_for_share, dynamic_scale_for_share = None, None
-        from vllm_ascend.quantization.w8a8_dynamic import \
-            AscendW8A8DynamicFusedMoEMethod
+        from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
+            TorchairAscendW8A8DynamicFusedMoEMethod
         if self.enable_multistream_moe:
             if not self.rm_router_logits:
                 router_logits, _ = gate(hidden_states)
             if hasattr(self.quant_method, "quant_method") and \
                 isinstance(self.quant_method.quant_method,
-                           AscendW8A8DynamicFusedMoEMethod
+                           TorchairAscendW8A8DynamicFusedMoEMethod
                            ) and fused_moe_state == FusedMoEState.MC2:
                 with npu_stream_switch("moe_secondary", 0):
                     quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
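
The hunk above only changes which quantization class the multistream gate is checked against. Below is a minimal, self-contained sketch of why that matters under torchair graph mode; the two classes are illustrative stand-ins, not the repository implementations, and the conclusion that the old branch was skipped is inferred from the commit title.

# Hedged sketch: with the torchair backend, the layer presumably carries the
# torchair-specific quant method, so an isinstance() check against the plain
# eager-mode class never matches and the secondary-stream npu_dynamic_quant
# branch is skipped. Stand-in classes for illustration only.


class AscendW8A8DynamicFusedMoEMethod:  # stand-in for the eager-mode class
    pass


class TorchairAscendW8A8DynamicFusedMoEMethod:  # stand-in for the torchair class
    pass


installed = TorchairAscendW8A8DynamicFusedMoEMethod()

# Old check: wrong class for the torchair path, so the multistream branch never runs.
print(isinstance(installed, AscendW8A8DynamicFusedMoEMethod))          # False
# New check: matches, so the secondary-stream quantization can run.
print(isinstance(installed, TorchairAscendW8A8DynamicFusedMoEMethod))  # True
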
@@ -1154,30 +1154,32 @@ def forward(self,
                 FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
                 FusedMoEState.NaiveMulticast
         ] and not replace_allreduce):
-            if fused_moe_state in {FusedMoEState.MC2}:
-                padding_size = forward_context.padded_num_tokens
-            else:
-                # TODO: Determine if we can remove the padding
-                padding_size = tp_size
-            if num_tokens < padding_size and not self.enable_shared_expert_dp:
-                hidden_states = nn.functional.pad(
-                    hidden_states, (0, 0, 0, padding_size - num_tokens))
-                router_logits = nn.functional.pad(
-                    router_logits, (0, 0, 0, padding_size - num_tokens))
             if tp_size > 1:
                 tp_rank = get_tensor_model_parallel_rank()
-                if not self.enable_shared_expert_dp:
-                    chunk_hidden_states = torch.tensor_split(hidden_states,
-                                                             tp_size,
-                                                             dim=0)
-                    chunk_router_logits = torch.tensor_split(router_logits,
-                                                             tp_size,
-                                                             dim=0)
-                    hidden_states = chunk_hidden_states[tp_rank]
-                    router_logits = chunk_router_logits[tp_rank]
-
                 chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
                 mc2_mask = chunk_mc2_mask[tp_rank]
+        if not replace_allreduce:
+            if fused_moe_state in {FusedMoEState.MC2}:
+                padding_size = forward_context.padded_num_tokens
+            else:
+                # TODO: Determine if we can remove the padding
+                padding_size = tp_size
+            if num_tokens < padding_size and not self.enable_shared_expert_dp:
+                hidden_states = nn.functional.pad(
+                    hidden_states, (0, 0, 0, padding_size - num_tokens))
+                router_logits = nn.functional.pad(
+                    router_logits, (0, 0, 0, padding_size - num_tokens))
+            if tp_size > 1:
+                tp_rank = get_tensor_model_parallel_rank()
+                if not self.enable_shared_expert_dp:
+                    chunk_hidden_states = torch.tensor_split(hidden_states,
+                                                             tp_size,
+                                                             dim=0)
+                    chunk_router_logits = torch.tensor_split(router_logits,
+                                                             tp_size,
+                                                             dim=0)
+                    hidden_states = chunk_hidden_states[tp_rank]
+                    router_logits = chunk_router_logits[tp_rank]
 
         if self.dp_size > 1:
             if fused_moe_state == FusedMoEState.AllGather:
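
This second hunk moves the token padding and tensor-parallel chunking of hidden_states and router_logits out of the fused_moe_state-conditional branch into a plain `if not replace_allreduce:` block, while the MC2-mask chunking stays where it was. Keeping the token dimension padded to a fixed size on every such step is presumably what avoids the torchair graph recompiles named in the commit title. A minimal sketch of that padding idea follows; `pad_tokens` and the example sizes are illustrative, and only the nn.functional.pad call mirrors the diff.

# Illustrative sketch (not the repository code): pad the token dimension up to
# a fixed size so a captured graph sees one stable shape instead of one shape
# per batch size.
import torch
import torch.nn as nn


def pad_tokens(hidden_states: torch.Tensor, padding_size: int) -> torch.Tensor:
    """Zero-pad dim 0 (tokens) up to padding_size; no-op if already large enough."""
    num_tokens = hidden_states.shape[0]
    if num_tokens < padding_size:
        # (0, 0, 0, n) pads nothing on the hidden dim and n rows on the token
        # dim, matching the nn.functional.pad call in the diff.
        hidden_states = nn.functional.pad(
            hidden_states, (0, 0, 0, padding_size - num_tokens))
    return hidden_states


x = torch.randn(3, 16)            # 3 tokens this step
print(pad_tokens(x, 8).shape)     # torch.Size([8, 16]) -> stable shape for graph capture

The tensor-parallel split then runs torch.tensor_split over this fixed-size tensor, so each rank also sees a stable chunk size.
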

vllm_ascend/torchair/torchair_mla.py

Lines changed: 1 addition & 0 deletions
@@ -626,6 +626,7 @@ def __init__(
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.running_in_graph = False
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
 
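
The MLA change is a single initialization. A hedged reading: if self.running_in_graph is only assigned during forward on the torchair graph path, reading it before the first call (or on a non-graph path) would fail or see a stale value, so defaulting it in __init__ keeps the attribute always defined. The class below is an illustrative stand-in for that pattern, not the vllm_ascend code.

# Illustrative stand-in (not vllm_ascend code) for defaulting a per-forward
# flag in __init__ so it is always readable.
class MLAStub:

    def __init__(self, torchair_graph_enabled: bool):
        self.torchair_graph_enabled = torchair_graph_enabled
        self.running_in_graph = False  # default set up front, as in the diff

    def forward(self, in_graph_decode: bool) -> bool:
        # The flag is refreshed per call on the graph path.
        self.running_in_graph = self.torchair_graph_enabled and in_graph_decode
        return self.running_in_graph


m = MLAStub(torchair_graph_enabled=True)
print(m.running_in_graph)   # False: safe to read before the first forward()
print(m.forward(True))      # True on a graph-mode decode step
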