
Conversation

@LucasWilkinson (Collaborator) commented Sep 8, 2025

Purpose

Fix the FA MLA return path and add DCP (decode context parallel) support for FA MLA.

Test Plan

VLLM_ATTENTION_BACKEND=FLASH_ATTN_MLA chg run -g 4 -- pytest tests/distributed/test_context_parallel.py -s

lm_eval results

Test Result

VLLM_ATTENTION_BACKEND=FLASH_ATTN_MLA chg run -g 2 --  vllm serve --model="deepseek-ai/DeepSeek-V2-Lite-Chat" --trust-remote-code -tp 2 -dcp 2 --port 3331

lm_eval --model local-completions --model_args "base_url=http://0.0.0.0:3331/v1/completions,model=deepseek-ai/DeepSeek-V2-Lite-Chat,num_concurrent=256" --tasks gsm8k

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6437|±  |0.0132|
|     |       |strict-match    |     5|exact_match|↑  |0.6391|±  |0.0132|
VLLM_ATTENTION_BACKEND=FLASH_ATTN_MLA chg run -g 4 -- pytest tests/distributed/test_context_parallel.py -s

passes



@youkaichao (Member) left a comment:

LGTM

mergify bot commented Sep 9, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 9, 2025
@youzhedian (Contributor) commented Sep 9, 2025

I believe the assert cannot be removed yet; some modifications to the MLA attention kernels are necessary, because each query token may have a different seqlen_k on different DCP ranks.
For example, with dcp=2 and query_len=2 (call the two query tokens A and B), if we treat this as a decode request (see the sketch below):

  • The KV cache for key k_A is stored on DCP rank 0, and k_B on DCP rank 1.
  • On DCP rank 0, both q_A and q_B should have seqlen_k = 1.
  • However, on DCP rank 1, q_A should have seqlen_k = 0, and q_B should have seqlen_k = 1.
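
A minimal sketch of that per-token bookkeeping, assuming KV token i is stored on rank i % dcp_world_size (the function name and the interleaved layout are illustrative assumptions, not vLLM's actual code):

```python
def per_token_seqlen_k(query_positions, dcp_rank, dcp_world_size):
    # For a query token at global position p (0-based), count how many of the
    # causally visible KV tokens 0..p live on this DCP rank, assuming KV token
    # i is stored on rank i % dcp_world_size. Illustrative sketch only.
    return [
        sum(1 for i in range(p + 1) if i % dcp_world_size == dcp_rank)
        for p in query_positions
    ]

# dcp=2, query_len=2, tokens A (position 0) and B (position 1):
# per_token_seqlen_k([0, 1], dcp_rank=0, dcp_world_size=2)  -> [1, 1]
# per_token_seqlen_k([0, 1], dcp_rank=1, dcp_world_size=2)  -> [0, 1]
```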


@youzhedian (Contributor) commented Sep 9, 2025

https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py#L670-L673 just re-corrects seqlens_k for DCP decode, which is trivial under query_len=1. But once query_len>1, the situation changes.
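
For the query_len == 1 case, that correction amounts to splitting the sequence length across ranks. A minimal sketch, assuming the same interleaved layout as above (illustrative only, not the exact code behind the link):

```python
def dcp_local_seqlen_k(total_seqlen_k: int, dcp_rank: int, dcp_world_size: int) -> int:
    # Number of KV tokens this rank holds when token i lives on rank
    # i % dcp_world_size. For a pure decode step (q_len == 1) this single
    # per-request value is all the correction that is needed.
    return total_seqlen_k // dcp_world_size + (dcp_rank < total_seqlen_k % dcp_world_size)
```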

@LucasWilkinson (Collaborator, Author):

https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py#L670-L673 corrects anything that's classified as "decode" by

    split_decodes_and_prefills(common_attn_metadata,
                               decode_threshold=self.reorder_batch_threshold)

i.e. anything with q_len <= reorder_batch_threshold, so I still fail to see the issue?

@LucasWilkinson (Collaborator, Author) commented Sep 9, 2025

When q_len == 1, num_reqs == num_tokens.

Sorry, I understand the issue now: we need a special causal mask for q_len > 1, i.e.

Normal:

k_toks >   0 1 2 3 4 5
q_toks v  _____________
       0 | 1 1 1
       1 | 1 1 1 1
       2 | 1 1 1 1 1
       3 | 1 1 1 1 1 1


DCP Rank 0:

k_toks >   0 2 4
q_toks v  _______
       0 | 1 1
       1 | 1 1
       2 | 1 1 1
       3 | 1 1 1 


DCP Rank 1:

k_toks >   1 3 5
q_toks v   ______
       0 | 1
       1 | 1 1
       2 | 1 1
       3 | 1 1 1
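
A minimal sketch of what that per-rank mask looks like, assuming interleaved assignment of KV token i to rank i % dcp_world_size and that the q_len query tokens are the tail of the sequence (the function name and layout are illustrative assumptions, not any backend's API):

```python
import torch

def dcp_rank_causal_mask(q_len: int, total_k: int, dcp_rank: int,
                         dcp_world_size: int) -> torch.Tensor:
    # Global indices of the KV tokens owned by this rank: rank 0 of 2 owns
    # tokens 0, 2, 4, ...; rank 1 owns tokens 1, 3, 5, ...
    local_k_idx = torch.arange(dcp_rank, total_k, dcp_world_size)
    # Global indices of the query tokens (the last q_len tokens of the sequence).
    q_idx = torch.arange(total_k - q_len, total_k)
    # Standard causal rule applied to the *global* indices:
    # query j may attend to KV token i iff i <= j.
    return local_k_idx.unsqueeze(0) <= q_idx.unsqueeze(1)  # [q_len, n_local_k]

# total_k=6, q_len=4, dcp_world_size=2 reproduces the DCP Rank 0 / Rank 1
# tables above: rank 0 sees columns (0, 2, 4), rank 1 sees columns (1, 3, 5).
```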

Apologies for the oversight on my side :face_palm: — because I've been working on #22789 I'm not used to thinking about interleaved token distribution, since that approach distributes contiguous blocks of tokens (full pages). Good catch! 🙏

I will add support for this mask in FA3 so we can combine DCP and FA3 (we should do the same for FlashMLA); in the meantime I'll force reorder_batch_threshold == 1 when DCP is turned on 👍

@LucasWilkinson (Collaborator, Author):

cc @MatthewBonanni who might have bandwidth before I do 👍

Contributor:

@LucasWilkinson @MatthewBonanni FYI #24864.

We could also split MLA DCP decode into two stages when query_len > 1: run the queries against the context KV with causal=False and against the query KV with causal=True; then we don't need to hack a custom mask into all the MLA backends (see the sketch below). I don't know which option is better.
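
A minimal sketch of that two-stage idea, using a naive attention that also returns the per-query log-sum-exp so the two partial results can be merged exactly; the helper names, the torch-only implementation, and the way the KV is split are all illustrative assumptions, not any backend's actual API:

```python
import torch

def attn_with_lse(q, k, v, causal: bool):
    # Naive attention that also returns the per-query log-sum-exp of the
    # scores, so two partial results can later be merged exactly.
    scale = q.shape[-1] ** -0.5
    scores = (q @ k.transpose(-2, -1)) * scale                  # [q_len, k_len]
    if causal:
        q_len, k_len = scores.shape[-2:]
        # Align the diagonal so the last query token sees all keys.
        keep = torch.arange(k_len)[None, :] <= torch.arange(q_len)[:, None] + (k_len - q_len)
        scores = scores.masked_fill(~keep, float("-inf"))
    lse = torch.logsumexp(scores, dim=-1, keepdim=True)         # [q_len, 1]
    return torch.softmax(scores, dim=-1) @ v, lse

def merge(out_a, lse_a, out_b, lse_b):
    # Combine two partial attention outputs as if computed in a single pass.
    lse = torch.logaddexp(lse_a, lse_b)
    return out_a * (lse_a - lse).exp() + out_b * (lse_b - lse).exp()

# Stage 1: queries vs. the context KV held on this rank (causal=False, since
#          the whole context precedes every query token).
# Stage 2: queries vs. the new query-token KV (ordinary causal mask).
# out = merge(*attn_with_lse(q, k_ctx, v_ctx, causal=False),
#             *attn_with_lse(q, k_new, v_new, causal=True))
```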

Contributor:

I expect to have PRs up today with the custom mask, so we'll at least have that option. The WIP flash attention PR is vllm-project/flash-attention#92.


Contributor:

Ah, thanks! Let's discuss on Slack.

LucasWilkinson and others added 3 commits September 9, 2025 13:52
@youkaichao youkaichao merged commit 0ae43db into vllm-project:main Sep 10, 2025
12 checks passed
@youkaichao youkaichao deleted the lwilkinson/fa-mla-dcp branch September 10, 2025 09:19
@youkaichao youkaichao changed the title [Attention] Fix FA MLA and add DCP support [Attention] add DCP support for FLASH_ATTN_MLA backend Sep 10, 2025
skyloevil pushed a commit to skyloevil/vllm that referenced this pull request Sep 13, 2025
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025