[V0.9.1] add support for flashcomm2 in qwen3 #1726

Merged: 1 commit into vllm-project:v0.9.1-dev on Jul 16, 2025

Conversation

David9857
Contributor

@David9857 David9857 commented Jul 10, 2025

What this PR does / why we need it?

Support FlashComm v2 in Qwen3, which can reduce latency in the prefill stage. Set VLLM_ENABLE_FC=1 and use eager mode to enable this feature.
Note: enabling FlashComm in the decode stage may increase latency, so it is recommended to use disaggregated prefilling and enable this feature on the prefill instance only!
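As a quick illustration, here is a minimal sketch of enabling the flag and reading it back the same way the env-var accessor added by this PR does (the variable name and default come from the diff in this PR; everything else is illustrative):

```python
import os

# Illustrative only: turn FlashComm2 on for this process.
os.environ["VLLM_ENABLE_FC"] = "1"

# Mirrors the accessor this PR adds to the env-var table:
#   lambda: int(os.getenv("VLLM_ENABLE_FC", 0))
enable_fc = int(os.getenv("VLLM_ENABLE_FC", 0))
print(enable_fc)  # 1
```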

Does this PR introduce any user-facing change?

NA

How was this patch tested?

NA


This pull request has conflicts, please resolve those before we can evaluate the pull request.

@@ -143,6 +143,8 @@
# Batch MC2 in prefill: The number of tokens in each batch
"VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE":
lambda: int(os.getenv("VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE", "128")),
"VLLM_ENABLE_FC":
lambda: int(os.getenv("VLLM_ENABLE_FC", 0))
Contributor
I think this default value would be better as the string '0', to match the other entries?


This pull request has conflicts, please resolve those before we can evaluate the pull request.

output = torch.empty(attn_output.shape,
dtype=attn_output.dtype,
device=attn_output.device)
dist.all_to_all_single(output, attn_output)
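For readers unfamiliar with all_to_all_single, here is a hedged, communication-free sketch of its semantics (pure NumPy, simulating all ranks in one process; the two-rank setup and shapes are made up for illustration):

```python
import numpy as np

# Sketch of what all_to_all_single does semantically, without real
# communication. Each of the `world_size` ranks holds a tensor split
# into `world_size` equal chunks; after the exchange, rank r holds
# chunk r from every rank, concatenated in rank order.
def all_to_all_single_sim(inputs):
    world_size = len(inputs)
    chunks = [np.split(x, world_size) for x in inputs]  # chunks[src][dst]
    return [np.concatenate([chunks[src][dst] for src in range(world_size)])
            for dst in range(world_size)]

# Two ranks: rank 0 holds [0, 1, 2, 3], rank 1 holds [4, 5, 6, 7].
out = all_to_all_single_sim([np.arange(4), np.arange(4, 8)])
print(out[0])  # [0 1 4 5]
print(out[1])  # [2 3 6 7]
```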
Collaborator
What is the comm_group of this operation? If you do this all_to_all over the whole world_size, what happens when we have pipeline parallelism?

Collaborator
And what's the purpose of this all_to_all here? Why introduce this additional communication here?

Contributor Author
And what's the purpose of this all_to_all here? Why introduce this additional communication here?

Linear+Allreduce is replaced with All2All+Linear+AllGather in Flashcomm2.

Collaborator
What's the point of this change? Does it save communication or computation? You just replicate copies of o_proj on each device, which will increase both bandwidth pressure and memory allocation. And the all2all does not actually reduce the input data amount, right? You just shuffle the data across the TP ranks. So what's the point of these changes? Does it really bring a performance boost?

Contributor Author
What's the point of this change? Does it save communication or computation? You just replicate copies of o_proj on each device, which will increase both bandwidth pressure and memory allocation. And the all2all does not actually reduce the input data amount, right? You just shuffle the data across the TP ranks. So what's the point of these changes? Does it really bring a performance boost?

This change saves communication because the input to the all2all is 1/tp_size the size of the input to the allreduce. We get a performance benefit as long as the communication savings outweigh the added bandwidth pressure on the linear.
Refer to this link for more details: https://gitcode.com/ascend-tribe/ascend-inference-cluster/blob/main/FlashComm/ascend-inference-cluster-flashcomm2.md
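A back-of-envelope sketch of the volume argument (the hidden size and tp_size below are assumed for illustration, not taken from the PR): in plain tensor parallelism every rank allreduces a full [seq, hidden] partial sum produced by o_proj, while in the FlashComm2 layout the all2all runs on the head-sharded [seq, hidden / tp_size] attention output:

```python
# Assumed shapes, not measured: compare per-rank communication inputs.
def comm_input_elems(seq_len, hidden, tp_size):
    allreduce_in = seq_len * hidden            # full partial-sum tensor
    all2all_in = seq_len * hidden // tp_size   # TP-sharded activation
    return allreduce_in, all2all_in

ar, a2a = comm_input_elems(seq_len=3000, hidden=4096, tp_size=8)
print(ar // a2a)  # 8, i.e. the all2all input is 1/tp_size of the allreduce input
```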

@ganyi1996ppo
Collaborator

Do you have any performance statistics for this PR?

Signed-off-by: David9857 <985700846@qq.com>
@David9857
Contributor Author

Do you have any performance statistics for this PR?

Here's the comparison of TTFT time when input_len=3000 and max_concurrency=20:
origin:
Mean TTFT (ms): 1112.23
Median TTFT (ms): 740.95
P99 TTFT (ms): 3382.67
with FlashComm2:
Mean TTFT (ms): 1064.02
Median TTFT (ms): 744.06
P99 TTFT (ms): 3001.43
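For context, the relative changes implied by the numbers above (simple arithmetic on the posted values):

```python
# Relative TTFT change, computed from the numbers posted above.
origin = {"mean": 1112.23, "median": 740.95, "p99": 3382.67}
fc2 = {"mean": 1064.02, "median": 744.06, "p99": 3001.43}

change = {k: (fc2[k] - origin[k]) / origin[k] * 100 for k in origin}
for k, v in change.items():
    print(f"{k} TTFT: {v:+.1f}%")  # mean -4.3%, median +0.4%, p99 -11.3%
```

On these numbers, FlashComm2 improves mean and tail (P99) latency while the median is essentially unchanged.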

@ganyi1996ppo
Collaborator

Looks good, thanks for the explanation!


@ganyi1996ppo ganyi1996ppo merged commit 89129a8 into vllm-project:v0.9.1-dev Jul 16, 2025
16 checks passed