[Attention] add DCP support for FLASH_ATTN_MLA backend #24453
Conversation
LGTM
This pull request has merge conflicts that must be resolved before it can be merged.
vllm/v1/worker/gpu_model_runner.py
I believe the assert cannot be removed yet; some modifications to the MLA attention kernels are necessary. This is because each query token may have a different seqlen_k on different DCP ranks.
For example, with dcp=2 and query_len=2 (denote the tokens as AB), if we treat this as a decode request (a sketch of the per-rank calculation follows this list):
- The KV cache for key k_A is stored on DCP rank 0, and k_B on DCP rank 1.
- On DCP rank 0, both q_A and q_B should have seqlen_k = 1.
- However, on DCP rank 1, q_A should have seqlen_k = 0, and q_B should have seqlen_k = 1.
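A minimal sketch of this per-rank seqlen_k calculation (a hypothetical helper for illustration, not vLLM's metadata code), assuming round-robin placement where the KV of global token g lives on rank g % dcp_size:

```python
def per_rank_seqlen_k(seq_len: int, query_len: int,
                      dcp_size: int, rank: int) -> list[int]:
    # Hypothetical helper, illustration only (not a vLLM API).
    seqlens = []
    for i in range(query_len):
        # Number of KV tokens this query token may attend to globally (causal).
        visible = seq_len - query_len + i + 1
        # How many of those tokens live on this DCP rank under round-robin
        # placement (global token g -> rank g % dcp_size).
        on_rank = 0 if visible <= rank else (visible - rank + dcp_size - 1) // dcp_size
        seqlens.append(on_rank)
    return seqlens

# dcp=2, query_len=2 ("AB"), no prior context, so seq_len == 2:
print(per_rank_seqlen_k(seq_len=2, query_len=2, dcp_size=2, rank=0))  # [1, 1]
print(per_rank_seqlen_k(seq_len=2, query_len=2, dcp_size=2, rank=1))  # [0, 1]
```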
You already handle this:
Am I missing something?
https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py#L670-L673 just re-corrects seqlens_k for DCP decode, which is trivial under query_len=1. But once query_len>1, the situation changes.
https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py#L670-L673 corrects anything that's classified as "decode" by

vllm/v1/attention/backends/mla/common.py, lines 666 to 667 in bba1042:
    split_decodes_and_prefills(common_attn_metadata,
                               decode_threshold=self.reorder_batch_threshold)

i.e. anything with q_len <= reorder_batch_threshold, so I still fail to see the issue?
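For reference, an illustrative sketch of that split (not the actual vLLM implementation), where any request with q_len <= decode_threshold is routed down the decode path:

```python
def split_decodes_and_prefills_sketch(query_lens: list[int],
                                      decode_threshold: int = 1):
    # Illustration only: requests with query length <= decode_threshold are
    # treated as "decodes"; everything else goes down the prefill path.
    decodes = [i for i, q in enumerate(query_lens) if q <= decode_threshold]
    prefills = [i for i, q in enumerate(query_lens) if q > decode_threshold]
    return decodes, prefills

# With reorder_batch_threshold > 1, a multi-token request (q_len=2) is still
# classified as a decode, which is exactly the case discussed above.
print(split_decodes_and_prefills_sketch([1, 2, 8], decode_threshold=2))  # ([0, 1], [2])
```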
When q_len == 1, num_reqs == num_tokens.
Sorry, I understand the issue now; we need a special causal mask for q_len > 1, i.e.:
Normal:
k_toks > 0 1 2 3 4 5
q_toks v _____________
0 | 1 1 1
1 | 1 1 1 1
2 | 1 1 1 1 1
3 | 1 1 1 1 1 1
DCP Rank 0:
k_toks > 0 2 4
q_toks v _______
0 | 1 1
1 | 1 1
2 | 1 1 1
3 | 1 1 1
DCP Rank 1:
k_toks > 1 3 5
q_toks v ______
0 | 1
1 | 1 1
2 | 1 1
3 | 1 1 1
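A small sketch (illustration only, not the FA3 kernel change; dcp_causal_mask is a made-up helper) that reproduces the masks above by intersecting the global causal mask with the KV columns stored on each rank under interleaved placement; num_context=2 is read off the "Normal" diagram, where q_tok 0 attends to three keys:

```python
import torch

def dcp_causal_mask(num_context: int, num_query: int,
                    dcp_size: int, rank: int) -> torch.Tensor:
    total_k = num_context + num_query
    # Global causal mask: query i attends to k tokens 0 .. num_context + i.
    q_idx = torch.arange(num_query).unsqueeze(1)   # (Q, 1)
    k_idx = torch.arange(total_k).unsqueeze(0)     # (1, K)
    causal = k_idx <= (num_context + q_idx)        # (Q, K)
    # Keep only the k columns stored on this rank (token g lives on rank g % dcp_size).
    local_cols = (torch.arange(total_k) % dcp_size) == rank
    return causal[:, local_cols]

print(dcp_causal_mask(num_context=2, num_query=4, dcp_size=2, rank=0).int())
# [[1, 1, 0],
#  [1, 1, 0],
#  [1, 1, 1],
#  [1, 1, 1]]
print(dcp_causal_mask(num_context=2, num_query=4, dcp_size=2, rank=1).int())
# [[1, 0, 0],
#  [1, 1, 0],
#  [1, 1, 0],
#  [1, 1, 1]]
```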
Apologies for the oversight on my side 🤦; because of working on #22789 I'm not used to thinking about interleaved token distribution, since that approach distributes contiguous blocks of tokens (full pages). Good catch! 🙏
I will add support for this mask in FA3 so we can combine DCP and FA3 (we should do the same for FlashMLA); in the meantime I'll make reorder_batch_threshold == 1 when DCP is turned on 👍
cc @MatthewBonanni who might have bandwidth before I do 👍
@LucasWilkinson @MatthewBonanni FYI #24864.
We could also separate MLA DCP decode into two stages when query_len > 1: run the context KV with causal=False and the query KV with causal=True. Then we don't need to hack a custom mask into all MLA backends. I don't know which option is better.
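For reference, a conceptual single-rank sketch of that two-stage idea in plain PyTorch (not an MLA backend implementation; all names here are hypothetical): the context part runs without a causal mask, the query-vs-query part runs causally, and the two partial results are merged via their log-sum-exp weights:

```python
import torch

def _attn(q, k, v, causal: bool):
    # q: (Q, d); k, v: (K, d). Returns the output (Q, d) and per-query LSE (Q,).
    scale = q.shape[-1] ** -0.5
    scores = (q @ k.T) * scale                                    # (Q, K)
    if causal:
        Q, K = scores.shape
        # Query i may attend to the first K - Q + i + 1 keys.
        allowed = torch.arange(K) <= (K - Q + torch.arange(Q)).unsqueeze(1)
        scores = scores.masked_fill(~allowed, float("-inf"))
    lse = torch.logsumexp(scores, dim=-1)
    return torch.softmax(scores, dim=-1) @ v, lse

def two_stage_decode(q, k_ctx, v_ctx, k_new, v_new):
    # Stage 1: queries vs. the cached context KV, no causal mask needed.
    o_ctx, lse_ctx = _attn(q, k_ctx, v_ctx, causal=False)
    # Stage 2: queries vs. the new query tokens' own KV, causal.
    o_new, lse_new = _attn(q, k_new, v_new, causal=True)
    # Merge the two partial softmaxes using their log-sum-exp weights.
    w_ctx = torch.sigmoid(lse_ctx - lse_new).unsqueeze(-1)
    return w_ctx * o_ctx + (1.0 - w_ctx) * o_new

torch.manual_seed(0)
d, ctx, new = 16, 5, 2
q = torch.randn(new, d)
out = two_stage_decode(q, torch.randn(ctx, d), torch.randn(ctx, d),
                       torch.randn(new, d), torch.randn(new, d))
print(out.shape)  # torch.Size([2, 16])
```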
I expect to have PRs up today with the custom mask, so we'll at least have that option. The WIP flash attention PR is: vllm-project/flash-attention#92
CC @minosfuture
ah, thanks! let's discuss on slack
Purpose
Fix FA MLA return and add DCP support for FA MLA
Test Plan
lm_eval results
Test Result
passes