Motivation.
This RFC aims to address the following issues:
- The ViT is still tightly coupled with the text backbone attention. This RFC continues the effort to decouple the ViT attention from the text backbone attention.
- Another pain point is that ViT logic overrides are scattered across many places. We should avoid overriding ViT logic in model definition classes; instead, the platform class should define which ViT attention backends are supported and how they are overridden.
  - The above logic applies to the general use case. At the time this RFC was proposed, this single logic applies to all of the VL models below:
    - `vllm/model_executor/models/qwen2_5_vl.py` (shared by Qwen2.5 VL and Qwen3 VL)
    - `vllm/model_executor/models/dots_ocr.py`
    - `vllm/model_executor/models/ernie45_vl.py`
    - `vllm/model_executor/models/glm4_1v.py`
    - `vllm/model_executor/models/qwen2_vl.py`
    - `vllm/model_executor/models/siglip2navit.py`
  - If a model only supports a specific type of attention, the ViT override logic will be implemented explicitly in the model definition file (`model.py`).
- Since `torch.compile` was introduced into the ViT (currently only for the Qwen VL models, in PR [Misc][qwen2_5_vl][torch.compile] Enable `supports_torch_compile` on generic nn.Module and demonstrate speedup on Qwen Vision model #23207), the AMD ViT code path has been broken. The new approach will try to accommodate this feature. `torch.compile` has brought substantial performance improvements, so we can now consider replacing Triton kernels with PyTorch-native implementations, since `torch.compile`d code can be faster than custom Triton kernel code.
  One proven case is [Bugfix][ROCm] Fix ViT rotary embeddings for torch.compile compatibility on ROCm #27748: the `torch.compile`d `rotary_embedding` is faster than the Triton `rotary_embedding` from `flash_attn`.
- Ensure ViT changes take the other model definition (`model.py`) files into account. Current changes only involve `qwen2_5_vl.py`, yet they potentially affect all of the following files:
  - `vllm/model_executor/models/dots_ocr.py`
  - `vllm/model_executor/models/ernie45_vl.py`
  - `vllm/model_executor/models/glm4_1v.py`
  - `vllm/model_executor/models/qwen2_vl.py`
  - `vllm/model_executor/models/siglip2navit.py`
Proposed Change.
NOTE: More details will be added as I write up a version with all these changes.
As a first quick reorganization of the ViT attention, we retain the current `--mm-encoder-attn-backend` interface introduced in PR #27061, together with the bugfix PR #27124.
- First, we should shrink the `_Backend` enum (see https://github.com/vllm-project/vllm/pull/27061/files#r2443909604) by introducing a separate `_MHA_Backend` registry:
```python
import enum


class _MHA_Backend(enum.Enum):
    VLLM_FLASH_ATTN = enum.auto()  # CUDA-only
    FLASH_ATTN = enum.auto()  # CUDA/ROCm
    XFORMERS = enum.auto()  # CUDA
    ROCM_AITER_FA = enum.auto()  # ROCm-only
    TORCH_SDPA = enum.auto()  # CUDA/ROCm/TPU/XPU/CPU
    PALLAS = enum.auto()  # TPU-only
```

- Make ViT attention selection platform-specific: the ViT attention backend should be determined in the `platform` interface, and any overrides should also be performed there. We should avoid doing that in the `model.py` files.
- `get_vit_attn_backend` in the `platform` interface must be able to access `--mm-encoder-attn-backend`.
- The `platform` interface should only return an `_MHA_Backend`, not the attention functions. The functions should only be returned through `maybe_get_vit_flash_attn_backend`. If a model only supports a specific type of attention, the ViT override logic will be implemented explicitly in the model definition file (`model.py`).
```python
from typing import Optional

import torch

from vllm.logger import init_logger

logger = init_logger(__name__)


class Platform:
    ...

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["_MHA_Backend"]:
        from vllm.attention.backends.registry import _MHA_Backend

        return [
            _MHA_Backend.TORCH_SDPA,
        ]

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: Optional["_MHA_Backend"] = None,
    ) -> "_MHA_Backend":
        # ViT attention should be checked and overridden only in the
        # platform-specific implementation, never anywhere else
        # (e.g. model_executor/models/<model_name>.py). The steps are:
        # 1. Check whether a backend was requested:
        #    a. If so, check that the platform supports it.
        #    b. If not, fall through to the default selection logic.
        # Import _MHA_Backend here to avoid a circular import.
        from vllm.attention.backends.registry import _MHA_Backend

        if backend is not None:
            supported = cls.get_supported_vit_attn_backends()
            assert backend in supported, (
                f"Backend {backend} is not supported for ViT attention. "
                f"Supported backends are: {supported}"
            )
            logger.info_once(f"Using backend {backend} for ViT attention")
            return backend

        logger.info_once(
            f"Using default backend {_MHA_Backend.TORCH_SDPA} for ViT attention"
        )
        return _MHA_Backend.TORCH_SDPA
```
- Honor `--mm-encoder-attn-backend` so that we can write unit tests for every backend. AMD Instinct GPUs can test all backends; Radeon GPUs can only use the TORCH_SDPA code path.
- We need to deprecate this line: https://github.com/vllm-project/vllm/blob/33a0ea5f3264b5b2f571b8a53357e10efcc94670/vllm/model_executor/models/vision.py#L96. It uses `VLLM_ATTENTION_BACKEND`, which is meant for the text backbone; the ViT should not use this environment variable.
- Add a `logger.info_once` call so that users know which `_MHA_Backend` was finally selected.
- Clean up the CUDA code path. Since `vllm.vllm_flash_attn` is just a wrapper around the `flash_attn` library, on CUDA we should always use `vllm.vllm_flash_attn` instead of `flash_attn`.
  The CUDA branch that currently switches to upstream `flash_attn` (lines 120 to 125 in ba33e88) is the code to clean up:

```python
elif current_platform.is_cuda():
    if attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
        torch.get_default_dtype()
    ):
        attn_backend = _Backend.FLASH_ATTN
        use_upstream_fa = True
```
- Enable unit tests for all backends. Since some models are large, the tests will check the available VRAM and only run when it is large enough. We provide such a unit test so that developers can run it locally.
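A minimal sketch of the VRAM gate and backend selection for such tests, assuming free memory is read via `torch.cuda.mem_get_info()` (which also covers ROCm builds of PyTorch). The helper names `free_vram_bytes`, `should_run`, and `backends_for_gpu` are hypothetical; the backend lists follow the Instinct/Radeon split described above.

```python
from typing import Optional

GIB = 1024**3


def free_vram_bytes() -> int:
    """Free memory on the current GPU, or 0 if no GPU/PyTorch is available."""
    try:
        import torch
    except ImportError:
        return 0
    if not torch.cuda.is_available():
        return 0
    # On ROCm builds of PyTorch, torch.cuda maps to HIP devices as well.
    free_bytes, _total_bytes = torch.cuda.mem_get_info()
    return free_bytes


def should_run(required_gib: float, available_bytes: Optional[int] = None) -> bool:
    """Return True if enough VRAM is free to load the model under test."""
    if available_bytes is None:
        available_bytes = free_vram_bytes()
    return available_bytes >= required_gib * GIB


def backends_for_gpu(is_instinct: bool) -> list[str]:
    """ViT backends to exercise: Instinct GPUs test all ROCm backends,
    Radeon GPUs only TORCH_SDPA."""
    if is_instinct:
        return ["FLASH_ATTN", "ROCM_AITER_FA", "TORCH_SDPA"]
    return ["TORCH_SDPA"]
```

In a pytest suite, each case could be parametrized over `backends_for_gpu(...)`, pass the backend name to `--mm-encoder-attn-backend`, and be guarded with `@pytest.mark.skipif(not should_run(required_gib), reason="not enough VRAM")`; the per-model `required_gib` values are left to the test author.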
Feedback Period.
Changes
CC List.
@ywang96 @DarkLight1337 @Isotr0py
Any Other Things.
No response