Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions vllm/v1/attention/backends/mla/flashattn_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from vllm.attention.utils.fa_utils import (flash_attn_supports_mla,
get_flash_attn_version)
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonDecodeMetadata,
Expand Down Expand Up @@ -98,6 +99,11 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
# pre-allocated during capture.
self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH

# TODO(lucas): Until we add support for the DCP custom masking we need
# to restrict decodes to q_len == 1 when DCP is enabled.
self.__class__.reorder_batch_threshold = 1 \
if get_dcp_group().world_size > 1 else self.reorder_batch_threshold

def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
max_seq_len, causal):
if self.fa_aot_schedule:
Expand Down Expand Up @@ -172,6 +178,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor,


class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
can_return_lse_for_decode: bool = True

def __init__(
self,
Expand Down Expand Up @@ -239,7 +246,7 @@ def _forward_decode(
# to prevent invalid grid configuration during graph capture.
max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)

o = flash_attn_varlen_func(
attn_out = flash_attn_varlen_func(
q=q_pe,
k=k_pe_cache.unsqueeze(-2), # Add head dim of 1
v=kv_c_cache.unsqueeze(-2), # Add head dim of 1
Expand All @@ -251,9 +258,16 @@ def _forward_decode(
block_table=attn_metadata.decode.block_table,
softmax_scale=self.scale,
causal=True,
return_softmax_lse=self.need_to_return_lse_for_decode,
fa_version=3, # only version 3 is supported
scheduler_metadata=attn_metadata.decode.scheduler_metadata,
num_splits=attn_metadata.decode.max_num_splits,
)

return self._v_up_proj(o)
if self.need_to_return_lse_for_decode:
o, lse = attn_out
# FA returns LSE in shape [ H, B ] but DCP wants [ B, H ]
return o, lse.transpose(0, 1) # [ H, B ] -> [ B, H ]
else:
o = attn_out
return o, None
3 changes: 3 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,9 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
return

if self.reorder_batch_threshold is not None:
# NOTE(lucas): currently no backend supports the custom masking
# required for DCP with q_len > 1, so we assert here. Remove this
# assert once the custom mask is support is added to FA3.
if self.dcp_world_size > 1:
assert self.reorder_batch_threshold == 1, \
Copy link
Contributor

@youzhedian youzhedian Sep 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the assert cannot be removed yet; some modifications to the MLA attention kernels are necessary. This is because, each query token may have a different seqlen_k on different DCP ranks.
For example, with dcp=2 and query_len=2, note as AB, if we treat this as a decode request:

  • The KV-cache for key k_A is stored on DCP rank 0, and k_B on DCP rank 1.
  • On DCP rank 0, both q_A and q_B should have seqlen_k = 1.
  • However, on DCP rank 1, q_A should have seqlen_k =0, and q_B should have seqlen_k = 1.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

@youzhedian youzhedian Sep 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py#L670-L673 just recorrect seqlens_k for DCP decode, and it's trivial under query_len=1. But once query_len>1, the situation changes.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py#L670-L673 corrects anything thats classified as "decode" by

split_decodes_and_prefills(common_attn_metadata,
decode_threshold=self.reorder_batch_threshold)
i.e. anything with q_len <= reorder_batch_threshold so I still fail to see the issue?

Copy link
Collaborator Author

@LucasWilkinson LucasWilkinson Sep 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when q_len == 1, num_reqs == num_tokens

sorry i understand the issue now; we need a special causal mask for q_len > 1; i.e.

Normal:

k_toks >   0 1 2 3 4 5
q_toks v  _____________
       0 | 1 1 1
       1 | 1 1 1 1
       2 | 1 1 1 1 1
       3 | 1 1 1 1 1 1


DCP Rank 0:

k_toks >   0 2 4
q_toks v  _______
       0 | 1 1
       1 | 1 1
       2 | 1 1 1
       3 | 1 1 1 


DCP Rank 1:

k_toks >   1 3 5
q_toks v   ______
       0 | 1
       1 | 1 1
       2 | 1 1
       3 | 1 1 1

Apologies for the oversight on my side :face_palm: because of working on #22789 im not used to thinking about interleaved tokens being distributed because that approach i distributed contiguous blocks of tokens (full pages). Good catch! 🙏

I will add support for this mask in FA3 so we can combine DCP and FA3 (we should do the same for FlashMLA); in the meantime ill make the reorder_batch_threshold == 1 when DCP is turned on 👍

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @MatthewBonanni who might have bandwidth before I do 👍

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@LucasWilkinson @MatthewBonanni FYI #24864.

we also can separate mla dcp decode into two stage when query_len>1, context_kv use causal=Fasle, query_kv use caual=True, than we don't need hack a custom mask for all mla backends. I don't know which one is better.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I expect to have PRs up today with the custom mask, so we'll at least have that option. The WIP flash attention PR is: vllm-project/flash-attention#92

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, thanks! let's discuss on slack

"DCP not support reorder_batch_threshold > 1 now."
Expand Down