[V0.9.1] add support for flashcomm2 in qwen3 #1726
Conversation
This pull request has conflicts, please resolve those before we can evaluate the pull request.
vllm_ascend/envs.py
@@ -143,6 +143,8 @@
 # Batch MC2 in prefill: The number of tokens in each batch
 "VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE":
 lambda: int(os.getenv("VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE", "128")),
+"VLLM_ENABLE_FC":
+lambda: int(os.getenv("VLLM_ENABLE_FC", 0))
I think this default value would be better as '0' (a string, to match the other entries)?
vllm_ascend/models/qwen3.py
output = torch.empty(attn_output.shape,
                     dtype=attn_output.dtype,
                     device=attn_output.device)
dist.all_to_all_single(output, attn_output)
What is the comm_group of this operation? If you do this all_to_all over the whole world_size, what happens when we have pipeline parallelism?
And what's the purpose of this all_to_all here? Why introduce this additional communication?
> And what's the purpose of this all_to_all here? Why introduce this additional communication?
Linear+AllReduce is replaced with All2All+Linear+AllGather in FlashComm2.
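To see why the two patterns are equivalent, here is a minimal single-process NumPy sketch of the algebra (not the actual distributed implementation; `tp`, the shapes, and the random data are illustrative):

```python
import numpy as np

tp, T, H = 4, 8, 16                       # illustrative TP size, tokens, hidden dim
rng = np.random.default_rng(0)
x = rng.standard_normal((T, H))           # full attention output
W = rng.standard_normal((H, H))           # full o_proj weight

# Baseline (Linear + AllReduce): rank r holds the r-th column slice of x and
# the r-th row block of W; the AllReduce sums the per-rank partial products.
h = H // tp
partials = [x[:, r*h:(r+1)*h] @ W[r*h:(r+1)*h, :] for r in range(tp)]
y_allreduce = sum(partials)

# FlashComm2 (All2All + Linear + AllGather): the All2All reshards from the
# hidden dim to the token dim, each rank applies the full (replicated) W to
# its token slice, and the AllGather concatenates results along tokens.
t = T // tp
y_parts = [x[r*t:(r+1)*t, :] @ W for r in range(tp)]
y_flashcomm2 = np.concatenate(y_parts, axis=0)

assert np.allclose(y_allreduce, y_flashcomm2)  # both compute x @ W
```

Note the trade-off visible in the sketch: the FlashComm2 path needs the full `W` on every rank (the weight replication discussed below), while the baseline keeps `W` sharded.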
What's the point of this change? Does it save communication or computation? You just replicate copies of o_proj on each device, which increases both bandwidth pressure and memory allocation. And the all2all does not actually reduce the input data amount, right? You just shuffle the data across the TP ranks. So what's the point of these changes, do they really bring a performance boost?
> What's the point of this change? Does it save communication or computation? You just replicate copies of o_proj on each device, which increases both bandwidth pressure and memory allocation. And the all2all does not actually reduce the input data amount, right? You just shuffle the data across the TP ranks. So what's the point of these changes, do they really bring a performance boost?
This change saves communication because the input data size of the all2all is 1/tp_size of that of the allreduce. We get a net performance benefit as long as the communication savings outweigh the increased bandwidth pressure on the linear.
Refer to this link for more details: https://gitcode.com/ascend-tribe/ascend-inference-cluster/blob/main/FlashComm/ascend-inference-cluster-flashcomm2.md
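The 1/tp_size claim above can be checked with simple per-rank volume accounting (the TP size, token count, hidden size, and fp16 element width below are illustrative, not measurements from this PR):

```python
tp, T, H, elem_bytes = 8, 3000, 4096, 2   # illustrative TP size, tokens, hidden dim, fp16

# Linear + AllReduce: each rank feeds a full (T, H) partial sum into the AllReduce.
allreduce_in = T * H * elem_bytes

# All2All + Linear + AllGather: each rank only feeds its (T, H/tp) slice into the All2All.
all2all_in = T * (H // tp) * elem_bytes

# The all2all input is 1/tp of the allreduce input.
assert all2all_in * tp == allreduce_in
```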
Do you have any performance statistics for this PR?
Signed-off-by: David9857 <985700846@qq.com>
Here's the comparison of TTFT time when input_len=3000 and max_concurrency=20: [table attached in the PR]
Looks good, thanks for the explanation!
What this PR does / why we need it?
Support FlashComm v2 in qwen3, which can reduce latency at the prefill stage. Set VLLM_ENABLE_FC=1 and use eager mode to enable this feature.
Note: Enabling FlashComm in the decoding stage may increase latency, so it is recommended to use disaggregated prefilling and enable this feature in the prefill instance only!
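A possible launch sketch for a prefill instance (the model name is a placeholder; `--enforce-eager` is vLLM's flag for eager mode, and `VLLM_ENABLE_FC` is the environment variable added by this PR):

```shell
# prefill instance only: enable FlashComm2 and force eager mode
export VLLM_ENABLE_FC=1
vllm serve Qwen/Qwen3-8B --enforce-eager
```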
Does this PR introduce any user-facing change?
NA
How was this patch tested?
NA