From 151e69b0cab3775658f4c02c517f805865cbb38f Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 8 Sep 2025 17:42:21 +0000 Subject: [PATCH 1/5] fa MLA cp support Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/flashattn_mla.py | 13 ++++++++++--- vllm/v1/worker/gpu_model_runner.py | 3 --- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 12f206637d7c..9fbff8c48f7e 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -172,6 +172,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor, class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]): + can_return_lse_for_decode: bool = True def __init__( self, @@ -239,7 +240,7 @@ def _forward_decode( # to prevent invalid grid configuration during graph capture. max_seqlen_q = max(attn_metadata.decode.max_query_len, 1) - o = flash_attn_varlen_func( + attn_out = flash_attn_varlen_func( q=q_pe, k=k_pe_cache.unsqueeze(-2), # Add head dim of 1 v=kv_c_cache.unsqueeze(-2), # Add head dim of 1 @@ -251,9 +252,15 @@ def _forward_decode( block_table=attn_metadata.decode.block_table, softmax_scale=self.scale, causal=True, + return_softmax_lse=self.need_to_return_lse_for_decode, fa_version=3, # only version 3 is supported scheduler_metadata=attn_metadata.decode.scheduler_metadata, num_splits=attn_metadata.decode.max_num_splits, ) - - return self._v_up_proj(o) + + if self.need_to_return_lse_for_decode: + o, lse = attn_out + return o, lse + else: + o = attn_out + return o, None diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 549c5dd2bbb2..166ea4680788 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -440,9 +440,6 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: return if self.reorder_batch_threshold is not None: - if self.dcp_world_size > 1: - assert self.reorder_batch_threshold == 1, \ - "DCP not support reorder_batch_threshold > 1 now." reorder_batch_to_split_decodes_and_prefills( self.input_batch, scheduler_output, From cd3bafac2d29225e463be3d15ed6a61524ee30ad Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 8 Sep 2025 18:40:31 +0000 Subject: [PATCH 2/5] accuracy fix Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/flashattn_mla.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 9fbff8c48f7e..995fb8840089 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -257,10 +257,11 @@ def _forward_decode( scheduler_metadata=attn_metadata.decode.scheduler_metadata, num_splits=attn_metadata.decode.max_num_splits, ) - + if self.need_to_return_lse_for_decode: o, lse = attn_out - return o, lse + # FA returns LSE in shape [ H, B ] but DCP wants [ B, H ] + return o, lse.transpose(0, 1) # [ H, B ] -> [ B, H ] else: o = attn_out return o, None From 39efba65b3d7a9dd6535e94e6d41dabf077fd98e Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 9 Sep 2025 13:52:49 +0000 Subject: [PATCH 3/5] restrict to q_len == 1 Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/flashattn_mla.py | 7 +++++++ vllm/v1/worker/gpu_model_runner.py | 6 ++++++ 2 files changed, 13 insertions(+) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 995fb8840089..42ae1421bd65 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -20,6 +20,7 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata +from vllm.distributed.parallel_state import get_dcp_group logger = init_logger(__name__) @@ -98,6 +99,12 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], # pre-allocated during capture. self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH + # TODO(lucas): Until we add support for the DCP custom masking we need + # to restrict decodes to q_len == 1 when DCP is enabled. + self.__class__.reorder_batch_threshold = 1 \ + if get_dcp_group().world_size > 1 else self.reorder_batch_threshold + + def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal): if self.fa_aot_schedule: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 166ea4680788..b3590342ebf4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -440,6 +440,12 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: return if self.reorder_batch_threshold is not None: + # NOTE(lucas): currently no backend supports the custom masking + # required for DCP with q_len > 1, so we assert here. Remove this + # assert once the custom mask is support is added to FA3. + if self.dcp_world_size > 1: + assert self.reorder_batch_threshold == 1, \ + "DCP not support reorder_batch_threshold > 1 now." reorder_batch_to_split_decodes_and_prefills( self.input_batch, scheduler_output, From 4edbdd0b6fd8fd772c250e131a7cfbdf036cb24c Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 9 Sep 2025 13:53:56 +0000 Subject: [PATCH 4/5] fix format Signed-off-by: Lucas Wilkinson --- vllm/v1/worker/gpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b3590342ebf4..8a90d3906ed3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -445,7 +445,7 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: # assert once the custom mask is support is added to FA3. if self.dcp_world_size > 1: assert self.reorder_batch_threshold == 1, \ - "DCP not support reorder_batch_threshold > 1 now." + "DCP not support reorder_batch_threshold > 1 now." reorder_batch_to_split_decodes_and_prefills( self.input_batch, scheduler_output, From a2e15ec27bc3b47d6f3f55ab92bf8dd621675be0 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 9 Sep 2025 14:38:11 -0400 Subject: [PATCH 5/5] Fix pre-commit Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/flashattn_mla.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 42ae1421bd65..472095e13615 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -11,6 +11,7 @@ from vllm.attention.utils.fa_utils import (flash_attn_supports_mla, get_flash_attn_version) from vllm.config import VllmConfig +from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import (MLACommonBackend, MLACommonDecodeMetadata, @@ -20,7 +21,6 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata -from vllm.distributed.parallel_state import get_dcp_group logger = init_logger(__name__) @@ -104,7 +104,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.__class__.reorder_batch_threshold = 1 \ if get_dcp_group().world_size > 1 else self.reorder_batch_threshold - def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal): if self.fa_aot_schedule: