diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index 12f206637d7c..472095e13615 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -11,6 +11,7 @@
 from vllm.attention.utils.fa_utils import (flash_attn_supports_mla,
                                            get_flash_attn_version)
 from vllm.config import VllmConfig
+from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonDecodeMetadata,
@@ -98,6 +99,11 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         # pre-allocated during capture.
         self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
 
+        # TODO(lucas): Until we add support for the DCP custom masking we need
+        # to restrict decodes to q_len == 1 when DCP is enabled.
+        self.__class__.reorder_batch_threshold = 1 \
+            if get_dcp_group().world_size > 1 else self.reorder_batch_threshold
+
     def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
                          max_seq_len, causal):
         if self.fa_aot_schedule:
@@ -172,6 +178,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor,
 
 
 class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
+    can_return_lse_for_decode: bool = True
 
     def __init__(
             self,
@@ -239,7 +246,7 @@ def _forward_decode(
         # to prevent invalid grid configuration during graph capture.
         max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)
 
-        o = flash_attn_varlen_func(
+        attn_out = flash_attn_varlen_func(
             q=q_pe,
             k=k_pe_cache.unsqueeze(-2),  # Add head dim of 1
             v=kv_c_cache.unsqueeze(-2),  # Add head dim of 1
@@ -251,9 +258,16 @@ def _forward_decode(
             block_table=attn_metadata.decode.block_table,
             softmax_scale=self.scale,
             causal=True,
+            return_softmax_lse=self.need_to_return_lse_for_decode,
             fa_version=3,  # only version 3 is supported
             scheduler_metadata=attn_metadata.decode.scheduler_metadata,
             num_splits=attn_metadata.decode.max_num_splits,
         )
 
-        return self._v_up_proj(o)
+        if self.need_to_return_lse_for_decode:
+            o, lse = attn_out
+            # FA returns LSE in shape [ H, B ] but DCP wants [ B, H ]
+            return o, lse.transpose(0, 1)  # [ H, B ] -> [ B, H ]
+        else:
+            o = attn_out
+        return o, None
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 549c5dd2bbb2..8a90d3906ed3 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -440,6 +440,9 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
             return
 
         if self.reorder_batch_threshold is not None:
+            # NOTE(lucas): currently no backend supports the custom masking
+            # required for DCP with q_len > 1, so we assert here. Remove this
+            # assert once custom mask support is added to FA3.
            if self.dcp_world_size > 1:
                assert self.reorder_batch_threshold == 1, \
                    "DCP not support reorder_batch_threshold > 1 now."
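
For context on the `[ H, B ] -> [ B, H ]` transpose in `_forward_decode`, here is a minimal standalone sketch of the decode return-value handling. The tensor sizes (`num_decode_tokens`, `num_heads`, `head_dim`) and the random stand-in tensors are illustrative assumptions, not values from the patch:

```python
import torch

# Hypothetical sizes for illustration only.
num_decode_tokens, num_heads, head_dim = 4, 16, 512

# Stand-ins for what flash_attn_varlen_func returns with
# return_softmax_lse=True: the attention output and the log-sum-exp.
o = torch.randn(num_decode_tokens, num_heads, head_dim)
lse = torch.randn(num_heads, num_decode_tokens)  # FA convention: [ H, B ]

need_to_return_lse_for_decode = True
if need_to_return_lse_for_decode:
    # DCP combines partial attention outputs across ranks using the LSE
    # and expects it laid out as [ B, H ].
    out, lse_bh = o, lse.transpose(0, 1)  # [ H, B ] -> [ B, H ]
else:
    out, lse_bh = o, None

assert lse_bh is None or lse_bh.shape == (num_decode_tokens, num_heads)
```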