[bugfix][torchair] fix multistream_moe problems in torchair graph mode #2681
Conversation
Code Review
This pull request addresses recompilation and multistream MoE issues in torchair graph mode. The changes involve refactoring import paths for quantization methods, adjusting the logic for tensor padding and splitting in torchair_fused_moe, and adding a running_in_graph flag in torchair_mla. My review focuses on improving the clarity and efficiency of the modified logic in torchair_fused_moe.py. I've identified a redundant conditional check and a duplicate function call, and provided a suggestion to refactor the code for better readability and performance.
if tp_size > 1:
    tp_rank = get_tensor_model_parallel_rank()
    if not self.enable_shared_expert_dp:
        chunk_hidden_states = torch.tensor_split(hidden_states,
                                                 tp_size,
                                                 dim=0)
        chunk_router_logits = torch.tensor_split(router_logits,
                                                 tp_size,
                                                 dim=0)
        hidden_states = chunk_hidden_states[tp_rank]
        router_logits = chunk_router_logits[tp_rank]

chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
mc2_mask = chunk_mc2_mask[tp_rank]
if not replace_allreduce:
    if fused_moe_state in {FusedMoEState.MC2}:
        padding_size = forward_context.padded_num_tokens
    else:
        # TODO: Determine if we can remove the padding
        padding_size = tp_size
    if num_tokens < padding_size and not self.enable_shared_expert_dp:
        hidden_states = nn.functional.pad(
            hidden_states, (0, 0, 0, padding_size - num_tokens))
        router_logits = nn.functional.pad(
            router_logits, (0, 0, 0, padding_size - num_tokens))
    if tp_size > 1:
        tp_rank = get_tensor_model_parallel_rank()
        if not self.enable_shared_expert_dp:
            chunk_hidden_states = torch.tensor_split(hidden_states,
                                                     tp_size,
                                                     dim=0)
            chunk_router_logits = torch.tensor_split(router_logits,
                                                     tp_size,
                                                     dim=0)
            hidden_states = chunk_hidden_states[tp_rank]
            router_logits = chunk_router_logits[tp_rank]
The logic in this block can be simplified to improve readability and remove redundant operations. Specifically:
- The "if not replace_allreduce:" check at line 1161 is redundant because it is already part of the outer condition at line 1156.
- get_tensor_model_parallel_rank() is called twice if tp_size > 1, which is inefficient. It should be called only once and the result reused.
I suggest refactoring this block to remove the nested if and the duplicate function call.
if tp_size > 1:
    tp_rank = get_tensor_model_parallel_rank()
    chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
    mc2_mask = chunk_mc2_mask[tp_rank]
if fused_moe_state in {FusedMoEState.MC2}:
    padding_size = forward_context.padded_num_tokens
else:
    # TODO: Determine if we can remove the padding
    padding_size = tp_size
if num_tokens < padding_size and not self.enable_shared_expert_dp:
    hidden_states = nn.functional.pad(
        hidden_states, (0, 0, 0, padding_size - num_tokens))
    router_logits = nn.functional.pad(
        router_logits, (0, 0, 0, padding_size - num_tokens))
if tp_size > 1:
    if not self.enable_shared_expert_dp:
        chunk_hidden_states = torch.tensor_split(hidden_states,
                                                 tp_size,
                                                 dim=0)
        chunk_router_logits = torch.tensor_split(router_logits,
                                                 tp_size,
                                                 dim=0)
        hidden_states = chunk_hidden_states[tp_rank]
        router_logits = chunk_router_logits[tp_rank]
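
For reference, here is a self-contained sketch of the pad-then-split pattern shown above, runnable outside the model code. The tensor shapes and the tp_size, tp_rank and padding_size values below are made-up stand-ins for what the real forward pass derives from the forward context and the tensor-parallel group.

import torch
import torch.nn.functional as F

tp_size, tp_rank = 4, 1          # assumed TP group of 4; this rank is 1
num_tokens, hidden_dim = 3, 8
padding_size = tp_size           # non-MC2 branch above: pad up to tp_size tokens

hidden_states = torch.randn(num_tokens, hidden_dim)
router_logits = torch.randn(num_tokens, 16)

if num_tokens < padding_size:
    # (0, 0, 0, n) pads only the token dimension (dim 0); dim 1 is untouched.
    pad = padding_size - num_tokens
    hidden_states = F.pad(hidden_states, (0, 0, 0, pad))
    router_logits = F.pad(router_logits, (0, 0, 0, pad))

# Each TP rank keeps only its chunk of the now evenly divisible token dimension.
hidden_states = torch.tensor_split(hidden_states, tp_size, dim=0)[tp_rank]
router_logits = torch.tensor_split(router_logits, tp_size, dim=0)[tp_rank]

print(hidden_states.shape, router_logits.shape)  # torch.Size([1, 8]) torch.Size([1, 16])

Padding the token dimension to a multiple of tp_size before torch.tensor_split keeps every rank's chunk the same size, which avoids shape mismatches across ranks.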
Codecov Report
❌ The patch check has failed because the patch coverage (56.25%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2681      +/-   ##
==========================================
- Coverage   73.49%   73.49%   -0.01%
==========================================
  Files         151      151
  Lines       21927    21931       +4
==========================================
+ Hits        16116    16118       +2
- Misses       5811     5813       +2
@linfeng-yuan This PR seems important. How's the progress?
I've rebased the code and CI has passed. Currently multistream_moe with ...
We've found this problem too. We can discuss solutions if needed.
@linfeng-yuan ready to merge? cc @Yikun @wangxiyuan
What this PR does / why we need it?
This PR fixes two problems when multistream_moe is enabled in torchair graph mode:
1. use the correct TorchairAscendW8A8DynamicFusedMoEMethod instead of the incorrect AscendW8A8DynamicFusedMoEMethod (see the sketch after this list)
2. fix the tensor padding and splitting logic for both the replace_allreduce is True and replace_allreduce is False cases in the forward function of TorchairAscendFusedMoE
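
Below is a minimal, hypothetical sketch of the first fix's idea: dispatch to the torchair-specific W8A8 dynamic fused-MoE method only when torchair graph mode is enabled. Only the two class names come from this PR; the placeholder class bodies, the select_fused_moe_quant_method helper, and the torchair_graph_enabled flag are illustrative assumptions, not vllm-ascend's actual API.

class AscendW8A8DynamicFusedMoEMethod:
    """Placeholder for the eager-mode W8A8 dynamic fused-MoE method."""


class TorchairAscendW8A8DynamicFusedMoEMethod:
    """Placeholder for the torchair-graph W8A8 dynamic fused-MoE method."""


def select_fused_moe_quant_method(torchair_graph_enabled: bool) -> type:
    # Pick the torchair variant only when running in torchair graph mode.
    if torchair_graph_enabled:
        return TorchairAscendW8A8DynamicFusedMoEMethod
    return AscendW8A8DynamicFusedMoEMethod


print(select_fused_moe_quant_method(True).__name__)
# -> TorchairAscendW8A8DynamicFusedMoEMethod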
Does this PR introduce any user-facing change?
No.
How was this patch tested?
e2e vLLM serving with multistream_moe enabled (DP32TP1EP32)
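
A hedged offline sketch of how this configuration is typically switched on; it is not the exact DP32/TP1/EP32 serving setup used for testing. The model path is a placeholder, and the additional_config keys follow vllm-ascend's torchair_graph_config documentation; verify the keys and whether your vLLM version forwards additional_config before relying on this.

from vllm import LLM, SamplingParams

llm = LLM(
    model="path/to/your-moe-model",   # placeholder; the PR was tested on a large MoE deployment
    tensor_parallel_size=1,           # TP1, as in the test setup
    enable_expert_parallel=True,      # route experts across parallel ranks
    additional_config={
        "torchair_graph_config": {
            "enabled": True,                 # run in torchair graph mode
            "enable_multistream_moe": True,  # the feature this PR fixes
        },
    },
)

out = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(out[0].outputs[0].text)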