
[RFC]: Reorganizing ViT Abstraction and Attention Selection Logic #27821

@tjtanaa

Motivation.

This RFC aims to address the following issues:

  1. The ViT is currently still tightly coupled with the text backbone attention. This RFC continues the effort to decouple ViT attention from text backbone attention.

  2. Another pain point is that the ViT attention override logic is scattered across many places. We should avoid overriding ViT logic in model definition classes; instead, the platform class should define which ViT attention backends are supported and how the selection is overridden.

  • The above logic applies to the general use case. At the time this RFC was proposed, this single logic applies to all of the VL models below:
    • vllm/model_executor/models/qwen2_5_vl.py (this is shared by Qwen2.5 VL and Qwen3 VL)
    • vllm/model_executor/models/dots_ocr.py
    • vllm/model_executor/models/ernie45_vl.py
    • vllm/model_executor/models/glm4_1v.py
    • vllm/model_executor/models/qwen2_vl.py
    • vllm/model_executor/models/siglip2navit.py
  • If a model supports only a specific type of attention, the ViT override logic will be implemented explicitly in its model definition file (model.py).
  3. Since torch.compile was introduced into the ViT (currently only for the Qwen VL models, starting with PR [Misc][qwen2_5_vl][torch.compile] Enable supports_torch_compile on generic nn.Module and demonstrate speedup on Qwen Vision model #23207), the AMD ViT code path has been broken. The new approach will try to accommodate this feature. torch.compile has brought significant performance improvements, so we can now consider replacing Triton kernels with PyTorch-native implementations, since torch.compile-generated code may be faster than custom Triton kernels.
    One proven case is [Bugfix][ROCm] Fix ViT rotary embeddings for torch.compile compatibility on ROCm #27748: the torch.compile rotary_embedding is faster than the Triton rotary_embedding from flash_attn.

  4. Ensure that ViT changes take the other model definition files (model.py) into account. Current changes involve only qwen2_5_vl.py, yet they potentially affect all of the following model files:

  • vllm/model_executor/models/dots_ocr.py
  • vllm/model_executor/models/ernie45_vl.py
  • vllm/model_executor/models/glm4_1v.py
  • vllm/model_executor/models/qwen2_vl.py
  • vllm/model_executor/models/siglip2navit.py
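The split between the general rule (the platform decides) and the exception (a model that supports only one attention type declares it in its own model.py) can be sketched as below. This is a minimal illustration, not vLLM code: backends are plain strings here and `resolve_for_model` is a hypothetical helper. Raising instead of silently falling back when the platform cannot run the required backend is a design choice of this sketch.

```python
from typing import Optional


def resolve_for_model(
    platform_choice: str,
    platform_supported: list[str],
    model_required: Optional[str] = None,
) -> str:
    """Return the ViT attention backend a model would actually use (sketch).

    Generic VL models pass model_required=None and keep the platform's
    choice unchanged. A model that supports only one attention type names
    it explicitly; this sketch raises if the platform cannot run it.
    """
    if model_required is None:
        # General case: no model-level override at all.
        return platform_choice
    if model_required not in platform_supported:
        raise ValueError(
            f"model requires {model_required!r}, but the platform only "
            f"supports {platform_supported}"
        )
    # Explicit, model-local override.
    return model_required
```

Keeping the exception as an explicit argument means a reader of model.py can see at a glance whether a model deviates from the platform default.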

Proposed Change.

NOTE: More details will be added while I write up a version that includes all of these changes.

For a first, quick reorganization of ViT attention, we retain the current --mm-encoder-attn-backend interface, which was introduced in PR #27061 and followed by the bugfix PR #27124.

  1. First, we should shrink the _Backend enum (see https://github.com/vllm-project/vllm/pull/27061/files#r2443909604) by introducing a separate _MHA_Backend registry:

```python
import enum


class _MHA_Backend(enum.Enum):
    VLLM_FLASH_ATTN = enum.auto()  # CUDA-only
    FLASH_ATTN = enum.auto()  # CUDA/ROCm
    XFORMERS = enum.auto()  # CUDA
    ROCM_AITER_FA = enum.auto()  # ROCm-only
    TORCH_SDPA = enum.auto()  # CUDA/ROCm/TPU/XPU/CPU
    PALLAS = enum.auto()  # TPU-only
```
  2. Make ViT attention selection platform-specific: the platform interface should determine which backends are supported and should also perform any override. We should avoid doing this in the model.py files.
  • get_vit_attn_backend in the platform interface must be able to access the value of --mm-encoder-attn-backend.

  • The platform interface should return only an _MHA_Backend, not the attention functions. The functions should be returned only through maybe_get_vit_flash_attn_backend. If a model supports only a specific type of attention, the ViT override logic will be implemented explicitly in its model definition file (model.py).

```python
from typing import Optional

import torch

from vllm.logger import init_logger

logger = init_logger(__name__)


class Platform:

    ...

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["_MHA_Backend"]:
        from vllm.attention.backends.registry import _MHA_Backend

        return [
            _MHA_Backend.TORCH_SDPA,
        ]

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: Optional["_MHA_Backend"] = None,
    ) -> "_MHA_Backend":
        # ViT attention should be checked and overridden only in the
        # platform-specific implementation. We should not override it
        # anywhere else, e.g. in model_executor/models/<model_name>.py.

        # The steps are:
        # 1. Check whether a backend was requested:
        #    a. If so, check that the platform supports it.
        #    b. If None, fall through to the default selection logic.

        # Import _MHA_Backend here to avoid a circular import.
        from vllm.attention.backends.registry import _MHA_Backend

        if backend is not None:
            supported = cls.get_supported_vit_attn_backends()
            assert backend in supported, (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: {supported}"
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        logger.info_once(
            f"Using default backend {_MHA_Backend.TORCH_SDPA} for vit attention"
        )
        return _MHA_Backend.TORCH_SDPA
```
  • Honor --mm-encoder-attn-backend so that we can write unit tests covering all backends. AMD Instinct GPUs can test all backends; Radeon GPUs can only use the TORCH_SDPA code path.

  • We need to deprecate the use of VLLM_ATTENTION_BACKEND at https://github.com/vllm-project/vllm/blob/33a0ea5f3264b5b2f571b8a53357e10efcc94670/vllm/model_executor/models/vision.py#L96 because that environment variable is meant for the text backbone. The ViT should not use it.

  • Add a logger.info_once call so that users know which _MHA_Backend is ultimately selected.

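  • Clean up the CUDA code path. Since vllm.vllm_flash_attn is just a wrapper around the flash_attn library, on CUDA we should always use vllm.vllm_flash_attn instead of flash_attn. This affects the current CUDA branch, which switches to the upstream flash_attn when it is available:

```python
# Excerpt from the existing selection logic (not standalone).
elif current_platform.is_cuda():
    if attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
        torch.get_default_dtype()
    ):
        attn_backend = _Backend.FLASH_ATTN
        use_upstream_fa = True
```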

  3. Enable unit tests that cover all backends. Because some models are large, each test will check the available VRAM and run only if it is sufficient. We provide such unit tests so that developers can run them locally.
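A minimal sketch of the VRAM gating, assuming the test author supplies each model's memory requirement; `select_test_cases` is a hypothetical helper, and any model names or sizes passed to it are illustrative only:

```python
from collections.abc import Iterable


def select_test_cases(
    model_vram_gib: dict[str, float],
    backends: Iterable[str],
    available_vram_gib: float,
) -> list[tuple[str, str]]:
    """Pair every model that fits in the available VRAM with every backend.

    Models whose requirement exceeds the available VRAM are skipped, so the
    same test file can run on GPUs with different memory sizes.
    """
    backend_list = list(backends)
    return [
        (model, backend)
        for model, required in model_vram_gib.items()
        if required <= available_vram_gib
        for backend in backend_list
    ]
```

In a pytest suite, the returned pairs could feed `pytest.mark.parametrize`, with the available memory read from the device, for example via `torch.cuda.mem_get_info()`, which returns free and total device memory in bytes.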

Feedback Period.

Changes

CC List.

@ywang96 @DarkLight1337 @Isotr0py

Any Other Things.

No response
